1use crate::errors::{ConnectorXOutError, OutResult};
2use crate::source_router::{SourceConn, SourceType};
3#[cfg(feature = "src_bigquery")]
4use crate::sources::bigquery::BigQueryDialect;
5#[cfg(feature = "src_mssql")]
6use crate::sources::mssql::{mssql_config, FloatN, IntN, MsSQLTypeSystem};
7#[cfg(feature = "src_mysql")]
8use crate::sources::mysql::{MySQLSourceError, MySQLTypeSystem};
9#[cfg(feature = "src_oracle")]
10use crate::sources::oracle::{connect_oracle, OracleDialect};
11#[cfg(feature = "src_postgres")]
12use crate::sources::postgres::{rewrite_tls_args, PostgresTypeSystem};
13#[cfg(feature = "src_trino")]
14use crate::sources::trino::TrinoDialect;
15#[cfg(feature = "src_sqlite")]
16use crate::sql::get_partition_range_query_sep;
17use crate::sql::{get_partition_range_query, single_col_partition_query, CXQuery};
18use anyhow::anyhow;
19use fehler::{throw, throws};
20#[cfg(feature = "src_bigquery")]
21use gcp_bigquery_client;
22#[cfg(feature = "src_mysql")]
23use r2d2_mysql::mysql::{prelude::Queryable, Opts, Pool, Row};
24#[cfg(feature = "src_sqlite")]
25use rusqlite::{types::Type, Connection};
26#[cfg(feature = "src_postgres")]
27use rust_decimal::{prelude::ToPrimitive, Decimal};
28#[cfg(feature = "src_postgres")]
29use rust_decimal_macros::dec;
30#[cfg(feature = "src_mssql")]
31use sqlparser::dialect::MsSqlDialect;
32#[cfg(feature = "src_mysql")]
33use sqlparser::dialect::MySqlDialect;
34#[cfg(feature = "src_postgres")]
35use sqlparser::dialect::PostgreSqlDialect;
36#[cfg(feature = "src_sqlite")]
37use sqlparser::dialect::SQLiteDialect;
38#[cfg(feature = "src_mssql")]
39use tiberius::Client;
40#[cfg(any(feature = "src_bigquery", feature = "src_mssql", feature = "src_trino"))]
41use tokio::{net::TcpStream, runtime::Runtime};
42#[cfg(feature = "src_mssql")]
43use tokio_util::compat::TokioAsyncWriteCompatExt;
44use url::Url;
45
/// Request to split one query into several range-partitioned queries.
pub struct PartitionQuery {
    /// SQL text of the query to partition.
    query: String,
    /// Name of the column whose value range is split.
    column: String,
    /// Smallest column value of the range, when supplied by the caller.
    min: Option<i64>,
    /// Largest column value of the range, when supplied by the caller.
    max: Option<i64>,
    /// Number of partitions to produce.
    num: usize,
}

impl PartitionQuery {
    /// Creates a partition request, copying `query` and `column` into owned strings.
    ///
    /// `min` and `max` are stored as given; `num` is the requested partition count.
    pub fn new(query: &str, column: &str, min: Option<i64>, max: Option<i64>, num: usize) -> Self {
        let query = String::from(query);
        let column = String::from(column);
        Self {
            query,
            column,
            min,
            max,
            num,
        }
    }
}
65
66pub fn partition(part: &PartitionQuery, source_conn: &SourceConn) -> OutResult<Vec<CXQuery>> {
67 let mut queries = vec![];
68 let num = part.num as i64;
69 let (min, max) = match (part.min, part.max) {
70 (None, None) => get_col_range(source_conn, &part.query, &part.column)?,
71 (Some(min), Some(max)) => (min, max),
72 _ => throw!(anyhow!(
73 "partition_query range can not be partially specified",
74 )),
75 };
76
77 let partition_size = (max - min + 1) / num;
78
79 for i in 0..num {
80 let lower = min + i * partition_size;
81 let upper = match i == num - 1 {
82 true => max + 1,
83 false => min + (i + 1) * partition_size,
84 };
85 let partition_query = get_part_query(source_conn, &part.query, &part.column, lower, upper)?;
86 queries.push(partition_query);
87 }
88 Ok(queries)
89}
90
91pub fn get_col_range(source_conn: &SourceConn, query: &str, col: &str) -> OutResult<(i64, i64)> {
92 match source_conn.ty {
93 #[cfg(feature = "src_postgres")]
94 SourceType::Postgres => pg_get_partition_range(&source_conn.conn, query, col),
95 #[cfg(feature = "src_sqlite")]
96 SourceType::SQLite => sqlite_get_partition_range(&source_conn.conn, query, col),
97 #[cfg(feature = "src_mysql")]
98 SourceType::MySQL => mysql_get_partition_range(&source_conn.conn, query, col),
99 #[cfg(feature = "src_mssql")]
100 SourceType::MsSQL => mssql_get_partition_range(&source_conn.conn, query, col),
101 #[cfg(feature = "src_oracle")]
102 SourceType::Oracle => oracle_get_partition_range(&source_conn.conn, query, col),
103 #[cfg(feature = "src_bigquery")]
104 SourceType::BigQuery => bigquery_get_partition_range(&source_conn.conn, query, col),
105 #[cfg(feature = "src_trino")]
106 SourceType::Trino => trino_get_partition_range(&source_conn.conn, query, col),
107 _ => unimplemented!("{:?} not implemented!", source_conn.ty),
108 }
109}
110
111#[throws(ConnectorXOutError)]
112pub fn get_part_query(
113 source_conn: &SourceConn,
114 query: &str,
115 col: &str,
116 lower: i64,
117 upper: i64,
118) -> CXQuery<String> {
119 let query = match source_conn.ty {
120 #[cfg(feature = "src_postgres")]
121 SourceType::Postgres => {
122 single_col_partition_query(query, col, lower, upper, &PostgreSqlDialect {})?
123 }
124 #[cfg(feature = "src_sqlite")]
125 SourceType::SQLite => {
126 single_col_partition_query(query, col, lower, upper, &SQLiteDialect {})?
127 }
128 #[cfg(feature = "src_mysql")]
129 SourceType::MySQL => {
130 single_col_partition_query(query, col, lower, upper, &MySqlDialect {})?
131 }
132 #[cfg(feature = "src_mssql")]
133 SourceType::MsSQL => {
134 single_col_partition_query(query, col, lower, upper, &MsSqlDialect {})?
135 }
136 #[cfg(feature = "src_oracle")]
137 SourceType::Oracle => {
138 single_col_partition_query(query, col, lower, upper, &OracleDialect {})?
139 }
140 #[cfg(feature = "src_bigquery")]
141 SourceType::BigQuery => {
142 single_col_partition_query(query, col, lower, upper, &BigQueryDialect {})?
143 }
144 #[cfg(feature = "src_trino")]
145 SourceType::Trino => {
146 single_col_partition_query(query, col, lower, upper, &TrinoDialect {})?
147 }
148 _ => unimplemented!("{:?} not implemented!", source_conn.ty),
149 };
150 CXQuery::Wrapped(query)
151}
152
/// Queries Postgres for the `(min, max)` of `col` over `query`, converted to `i64`.
///
/// Connects using the configuration (and optional TLS connector) derived from `conn` by
/// `rewrite_tls_args`, runs the range query, and converts both result cells according to the
/// type of the first result column. `NULL` cells are read as `0`.
///
/// For `float4`, `float8` and `numeric` columns the lower bound is rounded *down* (floor), so
/// the returned range always contains the true minimum; a plain cast truncates toward zero and
/// would turn e.g. `-1.5` into `-1`. The upper bound keeps the truncating conversion. A
/// `numeric` value that does not fit into an `i64` is read as `0`.
///
/// # Errors
///
/// Fails when connecting or querying fails, or when the column is not an integer, float or
/// numeric type.
#[cfg(feature = "src_postgres")]
#[throws(ConnectorXOutError)]
fn pg_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    let (config, tls) = rewrite_tls_args(conn)?;
    let mut client = match tls {
        None => config.connect(postgres::NoTls)?,
        Some(tls_conn) => config.connect(tls_conn)?,
    };
    let range_query = get_partition_range_query(query, col, &PostgreSqlDialect {})?;
    let row = client.query_one(range_query.as_str(), &[])?;

    // Both cells are read with the Rust type matching the first result column.
    let col_type = PostgresTypeSystem::from(row.columns()[0].type_());
    let (min_v, max_v) = match col_type {
        PostgresTypeSystem::Int2(_) => {
            let min_v: Option<i16> = row.get(0);
            let max_v: Option<i16> = row.get(1);
            (min_v.unwrap_or(0) as i64, max_v.unwrap_or(0) as i64)
        }
        PostgresTypeSystem::Int4(_) => {
            let min_v: Option<i32> = row.get(0);
            let max_v: Option<i32> = row.get(1);
            (min_v.unwrap_or(0) as i64, max_v.unwrap_or(0) as i64)
        }
        PostgresTypeSystem::Int8(_) => {
            let min_v: Option<i64> = row.get(0);
            let max_v: Option<i64> = row.get(1);
            (min_v.unwrap_or(0), max_v.unwrap_or(0))
        }
        PostgresTypeSystem::Float4(_) => {
            let min_v: Option<f32> = row.get(0);
            let max_v: Option<f32> = row.get(1);
            // Floor the minimum so a negative fractional value is not rounded up past itself.
            (
                min_v.unwrap_or(0.0).floor() as i64,
                max_v.unwrap_or(0.0) as i64,
            )
        }
        PostgresTypeSystem::Float8(_) => {
            let min_v: Option<f64> = row.get(0);
            let max_v: Option<f64> = row.get(1);
            // Floor the minimum so a negative fractional value is not rounded up past itself.
            (
                min_v.unwrap_or(0.0).floor() as i64,
                max_v.unwrap_or(0.0) as i64,
            )
        }
        PostgresTypeSystem::Numeric(_) => {
            let min_v: Option<Decimal> = row.get(0);
            let max_v: Option<Decimal> = row.get(1);
            (
                // Floor first for the same reason as the float branches.
                min_v.unwrap_or(dec!(0.0)).floor().to_i64().unwrap_or(0),
                max_v.unwrap_or(dec!(0.0)).to_i64().unwrap_or(0),
            )
        }
        _ => throw!(anyhow!(
            "Partition can only be done on int or float columns"
        )),
    };

    (min_v, max_v)
}
206
/// Queries SQLite for the `(min, max)` of `col` over `query`, converted to `i64`.
///
/// The minimum and maximum are fetched with two separate queries, and the type is checked on
/// each returned value itself: integers are returned as-is, reals are converted to `i64`, and
/// `NULL` is read as `0`.
///
/// For a real-valued minimum the value is rounded *down* (floor), so the returned range always
/// contains the true minimum; a plain cast truncates toward zero and would turn e.g. `-1.5`
/// into `-1`. The maximum keeps the truncating conversion.
///
/// # Errors
///
/// Fails when the database cannot be opened, a query fails, or a bound has any other type.
#[cfg(feature = "src_sqlite")]
#[throws(ConnectorXOutError)]
fn sqlite_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    // The database location is everything after the first 9 bytes of the URL
    // (presumably the `sqlite://` prefix).
    let conn = Connection::open(&conn.as_str()[9..])?;
    let (min_query, max_query) = get_partition_range_query_sep(query, col, &SQLiteDialect {})?;

    // Runs one single-row query and converts its first cell. `Ok(None)` signals a value type
    // that cannot be partitioned on, so the caller can raise the proper error after the query.
    let read_bound = |sql: &str, round_down: bool| {
        conn.query_row(sql, [], |row| {
            let bound = match row.get_ref(0)?.data_type() {
                Type::Integer => {
                    let v: i64 = row.get(0)?;
                    Some(v)
                }
                Type::Real => {
                    let v: f64 = row.get(0)?;
                    Some(if round_down {
                        v.floor() as i64
                    } else {
                        v as i64
                    })
                }
                Type::Null => Some(0),
                _ => None,
            };
            Ok(bound)
        })
    };

    let min_v = read_bound(min_query.as_str(), true)?
        .ok_or_else(|| anyhow!("Partition can only be done on integer columns"))?;
    let max_v = read_bound(max_query.as_str(), false)?
        .ok_or_else(|| anyhow!("Partition can only be done on integer columns"))?;

    (min_v, max_v)
}
258
/// Queries MySQL for the `(min, max)` of `col` over `query`, converted to `i64`.
///
/// Runs the range query on a connection from a pool built from `conn`, then reads both result
/// cells with the Rust type that matches the first result column. `NULL` cells are read as
/// `0`.
///
/// * Unsigned 64-bit values above `i64::MAX` are rejected with an error; a plain cast would
///   silently wrap them to negative numbers.
/// * For `FLOAT` / `DOUBLE` columns the lower bound is rounded *down* (floor), so the returned
///   range always contains the true minimum; a plain cast truncates toward zero and would turn
///   e.g. `-1.5` into `-1`. The upper bound keeps the truncating conversion.
///
/// # Errors
///
/// Fails when the URL is invalid, connecting or querying fails, no row is returned, a cell
/// cannot be read, an unsigned value does not fit into `i64`, or the column has an unsupported
/// type.
#[cfg(feature = "src_mysql")]
#[throws(ConnectorXOutError)]
fn mysql_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    let pool = Pool::new(Opts::from_url(conn.as_str()).map_err(MySQLSourceError::MySQLUrlError)?)?;
    let mut conn = pool.get_conn()?;
    let range_query = get_partition_range_query(query, col, &MySqlDialect {})?;
    let row: Row = conn
        .query_first(range_query)?
        .ok_or_else(|| anyhow!("mysql range: no row returns"))?;

    let col_type =
        MySQLTypeSystem::from((&row.columns()[0].column_type(), &row.columns()[0].flags()));

    // Reads cell 0 (min) and cell 1 (max) as `Option<$ty>` and substitutes the type's zero for
    // NULL. Evaluates to a `($ty, $ty)` tuple; a missing cell propagates an error.
    macro_rules! read_range {
        ($ty:ty) => {{
            let min_v: Option<$ty> = row
                .get(0)
                .ok_or_else(|| anyhow!("mysql range: cannot get min value"))?;
            let max_v: Option<$ty> = row
                .get(1)
                .ok_or_else(|| anyhow!("mysql range: cannot get max value"))?;
            (min_v.unwrap_or_default(), max_v.unwrap_or_default())
        }};
    }

    let (min_v, max_v) = match col_type {
        MySQLTypeSystem::Tiny(_) => {
            let (lo, hi) = read_range!(i8);
            (i64::from(lo), i64::from(hi))
        }
        MySQLTypeSystem::Short(_) => {
            let (lo, hi) = read_range!(i16);
            (i64::from(lo), i64::from(hi))
        }
        MySQLTypeSystem::Int24(_) => {
            let (lo, hi) = read_range!(i32);
            (i64::from(lo), i64::from(hi))
        }
        MySQLTypeSystem::Long(_) | MySQLTypeSystem::LongLong(_) => read_range!(i64),
        MySQLTypeSystem::UTiny(_) => {
            let (lo, hi) = read_range!(u8);
            (i64::from(lo), i64::from(hi))
        }
        MySQLTypeSystem::UShort(_) => {
            let (lo, hi) = read_range!(u16);
            (i64::from(lo), i64::from(hi))
        }
        MySQLTypeSystem::UInt24(_) | MySQLTypeSystem::ULong(_) => {
            let (lo, hi) = read_range!(u32);
            (i64::from(lo), i64::from(hi))
        }
        MySQLTypeSystem::ULongLong(_) => {
            let (lo, hi) = read_range!(u64);
            // `u64 as i64` wraps for values above `i64::MAX`; refuse those explicitly.
            let to_i64 = |v: u64| -> Result<i64, anyhow::Error> {
                if v > i64::MAX as u64 {
                    Err(anyhow!("mysql range: value {} does not fit into i64", v))
                } else {
                    Ok(v as i64)
                }
            };
            (to_i64(lo)?, to_i64(hi)?)
        }
        MySQLTypeSystem::Float(_) => {
            let (lo, hi) = read_range!(f32);
            // Floor the minimum so a negative fractional value is not rounded up past itself.
            (lo.floor() as i64, hi as i64)
        }
        MySQLTypeSystem::Double(_) => {
            let (lo, hi) = read_range!(f64);
            // Floor the minimum so a negative fractional value is not rounded up past itself.
            (lo.floor() as i64, hi as i64)
        }
        _ => throw!(anyhow!("Partition can only be done on int columns")),
    };

    (min_v, max_v)
}
386
/// Queries SQL Server for the `(min, max)` of `col` over `query`, converted to `i64`.
///
/// Creates a Tokio runtime, opens a TCP connection to the address from `mssql_config(conn)`,
/// runs the range query, and reads both result cells with the Rust type that matches the first
/// result column. Cells for which `row.get` yields `None` are read as `0`.
///
/// For float columns the lower bound is rounded *down* (floor), so the returned range always
/// contains the true minimum; a plain cast truncates toward zero and would turn e.g. `-1.5`
/// into `-1`. The upper bound keeps the truncating conversion.
///
/// # Errors
///
/// Fails when configuration, connecting or querying fails, when the query returns no row
/// (reported as an error rather than a panic), or when the column has an unsupported type.
///
/// # Panics
///
/// Panics if the Tokio runtime cannot be created.
#[cfg(feature = "src_mssql")]
#[throws(ConnectorXOutError)]
fn mssql_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    let rt = Runtime::new().expect("Failed to create runtime");
    let config = mssql_config(conn)?;
    let tcp = rt.block_on(TcpStream::connect(config.get_addr()))?;
    tcp.set_nodelay(true)?;

    let mut client = rt.block_on(Client::connect(config, tcp.compat_write()))?;

    let range_query = get_partition_range_query(query, col, &MsSqlDialect {})?;
    let query_result = rt.block_on(client.query(range_query.as_str(), &[]))?;
    // A missing row is surfaced as an error instead of unwinding through `unwrap`.
    let row = rt
        .block_on(query_result.into_row())?
        .ok_or_else(|| anyhow!("mssql range: no row returns"))?;

    let col_type = MsSQLTypeSystem::from(&row.columns()[0].column_type());
    let (min_v, max_v) = match col_type {
        MsSQLTypeSystem::Tinyint(_) => {
            let min_v: u8 = row.get(0).unwrap_or(0);
            let max_v: u8 = row.get(1).unwrap_or(0);
            (i64::from(min_v), i64::from(max_v))
        }
        MsSQLTypeSystem::Smallint(_) => {
            let min_v: i16 = row.get(0).unwrap_or(0);
            let max_v: i16 = row.get(1).unwrap_or(0);
            (i64::from(min_v), i64::from(max_v))
        }
        MsSQLTypeSystem::Int(_) => {
            let min_v: i32 = row.get(0).unwrap_or(0);
            let max_v: i32 = row.get(1).unwrap_or(0);
            (i64::from(min_v), i64::from(max_v))
        }
        MsSQLTypeSystem::Bigint(_) => {
            let min_v: i64 = row.get(0).unwrap_or(0);
            let max_v: i64 = row.get(1).unwrap_or(0);
            (min_v, max_v)
        }
        MsSQLTypeSystem::Intn(_) => {
            let min_v: IntN = row.get(0).unwrap_or(IntN(0));
            let max_v: IntN = row.get(1).unwrap_or(IntN(0));
            (min_v.0, max_v.0)
        }
        MsSQLTypeSystem::Float24(_) => {
            let min_v: f32 = row.get(0).unwrap_or(0.0);
            let max_v: f32 = row.get(1).unwrap_or(0.0);
            // Floor the minimum so a negative fractional value is not rounded up past itself.
            (min_v.floor() as i64, max_v as i64)
        }
        MsSQLTypeSystem::Float53(_) => {
            let min_v: f64 = row.get(0).unwrap_or(0.0);
            let max_v: f64 = row.get(1).unwrap_or(0.0);
            // Floor the minimum so a negative fractional value is not rounded up past itself.
            (min_v.floor() as i64, max_v as i64)
        }
        MsSQLTypeSystem::Floatn(_) => {
            let min_v: FloatN = row.get(0).unwrap_or(FloatN(0.0));
            let max_v: FloatN = row.get(1).unwrap_or(FloatN(0.0));
            // Floor the minimum so a negative fractional value is not rounded up past itself.
            (min_v.0.floor() as i64, max_v.0 as i64)
        }
        _ => throw!(anyhow!(
            "Partition can only be done on int or float columns"
        )),
    };

    (min_v, max_v)
}
450
/// Queries Oracle for the `(min, max)` of `col` over `query` as `i64` values.
///
/// Connects through `connect_oracle`, runs the range query and reads the first two cells of the
/// single result row. A cell that cannot be read as `i64` is replaced by `0`.
///
/// # Errors
///
/// Fails when connecting, building the range query, or running it fails.
#[cfg(feature = "src_oracle")]
#[throws(ConnectorXOutError)]
fn oracle_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    let connector = connect_oracle(conn)?;
    let session = connector.connect()?;

    let sql = get_partition_range_query(query, col, &OracleDialect {})?;
    let bounds = session.query_row(sql.as_str(), &[])?;

    // Read errors are deliberately collapsed to 0 here.
    let lower: i64 = bounds.get(0).unwrap_or(0);
    let upper: i64 = bounds.get(1).unwrap_or(0);
    (lower, upper)
}
462
/// Queries BigQuery for the `(min, max)` of `col` over `query` as `i64` values.
///
/// The path component of `conn` is used as the service-account key file: it authenticates the
/// client and is also parsed as JSON to obtain the `project_id` the query job runs under.
/// Result cells that are `None` are read as `0`.
///
/// # Errors
///
/// Fails when the URL cannot be parsed, the client cannot be created, the key file cannot be
/// read or parsed, `project_id` is missing or not a string, or the query or cell access fails.
///
/// # Panics
///
/// Panics if the Tokio runtime cannot be created.
#[cfg(feature = "src_bigquery")]
#[throws(ConnectorXOutError)]
fn bigquery_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    let runtime = Runtime::new().expect("Failed to create runtime");
    let parsed = Url::parse(conn.as_str())?;
    let key_file = parsed.path();
    let bq = runtime.block_on(
        gcp_bigquery_client::Client::from_service_account_key_file(key_file),
    )?;

    // The project id is not part of the URL; it is taken from the key file's JSON content.
    let key_text = std::fs::read_to_string(key_file)?;
    let key_json: serde_json::Value = serde_json::from_str(&key_text)?;
    let project_field = key_json
        .get("project_id")
        .ok_or_else(|| anyhow!("Cannot get project_id from auth file"))?;
    let project_id = project_field
        .as_str()
        .ok_or_else(|| anyhow!("Cannot get project_id as string from auth file"))?;

    let sql = get_partition_range_query(query, col, &BigQueryDialect {})?;
    let request = gcp_bigquery_client::model::query_request::QueryRequest::new(sql.as_str());
    let response = runtime.block_on(bq.job().query(project_id, request))?;

    let mut rows =
        gcp_bigquery_client::model::query_response::ResultSet::new_from_query_response(response);
    // Advance the cursor to the first row before reading cells.
    rows.next_row();
    let lower = rows.get_i64(0)?.unwrap_or(0);
    let upper = rows.get_i64(1)?.unwrap_or(0);

    (lower, upper)
}
495
/// Queries Trino for the `(min, max)` of `col` over `query` as `i64` values.
///
/// The client is configured from `conn`: the user name (falling back to `connectorx` when
/// empty), host, port (default 8080), TLS when the scheme is `trino+https`, the last path
/// segment as catalog, and basic auth when a password is present. An empty result (either the
/// `EmptyData` error or no rows) yields `(0, 0)`.
///
/// # Errors
///
/// Fails when `conn` has no host (reported as an error rather than a panic), the client cannot
/// be built, the range query cannot be generated, or the query fails for any reason other than
/// returning no data.
///
/// # Panics
///
/// Panics if the Tokio runtime cannot be created.
#[cfg(feature = "src_trino")]
#[throws(ConnectorXOutError)]
fn trino_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
    use prusto::{auth::Auth, ClientBuilder};

    use crate::sources::trino::{TrinoDialect, TrinoPartitionQueryResult};

    let rt = Runtime::new().expect("Failed to create runtime");

    let username = match conn.username() {
        "" => "connectorx",
        username => username,
    };

    // A URL without a host is a user input problem, so surface it as an error.
    let host = conn
        .host()
        .ok_or_else(|| anyhow!("Trino connection URL must contain a host"))?;
    // `path_segments` is only `None` for cannot-be-a-base URLs; fall back instead of panicking.
    let catalog = conn
        .path_segments()
        .and_then(|segments| segments.last())
        .unwrap_or("hive");

    let builder = ClientBuilder::new(username, host.to_owned())
        .port(conn.port().unwrap_or(8080))
        .ssl(prusto::ssl::Ssl { root_cert: None })
        .secure(conn.scheme() == "trino+https")
        .catalog(catalog);

    let builder = match conn.password() {
        None => builder,
        Some(password) => builder.auth(Auth::Basic(username.to_owned(), Some(password.to_owned()))),
    };

    let client = builder
        .build()
        .map_err(|e| anyhow!("Failed to build client: {}", e))?;

    let range_query = get_partition_range_query(query, col, &TrinoDialect {})?;

    // Only the first returned row is used; "no data" in either form means an empty range.
    let (min_v, max_v) =
        match rt.block_on(client.get_all::<TrinoPartitionQueryResult>(range_query)) {
            Ok(data) => data
                .into_vec()
                .first()
                .map_or((0, 0), |first| (first._col0, first._col1)),
            Err(prusto::error::Error::EmptyData) => (0, 0),
            Err(e) => throw!(anyhow!("Failed to get query result: {}", e)),
        };

    (min_v, max_v)
}