// connectorx/sources/oracle/mod.rs

1mod errors;
2mod typesystem;
3
4use std::collections::HashMap;
5
6pub use self::errors::OracleSourceError;
7pub use self::typesystem::OracleTypeSystem;
8use crate::constants::{DB_BUFFER_SIZE, ORACLE_ARRAY_SIZE};
9use crate::{
10    data_order::DataOrder,
11    errors::ConnectorXError,
12    sources::{PartitionParser, Produce, Source, SourcePartition},
13    sql::{count_query, limit1_query_oracle, CXQuery},
14    utils::DummyBox,
15};
16use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
17use fehler::{throw, throws};
18use log::debug;
19use owning_ref::OwningHandle;
20use r2d2::{Pool, PooledConnection};
21use r2d2_oracle::oracle::ResultSet;
22use r2d2_oracle::{
23    oracle::{Connector, Row, Statement},
24    OracleConnectionManager,
25};
26use rust_decimal::Decimal;
27use sqlparser::dialect::Dialect;
28use url::Url;
29use urlencoding::decode;
30
31type OracleManager = OracleConnectionManager;
32type OracleConn = PooledConnection<OracleManager>;
33
34#[derive(Debug)]
35pub struct OracleDialect {}
36
37// implementation copy from AnsiDialect
38impl Dialect for OracleDialect {
39    fn is_identifier_start(&self, ch: char) -> bool {
40        ch.is_ascii_lowercase() || ch.is_ascii_uppercase()
41    }
42
43    fn is_identifier_part(&self, ch: char) -> bool {
44        ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_'
45    }
46}
47
/// Oracle data source: owns the connection pool, the per-partition queries and
/// the result metadata discovered by `fetch_metadata`.
pub struct OracleSource {
    // r2d2 pool of Oracle connections; its max size is `nconn` from `OracleSource::new`.
    pool: Pool<OracleManager>,
    // The un-partitioned query; `result_rows` counts its rows when set.
    origin_query: Option<String>,
    // One query per partition (`partition` creates one partition for each).
    queries: Vec<CXQuery<String>>,
    // Result column names, filled by `fetch_metadata`.
    names: Vec<String>,
    // Result column types, filled by `fetch_metadata`.
    schema: Vec<OracleTypeSystem>,
    // Value of the optional `schema` URL parameter; `get_conn` applies it to
    // every connection taken from the pool.
    current_schema: Option<String>,
}
56
57#[throws(OracleSourceError)]
58pub fn connect_oracle(conn: &Url) -> Connector {
59    let user = decode(conn.username())?.into_owned();
60    let password = decode(conn.password().unwrap_or(""))?.into_owned();
61    let host = decode(conn.host_str().unwrap_or("localhost"))?.into_owned();
62
63    let params: HashMap<String, String> = conn.query_pairs().into_owned().collect();
64
65    let conn_str = if params.get("alias").map_or(false, |v| v == "true") {
66        host.clone()
67    } else {
68        let port = conn.port().unwrap_or(1521);
69        let path = decode(conn.path())?.into_owned();
70        format!("//{}:{}{}", host, port, path)
71    };
72
73    let mut connector = oracle::Connector::new(user.as_str(), password.as_str(), conn_str.as_str());
74    if user.is_empty() && password.is_empty() {
75        debug!("No username or password provided, assuming system auth.");
76        connector.external_auth(true);
77    }
78    connector
79}
80
81impl OracleSource {
82    #[throws(OracleSourceError)]
83    pub fn new(conn: &str, nconn: usize) -> Self {
84        let conn = Url::parse(conn)?;
85        let connector = connect_oracle(&conn)?;
86        let manager = OracleConnectionManager::from_connector(connector);
87        let pool = r2d2::Pool::builder()
88            .max_size(nconn as u32)
89            .build(manager)?;
90
91        let params: HashMap<String, String> = conn.query_pairs().into_owned().collect();
92        let current_schema = params.get("schema").cloned();
93
94        Self {
95            pool,
96            origin_query: None,
97            queries: vec![],
98            names: vec![],
99            schema: vec![],
100            current_schema,
101        }
102    }
103    pub fn get_conn(&self) -> Result<OracleConn, OracleSourceError> {
104        let conn = self.pool.get()?;
105        if let Some(schema) = &self.current_schema {
106            conn.set_current_schema(schema)?;
107        }
108        Ok(conn)
109    }
110}
111
112impl Source for OracleSource
113where
114    OracleSourcePartition:
115        SourcePartition<TypeSystem = OracleTypeSystem, Error = OracleSourceError>,
116{
117    const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
118    type Partition = OracleSourcePartition;
119    type TypeSystem = OracleTypeSystem;
120    type Error = OracleSourceError;
121
122    #[throws(OracleSourceError)]
123    fn set_data_order(&mut self, data_order: DataOrder) {
124        if !matches!(data_order, DataOrder::RowMajor) {
125            throw!(ConnectorXError::UnsupportedDataOrder(data_order));
126        }
127    }
128
129    fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
130        self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
131    }
132
133    fn set_origin_query(&mut self, query: Option<String>) {
134        self.origin_query = query;
135    }
136
137    #[throws(OracleSourceError)]
138    fn fetch_metadata(&mut self) {
139        assert!(!self.queries.is_empty());
140
141        let conn = self.get_conn()?;
142        for (i, query) in self.queries.iter().enumerate() {
143            // assuming all the partition queries yield same schema
144            // without rownum = 1, derived type might be wrong
145            // example: select avg(test_int), test_char from test_table group by test_char
146            // -> (NumInt, Char) instead of (NumtFloat, Char)
147            match conn.query(limit1_query_oracle(query)?.as_str(), &[]) {
148                Ok(rows) => {
149                    let (names, types) = rows
150                        .column_info()
151                        .iter()
152                        .map(|col| {
153                            (
154                                col.name().to_string(),
155                                OracleTypeSystem::from(col.oracle_type()),
156                            )
157                        })
158                        .unzip();
159                    self.names = names;
160                    self.schema = types;
161                    return;
162                }
163                Err(e) if i == self.queries.len() - 1 => {
164                    // tried the last query but still get an error
165                    debug!("cannot get metadata for '{}': {}", query, e);
166                    throw!(e);
167                }
168                Err(_) => {}
169            }
170        }
171        // tried all queries but all get empty result set
172        let iter = conn.query(self.queries[0].as_str(), &[])?;
173        let (names, types) = iter
174            .column_info()
175            .iter()
176            .map(|col| (col.name().to_string(), OracleTypeSystem::VarChar(false)))
177            .unzip();
178        self.names = names;
179        self.schema = types;
180    }
181
182    #[throws(OracleSourceError)]
183    fn result_rows(&mut self) -> Option<usize> {
184        match &self.origin_query {
185            Some(q) => {
186                let cxq = CXQuery::Naked(q.clone());
187                let conn = self.get_conn()?;
188
189                let nrows = conn
190                    .query_row_as::<usize>(count_query(&cxq, &OracleDialect {})?.as_str(), &[])?;
191                Some(nrows)
192            }
193            None => None,
194        }
195    }
196
197    fn names(&self) -> Vec<String> {
198        self.names.clone()
199    }
200
201    fn schema(&self) -> Vec<Self::TypeSystem> {
202        self.schema.clone()
203    }
204
205    #[throws(OracleSourceError)]
206    fn partition(self) -> Vec<Self::Partition> {
207        let mut ret = vec![];
208        for query in &self.queries {
209            let conn = self.get_conn()?;
210            ret.push(OracleSourcePartition::new(conn, &query, &self.schema));
211        }
212        ret
213    }
214}
215
/// One partition of an Oracle source: a single query run on its own pooled connection.
pub struct OracleSourcePartition {
    // Connection checked out from the source's pool for this partition.
    conn: OracleConn,
    // The query this partition executes.
    query: CXQuery<String>,
    // Copy of the source's column types.
    schema: Vec<OracleTypeSystem>,
    // Row count; 0 until `result_rows` runs the count query.
    nrows: usize,
    // Number of columns (`schema.len()` at construction).
    ncols: usize,
}
223
224impl OracleSourcePartition {
225    pub fn new(conn: OracleConn, query: &CXQuery<String>, schema: &[OracleTypeSystem]) -> Self {
226        Self {
227            conn,
228            query: query.clone(),
229            schema: schema.to_vec(),
230            nrows: 0,
231            ncols: schema.len(),
232        }
233    }
234}
235
236impl SourcePartition for OracleSourcePartition {
237    type TypeSystem = OracleTypeSystem;
238    type Parser<'a> = OracleTextSourceParser<'a>;
239    type Error = OracleSourceError;
240
241    #[throws(OracleSourceError)]
242    fn result_rows(&mut self) {
243        self.nrows = self
244            .conn
245            .query_row_as::<usize>(count_query(&self.query, &OracleDialect {})?.as_str(), &[])?;
246    }
247
248    #[throws(OracleSourceError)]
249    fn parser(&mut self) -> Self::Parser<'_> {
250        let query = self.query.clone();
251
252        // let iter = self.conn.query(query.as_str(), &[])?;
253        OracleTextSourceParser::new(&self.conn, query.as_str(), &self.schema)?
254    }
255
256    fn nrows(&self) -> usize {
257        self.nrows
258    }
259
260    fn ncols(&self) -> usize {
261        self.ncols
262    }
263}
264
// SAFETY(review): `Send` is asserted manually here, presumably because the
// compiler does not infer it for the `OwningHandle` that pairs a boxed
// `Statement` with a `ResultSet` borrowing it through a raw pointer. Soundness
// relies on the underlying rust-oracle handles being usable from another thread
// as long as they are not used concurrently. TODO confirm against rust-oracle's
// thread-safety guarantees.
unsafe impl<'a> Send for OracleTextSourceParser<'a> {}
266
/// Row-major parser that buffers fetched Oracle rows and hands out one cell at a time.
pub struct OracleTextSourceParser<'a> {
    // Boxed statement together with the result set that borrows it; the
    // `OwningHandle` keeps the statement alive for as long as the result set.
    rows: OwningHandle<Box<Statement>, DummyBox<ResultSet<'a, Row>>>,
    // Rows fetched by the latest `fetch_next` (at most `DB_BUFFER_SIZE`).
    rowbuf: Vec<Row>,
    // Number of columns per row.
    ncols: usize,
    // Column of the next cell to produce.
    current_col: usize,
    // Index into `rowbuf` of the next cell to produce.
    current_row: usize,
    // Set once the result set has been exhausted.
    is_finished: bool,
}
275
276impl<'a> OracleTextSourceParser<'a> {
277    #[throws(OracleSourceError)]
278    pub fn new(conn: &'a OracleConn, query: &str, schema: &[OracleTypeSystem]) -> Self {
279        let stmt = conn
280            .statement(query)
281            .prefetch_rows(ORACLE_ARRAY_SIZE)
282            .fetch_array_size(ORACLE_ARRAY_SIZE)
283            .build()?;
284        let rows: OwningHandle<Box<Statement>, DummyBox<ResultSet<'a, Row>>> =
285            OwningHandle::new_with_fn(Box::new(stmt), |stmt: *const Statement| unsafe {
286                DummyBox((*(stmt as *mut Statement)).query(&[]).unwrap())
287            });
288
289        Self {
290            rows,
291            rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
292            ncols: schema.len(),
293            current_row: 0,
294            current_col: 0,
295            is_finished: false,
296        }
297    }
298
299    #[throws(OracleSourceError)]
300    fn next_loc(&mut self) -> (usize, usize) {
301        let ret = (self.current_row, self.current_col);
302        self.current_row += (self.current_col + 1) / self.ncols;
303        self.current_col = (self.current_col + 1) % self.ncols;
304        ret
305    }
306}
307
308impl<'a> PartitionParser<'a> for OracleTextSourceParser<'a> {
309    type TypeSystem = OracleTypeSystem;
310    type Error = OracleSourceError;
311
312    #[throws(OracleSourceError)]
313    fn fetch_next(&mut self) -> (usize, bool) {
314        assert!(self.current_col == 0);
315        let remaining_rows = self.rowbuf.len() - self.current_row;
316        if remaining_rows > 0 {
317            return (remaining_rows, self.is_finished);
318        } else if self.is_finished {
319            return (0, self.is_finished);
320        }
321
322        if !self.rowbuf.is_empty() {
323            self.rowbuf.drain(..);
324        }
325        for _ in 0..DB_BUFFER_SIZE {
326            if let Some(item) = (*self.rows).next() {
327                self.rowbuf.push(item?);
328            } else {
329                self.is_finished = true;
330                break;
331            }
332        }
333        self.current_row = 0;
334        self.current_col = 0;
335        (self.rowbuf.len(), self.is_finished)
336    }
337}
338
/// Implements `Produce<T>` and `Produce<Option<T>>` on `OracleTextSourceParser`
/// for every listed type `T`. Each implementation reads the next cell directly
/// with `Row::get`, relying on the type's `FromSql` conversion.
macro_rules! impl_produce_text {
    ($($t: ty,)+) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for OracleTextSourceParser<'a> {
                type Error = OracleSourceError;

                #[throws(OracleSourceError)]
                fn produce(&'r mut self) -> $t {
                    let (row, col) = self.next_loc()?;
                    self.rowbuf[row].get(col)?
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for OracleTextSourceParser<'a> {
                type Error = OracleSourceError;

                // Nullable variant; presumably NULL cells come back as `None` via
                // rust-oracle's `Option<T>` conversion (TODO confirm).
                #[throws(OracleSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    let (row, col) = self.next_loc()?;
                    self.rowbuf[row].get(col)?
                }
            }
        )+
    };
}
366
367impl_produce_text!(
368    i64,
369    f64,
370    String,
371    NaiveDate,
372    NaiveDateTime,
373    DateTime<Utc>,
374    Vec<u8>,
375);
376
377// Manual implementation for Decimal since Oracle doesn't support it directly via FromSql
378impl<'r, 'a> Produce<'r, Decimal> for OracleTextSourceParser<'a> {
379    type Error = OracleSourceError;
380
381    #[throws(OracleSourceError)]
382    fn produce(&'r mut self) -> Decimal {
383        let (ridx, cidx) = self.next_loc()?;
384        let s: String = self.rowbuf[ridx].get(cidx)?;
385        let res = s.parse::<Decimal>()?;
386        res
387    }
388}
389
390impl<'r, 'a> Produce<'r, Option<Decimal>> for OracleTextSourceParser<'a> {
391    type Error = OracleSourceError;
392
393    #[throws(OracleSourceError)]
394    fn produce(&'r mut self) -> Option<Decimal> {
395        let (ridx, cidx) = self.next_loc()?;
396        let s: Option<String> = self.rowbuf[ridx].get(cidx)?;
397        match s {
398            Some(val) => Some(val.parse::<Decimal>()?),
399            None => None,
400        }
401    }
402}