connectorx/
partition.rs

1use crate::errors::{ConnectorXOutError, OutResult};
2use crate::source_router::{SourceConn, SourceType};
3#[cfg(feature = "src_bigquery")]
4use crate::sources::bigquery::BigQueryDialect;
5#[cfg(feature = "src_mssql")]
6use crate::sources::mssql::{mssql_config, FloatN, IntN, MsSQLTypeSystem};
7#[cfg(feature = "src_mysql")]
8use crate::sources::mysql::{MySQLSourceError, MySQLTypeSystem};
9#[cfg(feature = "src_oracle")]
10use crate::sources::oracle::{connect_oracle, OracleDialect};
11#[cfg(feature = "src_postgres")]
12use crate::sources::postgres::{rewrite_tls_args, PostgresTypeSystem};
13#[cfg(feature = "src_trino")]
14use crate::sources::trino::TrinoDialect;
15#[cfg(feature = "src_sqlite")]
16use crate::sql::get_partition_range_query_sep;
17use crate::sql::{get_partition_range_query, single_col_partition_query, CXQuery};
18use anyhow::anyhow;
19use fehler::{throw, throws};
20#[cfg(feature = "src_bigquery")]
21use gcp_bigquery_client;
22#[cfg(feature = "src_mysql")]
23use r2d2_mysql::mysql::{prelude::Queryable, Opts, Pool, Row};
24#[cfg(feature = "src_sqlite")]
25use rusqlite::{types::Type, Connection};
26#[cfg(feature = "src_postgres")]
27use rust_decimal::{prelude::ToPrimitive, Decimal};
28#[cfg(feature = "src_postgres")]
29use rust_decimal_macros::dec;
30#[cfg(feature = "src_mssql")]
31use sqlparser::dialect::MsSqlDialect;
32#[cfg(feature = "src_mysql")]
33use sqlparser::dialect::MySqlDialect;
34#[cfg(feature = "src_postgres")]
35use sqlparser::dialect::PostgreSqlDialect;
36#[cfg(feature = "src_sqlite")]
37use sqlparser::dialect::SQLiteDialect;
38#[cfg(feature = "src_mssql")]
39use tiberius::Client;
40#[cfg(any(feature = "src_bigquery", feature = "src_mssql", feature = "src_trino"))]
41use tokio::{net::TcpStream, runtime::Runtime};
42#[cfg(feature = "src_mssql")]
43use tokio_util::compat::TokioAsyncWriteCompatExt;
44use url::Url;
45
/// A request to split `query` into `num` sub-queries over the integer range of `column`.
///
/// `min` and `max` are either both given (an explicit range) or both `None`, in which case
/// the range is fetched from the source when partitioning.
#[derive(Debug, Clone)]
pub struct PartitionQuery {
    /// The original SQL query to partition.
    query: String,
    /// Name of the column the partition ranges are applied to.
    column: String,
    /// Explicit lower bound of `column`, if known.
    min: Option<i64>,
    /// Explicit upper bound of `column`, if known.
    max: Option<i64>,
    /// Number of partitions to produce.
    num: usize,
}
53
54impl PartitionQuery {
55    pub fn new(query: &str, column: &str, min: Option<i64>, max: Option<i64>, num: usize) -> Self {
56        Self {
57            query: query.into(),
58            column: column.into(),
59            min,
60            max,
61            num,
62        }
63    }
64}
65
66pub fn partition(part: &PartitionQuery, source_conn: &SourceConn) -> OutResult<Vec<CXQuery>> {
67    let mut queries = vec![];
68    let num = part.num as i64;
69    let (min, max) = match (part.min, part.max) {
70        (None, None) => get_col_range(source_conn, &part.query, &part.column)?,
71        (Some(min), Some(max)) => (min, max),
72        _ => throw!(anyhow!(
73            "partition_query range can not be partially specified",
74        )),
75    };
76
77    let partition_size = (max - min + 1) / num;
78
79    for i in 0..num {
80        let lower = min + i * partition_size;
81        let upper = match i == num - 1 {
82            true => max + 1,
83            false => min + (i + 1) * partition_size,
84        };
85        let partition_query = get_part_query(source_conn, &part.query, &part.column, lower, upper)?;
86        queries.push(partition_query);
87    }
88    Ok(queries)
89}
90
91pub fn get_col_range(source_conn: &SourceConn, query: &str, col: &str) -> OutResult<(i64, i64)> {
92    match source_conn.ty {
93        #[cfg(feature = "src_postgres")]
94        SourceType::Postgres => pg_get_partition_range(&source_conn.conn, query, col),
95        #[cfg(feature = "src_sqlite")]
96        SourceType::SQLite => sqlite_get_partition_range(&source_conn.conn, query, col),
97        #[cfg(feature = "src_mysql")]
98        SourceType::MySQL => mysql_get_partition_range(&source_conn.conn, query, col),
99        #[cfg(feature = "src_mssql")]
100        SourceType::MsSQL => mssql_get_partition_range(&source_conn.conn, query, col),
101        #[cfg(feature = "src_oracle")]
102        SourceType::Oracle => oracle_get_partition_range(&source_conn.conn, query, col),
103        #[cfg(feature = "src_bigquery")]
104        SourceType::BigQuery => bigquery_get_partition_range(&source_conn.conn, query, col),
105        #[cfg(feature = "src_trino")]
106        SourceType::Trino => trino_get_partition_range(&source_conn.conn, query, col),
107        _ => unimplemented!("{:?} not implemented!", source_conn.ty),
108    }
109}
110
111#[throws(ConnectorXOutError)]
112pub fn get_part_query(
113    source_conn: &SourceConn,
114    query: &str,
115    col: &str,
116    lower: i64,
117    upper: i64,
118) -> CXQuery<String> {
119    let query = match source_conn.ty {
120        #[cfg(feature = "src_postgres")]
121        SourceType::Postgres => {
122            single_col_partition_query(query, col, lower, upper, &PostgreSqlDialect {})?
123        }
124        #[cfg(feature = "src_sqlite")]
125        SourceType::SQLite => {
126            single_col_partition_query(query, col, lower, upper, &SQLiteDialect {})?
127        }
128        #[cfg(feature = "src_mysql")]
129        SourceType::MySQL => {
130            single_col_partition_query(query, col, lower, upper, &MySqlDialect {})?
131        }
132        #[cfg(feature = "src_mssql")]
133        SourceType::MsSQL => {
134            single_col_partition_query(query, col, lower, upper, &MsSqlDialect {})?
135        }
136        #[cfg(feature = "src_oracle")]
137        SourceType::Oracle => {
138            single_col_partition_query(query, col, lower, upper, &OracleDialect {})?
139        }
140        #[cfg(feature = "src_bigquery")]
141        SourceType::BigQuery => {
142            single_col_partition_query(query, col, lower, upper, &BigQueryDialect {})?
143        }
144        #[cfg(feature = "src_trino")]
145        SourceType::Trino => {
146            single_col_partition_query(query, col, lower, upper, &TrinoDialect {})?
147        }
148        _ => unimplemented!("{:?} not implemented!", source_conn.ty),
149    };
150    CXQuery::Wrapped(query)
151}
152
/// Fetches the minimum and maximum of `col` over `query` from Postgres, as `i64`.
///
/// Integer columns are widened; float and numeric bounds are rounded to integers. A NULL
/// bound (e.g. from an empty result) becomes 0.
///
/// # Errors
/// Fails on connection/query errors, or if the column is not an int, float or numeric type.
#[cfg(feature = "src_postgres")]
#[throws(ConnectorXOutError)]
fn pg_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    let (config, tls) = rewrite_tls_args(conn)?;
    let mut client = match tls {
        None => config.connect(postgres::NoTls)?,
        Some(tls_conn) => config.connect(tls_conn)?,
    };
    let range_query = get_partition_range_query(query, col, &PostgreSqlDialect {})?;
    let row = client.query_one(range_query.as_str(), &[])?;

    // For non-integer types the minimum is floored. `as i64` and `to_i64` round toward zero,
    // which would turn a negative fractional minimum such as -0.5 into 0. The rows in
    // [-0.5, 0) would then fall outside every partition. Rounding the maximum toward zero is
    // harmless, because `partition` ends the last chunk at the exclusive bound max + 1.
    let col_type = PostgresTypeSystem::from(row.columns()[0].type_());
    let (min_v, max_v) = match col_type {
        PostgresTypeSystem::Int2(_) => {
            let min_v: Option<i16> = row.get(0);
            let max_v: Option<i16> = row.get(1);
            (min_v.unwrap_or(0) as i64, max_v.unwrap_or(0) as i64)
        }
        PostgresTypeSystem::Int4(_) => {
            let min_v: Option<i32> = row.get(0);
            let max_v: Option<i32> = row.get(1);
            (min_v.unwrap_or(0) as i64, max_v.unwrap_or(0) as i64)
        }
        PostgresTypeSystem::Int8(_) => {
            let min_v: Option<i64> = row.get(0);
            let max_v: Option<i64> = row.get(1);
            (min_v.unwrap_or(0), max_v.unwrap_or(0))
        }
        PostgresTypeSystem::Float4(_) => {
            let min_v: Option<f32> = row.get(0);
            let max_v: Option<f32> = row.get(1);
            (min_v.unwrap_or(0.0).floor() as i64, max_v.unwrap_or(0.0) as i64)
        }
        PostgresTypeSystem::Float8(_) => {
            let min_v: Option<f64> = row.get(0);
            let max_v: Option<f64> = row.get(1);
            (min_v.unwrap_or(0.0).floor() as i64, max_v.unwrap_or(0.0) as i64)
        }
        PostgresTypeSystem::Numeric(_) => {
            let min_v: Option<Decimal> = row.get(0);
            let max_v: Option<Decimal> = row.get(1);
            (
                // Values outside the i64 range still fall back to 0, as before.
                min_v.unwrap_or(dec!(0.0)).floor().to_i64().unwrap_or(0),
                max_v.unwrap_or(dec!(0.0)).to_i64().unwrap_or(0),
            )
        }
        _ => throw!(anyhow!(
            "Partition can only be done on int or float columns"
        )),
    };

    (min_v, max_v)
}
206
/// Fetches the minimum and maximum of `col` over `query` from SQLite, as `i64`.
///
/// MIN and MAX run as two separate queries (see below). A NULL result becomes 0.
///
/// # Errors
/// Fails on connection/query errors, or if a bound is neither INTEGER, REAL nor NULL.
#[cfg(feature = "src_sqlite")]
#[throws(ConnectorXOutError)]
fn sqlite_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    // remove the first "sqlite://" manually since url.path is not correct for windows and for relative path
    let conn = Connection::open(&conn.as_str()[9..])?;
    // SQLite only optimize min max queries when there is only one aggregation
    // https://www.sqlite.org/optoverview.html#minmax
    let (min_query, max_query) = get_partition_range_query_sep(query, col, &SQLiteDialect {})?;
    // The REAL minimum is floored. Truncation toward zero would raise e.g. -0.5 to 0 and drop
    // rows in [-0.5, 0). Truncating the maximum is safe because the last partition ends at
    // the exclusive bound max + 1.
    let min_v = sqlite_query_bound(&conn, min_query.as_str(), f64::floor)?;
    let max_v = sqlite_query_bound(&conn, max_query.as_str(), f64::trunc)?;

    (min_v, max_v)
}

/// Runs a single-row bound query and converts its first column to `i64`.
///
/// INTEGER values are returned as-is, REAL values go through `round` before the cast, and
/// NULL becomes 0. Any other type is reported as an error.
#[cfg(feature = "src_sqlite")]
#[throws(ConnectorXOutError)]
fn sqlite_query_bound(conn: &Connection, sql: &str, round: fn(f64) -> f64) -> i64 {
    // The row closure can only return rusqlite errors, so type errors are stashed here.
    let mut error = None;
    let value = conn.query_row(sql, [], |row| {
        // declare type for count query will be None, only need to check the returned value type
        let col_type = row.get_ref(0)?.data_type();
        match col_type {
            Type::Integer => row.get(0),
            Type::Real => {
                let v: f64 = row.get(0)?;
                Ok(round(v) as i64)
            }
            Type::Null => Ok(0),
            _ => {
                error = Some(anyhow!("Partition can only be done on integer columns"));
                Ok(0)
            }
        }
    })?;
    if let Some(e) = error {
        throw!(e);
    }
    value
}
258
/// Fetches the minimum and maximum of `col` over `query` from MySQL, as `i64`.
///
/// Integer bounds are widened losslessly, and float bounds are rounded (minimum floored).
/// A NULL bound becomes 0.
///
/// # Errors
/// Fails on connection/query errors, if the column is not numeric, or if a
/// `BIGINT UNSIGNED` bound does not fit in `i64`.
#[cfg(feature = "src_mysql")]
#[throws(ConnectorXOutError)]
fn mysql_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    let pool = Pool::new(Opts::from_url(conn.as_str()).map_err(MySQLSourceError::MySQLUrlError)?)?;
    let mut conn = pool.get_conn()?;
    let range_query = get_partition_range_query(query, col, &MySqlDialect {})?;
    let row: Row = conn
        .query_first(range_query)?
        .ok_or_else(|| anyhow!("mysql range: no row returns"))?;

    let col_type =
        MySQLTypeSystem::from((&row.columns()[0].column_type(), &row.columns()[0].flags()));

    // Reads the MIN (column 0) and MAX (column 1) cells as `Option<$t>`; `None` is SQL NULL.
    macro_rules! fetch {
        ($t:ty) => {{
            let min_v: Option<$t> = row
                .get(0)
                .ok_or_else(|| anyhow!("mysql range: cannot get min value"))?;
            let max_v: Option<$t> = row
                .get(1)
                .ok_or_else(|| anyhow!("mysql range: cannot get max value"))?;
            (min_v, max_v)
        }};
    }
    // Integer types that widen to i64 without loss; NULL maps to 0.
    macro_rules! int_range {
        ($t:ty) => {{
            let (min_v, max_v) = fetch!($t);
            (min_v.map_or(0, i64::from), max_v.map_or(0, i64::from))
        }};
    }

    let (min_v, max_v) = match col_type {
        MySQLTypeSystem::Tiny(_) => int_range!(i8),
        MySQLTypeSystem::Short(_) => int_range!(i16),
        MySQLTypeSystem::Int24(_) => int_range!(i32),
        MySQLTypeSystem::Long(_) | MySQLTypeSystem::LongLong(_) => int_range!(i64),
        MySQLTypeSystem::UTiny(_) => int_range!(u8),
        MySQLTypeSystem::UShort(_) => int_range!(u16),
        MySQLTypeSystem::UInt24(_) | MySQLTypeSystem::ULong(_) => int_range!(u32),
        MySQLTypeSystem::ULongLong(_) => {
            let (min_v, max_v) = fetch!(u64);
            // BIGINT UNSIGNED can exceed i64::MAX. An `as` cast would wrap it to a negative
            // bound and corrupt every partition, so that case is reported instead.
            let to_i64 = |v: u64| {
                if v > i64::MAX as u64 {
                    Err(anyhow!("mysql range: value {} does not fit in i64", v))
                } else {
                    Ok(v as i64)
                }
            };
            (to_i64(min_v.unwrap_or(0))?, to_i64(max_v.unwrap_or(0))?)
        }
        // The minimum is floored. Truncation toward zero would raise e.g. -0.5 to 0 and drop
        // rows in [-0.5, 0). The maximum may truncate since the last partition ends at max + 1.
        MySQLTypeSystem::Float(_) => {
            let (min_v, max_v) = fetch!(f32);
            (min_v.unwrap_or(0.0).floor() as i64, max_v.unwrap_or(0.0) as i64)
        }
        MySQLTypeSystem::Double(_) => {
            let (min_v, max_v) = fetch!(f64);
            (min_v.unwrap_or(0.0).floor() as i64, max_v.unwrap_or(0.0) as i64)
        }
        _ => throw!(anyhow!("Partition can only be done on int columns")),
    };

    (min_v, max_v)
}
386
/// Fetches the minimum and maximum of `col` over `query` from SQL Server, as `i64`.
///
/// Integer bounds are widened, and float bounds are rounded (minimum floored). A NULL or
/// unreadable bound becomes 0.
///
/// # Errors
/// Fails on runtime/connection/query errors, if the query returns no row, or if the column
/// is not an int or float type.
#[cfg(feature = "src_mssql")]
#[throws(ConnectorXOutError)]
fn mssql_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    let rt = Runtime::new()?;
    let config = mssql_config(conn)?;
    let tcp = rt.block_on(TcpStream::connect(config.get_addr()))?;
    tcp.set_nodelay(true)?;

    let mut client = rt.block_on(Client::connect(config, tcp.compat_write()))?;

    let range_query = get_partition_range_query(query, col, &MsSqlDialect {})?;
    let query_result = rt.block_on(client.query(range_query.as_str(), &[]))?;
    // A missing row is reported as an error rather than a panic.
    let row = rt
        .block_on(query_result.into_row())?
        .ok_or_else(|| anyhow!("mssql range: no row returns"))?;

    // Float minimums are floored. Truncation toward zero would raise e.g. -0.5 to 0 and drop
    // rows in [-0.5, 0). Maximums may truncate since the last partition ends at max + 1.
    let col_type = MsSQLTypeSystem::from(&row.columns()[0].column_type());
    let (min_v, max_v) = match col_type {
        MsSQLTypeSystem::Tinyint(_) => {
            let min_v: u8 = row.get(0).unwrap_or(0);
            let max_v: u8 = row.get(1).unwrap_or(0);
            (min_v as i64, max_v as i64)
        }
        MsSQLTypeSystem::Smallint(_) => {
            let min_v: i16 = row.get(0).unwrap_or(0);
            let max_v: i16 = row.get(1).unwrap_or(0);
            (min_v as i64, max_v as i64)
        }
        MsSQLTypeSystem::Int(_) => {
            let min_v: i32 = row.get(0).unwrap_or(0);
            let max_v: i32 = row.get(1).unwrap_or(0);
            (min_v as i64, max_v as i64)
        }
        MsSQLTypeSystem::Bigint(_) => {
            let min_v: i64 = row.get(0).unwrap_or(0);
            let max_v: i64 = row.get(1).unwrap_or(0);
            (min_v, max_v)
        }
        MsSQLTypeSystem::Intn(_) => {
            let min_v: IntN = row.get(0).unwrap_or(IntN(0));
            let max_v: IntN = row.get(1).unwrap_or(IntN(0));
            (min_v.0, max_v.0)
        }
        MsSQLTypeSystem::Float24(_) => {
            let min_v: f32 = row.get(0).unwrap_or(0.0);
            let max_v: f32 = row.get(1).unwrap_or(0.0);
            (min_v.floor() as i64, max_v as i64)
        }
        MsSQLTypeSystem::Float53(_) => {
            let min_v: f64 = row.get(0).unwrap_or(0.0);
            let max_v: f64 = row.get(1).unwrap_or(0.0);
            (min_v.floor() as i64, max_v as i64)
        }
        MsSQLTypeSystem::Floatn(_) => {
            let min_v: FloatN = row.get(0).unwrap_or(FloatN(0.0));
            let max_v: FloatN = row.get(1).unwrap_or(FloatN(0.0));
            (min_v.0.floor() as i64, max_v.0 as i64)
        }
        _ => throw!(anyhow!(
            "Partition can only be done on int or float columns"
        )),
    };

    (min_v, max_v)
}
450
/// Fetches the minimum and maximum of `col` over `query` from Oracle, as `i64`.
///
/// Each bound is read directly as `i64`. Any failure to read it, including NULL, yields 0.
/// NOTE(review): a non-integral bound may also fail the `i64` read and map to 0 — confirm.
#[cfg(feature = "src_oracle")]
#[throws(ConnectorXOutError)]
fn oracle_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    let session = connect_oracle(conn)?.connect()?;
    let sql = get_partition_range_query(query, col, &OracleDialect {})?;
    let row = session.query_row(sql.as_str(), &[])?;
    let bounds: (i64, i64) = (row.get(0).unwrap_or(0), row.get(1).unwrap_or(0));
    bounds
}
462
/// Fetches the minimum and maximum of `col` over `query` from BigQuery, as `i64`.
///
/// The URL path is the service-account key file, which both authenticates the client and
/// supplies the `project_id` the query runs under. A NULL bound becomes 0.
///
/// # Panics
/// Panics if the tokio runtime cannot be created.
#[cfg(feature = "src_bigquery")]
#[throws(ConnectorXOutError)] // TODO
fn bigquery_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    let rt = Runtime::new().expect("Failed to create runtime");
    let parsed_url = Url::parse(conn.as_str())?;
    let key_path = parsed_url.path();
    let client = rt.block_on(gcp_bigquery_client::Client::from_service_account_key_file(
        key_path,
    ))?;

    // The project id is not in the URL; read it from the same key file.
    let key_json: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(key_path)?)?;
    let project_id = key_json
        .get("project_id")
        .ok_or_else(|| anyhow!("Cannot get project_id from auth file"))?
        .as_str()
        .ok_or_else(|| anyhow!("Cannot get project_id as string from auth file"))?;

    let request = gcp_bigquery_client::model::query_request::QueryRequest::new(
        get_partition_range_query(query, col, &BigQueryDialect {})?.as_str(),
    );
    let response = rt.block_on(client.job().query(project_id, request))?;

    let mut result_set =
        gcp_bigquery_client::model::query_response::ResultSet::new_from_query_response(response);
    // Move the cursor onto the first (and only expected) row before reading columns.
    result_set.next_row();
    (
        result_set.get_i64(0)?.unwrap_or(0),
        result_set.get_i64(1)?.unwrap_or(0),
    )
}
495
/// Fetches the minimum and maximum of `col` over `query` from Trino, as `i64`.
///
/// Connection settings come from the URL:
/// - user (default `"connectorx"`) and optional password for basic auth;
/// - host and port (default 8080);
/// - the `trino+https` scheme for TLS;
/// - the last path segment as the catalog (`"hive"` if there is none).
///
/// An empty result yields `(0, 0)`.
///
/// # Errors
/// Fails if the runtime cannot be created, if the URL has no host or cannot carry a path,
/// or if the client cannot be built or the query fails.
#[cfg(feature = "src_trino")]
#[throws(ConnectorXOutError)]
fn trino_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    use prusto::{auth::Auth, ClientBuilder};

    use crate::sources::trino::{TrinoDialect, TrinoPartitionQueryResult};

    let rt = Runtime::new()?;

    let username = match conn.username() {
        "" => "connectorx",
        username => username,
    };
    // Malformed connection URLs are user input: report them instead of panicking.
    let host = conn
        .host()
        .ok_or_else(|| anyhow!("trino range: connection url has no host"))?;
    let catalog = conn
        .path_segments()
        .ok_or_else(|| anyhow!("trino range: connection url cannot carry a catalog path"))?
        .last()
        .unwrap_or("hive");

    let builder = ClientBuilder::new(username, host.to_owned())
        .port(conn.port().unwrap_or(8080))
        .ssl(prusto::ssl::Ssl { root_cert: None })
        .secure(conn.scheme() == "trino+https")
        .catalog(catalog);

    let builder = match conn.password() {
        None => builder,
        Some(password) => builder.auth(Auth::Basic(username.to_owned(), Some(password.to_owned()))),
    };

    let client = builder
        .build()
        .map_err(|e| anyhow!("Failed to build client: {}", e))?;

    let range_query = get_partition_range_query(query, col, &TrinoDialect {})?;
    // `EmptyData` and an empty row set are both treated as an empty table: range (0, 0).
    let (min_v, max_v) = match rt.block_on(client.get_all::<TrinoPartitionQueryResult>(range_query)) {
        Ok(rows) => rows
            .into_vec()
            .first()
            .map_or((0, 0), |r| (r._col0, r._col1)),
        Err(prusto::error::Error::EmptyData) => (0, 0),
        Err(e) => throw!(anyhow!("Failed to get query result: {}", e)),
    };

    (min_v, max_v)
}