1mod errors;
4mod typesystem;
5
6pub use self::errors::MySQLSourceError;
7use crate::constants::DB_BUFFER_SIZE;
8use crate::{
9 data_order::DataOrder,
10 errors::ConnectorXError,
11 sources::{PartitionParser, Produce, Source, SourcePartition},
12 sql::{count_query, limit1_query, CXQuery},
13};
14use anyhow::anyhow;
15use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
16use fehler::{throw, throws};
17use log::{debug, warn};
18use r2d2::{Pool, PooledConnection};
19use r2d2_mysql::{
20 mysql::{
21 consts::{
22 ColumnFlags as MySQLColumnFlags, ColumnType as MySQLColumnType, UTF8MB4_GENERAL_CI,
23 UTF8_GENERAL_CI,
24 },
25 prelude::Queryable,
26 Binary, Opts, OptsBuilder, QueryResult, Row, Text,
27 },
28 MySqlConnectionManager,
29};
30use rust_decimal::Decimal;
31use serde_json::Value;
32use sqlparser::dialect::MySqlDialect;
33use std::marker::PhantomData;
34pub use typesystem::MySQLTypeSystem;
35
36type MysqlConn = PooledConnection<MySqlConnectionManager>;
37
38pub enum BinaryProtocol {}
39pub enum TextProtocol {}
40
41#[throws(MySQLSourceError)]
42fn get_total_rows(conn: &mut MysqlConn, query: &CXQuery<String>) -> usize {
43 conn.query_first(&count_query(query, &MySqlDialect {})?)?
44 .ok_or_else(|| anyhow!("mysql failed to get the count of query: {}", query))?
45}
46
47pub struct MySQLSource<P> {
48 pool: Pool<MySqlConnectionManager>,
49 origin_query: Option<String>,
50 queries: Vec<CXQuery<String>>,
51 names: Vec<String>,
52 schema: Vec<MySQLTypeSystem>,
53 pre_execution_queries: Option<Vec<String>>,
54 _protocol: PhantomData<P>,
55}
56
57impl<P> MySQLSource<P> {
58 #[throws(MySQLSourceError)]
59 pub fn new(conn: &str, nconn: usize) -> Self {
60 let manager = MySqlConnectionManager::new(OptsBuilder::from_opts(Opts::from_url(conn)?));
61 let pool = r2d2::Pool::builder()
62 .max_size(nconn as u32)
63 .build(manager)?;
64
65 Self {
66 pool,
67 origin_query: None,
68 queries: vec![],
69 names: vec![],
70 schema: vec![],
71 pre_execution_queries: None,
72 _protocol: PhantomData,
73 }
74 }
75}
76
77impl<P> Source for MySQLSource<P>
78where
79 MySQLSourcePartition<P>:
80 SourcePartition<TypeSystem = MySQLTypeSystem, Error = MySQLSourceError>,
81 P: Send,
82{
83 const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
84 type Partition = MySQLSourcePartition<P>;
85 type TypeSystem = MySQLTypeSystem;
86 type Error = MySQLSourceError;
87
88 #[throws(MySQLSourceError)]
89 fn set_data_order(&mut self, data_order: DataOrder) {
90 if !matches!(data_order, DataOrder::RowMajor) {
91 throw!(ConnectorXError::UnsupportedDataOrder(data_order));
92 }
93 }
94
95 fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
96 self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
97 }
98
99 fn set_origin_query(&mut self, query: Option<String>) {
100 self.origin_query = query;
101 }
102
103 fn set_pre_execution_queries(&mut self, pre_execution_queries: Option<&[String]>) {
104 self.pre_execution_queries = pre_execution_queries.map(|s| s.to_vec());
105 }
106
107 #[throws(MySQLSourceError)]
108 fn fetch_metadata(&mut self) {
109 assert!(!self.queries.is_empty());
110
111 let mut conn = self.pool.get()?;
112 let server_version_post_5_5_3 = conn.server_version() >= (5, 5, 3);
113
114 let first_query = &self.queries[0];
115
116 match conn.prep(first_query) {
117 Ok(stmt) => {
118 let (names, types) = stmt
119 .columns()
120 .iter()
121 .map(|col| {
122 let col_name = col.name_str().to_string();
123 let col_type = col.column_type();
124 let col_flags = col.flags();
125 let charset = col.character_set();
126 let charset_is_utf8 = (server_version_post_5_5_3
127 && charset == UTF8MB4_GENERAL_CI)
128 || (!server_version_post_5_5_3 && charset == UTF8_GENERAL_CI);
129 if charset_is_utf8
130 && (col_type == MySQLColumnType::MYSQL_TYPE_LONG_BLOB
131 || col_type == MySQLColumnType::MYSQL_TYPE_BLOB
132 || col_type == MySQLColumnType::MYSQL_TYPE_MEDIUM_BLOB
133 || col_type == MySQLColumnType::MYSQL_TYPE_TINY_BLOB)
134 {
135 return (
136 col_name,
137 MySQLTypeSystem::Char(
138 !col_flags.contains(MySQLColumnFlags::NOT_NULL_FLAG),
139 ),
140 );
141 }
142 let d = MySQLTypeSystem::from((&col_type, &col_flags));
143 (col_name, d)
144 })
145 .unzip();
146 self.names = names;
147 self.schema = types;
148 }
149 Err(e) => {
150 warn!(
151 "mysql text prepared statement error: {:?}, switch to limit1 method",
152 e
153 );
154 for (i, query) in self.queries.iter().enumerate() {
155 match conn
157 .query_first::<Row, _>(limit1_query(query, &MySqlDialect {})?.as_str())
158 {
159 Ok(Some(row)) => {
160 let (names, types) = row
161 .columns_ref()
162 .iter()
163 .map(|col| {
164 (
165 col.name_str().to_string(),
166 MySQLTypeSystem::from((&col.column_type(), &col.flags())),
167 )
168 })
169 .unzip();
170 self.names = names;
171 self.schema = types;
172 return;
173 }
174 Ok(None) => {}
175 Err(e) if i == self.queries.len() - 1 => {
176 debug!("cannot get metadata for '{}', try next query: {}", query, e);
178 throw!(e)
179 }
180 Err(_) => {}
181 }
182 }
183
184 let iter = conn.query_iter(self.queries[0].as_str())?;
186 let (names, types) = iter
187 .columns()
188 .as_ref()
189 .iter()
190 .map(|col| {
191 (
192 col.name_str().to_string(),
193 MySQLTypeSystem::VarChar(false), )
195 })
196 .unzip();
197 self.names = names;
198 self.schema = types;
199 }
200 }
201 }
202
203 #[throws(MySQLSourceError)]
204 fn result_rows(&mut self) -> Option<usize> {
205 match &self.origin_query {
206 Some(q) => {
207 let cxq = CXQuery::Naked(q.clone());
208 let mut conn = self.pool.get()?;
209 let nrows = get_total_rows(&mut conn, &cxq)?;
210 Some(nrows)
211 }
212 None => None,
213 }
214 }
215
216 fn names(&self) -> Vec<String> {
217 self.names.clone()
218 }
219
220 fn schema(&self) -> Vec<Self::TypeSystem> {
221 self.schema.clone()
222 }
223
224 #[throws(MySQLSourceError)]
225 fn partition(self) -> Vec<Self::Partition> {
226 let mut ret = vec![];
227 for query in self.queries {
228 let mut conn = self.pool.get()?;
229
230 if let Some(pre_queries) = &self.pre_execution_queries {
231 for pre_query in pre_queries {
232 conn.query_drop(pre_query)?;
233 }
234 }
235
236 ret.push(MySQLSourcePartition::new(conn, &query, &self.schema));
237 }
238 ret
239 }
240}
241
242pub struct MySQLSourcePartition<P> {
243 conn: MysqlConn,
244 query: CXQuery<String>,
245 schema: Vec<MySQLTypeSystem>,
246 nrows: usize,
247 ncols: usize,
248 _protocol: PhantomData<P>,
249}
250
251impl<P> MySQLSourcePartition<P> {
252 pub fn new(conn: MysqlConn, query: &CXQuery<String>, schema: &[MySQLTypeSystem]) -> Self {
253 Self {
254 conn,
255 query: query.clone(),
256 schema: schema.to_vec(),
257 nrows: 0,
258 ncols: schema.len(),
259 _protocol: PhantomData,
260 }
261 }
262}
263
264impl SourcePartition for MySQLSourcePartition<BinaryProtocol> {
265 type TypeSystem = MySQLTypeSystem;
266 type Parser<'a> = MySQLBinarySourceParser<'a>;
267 type Error = MySQLSourceError;
268
269 #[throws(MySQLSourceError)]
270 fn result_rows(&mut self) {
271 self.nrows = get_total_rows(&mut self.conn, &self.query)?;
272 }
273
274 #[throws(MySQLSourceError)]
275 fn parser(&mut self) -> Self::Parser<'_> {
276 let stmt = self.conn.prep(self.query.as_str())?;
277 let iter = self.conn.exec_iter(stmt, ())?;
278 MySQLBinarySourceParser::new(iter, &self.schema)
279 }
280
281 fn nrows(&self) -> usize {
282 self.nrows
283 }
284
285 fn ncols(&self) -> usize {
286 self.ncols
287 }
288}
289
290impl SourcePartition for MySQLSourcePartition<TextProtocol> {
291 type TypeSystem = MySQLTypeSystem;
292 type Parser<'a> = MySQLTextSourceParser<'a>;
293 type Error = MySQLSourceError;
294
295 #[throws(MySQLSourceError)]
296 fn result_rows(&mut self) {
297 self.nrows = get_total_rows(&mut self.conn, &self.query)?;
298 }
299
300 #[throws(MySQLSourceError)]
301 fn parser(&mut self) -> Self::Parser<'_> {
302 let query = self.query.clone();
303 let iter = self.conn.query_iter(query)?;
304 MySQLTextSourceParser::new(iter, &self.schema)
305 }
306
307 fn nrows(&self) -> usize {
308 self.nrows
309 }
310
311 fn ncols(&self) -> usize {
312 self.ncols
313 }
314}
315
316pub struct MySQLBinarySourceParser<'a> {
317 iter: QueryResult<'a, 'a, 'a, Binary>,
318 rowbuf: Vec<Row>,
319 ncols: usize,
320 current_col: usize,
321 current_row: usize,
322 is_finished: bool,
323}
324
325impl<'a> MySQLBinarySourceParser<'a> {
326 pub fn new(iter: QueryResult<'a, 'a, 'a, Binary>, schema: &[MySQLTypeSystem]) -> Self {
327 Self {
328 iter,
329 rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
330 ncols: schema.len(),
331 current_row: 0,
332 current_col: 0,
333 is_finished: false,
334 }
335 }
336
337 #[throws(MySQLSourceError)]
338 fn next_loc(&mut self) -> (usize, usize) {
339 let ret = (self.current_row, self.current_col);
340 self.current_row += (self.current_col + 1) / self.ncols;
341 self.current_col = (self.current_col + 1) % self.ncols;
342 ret
343 }
344}
345
346impl<'a> PartitionParser<'a> for MySQLBinarySourceParser<'a> {
347 type TypeSystem = MySQLTypeSystem;
348 type Error = MySQLSourceError;
349
350 #[throws(MySQLSourceError)]
351 fn fetch_next(&mut self) -> (usize, bool) {
352 assert!(self.current_col == 0);
353 let remaining_rows = self.rowbuf.len() - self.current_row;
354 if remaining_rows > 0 {
355 return (remaining_rows, self.is_finished);
356 } else if self.is_finished {
357 return (0, self.is_finished);
358 }
359
360 if !self.rowbuf.is_empty() {
361 self.rowbuf.drain(..);
362 }
363
364 for _ in 0..DB_BUFFER_SIZE {
365 if let Some(item) = self.iter.next() {
366 self.rowbuf.push(item?);
367 } else {
368 self.is_finished = true;
369 break;
370 }
371 }
372 self.current_row = 0;
373 self.current_col = 0;
374
375 (self.rowbuf.len(), self.is_finished)
376 }
377}
378
379macro_rules! impl_produce_binary {
380 ($($t: ty,)+) => {
381 $(
382 impl<'r, 'a> Produce<'r, $t> for MySQLBinarySourceParser<'a> {
383 type Error = MySQLSourceError;
384
385 #[throws(MySQLSourceError)]
386 fn produce(&'r mut self) -> $t {
387 let (ridx, cidx) = self.next_loc()?;
388 let res = self.rowbuf[ridx].take(cidx).ok_or_else(|| anyhow!("mysql cannot parse at position: ({}, {})", ridx, cidx))?;
389 res
390 }
391 }
392
393 impl<'r, 'a> Produce<'r, Option<$t>> for MySQLBinarySourceParser<'a> {
394 type Error = MySQLSourceError;
395
396 #[throws(MySQLSourceError)]
397 fn produce(&'r mut self) -> Option<$t> {
398 let (ridx, cidx) = self.next_loc()?;
399 let res = self.rowbuf[ridx].take(cidx).ok_or_else(|| anyhow!("mysql cannot parse at position: ({}, {})", ridx, cidx))?;
400 res
401 }
402 }
403 )+
404 };
405}
406
407impl_produce_binary!(
408 i8,
409 i16,
410 i32,
411 i64,
412 u8,
413 u16,
414 u32,
415 u64,
416 f32,
417 f64,
418 NaiveDate,
419 NaiveTime,
420 NaiveDateTime,
421 Decimal,
422 String,
423 Vec<u8>,
424 Value,
425);
426
427pub struct MySQLTextSourceParser<'a> {
428 iter: QueryResult<'a, 'a, 'a, Text>,
429 rowbuf: Vec<Row>,
430 ncols: usize,
431 current_col: usize,
432 current_row: usize,
433 is_finished: bool,
434}
435
436impl<'a> MySQLTextSourceParser<'a> {
437 pub fn new(iter: QueryResult<'a, 'a, 'a, Text>, schema: &[MySQLTypeSystem]) -> Self {
438 Self {
439 iter,
440 rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
441 ncols: schema.len(),
442 current_row: 0,
443 current_col: 0,
444 is_finished: false,
445 }
446 }
447
448 #[throws(MySQLSourceError)]
449 fn next_loc(&mut self) -> (usize, usize) {
450 let ret = (self.current_row, self.current_col);
451 self.current_row += (self.current_col + 1) / self.ncols;
452 self.current_col = (self.current_col + 1) % self.ncols;
453 ret
454 }
455}
456
457impl<'a> PartitionParser<'a> for MySQLTextSourceParser<'a> {
458 type TypeSystem = MySQLTypeSystem;
459 type Error = MySQLSourceError;
460
461 #[throws(MySQLSourceError)]
462 fn fetch_next(&mut self) -> (usize, bool) {
463 assert!(self.current_col == 0);
464 let remaining_rows = self.rowbuf.len() - self.current_row;
465 if remaining_rows > 0 {
466 return (remaining_rows, self.is_finished);
467 } else if self.is_finished {
468 return (0, self.is_finished);
469 }
470
471 if !self.rowbuf.is_empty() {
472 self.rowbuf.drain(..);
473 }
474 for _ in 0..DB_BUFFER_SIZE {
475 if let Some(item) = self.iter.next() {
476 self.rowbuf.push(item?);
477 } else {
478 self.is_finished = true;
479 break;
480 }
481 }
482 self.current_row = 0;
483 self.current_col = 0;
484 (self.rowbuf.len(), self.is_finished)
485 }
486}
487
488macro_rules! impl_produce_text {
489 ($($t: ty,)+) => {
490 $(
491 impl<'r, 'a> Produce<'r, $t> for MySQLTextSourceParser<'a> {
492 type Error = MySQLSourceError;
493
494 #[throws(MySQLSourceError)]
495 fn produce(&'r mut self) -> $t {
496 let (ridx, cidx) = self.next_loc()?;
497 let res = self.rowbuf[ridx].take(cidx).ok_or_else(|| anyhow!("mysql cannot parse at position: ({}, {})", ridx, cidx))?;
498 res
499 }
500 }
501
502 impl<'r, 'a> Produce<'r, Option<$t>> for MySQLTextSourceParser<'a> {
503 type Error = MySQLSourceError;
504
505 #[throws(MySQLSourceError)]
506 fn produce(&'r mut self) -> Option<$t> {
507 let (ridx, cidx) = self.next_loc()?;
508 let res = self.rowbuf[ridx].take(cidx).ok_or_else(|| anyhow!("mysql cannot parse at position: ({}, {})", ridx, cidx))?;
509 res
510 }
511 }
512 )+
513 };
514}
515
516impl_produce_text!(
517 i8,
518 i16,
519 i32,
520 i64,
521 u8,
522 u16,
523 u32,
524 u64,
525 f32,
526 f64,
527 NaiveDate,
528 NaiveTime,
529 NaiveDateTime,
530 Decimal,
531 String,
532 Vec<u8>,
533 Value,
534);