connectorx/sources/mssql/
mod.rs1mod errors;
4mod typesystem;
5
6pub use self::errors::MsSQLSourceError;
7pub use self::typesystem::{FloatN, IntN, MsSQLTypeSystem};
8use crate::constants::DB_BUFFER_SIZE;
9use crate::{
10 data_order::DataOrder,
11 errors::ConnectorXError,
12 sources::{PartitionParser, Produce, Source, SourcePartition},
13 sql::{count_query, CXQuery},
14 utils::DummyBox,
15};
16use anyhow::anyhow;
17use bb8::{Pool, PooledConnection};
18use bb8_tiberius::ConnectionManager;
19use chrono::{DateTime, Utc};
20use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
21use fehler::{throw, throws};
22use futures::StreamExt;
23use log::debug;
24use owning_ref::OwningHandle;
25use rust_decimal::Decimal;
26use sqlparser::dialect::MsSqlDialect;
27use std::collections::HashMap;
28use std::sync::Arc;
29use tiberius::{AuthMethod, Config, EncryptionLevel, QueryItem, QueryStream, Row};
30use tokio::runtime::{Handle, Runtime};
31use url::Url;
32use urlencoding::decode;
33use uuid_old::Uuid;
34
35type Conn<'a> = PooledConnection<'a, ConnectionManager>;
36pub struct MsSQLSource {
37 rt: Arc<Runtime>,
38 pool: Pool<ConnectionManager>,
39 origin_query: Option<String>,
40 queries: Vec<CXQuery<String>>,
41 names: Vec<String>,
42 schema: Vec<MsSQLTypeSystem>,
43}
44
45#[throws(MsSQLSourceError)]
46pub fn mssql_config(url: &Url) -> Config {
47 let mut config = Config::new();
48
49 let host = decode(url.host_str().unwrap_or("localhost"))?.into_owned();
50 let hosts: Vec<&str> = host.split('\\').collect();
51 match hosts.len() {
52 1 => config.host(host),
53 2 => {
54 config.host(hosts[0]);
56 config.instance_name(hosts[1]);
57 }
58 _ => throw!(anyhow!("MsSQL hostname parse error: {}", host)),
59 }
60 config.port(url.port().unwrap_or(1433));
61 config.database(decode(&url.path()[1..])?.to_owned());
63 #[allow(unused)]
65 let params: HashMap<String, String> = url.query_pairs().into_owned().collect();
66 #[cfg(any(windows, feature = "integrated-auth-gssapi"))]
67 match params.get("trusted_connection") {
68 Some(v) if v == "true" => {
70 debug!("mssql auth through trusted connection!");
71 config.authentication(AuthMethod::Integrated);
72 }
73 _ => {
74 debug!("mssql auth through sqlserver authentication");
75 config.authentication(AuthMethod::sql_server(
76 decode(url.username())?.to_owned(),
77 decode(url.password().unwrap_or(""))?.to_owned(),
78 ));
79 }
80 };
81 #[cfg(all(not(windows), not(feature = "integrated-auth-gssapi")))]
82 config.authentication(AuthMethod::sql_server(
83 decode(url.username())?.to_owned(),
84 decode(url.password().unwrap_or(""))?.to_owned(),
85 ));
86
87 match params.get("trust_server_certificate") {
88 Some(v) if v.to_lowercase() == "true" => config.trust_cert(),
89 _ => {}
90 };
91
92 match params.get("trust_server_certificate_ca") {
93 Some(v) => config.trust_cert_ca(v),
94 _ => {}
95 };
96
97 match params.get("encrypt") {
98 Some(v) if v.to_lowercase() == "true" => config.encryption(EncryptionLevel::Required),
99 Some(v) if v.to_lowercase() == "false" => config.encryption(EncryptionLevel::Off),
100 _ => config.encryption(EncryptionLevel::NotSupported),
101 };
102
103 match params.get("appname") {
104 Some(appname) => config.application_name(decode(appname)?.to_owned()),
105 _ => {}
106 };
107
108 config
109}
110
111impl MsSQLSource {
112 #[throws(MsSQLSourceError)]
113 pub fn new(rt: Arc<Runtime>, conn: &str, nconn: usize) -> Self {
114 let url = Url::parse(conn)?;
115 let config = mssql_config(&url)?;
116 let manager = bb8_tiberius::ConnectionManager::new(config);
117 let pool = rt.block_on(Pool::builder().max_size(nconn as u32).build(manager))?;
118
119 Self {
120 rt,
121 pool,
122 origin_query: None,
123 queries: vec![],
124 names: vec![],
125 schema: vec![],
126 }
127 }
128}
129
130impl Source for MsSQLSource
131where
132 MsSQLSourcePartition: SourcePartition<TypeSystem = MsSQLTypeSystem, Error = MsSQLSourceError>,
133{
134 const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
135 type Partition = MsSQLSourcePartition;
136 type TypeSystem = MsSQLTypeSystem;
137 type Error = MsSQLSourceError;
138
139 #[throws(MsSQLSourceError)]
140 fn set_data_order(&mut self, data_order: DataOrder) {
141 if !matches!(data_order, DataOrder::RowMajor) {
142 throw!(ConnectorXError::UnsupportedDataOrder(data_order));
143 }
144 }
145
146 fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
147 self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
148 }
149
150 fn set_origin_query(&mut self, query: Option<String>) {
151 self.origin_query = query;
152 }
153
154 #[throws(MsSQLSourceError)]
155 fn fetch_metadata(&mut self) {
156 assert!(!self.queries.is_empty());
157
158 let mut conn = self.rt.block_on(self.pool.get())?;
159 let first_query = &self.queries[0];
160 let (names, types) = match self.rt.block_on(conn.query(first_query.as_str(), &[])) {
161 Ok(mut stream) => match self.rt.block_on(async { stream.columns().await }) {
162 Ok(Some(columns)) => columns
163 .iter()
164 .map(|col| {
165 (
166 col.name().to_string(),
167 MsSQLTypeSystem::from(&col.column_type()),
168 )
169 })
170 .unzip(),
171 Ok(None) => {
172 throw!(anyhow!(
173 "MsSQL returned no columns for query: {}",
174 first_query
175 ));
176 }
177 Err(e) => {
178 throw!(anyhow!("Error fetching columns: {}", e));
179 }
180 },
181 Err(e) => {
182 debug!(
183 "cannot get metadata for '{}', try next query: {}",
184 first_query, e
185 );
186 throw!(e);
187 }
188 };
189
190 self.names = names;
191 self.schema = types;
192 }
193
194 #[throws(MsSQLSourceError)]
195 fn result_rows(&mut self) -> Option<usize> {
196 match &self.origin_query {
197 Some(q) => {
198 let cxq = CXQuery::Naked(q.clone());
199 let cquery = count_query(&cxq, &MsSqlDialect {})?;
200 let mut conn = self.rt.block_on(self.pool.get())?;
201
202 let stream = self.rt.block_on(conn.query(cquery.as_str(), &[]))?;
203 let row = self
204 .rt
205 .block_on(stream.into_row())?
206 .ok_or_else(|| anyhow!("MsSQL failed to get the count of query: {}", q))?;
207
208 let row: i32 = row.get(0).ok_or(MsSQLSourceError::GetNRowsFailed)?; Some(row as usize)
210 }
211 None => None,
212 }
213 }
214
215 fn names(&self) -> Vec<String> {
216 self.names.clone()
217 }
218
219 fn schema(&self) -> Vec<Self::TypeSystem> {
220 self.schema.clone()
221 }
222
223 #[throws(MsSQLSourceError)]
224 fn partition(self) -> Vec<Self::Partition> {
225 let mut ret = vec![];
226 for query in self.queries {
227 ret.push(MsSQLSourcePartition::new(
228 self.pool.clone(),
229 self.rt.clone(),
230 &query,
231 &self.schema,
232 ));
233 }
234 ret
235 }
236}
237
238pub struct MsSQLSourcePartition {
239 pool: Pool<ConnectionManager>,
240 rt: Arc<Runtime>,
241 query: CXQuery<String>,
242 schema: Vec<MsSQLTypeSystem>,
243 nrows: usize,
244 ncols: usize,
245}
246
247impl MsSQLSourcePartition {
248 pub fn new(
249 pool: Pool<ConnectionManager>,
250 handle: Arc<Runtime>,
251 query: &CXQuery<String>,
252 schema: &[MsSQLTypeSystem],
253 ) -> Self {
254 Self {
255 rt: handle,
256 pool,
257 query: query.clone(),
258 schema: schema.to_vec(),
259 nrows: 0,
260 ncols: schema.len(),
261 }
262 }
263}
264
265impl SourcePartition for MsSQLSourcePartition {
266 type TypeSystem = MsSQLTypeSystem;
267 type Parser<'a> = MsSQLSourceParser<'a>;
268 type Error = MsSQLSourceError;
269
270 #[throws(MsSQLSourceError)]
271 fn result_rows(&mut self) {
272 let cquery = count_query(&self.query, &MsSqlDialect {})?;
273 let mut conn = self.rt.block_on(self.pool.get())?;
274
275 let stream = self.rt.block_on(conn.query(cquery.as_str(), &[]))?;
276 let row = self
277 .rt
278 .block_on(stream.into_row())?
279 .ok_or_else(|| anyhow!("MsSQL failed to get the count of query: {}", self.query))?;
280
281 let row: i32 = row.get(0).ok_or(MsSQLSourceError::GetNRowsFailed)?; self.nrows = row as usize;
283 }
284
285 #[throws(MsSQLSourceError)]
286 fn parser<'a>(&'a mut self) -> Self::Parser<'a> {
287 let conn = self.rt.block_on(self.pool.get())?;
288 let rows: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>> =
289 OwningHandle::new_with_fn(Box::new(conn), |conn: *const Conn<'a>| unsafe {
290 let conn = &mut *(conn as *mut Conn<'a>);
291
292 DummyBox(
293 self.rt
294 .block_on(conn.query(self.query.as_str(), &[]))
295 .unwrap(),
296 )
297 });
298
299 MsSQLSourceParser::new(self.rt.handle(), rows, &self.schema)
300 }
301
302 fn nrows(&self) -> usize {
303 self.nrows
304 }
305
306 fn ncols(&self) -> usize {
307 self.ncols
308 }
309}
310
311pub struct MsSQLSourceParser<'a> {
312 rt: &'a Handle,
313 iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>>,
314 rowbuf: Vec<Row>,
315 ncols: usize,
316 current_col: usize,
317 current_row: usize,
318 is_finished: bool,
319}
320
321impl<'a> MsSQLSourceParser<'a> {
322 fn new(
323 rt: &'a Handle,
324 iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>>,
325 schema: &[MsSQLTypeSystem],
326 ) -> Self {
327 Self {
328 rt,
329 iter,
330 rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
331 ncols: schema.len(),
332 current_row: 0,
333 current_col: 0,
334 is_finished: false,
335 }
336 }
337
338 #[throws(MsSQLSourceError)]
339 fn next_loc(&mut self) -> (usize, usize) {
340 let ret = (self.current_row, self.current_col);
341 self.current_row += (self.current_col + 1) / self.ncols;
342 self.current_col = (self.current_col + 1) % self.ncols;
343 ret
344 }
345}
346
347impl<'a> PartitionParser<'a> for MsSQLSourceParser<'a> {
348 type TypeSystem = MsSQLTypeSystem;
349 type Error = MsSQLSourceError;
350
351 #[throws(MsSQLSourceError)]
352 fn fetch_next(&mut self) -> (usize, bool) {
353 assert!(self.current_col == 0);
354 let remaining_rows = self.rowbuf.len() - self.current_row;
355 if remaining_rows > 0 {
356 return (remaining_rows, self.is_finished);
357 } else if self.is_finished {
358 return (0, self.is_finished);
359 }
360
361 if !self.rowbuf.is_empty() {
362 self.rowbuf.drain(..);
363 }
364
365 for _ in 0..DB_BUFFER_SIZE {
366 if let Some(item) = self.rt.block_on(self.iter.next()) {
367 match item.map_err(MsSQLSourceError::MsSQLError)? {
368 QueryItem::Row(row) => self.rowbuf.push(row),
369 _ => continue,
370 }
371 } else {
372 self.is_finished = true;
373 break;
374 }
375 }
376 self.current_row = 0;
377 self.current_col = 0;
378 (self.rowbuf.len(), self.is_finished)
379 }
380}
381
382macro_rules! impl_produce {
383 ($($t: ty,)+) => {
384 $(
385 impl<'r, 'a> Produce<'r, $t> for MsSQLSourceParser<'a> {
386 type Error = MsSQLSourceError;
387
388 #[throws(MsSQLSourceError)]
389 fn produce(&'r mut self) -> $t {
390 let (ridx, cidx) = self.next_loc()?;
391 let res = self.rowbuf[ridx].get(cidx).ok_or_else(|| anyhow!("MsSQL get None at position: ({}, {})", ridx, cidx))?;
392 res
393 }
394 }
395
396 impl<'r, 'a> Produce<'r, Option<$t>> for MsSQLSourceParser<'a> {
397 type Error = MsSQLSourceError;
398
399 #[throws(MsSQLSourceError)]
400 fn produce(&'r mut self) -> Option<$t> {
401 let (ridx, cidx) = self.next_loc()?;
402 let res = self.rowbuf[ridx].get(cidx);
403 res
404 }
405 }
406 )+
407 };
408}
409
410impl_produce!(
411 u8,
412 i16,
413 i32,
414 i64,
415 IntN,
416 f32,
417 f64,
418 FloatN,
419 bool,
420 &'r str,
421 &'r [u8],
422 Uuid,
423 Decimal,
424 NaiveDateTime,
425 NaiveDate,
426 NaiveTime,
427 DateTime<Utc>,
428);