connectorx/sources/mssql/
mod.rs

1//! Source implementation for SQL Server.
2
3mod errors;
4mod typesystem;
5
6pub use self::errors::MsSQLSourceError;
7pub use self::typesystem::{FloatN, IntN, MsSQLTypeSystem};
8use crate::constants::DB_BUFFER_SIZE;
9use crate::{
10    data_order::DataOrder,
11    errors::ConnectorXError,
12    sources::{PartitionParser, Produce, Source, SourcePartition},
13    sql::{count_query, CXQuery},
14    utils::DummyBox,
15};
16use anyhow::anyhow;
17use bb8::{Pool, PooledConnection};
18use bb8_tiberius::ConnectionManager;
19use chrono::{DateTime, Utc};
20use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
21use fehler::{throw, throws};
22use futures::StreamExt;
23use log::debug;
24use owning_ref::OwningHandle;
25use rust_decimal::Decimal;
26use sqlparser::dialect::MsSqlDialect;
27use std::collections::HashMap;
28use std::sync::Arc;
29use tiberius::{AuthMethod, Config, EncryptionLevel, QueryItem, QueryStream, Row};
30use tokio::runtime::{Handle, Runtime};
31use url::Url;
32use urlencoding::decode;
33use uuid_old::Uuid;
34
35type Conn<'a> = PooledConnection<'a, ConnectionManager>;
36pub struct MsSQLSource {
37    rt: Arc<Runtime>,
38    pool: Pool<ConnectionManager>,
39    origin_query: Option<String>,
40    queries: Vec<CXQuery<String>>,
41    names: Vec<String>,
42    schema: Vec<MsSQLTypeSystem>,
43}
44
45#[throws(MsSQLSourceError)]
46pub fn mssql_config(url: &Url) -> Config {
47    let mut config = Config::new();
48
49    let host = decode(url.host_str().unwrap_or("localhost"))?.into_owned();
50    let hosts: Vec<&str> = host.split('\\').collect();
51    match hosts.len() {
52        1 => config.host(host),
53        2 => {
54            // SQL Server support instance name: `server\instance:port`
55            config.host(hosts[0]);
56            config.instance_name(hosts[1]);
57        }
58        _ => throw!(anyhow!("MsSQL hostname parse error: {}", host)),
59    }
60    config.port(url.port().unwrap_or(1433));
61    // remove the leading "/"
62    config.database(decode(&url.path()[1..])?.to_owned());
63    // Using SQL Server authentication.
64    #[allow(unused)]
65    let params: HashMap<String, String> = url.query_pairs().into_owned().collect();
66    #[cfg(any(windows, feature = "integrated-auth-gssapi"))]
67    match params.get("trusted_connection") {
68        // pefer trusted_connection if set to true
69        Some(v) if v == "true" => {
70            debug!("mssql auth through trusted connection!");
71            config.authentication(AuthMethod::Integrated);
72        }
73        _ => {
74            debug!("mssql auth through sqlserver authentication");
75            config.authentication(AuthMethod::sql_server(
76                decode(url.username())?.to_owned(),
77                decode(url.password().unwrap_or(""))?.to_owned(),
78            ));
79        }
80    };
81    #[cfg(all(not(windows), not(feature = "integrated-auth-gssapi")))]
82    config.authentication(AuthMethod::sql_server(
83        decode(url.username())?.to_owned(),
84        decode(url.password().unwrap_or(""))?.to_owned(),
85    ));
86
87    match params.get("trust_server_certificate") {
88        Some(v) if v.to_lowercase() == "true" => config.trust_cert(),
89        _ => {}
90    };
91
92    match params.get("trust_server_certificate_ca") {
93        Some(v) => config.trust_cert_ca(v),
94        _ => {}
95    };
96
97    match params.get("encrypt") {
98        Some(v) if v.to_lowercase() == "true" => config.encryption(EncryptionLevel::Required),
99        _ => config.encryption(EncryptionLevel::NotSupported),
100    };
101
102    match params.get("appname") {
103        Some(appname) => config.application_name(decode(appname)?.to_owned()),
104        _ => {}
105    };
106
107    config
108}
109
110impl MsSQLSource {
111    #[throws(MsSQLSourceError)]
112    pub fn new(rt: Arc<Runtime>, conn: &str, nconn: usize) -> Self {
113        let url = Url::parse(conn)?;
114        let config = mssql_config(&url)?;
115        let manager = bb8_tiberius::ConnectionManager::new(config);
116        let pool = rt.block_on(Pool::builder().max_size(nconn as u32).build(manager))?;
117
118        Self {
119            rt,
120            pool,
121            origin_query: None,
122            queries: vec![],
123            names: vec![],
124            schema: vec![],
125        }
126    }
127}
128
129impl Source for MsSQLSource
130where
131    MsSQLSourcePartition: SourcePartition<TypeSystem = MsSQLTypeSystem, Error = MsSQLSourceError>,
132{
133    const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
134    type Partition = MsSQLSourcePartition;
135    type TypeSystem = MsSQLTypeSystem;
136    type Error = MsSQLSourceError;
137
138    #[throws(MsSQLSourceError)]
139    fn set_data_order(&mut self, data_order: DataOrder) {
140        if !matches!(data_order, DataOrder::RowMajor) {
141            throw!(ConnectorXError::UnsupportedDataOrder(data_order));
142        }
143    }
144
145    fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
146        self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
147    }
148
149    fn set_origin_query(&mut self, query: Option<String>) {
150        self.origin_query = query;
151    }
152
153    #[throws(MsSQLSourceError)]
154    fn fetch_metadata(&mut self) {
155        assert!(!self.queries.is_empty());
156
157        let mut conn = self.rt.block_on(self.pool.get())?;
158        let first_query = &self.queries[0];
159        let (names, types) = match self.rt.block_on(conn.query(first_query.as_str(), &[])) {
160            Ok(mut stream) => match self.rt.block_on(async { stream.columns().await }) {
161                Ok(Some(columns)) => columns
162                    .iter()
163                    .map(|col| {
164                        (
165                            col.name().to_string(),
166                            MsSQLTypeSystem::from(&col.column_type()),
167                        )
168                    })
169                    .unzip(),
170                Ok(None) => {
171                    throw!(anyhow!(
172                        "MsSQL returned no columns for query: {}",
173                        first_query
174                    ));
175                }
176                Err(e) => {
177                    throw!(anyhow!("Error fetching columns: {}", e));
178                }
179            },
180            Err(e) => {
181                debug!(
182                    "cannot get metadata for '{}', try next query: {}",
183                    first_query, e
184                );
185                throw!(e);
186            }
187        };
188
189        self.names = names;
190        self.schema = types;
191    }
192
193    #[throws(MsSQLSourceError)]
194    fn result_rows(&mut self) -> Option<usize> {
195        match &self.origin_query {
196            Some(q) => {
197                let cxq = CXQuery::Naked(q.clone());
198                let cquery = count_query(&cxq, &MsSqlDialect {})?;
199                let mut conn = self.rt.block_on(self.pool.get())?;
200
201                let stream = self.rt.block_on(conn.query(cquery.as_str(), &[]))?;
202                let row = self
203                    .rt
204                    .block_on(stream.into_row())?
205                    .ok_or_else(|| anyhow!("MsSQL failed to get the count of query: {}", q))?;
206
207                let row: i32 = row.get(0).ok_or(MsSQLSourceError::GetNRowsFailed)?; // the count in mssql is i32
208                Some(row as usize)
209            }
210            None => None,
211        }
212    }
213
214    fn names(&self) -> Vec<String> {
215        self.names.clone()
216    }
217
218    fn schema(&self) -> Vec<Self::TypeSystem> {
219        self.schema.clone()
220    }
221
222    #[throws(MsSQLSourceError)]
223    fn partition(self) -> Vec<Self::Partition> {
224        let mut ret = vec![];
225        for query in self.queries {
226            ret.push(MsSQLSourcePartition::new(
227                self.pool.clone(),
228                self.rt.clone(),
229                &query,
230                &self.schema,
231            ));
232        }
233        ret
234    }
235}
236
237pub struct MsSQLSourcePartition {
238    pool: Pool<ConnectionManager>,
239    rt: Arc<Runtime>,
240    query: CXQuery<String>,
241    schema: Vec<MsSQLTypeSystem>,
242    nrows: usize,
243    ncols: usize,
244}
245
246impl MsSQLSourcePartition {
247    pub fn new(
248        pool: Pool<ConnectionManager>,
249        handle: Arc<Runtime>,
250        query: &CXQuery<String>,
251        schema: &[MsSQLTypeSystem],
252    ) -> Self {
253        Self {
254            rt: handle,
255            pool,
256            query: query.clone(),
257            schema: schema.to_vec(),
258            nrows: 0,
259            ncols: schema.len(),
260        }
261    }
262}
263
264impl SourcePartition for MsSQLSourcePartition {
265    type TypeSystem = MsSQLTypeSystem;
266    type Parser<'a> = MsSQLSourceParser<'a>;
267    type Error = MsSQLSourceError;
268
269    #[throws(MsSQLSourceError)]
270    fn result_rows(&mut self) {
271        let cquery = count_query(&self.query, &MsSqlDialect {})?;
272        let mut conn = self.rt.block_on(self.pool.get())?;
273
274        let stream = self.rt.block_on(conn.query(cquery.as_str(), &[]))?;
275        let row = self
276            .rt
277            .block_on(stream.into_row())?
278            .ok_or_else(|| anyhow!("MsSQL failed to get the count of query: {}", self.query))?;
279
280        let row: i32 = row.get(0).ok_or(MsSQLSourceError::GetNRowsFailed)?; // the count in mssql is i32
281        self.nrows = row as usize;
282    }
283
284    #[throws(MsSQLSourceError)]
285    fn parser<'a>(&'a mut self) -> Self::Parser<'a> {
286        let conn = self.rt.block_on(self.pool.get())?;
287        let rows: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>> =
288            OwningHandle::new_with_fn(Box::new(conn), |conn: *const Conn<'a>| unsafe {
289                let conn = &mut *(conn as *mut Conn<'a>);
290
291                DummyBox(
292                    self.rt
293                        .block_on(conn.query(self.query.as_str(), &[]))
294                        .unwrap(),
295                )
296            });
297
298        MsSQLSourceParser::new(self.rt.handle(), rows, &self.schema)
299    }
300
301    fn nrows(&self) -> usize {
302        self.nrows
303    }
304
305    fn ncols(&self) -> usize {
306        self.ncols
307    }
308}
309
310pub struct MsSQLSourceParser<'a> {
311    rt: &'a Handle,
312    iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>>,
313    rowbuf: Vec<Row>,
314    ncols: usize,
315    current_col: usize,
316    current_row: usize,
317    is_finished: bool,
318}
319
320impl<'a> MsSQLSourceParser<'a> {
321    fn new(
322        rt: &'a Handle,
323        iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>>,
324        schema: &[MsSQLTypeSystem],
325    ) -> Self {
326        Self {
327            rt,
328            iter,
329            rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
330            ncols: schema.len(),
331            current_row: 0,
332            current_col: 0,
333            is_finished: false,
334        }
335    }
336
337    #[throws(MsSQLSourceError)]
338    fn next_loc(&mut self) -> (usize, usize) {
339        let ret = (self.current_row, self.current_col);
340        self.current_row += (self.current_col + 1) / self.ncols;
341        self.current_col = (self.current_col + 1) % self.ncols;
342        ret
343    }
344}
345
346impl<'a> PartitionParser<'a> for MsSQLSourceParser<'a> {
347    type TypeSystem = MsSQLTypeSystem;
348    type Error = MsSQLSourceError;
349
350    #[throws(MsSQLSourceError)]
351    fn fetch_next(&mut self) -> (usize, bool) {
352        assert!(self.current_col == 0);
353        let remaining_rows = self.rowbuf.len() - self.current_row;
354        if remaining_rows > 0 {
355            return (remaining_rows, self.is_finished);
356        } else if self.is_finished {
357            return (0, self.is_finished);
358        }
359
360        if !self.rowbuf.is_empty() {
361            self.rowbuf.drain(..);
362        }
363
364        for _ in 0..DB_BUFFER_SIZE {
365            if let Some(item) = self.rt.block_on(self.iter.next()) {
366                match item.map_err(MsSQLSourceError::MsSQLError)? {
367                    QueryItem::Row(row) => self.rowbuf.push(row),
368                    _ => continue,
369                }
370            } else {
371                self.is_finished = true;
372                break;
373            }
374        }
375        self.current_row = 0;
376        self.current_col = 0;
377        (self.rowbuf.len(), self.is_finished)
378    }
379}
380
381macro_rules! impl_produce {
382    ($($t: ty,)+) => {
383        $(
384            impl<'r, 'a> Produce<'r, $t> for MsSQLSourceParser<'a> {
385                type Error = MsSQLSourceError;
386
387                #[throws(MsSQLSourceError)]
388                fn produce(&'r mut self) -> $t {
389                    let (ridx, cidx) = self.next_loc()?;
390                    let res = self.rowbuf[ridx].get(cidx).ok_or_else(|| anyhow!("MsSQL get None at position: ({}, {})", ridx, cidx))?;
391                    res
392                }
393            }
394
395            impl<'r, 'a> Produce<'r, Option<$t>> for MsSQLSourceParser<'a> {
396                type Error = MsSQLSourceError;
397
398                #[throws(MsSQLSourceError)]
399                fn produce(&'r mut self) -> Option<$t> {
400                    let (ridx, cidx) = self.next_loc()?;
401                    let res = self.rowbuf[ridx].get(cidx);
402                    res
403                }
404            }
405        )+
406    };
407}
408
409impl_produce!(
410    u8,
411    i16,
412    i32,
413    i64,
414    IntN,
415    f32,
416    f64,
417    FloatN,
418    bool,
419    &'r str,
420    &'r [u8],
421    Uuid,
422    Decimal,
423    NaiveDateTime,
424    NaiveDate,
425    NaiveTime,
426    DateTime<Utc>,
427);