1use std::{marker::PhantomData, sync::Arc};
2
3use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
4use fehler::{throw, throws};
5use prusto::{auth::Auth, Client, ClientBuilder, DataSet, Presto, Row};
6use serde_json::Value;
7use sqlparser::dialect::{Dialect, GenericDialect};
8use std::convert::TryFrom;
9use tokio::runtime::Runtime;
10
11use crate::{
12 data_order::DataOrder,
13 errors::ConnectorXError,
14 sources::Produce,
15 sql::{count_query, limit1_query, CXQuery},
16};
17
18pub use self::{errors::TrinoSourceError, typesystem::TrinoTypeSystem};
19use urlencoding::decode;
20
21use super::{PartitionParser, Source, SourcePartition};
22
23use anyhow::anyhow;
24
25pub mod errors;
26pub mod typesystem;
27
28#[throws(TrinoSourceError)]
29fn get_total_rows(rt: Arc<Runtime>, client: Arc<Client>, query: &CXQuery<String>) -> usize {
30 let cquery = count_query(query, &TrinoDialect {})?;
31
32 let row = rt
33 .block_on(client.get_all::<Row>(cquery.to_string()))
34 .map_err(TrinoSourceError::PrustoError)?
35 .split()
36 .1[0]
37 .clone();
38
39 let value = row
40 .value()
41 .first()
42 .ok_or_else(|| anyhow!("Trino count dataset is empty"))?;
43
44 value
45 .as_i64()
46 .ok_or_else(|| anyhow!("Trino cannot parse i64"))? as usize
47}
48
49#[derive(Presto, Debug)]
50pub struct TrinoPartitionQueryResult {
51 pub _col0: i64,
52 pub _col1: i64,
53}
54
55#[derive(Debug)]
56pub struct TrinoDialect {}
57
58impl Dialect for TrinoDialect {
60 fn is_identifier_start(&self, ch: char) -> bool {
61 ch.is_ascii_lowercase() || ch.is_ascii_uppercase()
62 }
63
64 fn is_identifier_part(&self, ch: char) -> bool {
65 ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_'
66 }
67}
68
69pub struct TrinoSource {
70 client: Arc<Client>,
71 rt: Arc<Runtime>,
72 origin_query: Option<String>,
73 queries: Vec<CXQuery<String>>,
74 names: Vec<String>,
75 schema: Vec<TrinoTypeSystem>,
76}
77
78impl TrinoSource {
79 #[throws(TrinoSourceError)]
80 pub fn new(rt: Arc<Runtime>, conn: &str) -> Self {
81 let decoded_conn = decode(conn)?.into_owned();
82
83 let url = decoded_conn
84 .parse::<url::Url>()
85 .map_err(TrinoSourceError::UrlParseError)?;
86
87 let username = match url.username() {
88 "" => "connectorx",
89 username => username,
90 };
91
92 let no_verify = url
93 .query_pairs()
94 .any(|(k, v)| k == "verify" && v == "false");
95
96 let builder = ClientBuilder::new(username, url.host().unwrap().to_owned())
97 .port(url.port().unwrap_or(8080))
98 .ssl(prusto::ssl::Ssl { root_cert: None })
99 .no_verify(no_verify)
100 .secure(url.scheme() == "trino+https")
101 .catalog(url.path_segments().unwrap().last().unwrap_or("hive"));
102
103 let builder = match url.password() {
104 None => builder,
105 Some(password) => {
106 builder.auth(Auth::Basic(username.to_owned(), Some(password.to_owned())))
107 }
108 };
109
110 let client = builder.build().map_err(TrinoSourceError::PrustoError)?;
111
112 Self {
113 client: Arc::new(client),
114 rt,
115 origin_query: None,
116 queries: vec![],
117 names: vec![],
118 schema: vec![],
119 }
120 }
121}
122
123impl Source for TrinoSource
124where
125 TrinoSourcePartition: SourcePartition<TypeSystem = TrinoTypeSystem, Error = TrinoSourceError>,
126{
127 const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
128 type TypeSystem = TrinoTypeSystem;
129 type Partition = TrinoSourcePartition;
130 type Error = TrinoSourceError;
131
132 #[throws(TrinoSourceError)]
133 fn set_data_order(&mut self, data_order: DataOrder) {
134 if !matches!(data_order, DataOrder::RowMajor) {
135 throw!(ConnectorXError::UnsupportedDataOrder(data_order));
136 }
137 }
138
139 fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
140 self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
141 }
142
143 fn set_origin_query(&mut self, query: Option<String>) {
144 self.origin_query = query;
145 }
146
147 #[throws(TrinoSourceError)]
148 fn fetch_metadata(&mut self) {
149 assert!(!self.queries.is_empty());
150
151 let first_query = &self.queries[0];
152 let cxq = limit1_query(first_query, &GenericDialect {})?;
153
154 let dataset: DataSet<Row> = self
155 .rt
156 .block_on(self.client.get_all::<Row>(cxq.to_string()))
157 .map_err(TrinoSourceError::PrustoError)?;
158
159 let schema = dataset.split().0;
160
161 for (name, t) in schema {
162 self.names.push(name.clone());
163 self.schema.push(TrinoTypeSystem::try_from(t.clone())?);
164 }
165 }
166
167 #[throws(TrinoSourceError)]
168 fn result_rows(&mut self) -> Option<usize> {
169 match &self.origin_query {
170 Some(q) => {
171 let cxq = CXQuery::Naked(q.clone());
172 let nrows = get_total_rows(self.rt.clone(), self.client.clone(), &cxq)?;
173 Some(nrows)
174 }
175 None => None,
176 }
177 }
178
179 fn names(&self) -> Vec<String> {
180 self.names.clone()
181 }
182
183 fn schema(&self) -> Vec<Self::TypeSystem> {
184 self.schema.clone()
185 }
186
187 #[throws(TrinoSourceError)]
188 fn partition(self) -> Vec<Self::Partition> {
189 let mut ret = vec![];
190
191 for query in self.queries {
192 ret.push(TrinoSourcePartition::new(
193 self.client.clone(),
194 query,
195 self.schema.clone(),
196 self.rt.clone(),
197 )?);
198 }
199 ret
200 }
201}
202
203pub struct TrinoSourcePartition {
204 client: Arc<Client>,
205 query: CXQuery<String>,
206 schema: Vec<TrinoTypeSystem>,
207 rt: Arc<Runtime>,
208 nrows: usize,
209}
210
211impl TrinoSourcePartition {
212 #[throws(TrinoSourceError)]
213 pub fn new(
214 client: Arc<Client>,
215 query: CXQuery<String>,
216 schema: Vec<TrinoTypeSystem>,
217 rt: Arc<Runtime>,
218 ) -> Self {
219 Self {
220 client,
221 query: query.clone(),
222 schema: schema.to_vec(),
223 rt,
224 nrows: 0,
225 }
226 }
227}
228
229impl SourcePartition for TrinoSourcePartition {
230 type TypeSystem = TrinoTypeSystem;
231 type Parser<'a> = TrinoSourcePartitionParser<'a>;
232 type Error = TrinoSourceError;
233
234 #[throws(TrinoSourceError)]
235 fn result_rows(&mut self) {
236 self.nrows = get_total_rows(self.rt.clone(), self.client.clone(), &self.query)?;
237 }
238
239 #[throws(TrinoSourceError)]
240 fn parser(&mut self) -> Self::Parser<'_> {
241 TrinoSourcePartitionParser::new(
242 self.rt.clone(),
243 self.client.clone(),
244 self.query.clone(),
245 &self.schema,
246 )?
247 }
248
249 fn nrows(&self) -> usize {
250 self.nrows
251 }
252
253 fn ncols(&self) -> usize {
254 self.schema.len()
255 }
256}
257
258pub struct TrinoSourcePartitionParser<'a> {
259 rt: Arc<Runtime>,
260 client: Arc<Client>,
261 next_uri: Option<String>,
262 rows: Vec<Row>,
263 ncols: usize,
264 current_col: usize,
265 current_row: usize,
266 _phantom: &'a PhantomData<DataSet<Row>>,
267}
268
269impl<'a> TrinoSourcePartitionParser<'a> {
270 #[throws(TrinoSourceError)]
271 pub fn new(
272 rt: Arc<Runtime>,
273 client: Arc<Client>,
274 query: CXQuery,
275 schema: &[TrinoTypeSystem],
276 ) -> Self {
277 let results = rt
278 .block_on(client.get::<Row>(query.to_string()))
279 .map_err(TrinoSourceError::PrustoError)?;
280
281 let rows = match results.data_set {
282 Some(x) => x.into_vec(),
283 _ => vec![],
284 };
285
286 Self {
287 rt,
288 client,
289 next_uri: results.next_uri,
290 rows,
291 ncols: schema.len(),
292 current_row: 0,
293 current_col: 0,
294 _phantom: &PhantomData,
295 }
296 }
297
298 #[throws(TrinoSourceError)]
299 fn next_loc(&mut self) -> (usize, usize) {
300 let ret = (self.current_row, self.current_col);
301 self.current_row += (self.current_col + 1) / self.ncols;
302 self.current_col = (self.current_col + 1) % self.ncols;
303 ret
304 }
305}
306
307impl<'a> PartitionParser<'a> for TrinoSourcePartitionParser<'a> {
308 type TypeSystem = TrinoTypeSystem;
309 type Error = TrinoSourceError;
310
311 #[throws(TrinoSourceError)]
312 fn fetch_next(&mut self) -> (usize, bool) {
313 assert!(self.current_col == 0);
314
315 match self.next_uri.clone() {
316 Some(uri) => {
317 let results = self
318 .rt
319 .block_on(self.client.get_next::<Row>(&uri))
320 .map_err(TrinoSourceError::PrustoError)?;
321
322 self.rows = match results.data_set {
323 Some(x) => x.into_vec(),
324 _ => vec![],
325 };
326
327 self.current_row = 0;
328 self.next_uri = results.next_uri;
329
330 (self.rows.len(), false)
331 }
332 None => return (self.rows.len(), true),
333 }
334 }
335}
336
// Implements `Produce<$t>` and `Produce<Option<$t>>` for each listed integer type.
//
// A cell is accepted only when it is a JSON number representable as `i64` that also fits
// in `$t`; the `Option` variant additionally maps JSON `null` to `None`.
macro_rules! impl_produce_int {
    ($($t: ty,)+) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> $t {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::Number(num) => match num.as_i64() {
                            Some(raw) => <$t>::try_from(raw).map_err(|_| {
                                anyhow!(
                                    "Trino cannot parse i64 at position: ({}, {}) {:?}",
                                    ridx,
                                    cidx,
                                    value
                                )
                            })?,
                            None => throw!(anyhow!(
                                "Trino cannot parse Number at position: ({}, {}) {:?}",
                                ridx,
                                cidx,
                                num
                            )),
                        },
                        _ => throw!(anyhow!(
                            "Trino cannot parse Number at position: ({}, {}) {:?}",
                            ridx,
                            cidx,
                            value
                        )),
                    }
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::Null => None,
                        Value::Number(num) => match num.as_i64() {
                            Some(raw) => Some(<$t>::try_from(raw).map_err(|_| {
                                anyhow!(
                                    "Trino cannot parse i64 at position: ({}, {}) {:?}",
                                    ridx,
                                    cidx,
                                    value
                                )
                            })?),
                            None => throw!(anyhow!(
                                "Trino cannot parse Number at position: ({}, {}) {:?}",
                                ridx,
                                cidx,
                                num
                            )),
                        },
                        _ => throw!(anyhow!(
                            "Trino cannot parse Number at position: ({}, {}) {:?}",
                            ridx,
                            cidx,
                            value
                        )),
                    }
                }
            }
        )+
    };
}
385
// Implements `Produce<$t>` and `Produce<Option<$t>>` for each listed float type.
//
// A cell is accepted when it is a JSON number stored as a float, or a string that parses
// as `$t`; the `Option` variant additionally maps JSON `null` to `None`.
macro_rules! impl_produce_float {
    ($($t: ty,)+) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> $t {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::Number(num) => {
                            // Numbers that are not stored as floats are rejected.
                            if !num.is_f64() {
                                throw!(anyhow!(
                                    "Trino cannot parse Number at position: ({}, {}) {:?}",
                                    ridx,
                                    cidx,
                                    num
                                ));
                            }
                            num.as_f64().unwrap() as $t
                        }
                        Value::String(text) => text.parse::<$t>().map_err(|_| {
                            anyhow!(
                                "Trino cannot parse String at position: ({}, {}) {:?}",
                                ridx,
                                cidx,
                                value
                            )
                        })?,
                        _ => throw!(anyhow!(
                            "Trino cannot parse Number at position: ({}, {}) {:?}",
                            ridx,
                            cidx,
                            value
                        )),
                    }
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::Null => None,
                        Value::Number(num) => {
                            // Numbers that are not stored as floats are rejected.
                            if !num.is_f64() {
                                throw!(anyhow!(
                                    "Trino cannot parse Number at position: ({}, {}) {:?}",
                                    ridx,
                                    cidx,
                                    num
                                ));
                            }
                            Some(num.as_f64().unwrap() as $t)
                        }
                        Value::String(text) => Some(text.parse::<$t>().map_err(|_| {
                            anyhow!(
                                "Trino cannot parse String at position: ({}, {}) {:?}",
                                ridx,
                                cidx,
                                value
                            )
                        })?),
                        _ => throw!(anyhow!(
                            "Trino cannot parse Number at position: ({}, {}) {:?}",
                            ridx,
                            cidx,
                            value
                        )),
                    }
                }
            }
        )+
    };
}
436
// Implements `Produce<$t>` and `Produce<Option<$t>>` for each listed text-like type.
//
// A cell must be a JSON string that parses as `$t`; the `Option` variant additionally
// maps JSON `null` to `None`.
macro_rules! impl_produce_text {
    ($($t: ty,)+) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> $t {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::String(text) => text.parse::<$t>().map_err(|_| {
                            anyhow!(
                                "Trino cannot parse String at position: ({}, {}): {:?}",
                                ridx,
                                cidx,
                                value
                            )
                        })?,
                        _ => throw!(anyhow!(
                            "Trino unknown value at position: ({}, {}): {:?}",
                            ridx,
                            cidx,
                            value
                        )),
                    }
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::Null => None,
                        Value::String(text) => Some(text.parse::<$t>().map_err(|_| {
                            anyhow!(
                                "Trino cannot parse String at position: ({}, {}): {:?}",
                                ridx,
                                cidx,
                                value
                            )
                        })?),
                        _ => throw!(anyhow!(
                            "Trino unknown value at position: ({}, {}): {:?}",
                            ridx,
                            cidx,
                            value
                        )),
                    }
                }
            }
        )+
    };
}
477
// Implements `Produce<$t>` and `Produce<Option<$t>>` for timestamp types.
//
// A cell must be a JSON string in the `%Y-%m-%d %H:%M:%S%.f` format; the `Option`
// variant additionally maps JSON `null` to `None`. The parsed value is always a
// `NaiveDateTime`, so `$t` has to be that type.
macro_rules! impl_produce_timestamp {
    ($($t: ty,)+) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> $t {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::String(text) => {
                            NaiveDateTime::parse_from_str(text, "%Y-%m-%d %H:%M:%S%.f")
                                .map_err(|_| {
                                    anyhow!(
                                        "Trino cannot parse String at position: ({}, {}): {:?}",
                                        ridx,
                                        cidx,
                                        value
                                    )
                                })?
                        }
                        _ => throw!(anyhow!(
                            "Trino unknown value at position: ({}, {}): {:?}",
                            ridx,
                            cidx,
                            value
                        )),
                    }
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::Null => None,
                        Value::String(text) => Some(
                            NaiveDateTime::parse_from_str(text, "%Y-%m-%d %H:%M:%S%.f")
                                .map_err(|_| {
                                    anyhow!(
                                        "Trino cannot parse String at position: ({}, {}): {:?}",
                                        ridx,
                                        cidx,
                                        value
                                    )
                                })?,
                        ),
                        _ => throw!(anyhow!(
                            "Trino unknown value at position: ({}, {}): {:?}",
                            ridx,
                            cidx,
                            value
                        )),
                    }
                }
            }
        )+
    };
}
514
// Implements `Produce<$t>` and `Produce<Option<$t>>` for boolean types.
//
// A cell must be a JSON boolean; the `Option` variant additionally maps JSON `null`
// to `None`.
macro_rules! impl_produce_bool {
    ($($t: ty,)+) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> $t {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::Bool(flag) => *flag,
                        _ => throw!(anyhow!(
                            "Trino unknown value at position: ({}, {}): {:?}",
                            ridx,
                            cidx,
                            value
                        )),
                    }
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> {
                type Error = TrinoSourceError;

                #[throws(TrinoSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    let (ridx, cidx) = self.next_loc()?;
                    let value = &self.rows[ridx].value()[cidx];

                    match value {
                        Value::Null => None,
                        Value::Bool(flag) => Some(*flag),
                        _ => throw!(anyhow!(
                            "Trino unknown value at position: ({}, {}): {:?}",
                            ridx,
                            cidx,
                            value
                        )),
                    }
                }
            }
        )+
    };
}
551
552impl_produce_bool!(bool,);
553impl_produce_int!(i8, i16, i32, i64,);
554impl_produce_float!(f32, f64,);
555impl_produce_timestamp!(NaiveDateTime,);
556impl_produce_text!(String, char,);
557
558impl<'r, 'a> Produce<'r, NaiveTime> for TrinoSourcePartitionParser<'a> {
559 type Error = TrinoSourceError;
560
561 #[throws(TrinoSourceError)]
562 fn produce(&'r mut self) -> NaiveTime {
563 let (ridx, cidx) = self.next_loc()?;
564 let value = &self.rows[ridx].value()[cidx];
565
566 match value {
567 Value::String(x) => NaiveTime::parse_from_str(x, "%H:%M:%S%.f").map_err(|_| {
568 anyhow!(
569 "Trino cannot parse String at position: ({}, {}): {:?}",
570 ridx,
571 cidx,
572 value
573 )
574 })?,
575 _ => throw!(anyhow!(
576 "Trino unknown value at position: ({}, {}): {:?}",
577 ridx,
578 cidx,
579 value
580 )),
581 }
582 }
583}
584
585impl<'r, 'a> Produce<'r, Option<NaiveTime>> for TrinoSourcePartitionParser<'a> {
586 type Error = TrinoSourceError;
587
588 #[throws(TrinoSourceError)]
589 fn produce(&'r mut self) -> Option<NaiveTime> {
590 let (ridx, cidx) = self.next_loc()?;
591 let value = &self.rows[ridx].value()[cidx];
592
593 match value {
594 Value::Null => None,
595 Value::String(x) => {
596 Some(NaiveTime::parse_from_str(x, "%H:%M:%S%.f").map_err(|_| {
597 anyhow!(
598 "Trino cannot parse Time at position: ({}, {}): {:?}",
599 ridx,
600 cidx,
601 value
602 )
603 })?)
604 }
605 _ => throw!(anyhow!(
606 "Trino unknown value at position: ({}, {}): {:?}",
607 ridx,
608 cidx,
609 value
610 )),
611 }
612 }
613}
614
615impl<'r, 'a> Produce<'r, NaiveDate> for TrinoSourcePartitionParser<'a> {
616 type Error = TrinoSourceError;
617
618 #[throws(TrinoSourceError)]
619 fn produce(&'r mut self) -> NaiveDate {
620 let (ridx, cidx) = self.next_loc()?;
621 let value = &self.rows[ridx].value()[cidx];
622
623 match value {
624 Value::String(x) => NaiveDate::parse_from_str(x, "%Y-%m-%d").map_err(|_| {
625 anyhow!(
626 "Trino cannot parse Date at position: ({}, {}): {:?}",
627 ridx,
628 cidx,
629 value
630 )
631 })?,
632 _ => throw!(anyhow!(
633 "Trino unknown value at position: ({}, {}): {:?}",
634 ridx,
635 cidx,
636 value
637 )),
638 }
639 }
640}
641
642impl<'r, 'a> Produce<'r, Option<NaiveDate>> for TrinoSourcePartitionParser<'a> {
643 type Error = TrinoSourceError;
644
645 #[throws(TrinoSourceError)]
646 fn produce(&'r mut self) -> Option<NaiveDate> {
647 let (ridx, cidx) = self.next_loc()?;
648 let value = &self.rows[ridx].value()[cidx];
649
650 match value {
651 Value::Null => None,
652 Value::String(x) => Some(NaiveDate::parse_from_str(x, "%Y-%m-%d").map_err(|_| {
653 anyhow!(
654 "Trino cannot parse Date at position: ({}, {}): {:?}",
655 ridx,
656 cidx,
657 value
658 )
659 })?),
660 _ => throw!(anyhow!(
661 "Trino unknown value at position: ({}, {}): {:?}",
662 ridx,
663 cidx,
664 value
665 )),
666 }
667 }
668}