connectorx/
sql.rs

1use crate::errors::ConnectorXError;
2#[cfg(feature = "src_oracle")]
3use crate::sources::oracle::OracleDialect;
4use fehler::{throw, throws};
5use log::{debug, trace, warn};
6use sqlparser::ast::{
7    BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Query, Select,
8    SelectItem, SetExpr, Statement, TableAlias, TableFactor, TableWithJoins, Value,
9    WildcardAdditionalOptions,
10};
11use sqlparser::dialect::Dialect;
12use sqlparser::parser::Parser;
13#[cfg(feature = "src_oracle")]
14use std::any::Any;
15
/// A query string tagged with whether it has already been wrapped.
///
/// The payload type `Q` defaults to `String`; helpers such as [`CXQuery::map`]
/// allow carrying other payloads (for example a parse result) while keeping
/// the tag.
#[derive(Debug, Clone)]
pub enum CXQuery<Q = String> {
    /// The query directly comes from the user.
    Naked(Q),
    /// The user query is already wrapped in a subquery.
    Wrapped(Q),
}

/// Displays the payload only; the variant tag is not printed.
impl<Q: std::fmt::Display> std::fmt::Display for CXQuery<Q> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CXQuery::Naked(inner) | CXQuery::Wrapped(inner) => write!(f, "{}", inner),
        }
    }
}

impl<Q: AsRef<str>> CXQuery<Q> {
    /// Borrows the payload as a string slice, regardless of the variant.
    pub fn as_str(&self) -> &str {
        match self {
            CXQuery::Naked(inner) | CXQuery::Wrapped(inner) => inner.as_ref(),
        }
    }
}

/// A plain string slice is treated as a naked (user supplied) query.
impl From<&str> for CXQuery {
    fn from(s: &str) -> Self {
        Self::Naked(s.to_owned())
    }
}

/// Same as `From<&str>`, for callers holding a reference to a slice.
impl From<&&str> for CXQuery {
    fn from(s: &&str) -> Self {
        Self::Naked((*s).to_owned())
    }
}

/// A borrowed `String` is copied into a naked query.
impl From<&String> for CXQuery {
    fn from(s: &String) -> Self {
        Self::Naked(s.to_owned())
    }
}

/// A borrowed query converts into an owned copy with the same variant.
impl From<&CXQuery> for CXQuery {
    fn from(q: &CXQuery) -> Self {
        Clone::clone(q)
    }
}

impl CXQuery<String> {
    /// Builds a `Naked` query from anything that can be viewed as `&str`.
    pub fn naked<Q: AsRef<str>>(q: Q) -> Self {
        let text: &str = q.as_ref();
        Self::Naked(text.to_owned())
    }
}

impl<Q: AsRef<str>> AsRef<str> for CXQuery<Q> {
    fn as_ref(&self) -> &str {
        match self {
            CXQuery::Naked(inner) | CXQuery::Wrapped(inner) => inner.as_ref(),
        }
    }
}

impl<Q> CXQuery<Q> {
    /// Applies `f` to a reference to the payload and returns the result under
    /// the same variant as `self`.
    pub fn map<F, U>(&self, f: F) -> CXQuery<U>
    where
        F: Fn(&Q) -> U,
    {
        match *self {
            CXQuery::Naked(ref inner) => CXQuery::Naked(f(inner)),
            CXQuery::Wrapped(ref inner) => CXQuery::Wrapped(f(inner)),
        }
    }
}

impl<Q, E> CXQuery<Result<Q, E>> {
    /// Turns a query of `Result` into a `Result` of query, keeping the variant
    /// on success and returning the inner error unchanged on failure.
    pub fn result(self) -> Result<CXQuery<Q>, E> {
        Ok(match self {
            CXQuery::Naked(outcome) => CXQuery::Naked(outcome?),
            CXQuery::Wrapped(outcome) => CXQuery::Wrapped(outcome?),
        })
    }
}
99
100// wrap a query into a derived table
101fn wrap_query(
102    query: &mut Query,
103    projection: Vec<SelectItem>,
104    selection: Option<Expr>,
105    tmp_tab_name: &str,
106) -> Statement {
107    let with = query.with.clone();
108    query.with = None;
109    let alias = if tmp_tab_name.is_empty() {
110        None
111    } else {
112        Some(TableAlias {
113            name: Ident {
114                value: tmp_tab_name.into(),
115                quote_style: None,
116            },
117            columns: vec![],
118        })
119    };
120    Statement::Query(Box::new(Query {
121        with,
122        locks: vec![],
123        body: Box::new(SetExpr::Select(Box::new(Select {
124            distinct: None,
125            top: None,
126            projection,
127            from: vec![TableWithJoins {
128                relation: TableFactor::Derived {
129                    lateral: false,
130                    subquery: Box::new(query.clone()),
131                    alias,
132                },
133                joins: vec![],
134            }],
135            lateral_views: vec![],
136            selection,
137            group_by: vec![],
138            cluster_by: vec![],
139            distribute_by: vec![],
140            sort_by: vec![],
141            having: None,
142            into: None,
143            named_window: vec![],
144            qualify: None,
145        }))),
146        order_by: vec![],
147        limit: None,
148        offset: None,
149        fetch: None,
150    }))
151}
152
153trait StatementExt {
154    fn as_query(&self) -> Option<&Query>;
155}
156
157impl StatementExt for Statement {
158    fn as_query(&self) -> Option<&Query> {
159        match self {
160            Statement::Query(q) => Some(q),
161            _ => None,
162        }
163    }
164}
165
166trait QueryExt {
167    fn as_select_mut(&mut self) -> Option<&mut Select>;
168}
169
170impl QueryExt for Query {
171    fn as_select_mut(&mut self) -> Option<&mut Select> {
172        match *self.body {
173            SetExpr::Select(ref mut select) => Some(select),
174            _ => None,
175        }
176    }
177}
178
179#[throws(ConnectorXError)]
180pub fn count_query<T: Dialect>(sql: &CXQuery<String>, dialect: &T) -> CXQuery<String> {
181    trace!("Incoming query: {}", sql);
182
183    const COUNT_TMP_TAB_NAME: &str = "CXTMPTAB_COUNT";
184
185    #[allow(unused_mut)]
186    let mut table_alias = COUNT_TMP_TAB_NAME;
187
188    // HACK: Some dialect (e.g. Oracle) does not support "AS" for alias
189    #[cfg(feature = "src_oracle")]
190    if dialect.type_id() == (OracleDialect {}.type_id()) {
191        // table_alias = "";
192        return CXQuery::Wrapped(format!(
193            "SELECT COUNT(*) FROM ({}) {}",
194            sql.as_str(),
195            COUNT_TMP_TAB_NAME
196        ));
197    }
198
199    let tsql = match sql.map(|sql| Parser::parse_sql(dialect, sql)).result() {
200        Ok(ast) => {
201            let projection = vec![SelectItem::UnnamedExpr(Expr::Function(Function {
202                name: ObjectName(vec![Ident {
203                    value: "count".to_string(),
204                    quote_style: None,
205                }]),
206                args: vec![FunctionArg::Unnamed(FunctionArgExpr::Wildcard)],
207                over: None,
208                distinct: false,
209                order_by: vec![],
210                special: false,
211            }))];
212            let ast_count: Statement = match ast {
213                CXQuery::Naked(ast) => {
214                    if ast.len() != 1 {
215                        throw!(ConnectorXError::SqlQueryNotSupported(sql.to_string()));
216                    }
217                    let mut query = ast[0]
218                        .as_query()
219                        .ok_or_else(|| ConnectorXError::SqlQueryNotSupported(sql.to_string()))?
220                        .clone();
221                    if query.offset.is_none() {
222                        query.order_by = vec![]; // mssql offset must appear with order by
223                    }
224                    let select = query
225                        .as_select_mut()
226                        .ok_or_else(|| ConnectorXError::SqlQueryNotSupported(sql.to_string()))?;
227                    select.sort_by = vec![];
228                    wrap_query(&mut query, projection, None, table_alias)
229                }
230                CXQuery::Wrapped(ast) => {
231                    if ast.len() != 1 {
232                        throw!(ConnectorXError::SqlQueryNotSupported(sql.to_string()));
233                    }
234                    let mut query = ast[0]
235                        .as_query()
236                        .ok_or_else(|| ConnectorXError::SqlQueryNotSupported(sql.to_string()))?
237                        .clone();
238                    let select = query
239                        .as_select_mut()
240                        .ok_or_else(|| ConnectorXError::SqlQueryNotSupported(sql.to_string()))?;
241                    select.projection = projection;
242                    Statement::Query(Box::new(query))
243                }
244            };
245            format!("{}", ast_count)
246        }
247        Err(e) => {
248            warn!("parser error: {:?}, manually compose query string", e);
249            format!(
250                "SELECT COUNT(*) FROM ({}) as {}",
251                sql.as_str(),
252                COUNT_TMP_TAB_NAME
253            )
254        }
255    };
256
257    debug!("Transformed count query: {}", tsql);
258    CXQuery::Wrapped(tsql)
259}
260
261#[throws(ConnectorXError)]
262pub fn limit1_query<T: Dialect>(sql: &CXQuery<String>, dialect: &T) -> CXQuery<String> {
263    trace!("Incoming query: {}", sql);
264
265    let sql = match Parser::parse_sql(dialect, sql.as_str()) {
266        Ok(mut ast) => {
267            if ast.len() != 1 {
268                throw!(ConnectorXError::SqlQueryNotSupported(sql.to_string()));
269            }
270
271            match &mut ast[0] {
272                Statement::Query(q) => {
273                    q.limit = Some(Expr::Value(Value::Number("1".to_string(), false)));
274                }
275                _ => throw!(ConnectorXError::SqlQueryNotSupported(sql.to_string())),
276            };
277
278            format!("{}", ast[0])
279        }
280        Err(e) => {
281            warn!("parser error: {:?}, manually compose query string", e);
282            format!("{} LIMIT 1", sql.as_str())
283        }
284    };
285
286    debug!("Transformed limit 1 query: {}", sql);
287    CXQuery::Wrapped(sql)
288}
289
/// Oracle variant of the "at most one row" rewrite.
///
/// The query is not parsed; it is wrapped textually as
/// `SELECT * FROM (<sql>) WHERE rownum = 1`.
#[throws(ConnectorXError)]
#[cfg(feature = "src_oracle")]
pub fn limit1_query_oracle(sql: &CXQuery<String>) -> CXQuery<String> {
    trace!("Incoming oracle query: {}", sql);

    let first_row_only = format!("SELECT * FROM ({}) WHERE rownum = 1", sql);
    CXQuery::Wrapped(first_row_only)
}
321
322#[throws(ConnectorXError)]
323pub fn single_col_partition_query<T: Dialect>(
324    sql: &str,
325    col: &str,
326    lower: i64,
327    upper: i64,
328    dialect: &T,
329) -> String {
330    trace!("Incoming query: {}", sql);
331    const PART_TMP_TAB_NAME: &str = "CXTMPTAB_PART";
332
333    #[allow(unused_mut)]
334    let mut table_alias = PART_TMP_TAB_NAME;
335    #[allow(unused_mut)]
336    let mut cid = Box::new(Expr::CompoundIdentifier(vec![
337        Ident {
338            value: PART_TMP_TAB_NAME.to_string(),
339            quote_style: None,
340        },
341        Ident {
342            value: col.to_string(),
343            quote_style: None,
344        },
345    ]));
346
347    // HACK: Some dialect (e.g. Oracle) does not support "AS" for alias
348    #[cfg(feature = "src_oracle")]
349    if dialect.type_id() == (OracleDialect {}.type_id()) {
350        return format!("SELECT * FROM ({}) CXTMPTAB_PART WHERE CXTMPTAB_PART.{} >= {} AND CXTMPTAB_PART.{} < {}", sql, col, lower, col, upper);
351        // table_alias = "";
352        // cid = Box::new(Expr::Identifier(Ident {
353        //     value: col.to_string(),
354        //     quote_style: None,
355        // }));
356    }
357
358    let tsql = match Parser::parse_sql(dialect, sql) {
359        Ok(ast) => {
360            if ast.len() != 1 {
361                throw!(ConnectorXError::SqlQueryNotSupported(sql.to_string()));
362            }
363
364            let mut query = ast[0]
365                .as_query()
366                .ok_or_else(|| ConnectorXError::SqlQueryNotSupported(sql.to_string()))?
367                .clone();
368
369            let select = query
370                .as_select_mut()
371                .ok_or_else(|| ConnectorXError::SqlQueryNotSupported(sql.to_string()))?
372                .clone();
373
374            let ast_part: Statement;
375
376            let lb = Expr::BinaryOp {
377                left: Box::new(Expr::Value(Value::Number(lower.to_string(), false))),
378                op: BinaryOperator::LtEq,
379                right: cid.clone(),
380            };
381
382            let ub = Expr::BinaryOp {
383                left: cid,
384                op: BinaryOperator::Lt,
385                right: Box::new(Expr::Value(Value::Number(upper.to_string(), false))),
386            };
387
388            let selection = Expr::BinaryOp {
389                left: Box::new(lb),
390                op: BinaryOperator::And,
391                right: Box::new(ub),
392            };
393
394            if query.limit.is_none() && select.top.is_none() && !query.order_by.is_empty() {
395                // order by in a partition query does not make sense because partition is unordered.
396                // clear the order by beceause mssql does not support order by in a derived table.
397                // also order by in the derived table does not make any difference.
398                query.order_by.clear();
399            }
400
401            ast_part = wrap_query(
402                &mut query,
403                vec![SelectItem::Wildcard(WildcardAdditionalOptions::default())],
404                Some(selection),
405                table_alias,
406            );
407            format!("{}", ast_part)
408        }
409        Err(e) => {
410            warn!("parser error: {:?}, manually compose query string", e);
411            format!("SELECT * FROM ({}) AS CXTMPTAB_PART WHERE CXTMPTAB_PART.{} >= {} AND CXTMPTAB_PART.{} < {}", sql, col, lower, col, upper)
412        }
413    };
414
415    debug!("Transformed single column partition query: {}", tsql);
416    tsql
417}
418
419#[throws(ConnectorXError)]
420pub fn get_partition_range_query<T: Dialect>(sql: &str, col: &str, dialect: &T) -> String {
421    trace!("Incoming query: {}", sql);
422    const RANGE_TMP_TAB_NAME: &str = "CXTMPTAB_RANGE";
423
424    #[allow(unused_mut)]
425    let mut table_alias = RANGE_TMP_TAB_NAME;
426    #[allow(unused_mut)]
427    let mut args = vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
428        Expr::CompoundIdentifier(vec![
429            Ident {
430                value: RANGE_TMP_TAB_NAME.to_string(),
431                quote_style: None,
432            },
433            Ident {
434                value: col.to_string(),
435                quote_style: None,
436            },
437        ]),
438    ))];
439
440    // HACK: Some dialect (e.g. Oracle) does not support "AS" for alias
441    #[cfg(feature = "src_oracle")]
442    if dialect.type_id() == (OracleDialect {}.type_id()) {
443        return format!(
444            "SELECT MIN({}.{}) as min, MAX({}.{}) as max FROM ({}) {}",
445            RANGE_TMP_TAB_NAME, col, RANGE_TMP_TAB_NAME, col, sql, RANGE_TMP_TAB_NAME
446        );
447        // table_alias = "";
448        // args = vec![FunctionArg::Unnamed(Expr::Identifier(Ident {
449        //     value: col.to_string(),
450        //     quote_style: None,
451        // }))];
452    }
453
454    let tsql = match Parser::parse_sql(dialect, sql) {
455        Ok(ast) => {
456            if ast.len() != 1 {
457                throw!(ConnectorXError::SqlQueryNotSupported(sql.to_string()));
458            }
459
460            let mut query = ast[0]
461                .as_query()
462                .ok_or_else(|| ConnectorXError::SqlQueryNotSupported(sql.to_string()))?
463                .clone();
464            let ast_range: Statement;
465
466            if query.limit.is_none() && query.offset.is_none() {
467                query.order_by = vec![]; // only omit orderby when there is no limit and offset in the query
468            }
469            let projection = vec![
470                SelectItem::UnnamedExpr(Expr::Function(Function {
471                    name: ObjectName(vec![Ident {
472                        value: "min".to_string(),
473                        quote_style: None,
474                    }]),
475                    args: args.clone(),
476                    over: None,
477                    distinct: false,
478                    order_by: vec![],
479                    special: false,
480                })),
481                SelectItem::UnnamedExpr(Expr::Function(Function {
482                    name: ObjectName(vec![Ident {
483                        value: "max".to_string(),
484                        quote_style: None,
485                    }]),
486                    args,
487                    over: None,
488                    distinct: false,
489                    order_by: vec![],
490                    special: false,
491                })),
492            ];
493            ast_range = wrap_query(&mut query, projection, None, table_alias);
494            format!("{}", ast_range)
495        }
496        Err(e) => {
497            warn!("parser error: {:?}, manually compose query string", e);
498            format!(
499                "SELECT MIN({}.{}) as min, MAX({}.{}) as max FROM ({}) AS {}",
500                RANGE_TMP_TAB_NAME, col, RANGE_TMP_TAB_NAME, col, sql, RANGE_TMP_TAB_NAME
501            )
502        }
503    };
504
505    debug!("Transformed partition range query: {}", tsql);
506    tsql
507}
508
509#[throws(ConnectorXError)]
510pub fn get_partition_range_query_sep<T: Dialect>(
511    sql: &str,
512    col: &str,
513    dialect: &T,
514) -> (String, String) {
515    trace!("Incoming query: {}", sql);
516    const RANGE_TMP_TAB_NAME: &str = "CXTMPTAB_RANGE";
517
518    let (sql_min, sql_max) = match Parser::parse_sql(dialect, sql) {
519        Ok(ast) => {
520            if ast.len() != 1 {
521                throw!(ConnectorXError::SqlQueryNotSupported(sql.to_string()));
522            }
523
524            let mut query = ast[0]
525                .as_query()
526                .ok_or_else(|| ConnectorXError::SqlQueryNotSupported(sql.to_string()))?
527                .clone();
528
529            let ast_range_min: Statement;
530            let ast_range_max: Statement;
531
532            query.order_by = vec![];
533            let min_proj = vec![SelectItem::UnnamedExpr(Expr::Function(Function {
534                name: ObjectName(vec![Ident {
535                    value: "min".to_string(),
536                    quote_style: None,
537                }]),
538                args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
539                    Expr::CompoundIdentifier(vec![
540                        Ident {
541                            value: RANGE_TMP_TAB_NAME.to_string(),
542                            quote_style: None,
543                        },
544                        Ident {
545                            value: col.to_string(),
546                            quote_style: None,
547                        },
548                    ]),
549                ))],
550                over: None,
551                distinct: false,
552                order_by: vec![],
553                special: false,
554            }))];
555            let max_proj = vec![SelectItem::UnnamedExpr(Expr::Function(Function {
556                name: ObjectName(vec![Ident {
557                    value: "max".to_string(),
558                    quote_style: None,
559                }]),
560                args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
561                    Expr::CompoundIdentifier(vec![
562                        Ident {
563                            value: RANGE_TMP_TAB_NAME.into(),
564                            quote_style: None,
565                        },
566                        Ident {
567                            value: col.into(),
568                            quote_style: None,
569                        },
570                    ]),
571                ))],
572                over: None,
573                distinct: false,
574                order_by: vec![],
575                special: false,
576            }))];
577            ast_range_min = wrap_query(&mut query.clone(), min_proj, None, RANGE_TMP_TAB_NAME);
578            ast_range_max = wrap_query(&mut query, max_proj, None, RANGE_TMP_TAB_NAME);
579            (format!("{}", ast_range_min), format!("{}", ast_range_max))
580        }
581        Err(e) => {
582            warn!("parser error: {:?}, manually compose query string", e);
583            (
584                format!(
585                    "SELECT MIN({}.{}) as min FROM ({}) AS {}",
586                    RANGE_TMP_TAB_NAME, col, sql, RANGE_TMP_TAB_NAME
587                ),
588                format!(
589                    "SELECT MAX({}.{}) as max FROM ({}) AS {}",
590                    RANGE_TMP_TAB_NAME, col, sql, RANGE_TMP_TAB_NAME
591                ),
592            )
593        }
594    };
595    debug!(
596        "Transformed separated partition range query: {}, {}",
597        sql_min, sql_max
598    );
599    (sql_min, sql_max)
600}