connectorx/sources/oracle/
mod.rs1mod errors;
2mod typesystem;
3
4use std::collections::HashMap;
5
6pub use self::errors::OracleSourceError;
7pub use self::typesystem::OracleTypeSystem;
8use crate::constants::{DB_BUFFER_SIZE, ORACLE_ARRAY_SIZE};
9use crate::{
10 data_order::DataOrder,
11 errors::ConnectorXError,
12 sources::{PartitionParser, Produce, Source, SourcePartition},
13 sql::{count_query, limit1_query_oracle, CXQuery},
14 utils::DummyBox,
15};
16use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
17use fehler::{throw, throws};
18use log::debug;
19use owning_ref::OwningHandle;
20use r2d2::{Pool, PooledConnection};
21use r2d2_oracle::oracle::ResultSet;
22use r2d2_oracle::{
23 oracle::{Connector, Row, Statement},
24 OracleConnectionManager,
25};
26use rust_decimal::Decimal;
27use sqlparser::dialect::Dialect;
28use url::Url;
29use urlencoding::decode;
30
31type OracleManager = OracleConnectionManager;
32type OracleConn = PooledConnection<OracleManager>;
33
34#[derive(Debug)]
35pub struct OracleDialect {}
36
37impl Dialect for OracleDialect {
39 fn is_identifier_start(&self, ch: char) -> bool {
40 ch.is_ascii_lowercase() || ch.is_ascii_uppercase()
41 }
42
43 fn is_identifier_part(&self, ch: char) -> bool {
44 ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_'
45 }
46}
47
48pub struct OracleSource {
49 pool: Pool<OracleManager>,
50 origin_query: Option<String>,
51 queries: Vec<CXQuery<String>>,
52 names: Vec<String>,
53 schema: Vec<OracleTypeSystem>,
54 current_schema: Option<String>,
55}
56
57#[throws(OracleSourceError)]
58pub fn connect_oracle(conn: &Url) -> Connector {
59 let user = decode(conn.username())?.into_owned();
60 let password = decode(conn.password().unwrap_or(""))?.into_owned();
61 let host = decode(conn.host_str().unwrap_or("localhost"))?.into_owned();
62
63 let params: HashMap<String, String> = conn.query_pairs().into_owned().collect();
64
65 let conn_str = if params.get("alias").map_or(false, |v| v == "true") {
66 host.clone()
67 } else {
68 let port = conn.port().unwrap_or(1521);
69 let path = decode(conn.path())?.into_owned();
70 format!("//{}:{}{}", host, port, path)
71 };
72
73 let mut connector = oracle::Connector::new(user.as_str(), password.as_str(), conn_str.as_str());
74 if user.is_empty() && password.is_empty() {
75 debug!("No username or password provided, assuming system auth.");
76 connector.external_auth(true);
77 }
78 connector
79}
80
81impl OracleSource {
82 #[throws(OracleSourceError)]
83 pub fn new(conn: &str, nconn: usize) -> Self {
84 let conn = Url::parse(conn)?;
85 let connector = connect_oracle(&conn)?;
86 let manager = OracleConnectionManager::from_connector(connector);
87 let pool = r2d2::Pool::builder()
88 .max_size(nconn as u32)
89 .build(manager)?;
90
91 let params: HashMap<String, String> = conn.query_pairs().into_owned().collect();
92 let current_schema = params.get("schema").cloned();
93
94 Self {
95 pool,
96 origin_query: None,
97 queries: vec![],
98 names: vec![],
99 schema: vec![],
100 current_schema,
101 }
102 }
103 pub fn get_conn(&self) -> Result<OracleConn, OracleSourceError> {
104 let conn = self.pool.get()?;
105 if let Some(schema) = &self.current_schema {
106 conn.set_current_schema(schema)?;
107 }
108 Ok(conn)
109 }
110}
111
112impl Source for OracleSource
113where
114 OracleSourcePartition:
115 SourcePartition<TypeSystem = OracleTypeSystem, Error = OracleSourceError>,
116{
117 const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
118 type Partition = OracleSourcePartition;
119 type TypeSystem = OracleTypeSystem;
120 type Error = OracleSourceError;
121
122 #[throws(OracleSourceError)]
123 fn set_data_order(&mut self, data_order: DataOrder) {
124 if !matches!(data_order, DataOrder::RowMajor) {
125 throw!(ConnectorXError::UnsupportedDataOrder(data_order));
126 }
127 }
128
129 fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
130 self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
131 }
132
133 fn set_origin_query(&mut self, query: Option<String>) {
134 self.origin_query = query;
135 }
136
137 #[throws(OracleSourceError)]
138 fn fetch_metadata(&mut self) {
139 assert!(!self.queries.is_empty());
140
141 let conn = self.get_conn()?;
142 for (i, query) in self.queries.iter().enumerate() {
143 match conn.query(limit1_query_oracle(query)?.as_str(), &[]) {
148 Ok(rows) => {
149 let (names, types) = rows
150 .column_info()
151 .iter()
152 .map(|col| {
153 (
154 col.name().to_string(),
155 OracleTypeSystem::from(col.oracle_type()),
156 )
157 })
158 .unzip();
159 self.names = names;
160 self.schema = types;
161 return;
162 }
163 Err(e) if i == self.queries.len() - 1 => {
164 debug!("cannot get metadata for '{}': {}", query, e);
166 throw!(e);
167 }
168 Err(_) => {}
169 }
170 }
171 let iter = conn.query(self.queries[0].as_str(), &[])?;
173 let (names, types) = iter
174 .column_info()
175 .iter()
176 .map(|col| (col.name().to_string(), OracleTypeSystem::VarChar(false)))
177 .unzip();
178 self.names = names;
179 self.schema = types;
180 }
181
182 #[throws(OracleSourceError)]
183 fn result_rows(&mut self) -> Option<usize> {
184 match &self.origin_query {
185 Some(q) => {
186 let cxq = CXQuery::Naked(q.clone());
187 let conn = self.get_conn()?;
188
189 let nrows = conn
190 .query_row_as::<usize>(count_query(&cxq, &OracleDialect {})?.as_str(), &[])?;
191 Some(nrows)
192 }
193 None => None,
194 }
195 }
196
197 fn names(&self) -> Vec<String> {
198 self.names.clone()
199 }
200
201 fn schema(&self) -> Vec<Self::TypeSystem> {
202 self.schema.clone()
203 }
204
205 #[throws(OracleSourceError)]
206 fn partition(self) -> Vec<Self::Partition> {
207 let mut ret = vec![];
208 for query in &self.queries {
209 let conn = self.get_conn()?;
210 ret.push(OracleSourcePartition::new(conn, &query, &self.schema));
211 }
212 ret
213 }
214}
215
216pub struct OracleSourcePartition {
217 conn: OracleConn,
218 query: CXQuery<String>,
219 schema: Vec<OracleTypeSystem>,
220 nrows: usize,
221 ncols: usize,
222}
223
224impl OracleSourcePartition {
225 pub fn new(conn: OracleConn, query: &CXQuery<String>, schema: &[OracleTypeSystem]) -> Self {
226 Self {
227 conn,
228 query: query.clone(),
229 schema: schema.to_vec(),
230 nrows: 0,
231 ncols: schema.len(),
232 }
233 }
234}
235
236impl SourcePartition for OracleSourcePartition {
237 type TypeSystem = OracleTypeSystem;
238 type Parser<'a> = OracleTextSourceParser<'a>;
239 type Error = OracleSourceError;
240
241 #[throws(OracleSourceError)]
242 fn result_rows(&mut self) {
243 self.nrows = self
244 .conn
245 .query_row_as::<usize>(count_query(&self.query, &OracleDialect {})?.as_str(), &[])?;
246 }
247
248 #[throws(OracleSourceError)]
249 fn parser(&mut self) -> Self::Parser<'_> {
250 let query = self.query.clone();
251
252 OracleTextSourceParser::new(&self.conn, query.as_str(), &self.schema)?
254 }
255
256 fn nrows(&self) -> usize {
257 self.nrows
258 }
259
260 fn ncols(&self) -> usize {
261 self.ncols
262 }
263}
264
265unsafe impl<'a> Send for OracleTextSourceParser<'a> {}
266
267pub struct OracleTextSourceParser<'a> {
268 rows: OwningHandle<Box<Statement>, DummyBox<ResultSet<'a, Row>>>,
269 rowbuf: Vec<Row>,
270 ncols: usize,
271 current_col: usize,
272 current_row: usize,
273 is_finished: bool,
274}
275
276impl<'a> OracleTextSourceParser<'a> {
277 #[throws(OracleSourceError)]
278 pub fn new(conn: &'a OracleConn, query: &str, schema: &[OracleTypeSystem]) -> Self {
279 let stmt = conn
280 .statement(query)
281 .prefetch_rows(ORACLE_ARRAY_SIZE)
282 .fetch_array_size(ORACLE_ARRAY_SIZE)
283 .build()?;
284 let rows: OwningHandle<Box<Statement>, DummyBox<ResultSet<'a, Row>>> =
285 OwningHandle::new_with_fn(Box::new(stmt), |stmt: *const Statement| unsafe {
286 DummyBox((*(stmt as *mut Statement)).query(&[]).unwrap())
287 });
288
289 Self {
290 rows,
291 rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
292 ncols: schema.len(),
293 current_row: 0,
294 current_col: 0,
295 is_finished: false,
296 }
297 }
298
299 #[throws(OracleSourceError)]
300 fn next_loc(&mut self) -> (usize, usize) {
301 let ret = (self.current_row, self.current_col);
302 self.current_row += (self.current_col + 1) / self.ncols;
303 self.current_col = (self.current_col + 1) % self.ncols;
304 ret
305 }
306}
307
308impl<'a> PartitionParser<'a> for OracleTextSourceParser<'a> {
309 type TypeSystem = OracleTypeSystem;
310 type Error = OracleSourceError;
311
312 #[throws(OracleSourceError)]
313 fn fetch_next(&mut self) -> (usize, bool) {
314 assert!(self.current_col == 0);
315 let remaining_rows = self.rowbuf.len() - self.current_row;
316 if remaining_rows > 0 {
317 return (remaining_rows, self.is_finished);
318 } else if self.is_finished {
319 return (0, self.is_finished);
320 }
321
322 if !self.rowbuf.is_empty() {
323 self.rowbuf.drain(..);
324 }
325 for _ in 0..DB_BUFFER_SIZE {
326 if let Some(item) = (*self.rows).next() {
327 self.rowbuf.push(item?);
328 } else {
329 self.is_finished = true;
330 break;
331 }
332 }
333 self.current_row = 0;
334 self.current_col = 0;
335 (self.rowbuf.len(), self.is_finished)
336 }
337}
338
339macro_rules! impl_produce_text {
340 ($($t: ty,)+) => {
341 $(
342 impl<'r, 'a> Produce<'r, $t> for OracleTextSourceParser<'a> {
343 type Error = OracleSourceError;
344
345 #[throws(OracleSourceError)]
346 fn produce(&'r mut self) -> $t {
347 let (ridx, cidx) = self.next_loc()?;
348 let res = self.rowbuf[ridx].get(cidx)?;
349 res
350 }
351 }
352
353 impl<'r, 'a> Produce<'r, Option<$t>> for OracleTextSourceParser<'a> {
354 type Error = OracleSourceError;
355
356 #[throws(OracleSourceError)]
357 fn produce(&'r mut self) -> Option<$t> {
358 let (ridx, cidx) = self.next_loc()?;
359 let res = self.rowbuf[ridx].get(cidx)?;
360 res
361 }
362 }
363 )+
364 };
365}
366
367impl_produce_text!(
368 i64,
369 f64,
370 String,
371 NaiveDate,
372 NaiveDateTime,
373 DateTime<Utc>,
374 Vec<u8>,
375);
376
377impl<'r, 'a> Produce<'r, Decimal> for OracleTextSourceParser<'a> {
379 type Error = OracleSourceError;
380
381 #[throws(OracleSourceError)]
382 fn produce(&'r mut self) -> Decimal {
383 let (ridx, cidx) = self.next_loc()?;
384 let s: String = self.rowbuf[ridx].get(cidx)?;
385 let res = s.parse::<Decimal>()?;
386 res
387 }
388}
389
390impl<'r, 'a> Produce<'r, Option<Decimal>> for OracleTextSourceParser<'a> {
391 type Error = OracleSourceError;
392
393 #[throws(OracleSourceError)]
394 fn produce(&'r mut self) -> Option<Decimal> {
395 let (ridx, cidx) = self.next_loc()?;
396 let s: Option<String> = self.rowbuf[ridx].get(cidx)?;
397 match s {
398 Some(val) => Some(val.parse::<Decimal>()?),
399 None => None,
400 }
401 }
402}