// connectorx/sources/oracle/mod.rs

1mod errors;
2mod typesystem;
3
4use std::collections::HashMap;
5
6pub use self::errors::OracleSourceError;
7pub use self::typesystem::OracleTypeSystem;
8use crate::constants::{DB_BUFFER_SIZE, ORACLE_ARRAY_SIZE};
9use crate::{
10    data_order::DataOrder,
11    errors::ConnectorXError,
12    sources::{PartitionParser, Produce, Source, SourcePartition},
13    sql::{count_query, limit1_query_oracle, CXQuery},
14    utils::DummyBox,
15};
16use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
17use fehler::{throw, throws};
18use log::debug;
19use owning_ref::OwningHandle;
20use r2d2::{Pool, PooledConnection};
21use r2d2_oracle::oracle::ResultSet;
22use r2d2_oracle::{
23    oracle::{Connector, Row, Statement},
24    OracleConnectionManager,
25};
26use sqlparser::dialect::Dialect;
27use url::Url;
28use urlencoding::decode;
29
30type OracleManager = OracleConnectionManager;
31type OracleConn = PooledConnection<OracleManager>;
32
33#[derive(Debug)]
34pub struct OracleDialect {}
35
36// implementation copy from AnsiDialect
37impl Dialect for OracleDialect {
38    fn is_identifier_start(&self, ch: char) -> bool {
39        ch.is_ascii_lowercase() || ch.is_ascii_uppercase()
40    }
41
42    fn is_identifier_part(&self, ch: char) -> bool {
43        ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_'
44    }
45}
46
/// Oracle data source: owns a connection pool plus the queries and the result
/// schema discovered for them.
pub struct OracleSource {
    /// Pool of Oracle connections; sized by `nconn` in `OracleSource::new`.
    pool: Pool<OracleManager>,
    /// The user's original (unpartitioned) query; `result_rows` counts its rows
    /// when set.
    origin_query: Option<String>,
    /// Partition queries; `partition` turns each one into its own partition.
    queries: Vec<CXQuery<String>>,
    /// Column names filled in by `fetch_metadata`.
    names: Vec<String>,
    /// Column types filled in by `fetch_metadata`, parallel to `names`.
    schema: Vec<OracleTypeSystem>,
}
54
55#[throws(OracleSourceError)]
56pub fn connect_oracle(conn: &Url) -> Connector {
57    let user = decode(conn.username())?.into_owned();
58    let password = decode(conn.password().unwrap_or(""))?.into_owned();
59    let host = decode(conn.host_str().unwrap_or("localhost"))?.into_owned();
60
61    let params: HashMap<String, String> = conn.query_pairs().into_owned().collect();
62
63    let conn_str = if params.get("alias").map_or(false, |v| v == "true") {
64        host.clone()
65    } else {
66        let port = conn.port().unwrap_or(1521);
67        let path = decode(conn.path())?.into_owned();
68        format!("//{}:{}{}", host, port, path)
69    };
70
71    let mut connector = oracle::Connector::new(user.as_str(), password.as_str(), conn_str.as_str());
72    if user.is_empty() && password.is_empty() {
73        debug!("No username or password provided, assuming system auth.");
74        connector.external_auth(true);
75    }
76    connector
77}
78
79impl OracleSource {
80    #[throws(OracleSourceError)]
81    pub fn new(conn: &str, nconn: usize) -> Self {
82        let conn = Url::parse(conn)?;
83        let connector = connect_oracle(&conn)?;
84        let manager = OracleConnectionManager::from_connector(connector);
85        let pool = r2d2::Pool::builder()
86            .max_size(nconn as u32)
87            .build(manager)?;
88
89        Self {
90            pool,
91            origin_query: None,
92            queries: vec![],
93            names: vec![],
94            schema: vec![],
95        }
96    }
97}
98
99impl Source for OracleSource
100where
101    OracleSourcePartition:
102        SourcePartition<TypeSystem = OracleTypeSystem, Error = OracleSourceError>,
103{
104    const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
105    type Partition = OracleSourcePartition;
106    type TypeSystem = OracleTypeSystem;
107    type Error = OracleSourceError;
108
109    #[throws(OracleSourceError)]
110    fn set_data_order(&mut self, data_order: DataOrder) {
111        if !matches!(data_order, DataOrder::RowMajor) {
112            throw!(ConnectorXError::UnsupportedDataOrder(data_order));
113        }
114    }
115
116    fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
117        self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
118    }
119
120    fn set_origin_query(&mut self, query: Option<String>) {
121        self.origin_query = query;
122    }
123
124    #[throws(OracleSourceError)]
125    fn fetch_metadata(&mut self) {
126        assert!(!self.queries.is_empty());
127
128        let conn = self.pool.get()?;
129        for (i, query) in self.queries.iter().enumerate() {
130            // assuming all the partition queries yield same schema
131            // without rownum = 1, derived type might be wrong
132            // example: select avg(test_int), test_char from test_table group by test_char
133            // -> (NumInt, Char) instead of (NumtFloat, Char)
134            match conn.query(limit1_query_oracle(query)?.as_str(), &[]) {
135                Ok(rows) => {
136                    let (names, types) = rows
137                        .column_info()
138                        .iter()
139                        .map(|col| {
140                            (
141                                col.name().to_string(),
142                                OracleTypeSystem::from(col.oracle_type()),
143                            )
144                        })
145                        .unzip();
146                    self.names = names;
147                    self.schema = types;
148                    return;
149                }
150                Err(e) if i == self.queries.len() - 1 => {
151                    // tried the last query but still get an error
152                    debug!("cannot get metadata for '{}': {}", query, e);
153                    throw!(e);
154                }
155                Err(_) => {}
156            }
157        }
158        // tried all queries but all get empty result set
159        let iter = conn.query(self.queries[0].as_str(), &[])?;
160        let (names, types) = iter
161            .column_info()
162            .iter()
163            .map(|col| (col.name().to_string(), OracleTypeSystem::VarChar(false)))
164            .unzip();
165        self.names = names;
166        self.schema = types;
167    }
168
169    #[throws(OracleSourceError)]
170    fn result_rows(&mut self) -> Option<usize> {
171        match &self.origin_query {
172            Some(q) => {
173                let cxq = CXQuery::Naked(q.clone());
174                let conn = self.pool.get()?;
175
176                let nrows = conn
177                    .query_row_as::<usize>(count_query(&cxq, &OracleDialect {})?.as_str(), &[])?;
178                Some(nrows)
179            }
180            None => None,
181        }
182    }
183
184    fn names(&self) -> Vec<String> {
185        self.names.clone()
186    }
187
188    fn schema(&self) -> Vec<Self::TypeSystem> {
189        self.schema.clone()
190    }
191
192    #[throws(OracleSourceError)]
193    fn partition(self) -> Vec<Self::Partition> {
194        let mut ret = vec![];
195        for query in self.queries {
196            let conn = self.pool.get()?;
197            ret.push(OracleSourcePartition::new(conn, &query, &self.schema));
198        }
199        ret
200    }
201}
202
/// One partition of an Oracle source: a single query run on its own pooled
/// connection.
pub struct OracleSourcePartition {
    /// Connection checked out of the source's pool for this partition.
    conn: OracleConn,
    /// The partition's query.
    query: CXQuery<String>,
    /// Column types shared by all partitions (from `fetch_metadata`).
    schema: Vec<OracleTypeSystem>,
    /// Row count; 0 until `result_rows` is called.
    nrows: usize,
    /// Number of columns, i.e. `schema.len()`.
    ncols: usize,
}
210
211impl OracleSourcePartition {
212    pub fn new(conn: OracleConn, query: &CXQuery<String>, schema: &[OracleTypeSystem]) -> Self {
213        Self {
214            conn,
215            query: query.clone(),
216            schema: schema.to_vec(),
217            nrows: 0,
218            ncols: schema.len(),
219        }
220    }
221}
222
223impl SourcePartition for OracleSourcePartition {
224    type TypeSystem = OracleTypeSystem;
225    type Parser<'a> = OracleTextSourceParser<'a>;
226    type Error = OracleSourceError;
227
228    #[throws(OracleSourceError)]
229    fn result_rows(&mut self) {
230        self.nrows = self
231            .conn
232            .query_row_as::<usize>(count_query(&self.query, &OracleDialect {})?.as_str(), &[])?;
233    }
234
235    #[throws(OracleSourceError)]
236    fn parser(&mut self) -> Self::Parser<'_> {
237        let query = self.query.clone();
238
239        // let iter = self.conn.query(query.as_str(), &[])?;
240        OracleTextSourceParser::new(&self.conn, query.as_str(), &self.schema)?
241    }
242
243    fn nrows(&self) -> usize {
244        self.nrows
245    }
246
247    fn ncols(&self) -> usize {
248        self.ncols
249    }
250}
251
// SAFETY: NOTE(review): this asserts that the parser may be moved to another
// thread, even though it holds an Oracle `Statement`, the `ResultSet` created
// from it, and buffered `Row`s. The assertion presumably relies on each parser
// being used by only one thread at a time, and on the driver's handles
// tolerating a thread switch. Confirm this against the oracle crate's
// threading guarantees.
unsafe impl<'a> Send for OracleTextSourceParser<'a> {}
253
/// Row-major parser over the result set of one partition query.
pub struct OracleTextSourceParser<'a> {
    /// The boxed statement together with the result set created from it, kept
    /// in one `OwningHandle` because the result set refers to the statement.
    rows: OwningHandle<Box<Statement>, DummyBox<ResultSet<'a, Row>>>,
    /// Current batch of fetched rows; refilled by `fetch_next` with up to
    /// `DB_BUFFER_SIZE` rows at a time.
    rowbuf: Vec<Row>,
    /// Number of columns per row (the schema length).
    ncols: usize,
    /// Column of the next cell to produce; advanced by `next_loc`.
    current_col: usize,
    /// Row (index into `rowbuf`) of the next cell to produce.
    current_row: usize,
    /// Set once the result set has no more rows.
    is_finished: bool,
}
262
263impl<'a> OracleTextSourceParser<'a> {
264    #[throws(OracleSourceError)]
265    pub fn new(conn: &'a OracleConn, query: &str, schema: &[OracleTypeSystem]) -> Self {
266        let stmt = conn
267            .statement(query)
268            .prefetch_rows(ORACLE_ARRAY_SIZE)
269            .fetch_array_size(ORACLE_ARRAY_SIZE)
270            .build()?;
271        let rows: OwningHandle<Box<Statement>, DummyBox<ResultSet<'a, Row>>> =
272            OwningHandle::new_with_fn(Box::new(stmt), |stmt: *const Statement| unsafe {
273                DummyBox((*(stmt as *mut Statement)).query(&[]).unwrap())
274            });
275
276        Self {
277            rows,
278            rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
279            ncols: schema.len(),
280            current_row: 0,
281            current_col: 0,
282            is_finished: false,
283        }
284    }
285
286    #[throws(OracleSourceError)]
287    fn next_loc(&mut self) -> (usize, usize) {
288        let ret = (self.current_row, self.current_col);
289        self.current_row += (self.current_col + 1) / self.ncols;
290        self.current_col = (self.current_col + 1) % self.ncols;
291        ret
292    }
293}
294
295impl<'a> PartitionParser<'a> for OracleTextSourceParser<'a> {
296    type TypeSystem = OracleTypeSystem;
297    type Error = OracleSourceError;
298
299    #[throws(OracleSourceError)]
300    fn fetch_next(&mut self) -> (usize, bool) {
301        assert!(self.current_col == 0);
302        let remaining_rows = self.rowbuf.len() - self.current_row;
303        if remaining_rows > 0 {
304            return (remaining_rows, self.is_finished);
305        } else if self.is_finished {
306            return (0, self.is_finished);
307        }
308
309        if !self.rowbuf.is_empty() {
310            self.rowbuf.drain(..);
311        }
312        for _ in 0..DB_BUFFER_SIZE {
313            if let Some(item) = (*self.rows).next() {
314                self.rowbuf.push(item?);
315            } else {
316                self.is_finished = true;
317                break;
318            }
319        }
320        self.current_row = 0;
321        self.current_col = 0;
322        (self.rowbuf.len(), self.is_finished)
323    }
324}
325
// Generates, for every listed type `$t`, two `Produce` impls on
// `OracleTextSourceParser`: one yielding `$t` and one yielding `Option<$t>`.
// Each impl reads the cell at the parser's cursor (`next_loc`) from the
// buffered rows via `Row::get`.
macro_rules! impl_produce_text {
    ($($t: ty,)+) => {
        $(
            impl<'r, 'a> Produce<'r, $t> for OracleTextSourceParser<'a> {
                type Error = OracleSourceError;

                #[throws(OracleSourceError)]
                fn produce(&'r mut self) -> $t {
                    let (row, col) = self.next_loc()?;
                    self.rowbuf[row].get(col)?
                }
            }

            impl<'r, 'a> Produce<'r, Option<$t>> for OracleTextSourceParser<'a> {
                type Error = OracleSourceError;

                #[throws(OracleSourceError)]
                fn produce(&'r mut self) -> Option<$t> {
                    let (row, col) = self.next_loc()?;
                    self.rowbuf[row].get(col)?
                }
            }
        )+
    };
}
353
354impl_produce_text!(
355    i64,
356    f64,
357    String,
358    NaiveDate,
359    NaiveDateTime,
360    DateTime<Utc>,
361    Vec<u8>,
362);