// connectorx/sources/trino/mod.rs

1use std::{marker::PhantomData, sync::Arc};
2
3use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
4use fehler::{throw, throws};
5use prusto::{auth::Auth, Client, ClientBuilder, DataSet, Presto, Row};
6use serde_json::Value;
7use sqlparser::dialect::{Dialect, GenericDialect};
8use std::convert::TryFrom;
9use tokio::runtime::Runtime;
10
11use crate::{
12    data_order::DataOrder,
13    errors::ConnectorXError,
14    sources::Produce,
15    sql::{count_query, limit1_query, CXQuery},
16};
17
18pub use self::{errors::TrinoSourceError, typesystem::TrinoTypeSystem};
19use urlencoding::decode;
20
21use super::{PartitionParser, Source, SourcePartition};
22
23use anyhow::anyhow;
24
25pub mod errors;
26pub mod typesystem;
27
28#[throws(TrinoSourceError)]
29fn get_total_rows(rt: Arc<Runtime>, client: Arc<Client>, query: &CXQuery<String>) -> usize {
30    let cquery = count_query(query, &TrinoDialect {})?;
31
32    let row = rt
33        .block_on(client.get_all::<Row>(cquery.to_string()))
34        .map_err(TrinoSourceError::PrustoError)?
35        .split()
36        .1[0]
37        .clone();
38
39    let value = row
40        .value()
41        .first()
42        .ok_or_else(|| anyhow!("Trino count dataset is empty"))?;
43
44    value
45        .as_i64()
46        .ok_or_else(|| anyhow!("Trino cannot parse i64"))? as usize
47}
48
49#[derive(Presto, Debug)]
50pub struct TrinoPartitionQueryResult {
51    pub _col0: i64,
52    pub _col1: i64,
53}
54
55#[derive(Debug)]
56pub struct TrinoDialect {}
57
58// implementation copy from AnsiDialect
59impl Dialect for TrinoDialect {
60    fn is_identifier_start(&self, ch: char) -> bool {
61        ch.is_ascii_lowercase() || ch.is_ascii_uppercase()
62    }
63
64    fn is_identifier_part(&self, ch: char) -> bool {
65        ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_'
66    }
67}
68
69pub struct TrinoSource {
70    client: Arc<Client>,
71    rt: Arc<Runtime>,
72    origin_query: Option<String>,
73    queries: Vec<CXQuery<String>>,
74    names: Vec<String>,
75    schema: Vec<TrinoTypeSystem>,
76}
77
78impl TrinoSource {
79    #[throws(TrinoSourceError)]
80    pub fn new(rt: Arc<Runtime>, conn: &str) -> Self {
81        let decoded_conn = decode(conn)?.into_owned();
82
83        let url = decoded_conn
84            .parse::<url::Url>()
85            .map_err(TrinoSourceError::UrlParseError)?;
86
87        let username = match url.username() {
88            "" => "connectorx",
89            username => username,
90        };
91
92        let no_verify = url
93            .query_pairs()
94            .any(|(k, v)| k == "verify" && v == "false");
95
96        let builder = ClientBuilder::new(username, url.host().unwrap().to_owned())
97            .port(url.port().unwrap_or(8080))
98            .ssl(prusto::ssl::Ssl { root_cert: None })
99            .no_verify(no_verify)
100            .secure(url.scheme() == "trino+https")
101            .catalog(url.path_segments().unwrap().last().unwrap_or("hive"));
102
103        let builder = match url.password() {
104            None => builder,
105            Some(password) => {
106                builder.auth(Auth::Basic(username.to_owned(), Some(password.to_owned())))
107            }
108        };
109
110        let client = builder.build().map_err(TrinoSourceError::PrustoError)?;
111
112        Self {
113            client: Arc::new(client),
114            rt,
115            origin_query: None,
116            queries: vec![],
117            names: vec![],
118            schema: vec![],
119        }
120    }
121}
122
123impl Source for TrinoSource
124where
125    TrinoSourcePartition: SourcePartition<TypeSystem = TrinoTypeSystem, Error = TrinoSourceError>,
126{
127    const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
128    type TypeSystem = TrinoTypeSystem;
129    type Partition = TrinoSourcePartition;
130    type Error = TrinoSourceError;
131
132    #[throws(TrinoSourceError)]
133    fn set_data_order(&mut self, data_order: DataOrder) {
134        if !matches!(data_order, DataOrder::RowMajor) {
135            throw!(ConnectorXError::UnsupportedDataOrder(data_order));
136        }
137    }
138
139    fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
140        self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
141    }
142
143    fn set_origin_query(&mut self, query: Option<String>) {
144        self.origin_query = query;
145    }
146
147    #[throws(TrinoSourceError)]
148    fn fetch_metadata(&mut self) {
149        assert!(!self.queries.is_empty());
150
151        let first_query = &self.queries[0];
152        let cxq = limit1_query(first_query, &GenericDialect {})?;
153
154        let dataset: DataSet<Row> = self
155            .rt
156            .block_on(self.client.get_all::<Row>(cxq.to_string()))
157            .map_err(TrinoSourceError::PrustoError)?;
158
159        let schema = dataset.split().0;
160
161        for (name, t) in schema {
162            self.names.push(name.clone());
163            self.schema.push(TrinoTypeSystem::try_from(t.clone())?);
164        }
165    }
166
167    #[throws(TrinoSourceError)]
168    fn result_rows(&mut self) -> Option<usize> {
169        match &self.origin_query {
170            Some(q) => {
171                let cxq = CXQuery::Naked(q.clone());
172                let nrows = get_total_rows(self.rt.clone(), self.client.clone(), &cxq)?;
173                Some(nrows)
174            }
175            None => None,
176        }
177    }
178
179    fn names(&self) -> Vec<String> {
180        self.names.clone()
181    }
182
183    fn schema(&self) -> Vec<Self::TypeSystem> {
184        self.schema.clone()
185    }
186
187    #[throws(TrinoSourceError)]
188    fn partition(self) -> Vec<Self::Partition> {
189        let mut ret = vec![];
190
191        for query in self.queries {
192            ret.push(TrinoSourcePartition::new(
193                self.client.clone(),
194                query,
195                self.schema.clone(),
196                self.rt.clone(),
197            )?);
198        }
199        ret
200    }
201}
202
203pub struct TrinoSourcePartition {
204    client: Arc<Client>,
205    query: CXQuery<String>,
206    schema: Vec<TrinoTypeSystem>,
207    rt: Arc<Runtime>,
208    nrows: usize,
209}
210
211impl TrinoSourcePartition {
212    #[throws(TrinoSourceError)]
213    pub fn new(
214        client: Arc<Client>,
215        query: CXQuery<String>,
216        schema: Vec<TrinoTypeSystem>,
217        rt: Arc<Runtime>,
218    ) -> Self {
219        Self {
220            client,
221            query: query.clone(),
222            schema: schema.to_vec(),
223            rt,
224            nrows: 0,
225        }
226    }
227}
228
229impl SourcePartition for TrinoSourcePartition {
230    type TypeSystem = TrinoTypeSystem;
231    type Parser<'a> = TrinoSourcePartitionParser<'a>;
232    type Error = TrinoSourceError;
233
234    #[throws(TrinoSourceError)]
235    fn result_rows(&mut self) {
236        self.nrows = get_total_rows(self.rt.clone(), self.client.clone(), &self.query)?;
237    }
238
239    #[throws(TrinoSourceError)]
240    fn parser(&mut self) -> Self::Parser<'_> {
241        TrinoSourcePartitionParser::new(
242            self.rt.clone(),
243            self.client.clone(),
244            self.query.clone(),
245            &self.schema,
246        )?
247    }
248
249    fn nrows(&self) -> usize {
250        self.nrows
251    }
252
253    fn ncols(&self) -> usize {
254        self.schema.len()
255    }
256}
257
258pub struct TrinoSourcePartitionParser<'a> {
259    rt: Arc<Runtime>,
260    client: Arc<Client>,
261    next_uri: Option<String>,
262    rows: Vec<Row>,
263    ncols: usize,
264    current_col: usize,
265    current_row: usize,
266    _phantom: &'a PhantomData<DataSet<Row>>,
267}
268
269impl<'a> TrinoSourcePartitionParser<'a> {
270    #[throws(TrinoSourceError)]
271    pub fn new(
272        rt: Arc<Runtime>,
273        client: Arc<Client>,
274        query: CXQuery,
275        schema: &[TrinoTypeSystem],
276    ) -> Self {
277        let results = rt
278            .block_on(client.get::<Row>(query.to_string()))
279            .map_err(TrinoSourceError::PrustoError)?;
280
281        let rows = match results.data_set {
282            Some(x) => x.into_vec(),
283            _ => vec![],
284        };
285
286        Self {
287            rt,
288            client,
289            next_uri: results.next_uri,
290            rows,
291            ncols: schema.len(),
292            current_row: 0,
293            current_col: 0,
294            _phantom: &PhantomData,
295        }
296    }
297
298    #[throws(TrinoSourceError)]
299    fn next_loc(&mut self) -> (usize, usize) {
300        let ret = (self.current_row, self.current_col);
301        self.current_row += (self.current_col + 1) / self.ncols;
302        self.current_col = (self.current_col + 1) % self.ncols;
303        ret
304    }
305}
306
307impl<'a> PartitionParser<'a> for TrinoSourcePartitionParser<'a> {
308    type TypeSystem = TrinoTypeSystem;
309    type Error = TrinoSourceError;
310
311    #[throws(TrinoSourceError)]
312    fn fetch_next(&mut self) -> (usize, bool) {
313        assert!(self.current_col == 0);
314
315        match self.next_uri.clone() {
316            Some(uri) => {
317                let results = self
318                    .rt
319                    .block_on(self.client.get_next::<Row>(&uri))
320                    .map_err(TrinoSourceError::PrustoError)?;
321
322                self.rows = match results.data_set {
323                    Some(x) => x.into_vec(),
324                    _ => vec![],
325                };
326
327                self.current_row = 0;
328                self.next_uri = results.next_uri;
329
330                (self.rows.len(), false)
331            }
332            None => return (self.rows.len(), true),
333        }
334    }
335}
336
// Implements `Produce<$t>` and `Produce<Option<$t>>` for integer types.
// A cell must be a JSON number representable as `i64` that also fits in `$t`;
// the `Option` variant additionally maps JSON `null` to `None`.
macro_rules! impl_produce_int {
    ($($t: ty,)+) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> $t {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        // `as_i64` is `Some` exactly when the number is an i64, which
                        // removes the separate `is_i64` check and the `unwrap`.
                        Value::Number(x) => match x.as_i64() {
                            Some(v) => <$t>::try_from(v).map_err(|_| anyhow!("Trino cannot parse i64 at position: ({}, {}) {:?}", ridx, cidx, value))?,
                            None => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)),
                        },
                        _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)),
                    }
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::Null => None,
                        Value::Number(x) => match x.as_i64() {
                            Some(v) => Some(<$t>::try_from(v).map_err(|_| anyhow!("Trino cannot parse i64 at position: ({}, {}) {:?}", ridx, cidx, value))?),
                            None => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)),
                        },
                        _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)),
                    }
                }
            }
        )+
    };
}
385
// Implements `Produce<$t>` and `Produce<Option<$t>>` for floating-point types.
// A cell may be a JSON number (integer-valued numbers are converted as well)
// or a string that `str::parse::<$t>` accepts; the `Option` variant
// additionally maps JSON `null` to `None`.
macro_rules! impl_produce_float {
    ($($t: ty,)+) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> $t {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        // `as_f64` also converts integer-valued numbers, which the old
                        // `is_f64` guard rejected even though they map to floats exactly.
                        Value::Number(x) => match x.as_f64() {
                            Some(v) => v as $t,
                            None => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)),
                        },
                        Value::String(x) => x.parse::<$t>().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}) {:?}", ridx, cidx, value))?,
                        _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)),
                    }
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::Null => None,
                        Value::Number(x) => match x.as_f64() {
                            Some(v) => Some(v as $t),
                            None => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)),
                        },
                        Value::String(x) => Some(x.parse::<$t>().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}) {:?}", ridx, cidx, value))?),
                        _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)),
                    }
                }
            }
        )+
    };
}
436
// Implements `Produce<$t>` and `Produce<Option<$t>>` for types parsed from a
// JSON string cell via `str::parse`; the `Option` variant maps `null` to `None`.
macro_rules! impl_produce_text {
    ($($t: ty,)+) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> $t {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    let parsed: $t = if let Value::String(text) = value {
                        text.parse::<$t>().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}): {:?}", ridx, cidx, value))?
                    } else {
                        throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value))
                    };
                    parsed
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::String(text) => text
                            .parse::<$t>()
                            .map(Some)
                            .map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}): {:?}", ridx, cidx, value))?,
                        Value::Null => None,
                        _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)),
                    }
                }
            }
        )+
    };
}
477
// Implements `Produce<$t>` and `Produce<Option<$t>>` for timestamp cells: JSON
// strings of the form `YYYY-MM-DD HH:MM:SS[.fraction]`, parsed with
// `NaiveDateTime::parse_from_str`. The `Option` variant maps `null` to `None`.
macro_rules! impl_produce_timestamp {
    ($($t: ty,)+) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> $t {
                    const TIMESTAMP_FORMAT: &str = "%Y-%m-%d %H:%M:%S%.f";

                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    if let Value::String(text) = value {
                        NaiveDateTime::parse_from_str(text, TIMESTAMP_FORMAT).map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}): {:?}", ridx, cidx, value))?
                    } else {
                        throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value))
                    }
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    const TIMESTAMP_FORMAT: &str = "%Y-%m-%d %H:%M:%S%.f";

                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::String(text) => {
                            let parsed = NaiveDateTime::parse_from_str(text, TIMESTAMP_FORMAT).map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}): {:?}", ridx, cidx, value))?;
                            Some(parsed)
                        }
                        Value::Null => None,
                        _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)),
                    }
                }
            }
        )+
    };
}
514
// Implements `Produce<$t>` and `Produce<Option<$t>>` for JSON boolean cells;
// the `Option` variant maps `null` to `None`.
macro_rules! impl_produce_bool {
    ($($t: ty,)+) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> $t {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    if let Value::Bool(flag) = value {
                        *flag
                    } else {
                        throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value))
                    }
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::Bool(flag) => Some(*flag),
                        Value::Null => None,
                        _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)),
                    }
                }
            }
        )+
    };
}
551
552impl_produce_bool!(bool,);
553impl_produce_int!(i8, i16, i32, i64,);
554impl_produce_float!(f32, f64,);
555impl_produce_timestamp!(NaiveDateTime,);
556impl_produce_text!(String, char,);
557
558impl<'r, 'a> Produce<'r, NaiveTime> for TrinoSourcePartitionParser<'a> {
559    type Error = TrinoSourceError;
560
561    #[throws(TrinoSourceError)]
562    fn produce(&'r mut self) -> NaiveTime {
563        let (ridx, cidx) = self.next_loc()?;
564        let value = &self.rows[ridx].value()[cidx];
565
566        match value {
567            Value::String(x) => NaiveTime::parse_from_str(x, "%H:%M:%S%.f").map_err(|_| {
568                anyhow!(
569                    "Trino cannot parse String at position: ({}, {}): {:?}",
570                    ridx,
571                    cidx,
572                    value
573                )
574            })?,
575            _ => throw!(anyhow!(
576                "Trino unknown value at position: ({}, {}): {:?}",
577                ridx,
578                cidx,
579                value
580            )),
581        }
582    }
583}
584
585impl<'r, 'a> Produce<'r, Option<NaiveTime>> for TrinoSourcePartitionParser<'a> {
586    type Error = TrinoSourceError;
587
588    #[throws(TrinoSourceError)]
589    fn produce(&'r mut self) -> Option<NaiveTime> {
590        let (ridx, cidx) = self.next_loc()?;
591        let value = &self.rows[ridx].value()[cidx];
592
593        match value {
594            Value::Null => None,
595            Value::String(x) => {
596                Some(NaiveTime::parse_from_str(x, "%H:%M:%S%.f").map_err(|_| {
597                    anyhow!(
598                        "Trino cannot parse Time at position: ({}, {}): {:?}",
599                        ridx,
600                        cidx,
601                        value
602                    )
603                })?)
604            }
605            _ => throw!(anyhow!(
606                "Trino unknown value at position: ({}, {}): {:?}",
607                ridx,
608                cidx,
609                value
610            )),
611        }
612    }
613}
614
615impl<'r, 'a> Produce<'r, NaiveDate> for TrinoSourcePartitionParser<'a> {
616    type Error = TrinoSourceError;
617
618    #[throws(TrinoSourceError)]
619    fn produce(&'r mut self) -> NaiveDate {
620        let (ridx, cidx) = self.next_loc()?;
621        let value = &self.rows[ridx].value()[cidx];
622
623        match value {
624            Value::String(x) => NaiveDate::parse_from_str(x, "%Y-%m-%d").map_err(|_| {
625                anyhow!(
626                    "Trino cannot parse Date at position: ({}, {}): {:?}",
627                    ridx,
628                    cidx,
629                    value
630                )
631            })?,
632            _ => throw!(anyhow!(
633                "Trino unknown value at position: ({}, {}): {:?}",
634                ridx,
635                cidx,
636                value
637            )),
638        }
639    }
640}
641
642impl<'r, 'a> Produce<'r, Option<NaiveDate>> for TrinoSourcePartitionParser<'a> {
643    type Error = TrinoSourceError;
644
645    #[throws(TrinoSourceError)]
646    fn produce(&'r mut self) -> Option<NaiveDate> {
647        let (ridx, cidx) = self.next_loc()?;
648        let value = &self.rows[ridx].value()[cidx];
649
650        match value {
651            Value::Null => None,
652            Value::String(x) => Some(NaiveDate::parse_from_str(x, "%Y-%m-%d").map_err(|_| {
653                anyhow!(
654                    "Trino cannot parse Date at position: ({}, {}): {:?}",
655                    ridx,
656                    cidx,
657                    value
658                )
659            })?),
660            _ => throw!(anyhow!(
661                "Trino unknown value at position: ({}, {}): {:?}",
662                ridx,
663                cidx,
664                value
665            )),
666        }
667    }
668}