connectorx/sources/oracle/
mod.rs

1mod errors;
2mod typesystem;
3
4use std::collections::HashMap;
5
6pub use self::errors::OracleSourceError;
7pub use self::typesystem::OracleTypeSystem;
8use crate::constants::{DB_BUFFER_SIZE, ORACLE_ARRAY_SIZE};
9use crate::{
10    data_order::DataOrder,
11    errors::ConnectorXError,
12    sources::{PartitionParser, Produce, Source, SourcePartition},
13    sql::{count_query, limit1_query_oracle, CXQuery},
14    utils::DummyBox,
15};
16use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
17use fehler::{throw, throws};
18use log::debug;
19use owning_ref::OwningHandle;
20use r2d2::{Pool, PooledConnection};
21use r2d2_oracle::oracle::ResultSet;
22use r2d2_oracle::{
23    oracle::{Connector, Row, Statement},
24    OracleConnectionManager,
25};
26use sqlparser::dialect::Dialect;
27use url::Url;
28use urlencoding::decode;
29
30type OracleManager = OracleConnectionManager;
31type OracleConn = PooledConnection<OracleManager>;
32
33#[derive(Debug)]
34pub struct OracleDialect {}
35
36// implementation copy from AnsiDialect
37impl Dialect for OracleDialect {
38    fn is_identifier_start(&self, ch: char) -> bool {
39        ch.is_ascii_lowercase() || ch.is_ascii_uppercase()
40    }
41
42    fn is_identifier_part(&self, ch: char) -> bool {
43        ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_'
44    }
45}
46
/// Oracle-backed data source: owns the connection pool together with the queries
/// and result metadata the `Source` implementation works with.
pub struct OracleSource {
    /// Connection pool built in `new`; `get_conn` hands out connections from it.
    pool: Pool<OracleManager>,
    /// Un-partitioned query set through `set_origin_query`; when present,
    /// `result_rows` counts its rows.
    origin_query: Option<String>,
    /// Partition queries set through `set_queries`; one partition is created per entry.
    queries: Vec<CXQuery<String>>,
    /// Column names filled in by `fetch_metadata`.
    names: Vec<String>,
    /// Column types filled in by `fetch_metadata`.
    schema: Vec<OracleTypeSystem>,
    /// Value of the `schema` URL query parameter, if any; `get_conn` applies it
    /// to every connection it returns.
    current_schema: Option<String>,
}
55
56#[throws(OracleSourceError)]
57pub fn connect_oracle(conn: &Url) -> Connector {
58    let user = decode(conn.username())?.into_owned();
59    let password = decode(conn.password().unwrap_or(""))?.into_owned();
60    let host = decode(conn.host_str().unwrap_or("localhost"))?.into_owned();
61
62    let params: HashMap<String, String> = conn.query_pairs().into_owned().collect();
63
64    let conn_str = if params.get("alias").map_or(false, |v| v == "true") {
65        host.clone()
66    } else {
67        let port = conn.port().unwrap_or(1521);
68        let path = decode(conn.path())?.into_owned();
69        format!("//{}:{}{}", host, port, path)
70    };
71
72    let mut connector = oracle::Connector::new(user.as_str(), password.as_str(), conn_str.as_str());
73    if user.is_empty() && password.is_empty() {
74        debug!("No username or password provided, assuming system auth.");
75        connector.external_auth(true);
76    }
77    connector
78}
79
80impl OracleSource {
81    #[throws(OracleSourceError)]
82    pub fn new(conn: &str, nconn: usize) -> Self {
83        let conn = Url::parse(conn)?;
84        let connector = connect_oracle(&conn)?;
85        let manager = OracleConnectionManager::from_connector(connector);
86        let pool = r2d2::Pool::builder()
87            .max_size(nconn as u32)
88            .build(manager)?;
89
90        let params: HashMap<String, String> = conn.query_pairs().into_owned().collect();
91        let current_schema = params.get("schema").cloned();
92
93        Self {
94            pool,
95            origin_query: None,
96            queries: vec![],
97            names: vec![],
98            schema: vec![],
99            current_schema,
100        }
101    }
102    pub fn get_conn(&self) -> Result<OracleConn, OracleSourceError> {
103        let conn = self.pool.get()?;
104        if let Some(schema) = &self.current_schema {
105            conn.set_current_schema(schema)?;
106        }
107        Ok(conn)
108    }
109}
110
111impl Source for OracleSource
112where
113    OracleSourcePartition:
114        SourcePartition<TypeSystem = OracleTypeSystem, Error = OracleSourceError>,
115{
116    const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
117    type Partition = OracleSourcePartition;
118    type TypeSystem = OracleTypeSystem;
119    type Error = OracleSourceError;
120
121    #[throws(OracleSourceError)]
122    fn set_data_order(&mut self, data_order: DataOrder) {
123        if !matches!(data_order, DataOrder::RowMajor) {
124            throw!(ConnectorXError::UnsupportedDataOrder(data_order));
125        }
126    }
127
128    fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
129        self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
130    }
131
132    fn set_origin_query(&mut self, query: Option<String>) {
133        self.origin_query = query;
134    }
135
136    #[throws(OracleSourceError)]
137    fn fetch_metadata(&mut self) {
138        assert!(!self.queries.is_empty());
139
140        let conn = self.get_conn()?;
141        for (i, query) in self.queries.iter().enumerate() {
142            // assuming all the partition queries yield same schema
143            // without rownum = 1, derived type might be wrong
144            // example: select avg(test_int), test_char from test_table group by test_char
145            // -> (NumInt, Char) instead of (NumtFloat, Char)
146            match conn.query(limit1_query_oracle(query)?.as_str(), &[]) {
147                Ok(rows) => {
148                    let (names, types) = rows
149                        .column_info()
150                        .iter()
151                        .map(|col| {
152                            (
153                                col.name().to_string(),
154                                OracleTypeSystem::from(col.oracle_type()),
155                            )
156                        })
157                        .unzip();
158                    self.names = names;
159                    self.schema = types;
160                    return;
161                }
162                Err(e) if i == self.queries.len() - 1 => {
163                    // tried the last query but still get an error
164                    debug!("cannot get metadata for '{}': {}", query, e);
165                    throw!(e);
166                }
167                Err(_) => {}
168            }
169        }
170        // tried all queries but all get empty result set
171        let iter = conn.query(self.queries[0].as_str(), &[])?;
172        let (names, types) = iter
173            .column_info()
174            .iter()
175            .map(|col| (col.name().to_string(), OracleTypeSystem::VarChar(false)))
176            .unzip();
177        self.names = names;
178        self.schema = types;
179    }
180
181    #[throws(OracleSourceError)]
182    fn result_rows(&mut self) -> Option<usize> {
183        match &self.origin_query {
184            Some(q) => {
185                let cxq = CXQuery::Naked(q.clone());
186                let conn = self.get_conn()?;
187
188                let nrows = conn
189                    .query_row_as::<usize>(count_query(&cxq, &OracleDialect {})?.as_str(), &[])?;
190                Some(nrows)
191            }
192            None => None,
193        }
194    }
195
196    fn names(&self) -> Vec<String> {
197        self.names.clone()
198    }
199
200    fn schema(&self) -> Vec<Self::TypeSystem> {
201        self.schema.clone()
202    }
203
204    #[throws(OracleSourceError)]
205    fn partition(self) -> Vec<Self::Partition> {
206        let mut ret = vec![];
207        for query in &self.queries {
208            let conn = self.get_conn()?;
209            ret.push(OracleSourcePartition::new(conn, &query, &self.schema));
210        }
211        ret
212    }
213}
214
215pub struct OracleSourcePartition {
216    conn: OracleConn,
217    query: CXQuery<String>,
218    schema: Vec<OracleTypeSystem>,
219    nrows: usize,
220    ncols: usize,
221}
222
223impl OracleSourcePartition {
224    pub fn new(conn: OracleConn, query: &CXQuery<String>, schema: &[OracleTypeSystem]) -> Self {
225        Self {
226            conn,
227            query: query.clone(),
228            schema: schema.to_vec(),
229            nrows: 0,
230            ncols: schema.len(),
231        }
232    }
233}
234
235impl SourcePartition for OracleSourcePartition {
236    type TypeSystem = OracleTypeSystem;
237    type Parser<'a> = OracleTextSourceParser<'a>;
238    type Error = OracleSourceError;
239
240    #[throws(OracleSourceError)]
241    fn result_rows(&mut self) {
242        self.nrows = self
243            .conn
244            .query_row_as::<usize>(count_query(&self.query, &OracleDialect {})?.as_str(), &[])?;
245    }
246
247    #[throws(OracleSourceError)]
248    fn parser(&mut self) -> Self::Parser<'_> {
249        let query = self.query.clone();
250
251        // let iter = self.conn.query(query.as_str(), &[])?;
252        OracleTextSourceParser::new(&self.conn, query.as_str(), &self.schema)?
253    }
254
255    fn nrows(&self) -> usize {
256        self.nrows
257    }
258
259    fn ncols(&self) -> usize {
260        self.ncols
261    }
262}
263
264unsafe impl<'a> Send for OracleTextSourceParser<'a> {}
265
266pub struct OracleTextSourceParser<'a> {
267    rows: OwningHandle<Box<Statement>, DummyBox<ResultSet<'a, Row>>>,
268    rowbuf: Vec<Row>,
269    ncols: usize,
270    current_col: usize,
271    current_row: usize,
272    is_finished: bool,
273}
274
275impl<'a> OracleTextSourceParser<'a> {
276    #[throws(OracleSourceError)]
277    pub fn new(conn: &'a OracleConn, query: &str, schema: &[OracleTypeSystem]) -> Self {
278        let stmt = conn
279            .statement(query)
280            .prefetch_rows(ORACLE_ARRAY_SIZE)
281            .fetch_array_size(ORACLE_ARRAY_SIZE)
282            .build()?;
283        let rows: OwningHandle<Box<Statement>, DummyBox<ResultSet<'a, Row>>> =
284            OwningHandle::new_with_fn(Box::new(stmt), |stmt: *const Statement| unsafe {
285                DummyBox((*(stmt as *mut Statement)).query(&[]).unwrap())
286            });
287
288        Self {
289            rows,
290            rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
291            ncols: schema.len(),
292            current_row: 0,
293            current_col: 0,
294            is_finished: false,
295        }
296    }
297
298    #[throws(OracleSourceError)]
299    fn next_loc(&mut self) -> (usize, usize) {
300        let ret = (self.current_row, self.current_col);
301        self.current_row += (self.current_col + 1) / self.ncols;
302        self.current_col = (self.current_col + 1) % self.ncols;
303        ret
304    }
305}
306
307impl<'a> PartitionParser<'a> for OracleTextSourceParser<'a> {
308    type TypeSystem = OracleTypeSystem;
309    type Error = OracleSourceError;
310
311    #[throws(OracleSourceError)]
312    fn fetch_next(&mut self) -> (usize, bool) {
313        assert!(self.current_col == 0);
314        let remaining_rows = self.rowbuf.len() - self.current_row;
315        if remaining_rows > 0 {
316            return (remaining_rows, self.is_finished);
317        } else if self.is_finished {
318            return (0, self.is_finished);
319        }
320
321        if !self.rowbuf.is_empty() {
322            self.rowbuf.drain(..);
323        }
324        for _ in 0..DB_BUFFER_SIZE {
325            if let Some(item) = (*self.rows).next() {
326                self.rowbuf.push(item?);
327            } else {
328                self.is_finished = true;
329                break;
330            }
331        }
332        self.current_row = 0;
333        self.current_col = 0;
334        (self.rowbuf.len(), self.is_finished)
335    }
336}
337
338macro_rules! impl_produce_text {
339    ($($t: ty,)+) => {
340        $(
341            impl<'r, 'a> Produce<'r, $t> for OracleTextSourceParser<'a> {
342                type Error = OracleSourceError;
343
344                #[throws(OracleSourceError)]
345                fn produce(&'r mut self) -> $t {
346                    let (ridx, cidx) = self.next_loc()?;
347                    let res = self.rowbuf[ridx].get(cidx)?;
348                    res
349                }
350            }
351
352            impl<'r, 'a> Produce<'r, Option<$t>> for OracleTextSourceParser<'a> {
353                type Error = OracleSourceError;
354
355                #[throws(OracleSourceError)]
356                fn produce(&'r mut self) -> Option<$t> {
357                    let (ridx, cidx) = self.next_loc()?;
358                    let res = self.rowbuf[ridx].get(cidx)?;
359                    res
360                }
361            }
362        )+
363    };
364}
365
366impl_produce_text!(
367    i64,
368    f64,
369    String,
370    NaiveDate,
371    NaiveDateTime,
372    DateTime<Utc>,
373    Vec<u8>,
374);