connectorx/sources/mssql/
mod.rs1mod errors;
4mod typesystem;
5
6pub use self::errors::MsSQLSourceError;
7pub use self::typesystem::{FloatN, IntN, MsSQLTypeSystem};
8use crate::constants::DB_BUFFER_SIZE;
9use crate::{
10 data_order::DataOrder,
11 errors::ConnectorXError,
12 sources::{PartitionParser, Produce, Source, SourcePartition},
13 sql::{count_query, CXQuery},
14 utils::DummyBox,
15};
16use anyhow::anyhow;
17use bb8::{Pool, PooledConnection};
18use bb8_tiberius::ConnectionManager;
19use chrono::{DateTime, Utc};
20use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
21use fehler::{throw, throws};
22use futures::StreamExt;
23use log::debug;
24use owning_ref::OwningHandle;
25use rust_decimal::Decimal;
26use sqlparser::dialect::MsSqlDialect;
27use std::collections::HashMap;
28use std::sync::Arc;
29use tiberius::{AuthMethod, Config, EncryptionLevel, QueryItem, QueryStream, Row};
30use tokio::runtime::{Handle, Runtime};
31use url::Url;
32use urlencoding::decode;
33use uuid_old::Uuid;
34
35type Conn<'a> = PooledConnection<'a, ConnectionManager>;
36pub struct MsSQLSource {
37 rt: Arc<Runtime>,
38 pool: Pool<ConnectionManager>,
39 origin_query: Option<String>,
40 queries: Vec<CXQuery<String>>,
41 names: Vec<String>,
42 schema: Vec<MsSQLTypeSystem>,
43}
44
45#[throws(MsSQLSourceError)]
46pub fn mssql_config(url: &Url) -> Config {
47 let mut config = Config::new();
48
49 let host = decode(url.host_str().unwrap_or("localhost"))?.into_owned();
50 let hosts: Vec<&str> = host.split('\\').collect();
51 match hosts.len() {
52 1 => config.host(host),
53 2 => {
54 config.host(hosts[0]);
56 config.instance_name(hosts[1]);
57 }
58 _ => throw!(anyhow!("MsSQL hostname parse error: {}", host)),
59 }
60 config.port(url.port().unwrap_or(1433));
61 config.database(decode(&url.path()[1..])?.to_owned());
63 #[allow(unused)]
65 let params: HashMap<String, String> = url.query_pairs().into_owned().collect();
66 #[cfg(any(windows, feature = "integrated-auth-gssapi"))]
67 match params.get("trusted_connection") {
68 Some(v) if v == "true" => {
70 debug!("mssql auth through trusted connection!");
71 config.authentication(AuthMethod::Integrated);
72 }
73 _ => {
74 debug!("mssql auth through sqlserver authentication");
75 config.authentication(AuthMethod::sql_server(
76 decode(url.username())?.to_owned(),
77 decode(url.password().unwrap_or(""))?.to_owned(),
78 ));
79 }
80 };
81 #[cfg(all(not(windows), not(feature = "integrated-auth-gssapi")))]
82 config.authentication(AuthMethod::sql_server(
83 decode(url.username())?.to_owned(),
84 decode(url.password().unwrap_or(""))?.to_owned(),
85 ));
86
87 match params.get("trust_server_certificate") {
88 Some(v) if v.to_lowercase() == "true" => config.trust_cert(),
89 _ => {}
90 };
91
92 match params.get("trust_server_certificate_ca") {
93 Some(v) => config.trust_cert_ca(v),
94 _ => {}
95 };
96
97 match params.get("encrypt") {
98 Some(v) if v.to_lowercase() == "true" => config.encryption(EncryptionLevel::Required),
99 _ => config.encryption(EncryptionLevel::NotSupported),
100 };
101
102 match params.get("appname") {
103 Some(appname) => config.application_name(decode(appname)?.to_owned()),
104 _ => {}
105 };
106
107 config
108}
109
110impl MsSQLSource {
111 #[throws(MsSQLSourceError)]
112 pub fn new(rt: Arc<Runtime>, conn: &str, nconn: usize) -> Self {
113 let url = Url::parse(conn)?;
114 let config = mssql_config(&url)?;
115 let manager = bb8_tiberius::ConnectionManager::new(config);
116 let pool = rt.block_on(Pool::builder().max_size(nconn as u32).build(manager))?;
117
118 Self {
119 rt,
120 pool,
121 origin_query: None,
122 queries: vec![],
123 names: vec![],
124 schema: vec![],
125 }
126 }
127}
128
129impl Source for MsSQLSource
130where
131 MsSQLSourcePartition: SourcePartition<TypeSystem = MsSQLTypeSystem, Error = MsSQLSourceError>,
132{
133 const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
134 type Partition = MsSQLSourcePartition;
135 type TypeSystem = MsSQLTypeSystem;
136 type Error = MsSQLSourceError;
137
138 #[throws(MsSQLSourceError)]
139 fn set_data_order(&mut self, data_order: DataOrder) {
140 if !matches!(data_order, DataOrder::RowMajor) {
141 throw!(ConnectorXError::UnsupportedDataOrder(data_order));
142 }
143 }
144
145 fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
146 self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
147 }
148
149 fn set_origin_query(&mut self, query: Option<String>) {
150 self.origin_query = query;
151 }
152
153 #[throws(MsSQLSourceError)]
154 fn fetch_metadata(&mut self) {
155 assert!(!self.queries.is_empty());
156
157 let mut conn = self.rt.block_on(self.pool.get())?;
158 let first_query = &self.queries[0];
159 let (names, types) = match self.rt.block_on(conn.query(first_query.as_str(), &[])) {
160 Ok(mut stream) => match self.rt.block_on(async { stream.columns().await }) {
161 Ok(Some(columns)) => columns
162 .iter()
163 .map(|col| {
164 (
165 col.name().to_string(),
166 MsSQLTypeSystem::from(&col.column_type()),
167 )
168 })
169 .unzip(),
170 Ok(None) => {
171 throw!(anyhow!(
172 "MsSQL returned no columns for query: {}",
173 first_query
174 ));
175 }
176 Err(e) => {
177 throw!(anyhow!("Error fetching columns: {}", e));
178 }
179 },
180 Err(e) => {
181 debug!(
182 "cannot get metadata for '{}', try next query: {}",
183 first_query, e
184 );
185 throw!(e);
186 }
187 };
188
189 self.names = names;
190 self.schema = types;
191 }
192
193 #[throws(MsSQLSourceError)]
194 fn result_rows(&mut self) -> Option<usize> {
195 match &self.origin_query {
196 Some(q) => {
197 let cxq = CXQuery::Naked(q.clone());
198 let cquery = count_query(&cxq, &MsSqlDialect {})?;
199 let mut conn = self.rt.block_on(self.pool.get())?;
200
201 let stream = self.rt.block_on(conn.query(cquery.as_str(), &[]))?;
202 let row = self
203 .rt
204 .block_on(stream.into_row())?
205 .ok_or_else(|| anyhow!("MsSQL failed to get the count of query: {}", q))?;
206
207 let row: i32 = row.get(0).ok_or(MsSQLSourceError::GetNRowsFailed)?; Some(row as usize)
209 }
210 None => None,
211 }
212 }
213
214 fn names(&self) -> Vec<String> {
215 self.names.clone()
216 }
217
218 fn schema(&self) -> Vec<Self::TypeSystem> {
219 self.schema.clone()
220 }
221
222 #[throws(MsSQLSourceError)]
223 fn partition(self) -> Vec<Self::Partition> {
224 let mut ret = vec![];
225 for query in self.queries {
226 ret.push(MsSQLSourcePartition::new(
227 self.pool.clone(),
228 self.rt.clone(),
229 &query,
230 &self.schema,
231 ));
232 }
233 ret
234 }
235}
236
237pub struct MsSQLSourcePartition {
238 pool: Pool<ConnectionManager>,
239 rt: Arc<Runtime>,
240 query: CXQuery<String>,
241 schema: Vec<MsSQLTypeSystem>,
242 nrows: usize,
243 ncols: usize,
244}
245
246impl MsSQLSourcePartition {
247 pub fn new(
248 pool: Pool<ConnectionManager>,
249 handle: Arc<Runtime>,
250 query: &CXQuery<String>,
251 schema: &[MsSQLTypeSystem],
252 ) -> Self {
253 Self {
254 rt: handle,
255 pool,
256 query: query.clone(),
257 schema: schema.to_vec(),
258 nrows: 0,
259 ncols: schema.len(),
260 }
261 }
262}
263
264impl SourcePartition for MsSQLSourcePartition {
265 type TypeSystem = MsSQLTypeSystem;
266 type Parser<'a> = MsSQLSourceParser<'a>;
267 type Error = MsSQLSourceError;
268
269 #[throws(MsSQLSourceError)]
270 fn result_rows(&mut self) {
271 let cquery = count_query(&self.query, &MsSqlDialect {})?;
272 let mut conn = self.rt.block_on(self.pool.get())?;
273
274 let stream = self.rt.block_on(conn.query(cquery.as_str(), &[]))?;
275 let row = self
276 .rt
277 .block_on(stream.into_row())?
278 .ok_or_else(|| anyhow!("MsSQL failed to get the count of query: {}", self.query))?;
279
280 let row: i32 = row.get(0).ok_or(MsSQLSourceError::GetNRowsFailed)?; self.nrows = row as usize;
282 }
283
284 #[throws(MsSQLSourceError)]
285 fn parser<'a>(&'a mut self) -> Self::Parser<'a> {
286 let conn = self.rt.block_on(self.pool.get())?;
287 let rows: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>> =
288 OwningHandle::new_with_fn(Box::new(conn), |conn: *const Conn<'a>| unsafe {
289 let conn = &mut *(conn as *mut Conn<'a>);
290
291 DummyBox(
292 self.rt
293 .block_on(conn.query(self.query.as_str(), &[]))
294 .unwrap(),
295 )
296 });
297
298 MsSQLSourceParser::new(self.rt.handle(), rows, &self.schema)
299 }
300
301 fn nrows(&self) -> usize {
302 self.nrows
303 }
304
305 fn ncols(&self) -> usize {
306 self.ncols
307 }
308}
309
310pub struct MsSQLSourceParser<'a> {
311 rt: &'a Handle,
312 iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>>,
313 rowbuf: Vec<Row>,
314 ncols: usize,
315 current_col: usize,
316 current_row: usize,
317 is_finished: bool,
318}
319
320impl<'a> MsSQLSourceParser<'a> {
321 fn new(
322 rt: &'a Handle,
323 iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>>,
324 schema: &[MsSQLTypeSystem],
325 ) -> Self {
326 Self {
327 rt,
328 iter,
329 rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
330 ncols: schema.len(),
331 current_row: 0,
332 current_col: 0,
333 is_finished: false,
334 }
335 }
336
337 #[throws(MsSQLSourceError)]
338 fn next_loc(&mut self) -> (usize, usize) {
339 let ret = (self.current_row, self.current_col);
340 self.current_row += (self.current_col + 1) / self.ncols;
341 self.current_col = (self.current_col + 1) % self.ncols;
342 ret
343 }
344}
345
346impl<'a> PartitionParser<'a> for MsSQLSourceParser<'a> {
347 type TypeSystem = MsSQLTypeSystem;
348 type Error = MsSQLSourceError;
349
350 #[throws(MsSQLSourceError)]
351 fn fetch_next(&mut self) -> (usize, bool) {
352 assert!(self.current_col == 0);
353 let remaining_rows = self.rowbuf.len() - self.current_row;
354 if remaining_rows > 0 {
355 return (remaining_rows, self.is_finished);
356 } else if self.is_finished {
357 return (0, self.is_finished);
358 }
359
360 if !self.rowbuf.is_empty() {
361 self.rowbuf.drain(..);
362 }
363
364 for _ in 0..DB_BUFFER_SIZE {
365 if let Some(item) = self.rt.block_on(self.iter.next()) {
366 match item.map_err(MsSQLSourceError::MsSQLError)? {
367 QueryItem::Row(row) => self.rowbuf.push(row),
368 _ => continue,
369 }
370 } else {
371 self.is_finished = true;
372 break;
373 }
374 }
375 self.current_row = 0;
376 self.current_col = 0;
377 (self.rowbuf.len(), self.is_finished)
378 }
379}
380
/// Implements `Produce<T>` and `Produce<Option<T>>` for [`MsSQLSourceParser`]
/// for each listed type `T`.
///
/// Each call reads the next cell (in row-major order) from the buffered rows.
/// The values are read with `Row::try_get`, so a conversion failure becomes an
/// error rather than a panic (`Row::get` would panic).
/// The non-`Option` impl also returns an error when the cell is NULL.
macro_rules! impl_produce {
    ($($t: ty,)+) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for MsSQLSourceParser<'a> {
                type Error = MsSQLSourceError;

                #[throws(MsSQLSourceError)]
                fn produce(&'r mut self) -> $t {
                    let (ridx, cidx) = self.next_loc()?;
                    self.rowbuf[ridx]
                        .try_get(cidx)?
                        .ok_or_else(|| anyhow!("MsSQL get None at position: ({}, {})", ridx, cidx))?
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for MsSQLSourceParser<'a> {
                type Error = MsSQLSourceError;

                #[throws(MsSQLSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    let (ridx, cidx) = self.next_loc()?;
                    self.rowbuf[ridx].try_get(cidx)?
                }
            }
        )+
    };
}
408
409impl_produce!(
410 u8,
411 i16,
412 i32,
413 i64,
414 IntN,
415 f32,
416 f64,
417 FloatN,
418 bool,
419 &'r str,
420 &'r [u8],
421 Uuid,
422 Decimal,
423 NaiveDateTime,
424 NaiveDate,
425 NaiveTime,
426 DateTime<Utc>,
427);