connectorx/sources/mysql/
mod.rs

1//! Source implementation for MySQL database.
2
3mod errors;
4mod typesystem;
5
6pub use self::errors::MySQLSourceError;
7use crate::constants::DB_BUFFER_SIZE;
8use crate::{
9    data_order::DataOrder,
10    errors::ConnectorXError,
11    sources::{PartitionParser, Produce, Source, SourcePartition},
12    sql::{count_query, limit1_query, CXQuery},
13};
14use anyhow::anyhow;
15use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
16use fehler::{throw, throws};
17use log::{debug, warn};
18use r2d2::{Pool, PooledConnection};
19use r2d2_mysql::{
20    mysql::{
21        consts::{
22            ColumnFlags as MySQLColumnFlags, ColumnType as MySQLColumnType, UTF8MB4_GENERAL_CI,
23            UTF8_GENERAL_CI,
24        },
25        prelude::Queryable,
26        Binary, Opts, OptsBuilder, QueryResult, Row, Text,
27    },
28    MySqlConnectionManager,
29};
30use rust_decimal::Decimal;
31use serde_json::Value;
32use sqlparser::dialect::MySqlDialect;
33use std::marker::PhantomData;
34pub use typesystem::MySQLTypeSystem;
35
36type MysqlConn = PooledConnection<MySqlConnectionManager>;
37
/// Marker type used as the `P` parameter to select the partition/parser pair that reads
/// rows through a prepared statement (`prep` + `exec_iter`). It has no variants and is
/// never instantiated.
pub enum BinaryProtocol {}
/// Marker type used as the `P` parameter to select the partition/parser pair that reads
/// rows through a plain `query_iter` call. It has no variants and is never instantiated.
pub enum TextProtocol {}
40
41#[throws(MySQLSourceError)]
42fn get_total_rows(conn: &mut MysqlConn, query: &CXQuery<String>) -> usize {
43    conn.query_first(&count_query(query, &MySqlDialect {})?)?
44        .ok_or_else(|| anyhow!("mysql failed to get the count of query: {}", query))?
45}
46
47pub struct MySQLSource<P> {
48    pool: Pool<MySqlConnectionManager>,
49    origin_query: Option<String>,
50    queries: Vec<CXQuery<String>>,
51    names: Vec<String>,
52    schema: Vec<MySQLTypeSystem>,
53    pre_execution_queries: Option<Vec<String>>,
54    _protocol: PhantomData<P>,
55}
56
57impl<P> MySQLSource<P> {
58    #[throws(MySQLSourceError)]
59    pub fn new(conn: &str, nconn: usize) -> Self {
60        let manager = MySqlConnectionManager::new(OptsBuilder::from_opts(Opts::from_url(conn)?));
61        let pool = r2d2::Pool::builder()
62            .max_size(nconn as u32)
63            .build(manager)?;
64
65        Self {
66            pool,
67            origin_query: None,
68            queries: vec![],
69            names: vec![],
70            schema: vec![],
71            pre_execution_queries: None,
72            _protocol: PhantomData,
73        }
74    }
75}
76
77impl<P> Source for MySQLSource<P>
78where
79    MySQLSourcePartition<P>:
80        SourcePartition<TypeSystem = MySQLTypeSystem, Error = MySQLSourceError>,
81    P: Send,
82{
83    const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
84    type Partition = MySQLSourcePartition<P>;
85    type TypeSystem = MySQLTypeSystem;
86    type Error = MySQLSourceError;
87
88    #[throws(MySQLSourceError)]
89    fn set_data_order(&mut self, data_order: DataOrder) {
90        if !matches!(data_order, DataOrder::RowMajor) {
91            throw!(ConnectorXError::UnsupportedDataOrder(data_order));
92        }
93    }
94
95    fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
96        self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
97    }
98
99    fn set_origin_query(&mut self, query: Option<String>) {
100        self.origin_query = query;
101    }
102
103    fn set_pre_execution_queries(&mut self, pre_execution_queries: Option<&[String]>) {
104        self.pre_execution_queries = pre_execution_queries.map(|s| s.to_vec());
105    }
106
107    #[throws(MySQLSourceError)]
108    fn fetch_metadata(&mut self) {
109        assert!(!self.queries.is_empty());
110
111        let mut conn = self.pool.get()?;
112        let server_version_post_5_5_3 = conn.server_version() >= (5, 5, 3);
113
114        let first_query = &self.queries[0];
115
116        match conn.prep(first_query) {
117            Ok(stmt) => {
118                let (names, types) = stmt
119                    .columns()
120                    .iter()
121                    .map(|col| {
122                        let col_name = col.name_str().to_string();
123                        let col_type = col.column_type();
124                        let col_flags = col.flags();
125                        let charset = col.character_set();
126                        let charset_is_utf8 = (server_version_post_5_5_3
127                            && charset == UTF8MB4_GENERAL_CI)
128                            || (!server_version_post_5_5_3 && charset == UTF8_GENERAL_CI);
129                        if charset_is_utf8
130                            && (col_type == MySQLColumnType::MYSQL_TYPE_LONG_BLOB
131                                || col_type == MySQLColumnType::MYSQL_TYPE_BLOB
132                                || col_type == MySQLColumnType::MYSQL_TYPE_MEDIUM_BLOB
133                                || col_type == MySQLColumnType::MYSQL_TYPE_TINY_BLOB)
134                        {
135                            return (
136                                col_name,
137                                MySQLTypeSystem::Char(
138                                    !col_flags.contains(MySQLColumnFlags::NOT_NULL_FLAG),
139                                ),
140                            );
141                        }
142                        let d = MySQLTypeSystem::from((&col_type, &col_flags));
143                        (col_name, d)
144                    })
145                    .unzip();
146                self.names = names;
147                self.schema = types;
148            }
149            Err(e) => {
150                warn!(
151                    "mysql text prepared statement error: {:?}, switch to limit1 method",
152                    e
153                );
154                for (i, query) in self.queries.iter().enumerate() {
155                    // assuming all the partition queries yield same schema
156                    match conn
157                        .query_first::<Row, _>(limit1_query(query, &MySqlDialect {})?.as_str())
158                    {
159                        Ok(Some(row)) => {
160                            let (names, types) = row
161                                .columns_ref()
162                                .iter()
163                                .map(|col| {
164                                    (
165                                        col.name_str().to_string(),
166                                        MySQLTypeSystem::from((&col.column_type(), &col.flags())),
167                                    )
168                                })
169                                .unzip();
170                            self.names = names;
171                            self.schema = types;
172                            return;
173                        }
174                        Ok(None) => {}
175                        Err(e) if i == self.queries.len() - 1 => {
176                            // tried the last query but still get an error
177                            debug!("cannot get metadata for '{}', try next query: {}", query, e);
178                            throw!(e)
179                        }
180                        Err(_) => {}
181                    }
182                }
183
184                // tried all queries but all get empty result set
185                let iter = conn.query_iter(self.queries[0].as_str())?;
186                let (names, types) = iter
187                    .columns()
188                    .as_ref()
189                    .iter()
190                    .map(|col| {
191                        (
192                            col.name_str().to_string(),
193                            MySQLTypeSystem::VarChar(false), // set all columns as string (align with pandas)
194                        )
195                    })
196                    .unzip();
197                self.names = names;
198                self.schema = types;
199            }
200        }
201    }
202
203    #[throws(MySQLSourceError)]
204    fn result_rows(&mut self) -> Option<usize> {
205        match &self.origin_query {
206            Some(q) => {
207                let cxq = CXQuery::Naked(q.clone());
208                let mut conn = self.pool.get()?;
209                let nrows = get_total_rows(&mut conn, &cxq)?;
210                Some(nrows)
211            }
212            None => None,
213        }
214    }
215
216    fn names(&self) -> Vec<String> {
217        self.names.clone()
218    }
219
220    fn schema(&self) -> Vec<Self::TypeSystem> {
221        self.schema.clone()
222    }
223
224    #[throws(MySQLSourceError)]
225    fn partition(self) -> Vec<Self::Partition> {
226        let mut ret = vec![];
227        for query in self.queries {
228            let mut conn = self.pool.get()?;
229
230            if let Some(pre_queries) = &self.pre_execution_queries {
231                for pre_query in pre_queries {
232                    conn.query_drop(pre_query)?;
233                }
234            }
235
236            ret.push(MySQLSourcePartition::new(conn, &query, &self.schema));
237        }
238        ret
239    }
240}
241
242pub struct MySQLSourcePartition<P> {
243    conn: MysqlConn,
244    query: CXQuery<String>,
245    schema: Vec<MySQLTypeSystem>,
246    nrows: usize,
247    ncols: usize,
248    _protocol: PhantomData<P>,
249}
250
251impl<P> MySQLSourcePartition<P> {
252    pub fn new(conn: MysqlConn, query: &CXQuery<String>, schema: &[MySQLTypeSystem]) -> Self {
253        Self {
254            conn,
255            query: query.clone(),
256            schema: schema.to_vec(),
257            nrows: 0,
258            ncols: schema.len(),
259            _protocol: PhantomData,
260        }
261    }
262}
263
264impl SourcePartition for MySQLSourcePartition<BinaryProtocol> {
265    type TypeSystem = MySQLTypeSystem;
266    type Parser<'a> = MySQLBinarySourceParser<'a>;
267    type Error = MySQLSourceError;
268
269    #[throws(MySQLSourceError)]
270    fn result_rows(&mut self) {
271        self.nrows = get_total_rows(&mut self.conn, &self.query)?;
272    }
273
274    #[throws(MySQLSourceError)]
275    fn parser(&mut self) -> Self::Parser<'_> {
276        let stmt = self.conn.prep(self.query.as_str())?;
277        let iter = self.conn.exec_iter(stmt, ())?;
278        MySQLBinarySourceParser::new(iter, &self.schema)
279    }
280
281    fn nrows(&self) -> usize {
282        self.nrows
283    }
284
285    fn ncols(&self) -> usize {
286        self.ncols
287    }
288}
289
290impl SourcePartition for MySQLSourcePartition<TextProtocol> {
291    type TypeSystem = MySQLTypeSystem;
292    type Parser<'a> = MySQLTextSourceParser<'a>;
293    type Error = MySQLSourceError;
294
295    #[throws(MySQLSourceError)]
296    fn result_rows(&mut self) {
297        self.nrows = get_total_rows(&mut self.conn, &self.query)?;
298    }
299
300    #[throws(MySQLSourceError)]
301    fn parser(&mut self) -> Self::Parser<'_> {
302        let query = self.query.clone();
303        let iter = self.conn.query_iter(query)?;
304        MySQLTextSourceParser::new(iter, &self.schema)
305    }
306
307    fn nrows(&self) -> usize {
308        self.nrows
309    }
310
311    fn ncols(&self) -> usize {
312        self.ncols
313    }
314}
315
316pub struct MySQLBinarySourceParser<'a> {
317    iter: QueryResult<'a, 'a, 'a, Binary>,
318    rowbuf: Vec<Row>,
319    ncols: usize,
320    current_col: usize,
321    current_row: usize,
322    is_finished: bool,
323}
324
325impl<'a> MySQLBinarySourceParser<'a> {
326    pub fn new(iter: QueryResult<'a, 'a, 'a, Binary>, schema: &[MySQLTypeSystem]) -> Self {
327        Self {
328            iter,
329            rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
330            ncols: schema.len(),
331            current_row: 0,
332            current_col: 0,
333            is_finished: false,
334        }
335    }
336
337    #[throws(MySQLSourceError)]
338    fn next_loc(&mut self) -> (usize, usize) {
339        let ret = (self.current_row, self.current_col);
340        self.current_row += (self.current_col + 1) / self.ncols;
341        self.current_col = (self.current_col + 1) % self.ncols;
342        ret
343    }
344}
345
346impl<'a> PartitionParser<'a> for MySQLBinarySourceParser<'a> {
347    type TypeSystem = MySQLTypeSystem;
348    type Error = MySQLSourceError;
349
350    #[throws(MySQLSourceError)]
351    fn fetch_next(&mut self) -> (usize, bool) {
352        assert!(self.current_col == 0);
353        let remaining_rows = self.rowbuf.len() - self.current_row;
354        if remaining_rows > 0 {
355            return (remaining_rows, self.is_finished);
356        } else if self.is_finished {
357            return (0, self.is_finished);
358        }
359
360        if !self.rowbuf.is_empty() {
361            self.rowbuf.drain(..);
362        }
363
364        for _ in 0..DB_BUFFER_SIZE {
365            if let Some(item) = self.iter.next() {
366                self.rowbuf.push(item?);
367            } else {
368                self.is_finished = true;
369                break;
370            }
371        }
372        self.current_row = 0;
373        self.current_col = 0;
374
375        (self.rowbuf.len(), self.is_finished)
376    }
377}
378
379macro_rules! impl_produce_binary {
380    ($($t: ty,)+) => {
381        $(
382            impl<'r, 'a> Produce<'r, $t> for MySQLBinarySourceParser<'a> {
383                type Error = MySQLSourceError;
384
385                #[throws(MySQLSourceError)]
386                fn produce(&'r mut self) -> $t {
387                    let (ridx, cidx) = self.next_loc()?;
388                    let res = self.rowbuf[ridx].take(cidx).ok_or_else(|| anyhow!("mysql cannot parse at position: ({}, {})", ridx, cidx))?;
389                    res
390                }
391            }
392
393            impl<'r, 'a> Produce<'r, Option<$t>> for MySQLBinarySourceParser<'a> {
394                type Error = MySQLSourceError;
395
396                #[throws(MySQLSourceError)]
397                fn produce(&'r mut self) -> Option<$t> {
398                    let (ridx, cidx) = self.next_loc()?;
399                    let res = self.rowbuf[ridx].take(cidx).ok_or_else(|| anyhow!("mysql cannot parse at position: ({}, {})", ridx, cidx))?;
400                    res
401                }
402            }
403        )+
404    };
405}
406
407impl_produce_binary!(
408    i8,
409    i16,
410    i32,
411    i64,
412    u8,
413    u16,
414    u32,
415    u64,
416    f32,
417    f64,
418    NaiveDate,
419    NaiveTime,
420    NaiveDateTime,
421    Decimal,
422    String,
423    Vec<u8>,
424    Value,
425);
426
427pub struct MySQLTextSourceParser<'a> {
428    iter: QueryResult<'a, 'a, 'a, Text>,
429    rowbuf: Vec<Row>,
430    ncols: usize,
431    current_col: usize,
432    current_row: usize,
433    is_finished: bool,
434}
435
436impl<'a> MySQLTextSourceParser<'a> {
437    pub fn new(iter: QueryResult<'a, 'a, 'a, Text>, schema: &[MySQLTypeSystem]) -> Self {
438        Self {
439            iter,
440            rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
441            ncols: schema.len(),
442            current_row: 0,
443            current_col: 0,
444            is_finished: false,
445        }
446    }
447
448    #[throws(MySQLSourceError)]
449    fn next_loc(&mut self) -> (usize, usize) {
450        let ret = (self.current_row, self.current_col);
451        self.current_row += (self.current_col + 1) / self.ncols;
452        self.current_col = (self.current_col + 1) % self.ncols;
453        ret
454    }
455}
456
457impl<'a> PartitionParser<'a> for MySQLTextSourceParser<'a> {
458    type TypeSystem = MySQLTypeSystem;
459    type Error = MySQLSourceError;
460
461    #[throws(MySQLSourceError)]
462    fn fetch_next(&mut self) -> (usize, bool) {
463        assert!(self.current_col == 0);
464        let remaining_rows = self.rowbuf.len() - self.current_row;
465        if remaining_rows > 0 {
466            return (remaining_rows, self.is_finished);
467        } else if self.is_finished {
468            return (0, self.is_finished);
469        }
470
471        if !self.rowbuf.is_empty() {
472            self.rowbuf.drain(..);
473        }
474        for _ in 0..DB_BUFFER_SIZE {
475            if let Some(item) = self.iter.next() {
476                self.rowbuf.push(item?);
477            } else {
478                self.is_finished = true;
479                break;
480            }
481        }
482        self.current_row = 0;
483        self.current_col = 0;
484        (self.rowbuf.len(), self.is_finished)
485    }
486}
487
488macro_rules! impl_produce_text {
489    ($($t: ty,)+) => {
490        $(
491            impl<'r, 'a> Produce<'r, $t> for MySQLTextSourceParser<'a> {
492                type Error = MySQLSourceError;
493
494                #[throws(MySQLSourceError)]
495                fn produce(&'r mut self) -> $t {
496                    let (ridx, cidx) = self.next_loc()?;
497                    let res = self.rowbuf[ridx].take(cidx).ok_or_else(|| anyhow!("mysql cannot parse at position: ({}, {})", ridx, cidx))?;
498                    res
499                }
500            }
501
502            impl<'r, 'a> Produce<'r, Option<$t>> for MySQLTextSourceParser<'a> {
503                type Error = MySQLSourceError;
504
505                #[throws(MySQLSourceError)]
506                fn produce(&'r mut self) -> Option<$t> {
507                    let (ridx, cidx) = self.next_loc()?;
508                    let res = self.rowbuf[ridx].take(cidx).ok_or_else(|| anyhow!("mysql cannot parse at position: ({}, {})", ridx, cidx))?;
509                    res
510                }
511            }
512        )+
513    };
514}
515
516impl_produce_text!(
517    i8,
518    i16,
519    i32,
520    i64,
521    u8,
522    u16,
523    u32,
524    u64,
525    f32,
526    f64,
527    NaiveDate,
528    NaiveTime,
529    NaiveDateTime,
530    Decimal,
531    String,
532    Vec<u8>,
533    Value,
534);