connectorx/sources/mssql/mod.rs

1//! Source implementation for SQL Server.
2
3mod errors;
4mod typesystem;
5
6pub use self::errors::MsSQLSourceError;
7pub use self::typesystem::{FloatN, IntN, MsSQLTypeSystem};
8use crate::constants::DB_BUFFER_SIZE;
9use crate::{
10    data_order::DataOrder,
11    errors::ConnectorXError,
12    sources::{PartitionParser, Produce, Source, SourcePartition},
13    sql::{count_query, CXQuery},
14    utils::DummyBox,
15};
16use anyhow::anyhow;
17use bb8::{Pool, PooledConnection};
18use bb8_tiberius::ConnectionManager;
19use chrono::{DateTime, Utc};
20use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
21use fehler::{throw, throws};
22use futures::StreamExt;
23use log::debug;
24use owning_ref::OwningHandle;
25use rust_decimal::Decimal;
26use sqlparser::dialect::MsSqlDialect;
27use std::collections::HashMap;
28use std::sync::Arc;
29use tiberius::{AuthMethod, Config, EncryptionLevel, QueryItem, QueryStream, Row};
30use tokio::runtime::{Handle, Runtime};
31use url::Url;
32use urlencoding::decode;
33use uuid_old::Uuid;
34
35type Conn<'a> = PooledConnection<'a, ConnectionManager>;
36pub struct MsSQLSource {
37    rt: Arc<Runtime>,
38    pool: Pool<ConnectionManager>,
39    origin_query: Option<String>,
40    queries: Vec<CXQuery<String>>,
41    names: Vec<String>,
42    schema: Vec<MsSQLTypeSystem>,
43}
44
45#[throws(MsSQLSourceError)]
46pub fn mssql_config(url: &Url) -> Config {
47    let mut config = Config::new();
48
49    let host = decode(url.host_str().unwrap_or("localhost"))?.into_owned();
50    let hosts: Vec<&str> = host.split('\\').collect();
51    match hosts.len() {
52        1 => config.host(host),
53        2 => {
54            // SQL Server support instance name: `server\instance:port`
55            config.host(hosts[0]);
56            config.instance_name(hosts[1]);
57        }
58        _ => throw!(anyhow!("MsSQL hostname parse error: {}", host)),
59    }
60    config.port(url.port().unwrap_or(1433));
61    // remove the leading "/"
62    config.database(decode(&url.path()[1..])?.to_owned());
63    // Using SQL Server authentication.
64    #[allow(unused)]
65    let params: HashMap<String, String> = url.query_pairs().into_owned().collect();
66    #[cfg(any(windows, feature = "integrated-auth-gssapi"))]
67    match params.get("trusted_connection") {
68        // pefer trusted_connection if set to true
69        Some(v) if v == "true" => {
70            debug!("mssql auth through trusted connection!");
71            config.authentication(AuthMethod::Integrated);
72        }
73        _ => {
74            debug!("mssql auth through sqlserver authentication");
75            config.authentication(AuthMethod::sql_server(
76                decode(url.username())?.to_owned(),
77                decode(url.password().unwrap_or(""))?.to_owned(),
78            ));
79        }
80    };
81    #[cfg(all(not(windows), not(feature = "integrated-auth-gssapi")))]
82    config.authentication(AuthMethod::sql_server(
83        decode(url.username())?.to_owned(),
84        decode(url.password().unwrap_or(""))?.to_owned(),
85    ));
86
87    match params.get("trust_server_certificate") {
88        Some(v) if v.to_lowercase() == "true" => config.trust_cert(),
89        _ => {}
90    };
91
92    match params.get("trust_server_certificate_ca") {
93        Some(v) => config.trust_cert_ca(v),
94        _ => {}
95    };
96
97    match params.get("encrypt") {
98        Some(v) if v.to_lowercase() == "true" => config.encryption(EncryptionLevel::Required),
99        Some(v) if v.to_lowercase() == "false" => config.encryption(EncryptionLevel::Off),
100        _ => config.encryption(EncryptionLevel::NotSupported),
101    };
102
103    match params.get("appname") {
104        Some(appname) => config.application_name(decode(appname)?.to_owned()),
105        _ => {}
106    };
107
108    config
109}
110
111impl MsSQLSource {
112    #[throws(MsSQLSourceError)]
113    pub fn new(rt: Arc<Runtime>, conn: &str, nconn: usize) -> Self {
114        let url = Url::parse(conn)?;
115        let config = mssql_config(&url)?;
116        let manager = bb8_tiberius::ConnectionManager::new(config);
117        let pool = rt.block_on(Pool::builder().max_size(nconn as u32).build(manager))?;
118
119        Self {
120            rt,
121            pool,
122            origin_query: None,
123            queries: vec![],
124            names: vec![],
125            schema: vec![],
126        }
127    }
128}
129
130impl Source for MsSQLSource
131where
132    MsSQLSourcePartition: SourcePartition<TypeSystem = MsSQLTypeSystem, Error = MsSQLSourceError>,
133{
134    const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
135    type Partition = MsSQLSourcePartition;
136    type TypeSystem = MsSQLTypeSystem;
137    type Error = MsSQLSourceError;
138
139    #[throws(MsSQLSourceError)]
140    fn set_data_order(&mut self, data_order: DataOrder) {
141        if !matches!(data_order, DataOrder::RowMajor) {
142            throw!(ConnectorXError::UnsupportedDataOrder(data_order));
143        }
144    }
145
146    fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
147        self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
148    }
149
150    fn set_origin_query(&mut self, query: Option<String>) {
151        self.origin_query = query;
152    }
153
154    #[throws(MsSQLSourceError)]
155    fn fetch_metadata(&mut self) {
156        assert!(!self.queries.is_empty());
157
158        let mut conn = self.rt.block_on(self.pool.get())?;
159        let first_query = &self.queries[0];
160        let (names, types) = match self.rt.block_on(conn.query(first_query.as_str(), &[])) {
161            Ok(mut stream) => match self.rt.block_on(async { stream.columns().await }) {
162                Ok(Some(columns)) => columns
163                    .iter()
164                    .map(|col| {
165                        (
166                            col.name().to_string(),
167                            MsSQLTypeSystem::from(&col.column_type()),
168                        )
169                    })
170                    .unzip(),
171                Ok(None) => {
172                    throw!(anyhow!(
173                        "MsSQL returned no columns for query: {}",
174                        first_query
175                    ));
176                }
177                Err(e) => {
178                    throw!(anyhow!("Error fetching columns: {}", e));
179                }
180            },
181            Err(e) => {
182                debug!(
183                    "cannot get metadata for '{}', try next query: {}",
184                    first_query, e
185                );
186                throw!(e);
187            }
188        };
189
190        self.names = names;
191        self.schema = types;
192    }
193
194    #[throws(MsSQLSourceError)]
195    fn result_rows(&mut self) -> Option<usize> {
196        match &self.origin_query {
197            Some(q) => {
198                let cxq = CXQuery::Naked(q.clone());
199                let cquery = count_query(&cxq, &MsSqlDialect {})?;
200                let mut conn = self.rt.block_on(self.pool.get())?;
201
202                let stream = self.rt.block_on(conn.query(cquery.as_str(), &[]))?;
203                let row = self
204                    .rt
205                    .block_on(stream.into_row())?
206                    .ok_or_else(|| anyhow!("MsSQL failed to get the count of query: {}", q))?;
207
208                let row: i32 = row.get(0).ok_or(MsSQLSourceError::GetNRowsFailed)?; // the count in mssql is i32
209                Some(row as usize)
210            }
211            None => None,
212        }
213    }
214
215    fn names(&self) -> Vec<String> {
216        self.names.clone()
217    }
218
219    fn schema(&self) -> Vec<Self::TypeSystem> {
220        self.schema.clone()
221    }
222
223    #[throws(MsSQLSourceError)]
224    fn partition(self) -> Vec<Self::Partition> {
225        let mut ret = vec![];
226        for query in self.queries {
227            ret.push(MsSQLSourcePartition::new(
228                self.pool.clone(),
229                self.rt.clone(),
230                &query,
231                &self.schema,
232            ));
233        }
234        ret
235    }
236}
237
238pub struct MsSQLSourcePartition {
239    pool: Pool<ConnectionManager>,
240    rt: Arc<Runtime>,
241    query: CXQuery<String>,
242    schema: Vec<MsSQLTypeSystem>,
243    nrows: usize,
244    ncols: usize,
245}
246
247impl MsSQLSourcePartition {
248    pub fn new(
249        pool: Pool<ConnectionManager>,
250        handle: Arc<Runtime>,
251        query: &CXQuery<String>,
252        schema: &[MsSQLTypeSystem],
253    ) -> Self {
254        Self {
255            rt: handle,
256            pool,
257            query: query.clone(),
258            schema: schema.to_vec(),
259            nrows: 0,
260            ncols: schema.len(),
261        }
262    }
263}
264
265impl SourcePartition for MsSQLSourcePartition {
266    type TypeSystem = MsSQLTypeSystem;
267    type Parser<'a> = MsSQLSourceParser<'a>;
268    type Error = MsSQLSourceError;
269
270    #[throws(MsSQLSourceError)]
271    fn result_rows(&mut self) {
272        let cquery = count_query(&self.query, &MsSqlDialect {})?;
273        let mut conn = self.rt.block_on(self.pool.get())?;
274
275        let stream = self.rt.block_on(conn.query(cquery.as_str(), &[]))?;
276        let row = self
277            .rt
278            .block_on(stream.into_row())?
279            .ok_or_else(|| anyhow!("MsSQL failed to get the count of query: {}", self.query))?;
280
281        let row: i32 = row.get(0).ok_or(MsSQLSourceError::GetNRowsFailed)?; // the count in mssql is i32
282        self.nrows = row as usize;
283    }
284
285    #[throws(MsSQLSourceError)]
286    fn parser<'a>(&'a mut self) -> Self::Parser<'a> {
287        let conn = self.rt.block_on(self.pool.get())?;
288        let rows: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>> =
289            OwningHandle::new_with_fn(Box::new(conn), |conn: *const Conn<'a>| unsafe {
290                let conn = &mut *(conn as *mut Conn<'a>);
291
292                DummyBox(
293                    self.rt
294                        .block_on(conn.query(self.query.as_str(), &[]))
295                        .unwrap(),
296                )
297            });
298
299        MsSQLSourceParser::new(self.rt.handle(), rows, &self.schema)
300    }
301
302    fn nrows(&self) -> usize {
303        self.nrows
304    }
305
306    fn ncols(&self) -> usize {
307        self.ncols
308    }
309}
310
311pub struct MsSQLSourceParser<'a> {
312    rt: &'a Handle,
313    iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>>,
314    rowbuf: Vec<Row>,
315    ncols: usize,
316    current_col: usize,
317    current_row: usize,
318    is_finished: bool,
319}
320
321impl<'a> MsSQLSourceParser<'a> {
322    fn new(
323        rt: &'a Handle,
324        iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>>,
325        schema: &[MsSQLTypeSystem],
326    ) -> Self {
327        Self {
328            rt,
329            iter,
330            rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
331            ncols: schema.len(),
332            current_row: 0,
333            current_col: 0,
334            is_finished: false,
335        }
336    }
337
338    #[throws(MsSQLSourceError)]
339    fn next_loc(&mut self) -> (usize, usize) {
340        let ret = (self.current_row, self.current_col);
341        self.current_row += (self.current_col + 1) / self.ncols;
342        self.current_col = (self.current_col + 1) % self.ncols;
343        ret
344    }
345}
346
347impl<'a> PartitionParser<'a> for MsSQLSourceParser<'a> {
348    type TypeSystem = MsSQLTypeSystem;
349    type Error = MsSQLSourceError;
350
351    #[throws(MsSQLSourceError)]
352    fn fetch_next(&mut self) -> (usize, bool) {
353        assert!(self.current_col == 0);
354        let remaining_rows = self.rowbuf.len() - self.current_row;
355        if remaining_rows > 0 {
356            return (remaining_rows, self.is_finished);
357        } else if self.is_finished {
358            return (0, self.is_finished);
359        }
360
361        if !self.rowbuf.is_empty() {
362            self.rowbuf.drain(..);
363        }
364
365        for _ in 0..DB_BUFFER_SIZE {
366            if let Some(item) = self.rt.block_on(self.iter.next()) {
367                match item.map_err(MsSQLSourceError::MsSQLError)? {
368                    QueryItem::Row(row) => self.rowbuf.push(row),
369                    _ => continue,
370                }
371            } else {
372                self.is_finished = true;
373                break;
374            }
375        }
376        self.current_row = 0;
377        self.current_col = 0;
378        (self.rowbuf.len(), self.is_finished)
379    }
380}
381
/// Implements `Produce<T>` and `Produce<Option<T>>` on `MsSQLSourceParser` for
/// each listed type `T`. A trailing comma is optional.
///
/// Cells are read with `Row::try_get`, so a value that cannot be converted to
/// `T` is returned as an error instead of panicking.
macro_rules! impl_produce {
    ($($t: ty),+ $(,)?) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for MsSQLSourceParser<'a> {
                type Error = MsSQLSourceError;

                /// Reads the next cell. A missing (`None`) value is an error.
                #[throws(MsSQLSourceError)]
                fn produce(&'r mut self) -> $t {
                    let (ridx, cidx) = self.next_loc()?;
                    self.rowbuf[ridx]
                        .try_get(cidx)?
                        .ok_or_else(|| anyhow!("MsSQL get None at position: ({}, {})", ridx, cidx))?
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for MsSQLSourceParser<'a> {
                type Error = MsSQLSourceError;

                /// Reads the next cell, passing a missing value through as `None`.
                #[throws(MsSQLSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    let (ridx, cidx) = self.next_loc()?;
                    self.rowbuf[ridx].try_get(cidx)?
                }
            }
        )+
    };
}
409
410impl_produce!(
411    u8,
412    i16,
413    i32,
414    i64,
415    IntN,
416    f32,
417    f64,
418    FloatN,
419    bool,
420    &'r str,
421    &'r [u8],
422    Uuid,
423    Decimal,
424    NaiveDateTime,
425    NaiveDate,
426    NaiveTime,
427    DateTime<Utc>,
428);