connectorx/sources/oracle/
mod.rs1mod errors;
2mod typesystem;
3
4use std::collections::HashMap;
5
6pub use self::errors::OracleSourceError;
7pub use self::typesystem::OracleTypeSystem;
8use crate::constants::{DB_BUFFER_SIZE, ORACLE_ARRAY_SIZE};
9use crate::{
10 data_order::DataOrder,
11 errors::ConnectorXError,
12 sources::{PartitionParser, Produce, Source, SourcePartition},
13 sql::{count_query, limit1_query_oracle, CXQuery},
14 utils::DummyBox,
15};
16use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
17use fehler::{throw, throws};
18use log::debug;
19use owning_ref::OwningHandle;
20use r2d2::{Pool, PooledConnection};
21use r2d2_oracle::oracle::ResultSet;
22use r2d2_oracle::{
23 oracle::{Connector, Row, Statement},
24 OracleConnectionManager,
25};
26use sqlparser::dialect::Dialect;
27use url::Url;
28use urlencoding::decode;
29
30type OracleManager = OracleConnectionManager;
31type OracleConn = PooledConnection<OracleManager>;
32
33#[derive(Debug)]
34pub struct OracleDialect {}
35
36impl Dialect for OracleDialect {
38 fn is_identifier_start(&self, ch: char) -> bool {
39 ch.is_ascii_lowercase() || ch.is_ascii_uppercase()
40 }
41
42 fn is_identifier_part(&self, ch: char) -> bool {
43 ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_'
44 }
45}
46
47pub struct OracleSource {
48 pool: Pool<OracleManager>,
49 origin_query: Option<String>,
50 queries: Vec<CXQuery<String>>,
51 names: Vec<String>,
52 schema: Vec<OracleTypeSystem>,
53 current_schema: Option<String>,
54}
55
56#[throws(OracleSourceError)]
57pub fn connect_oracle(conn: &Url) -> Connector {
58 let user = decode(conn.username())?.into_owned();
59 let password = decode(conn.password().unwrap_or(""))?.into_owned();
60 let host = decode(conn.host_str().unwrap_or("localhost"))?.into_owned();
61
62 let params: HashMap<String, String> = conn.query_pairs().into_owned().collect();
63
64 let conn_str = if params.get("alias").map_or(false, |v| v == "true") {
65 host.clone()
66 } else {
67 let port = conn.port().unwrap_or(1521);
68 let path = decode(conn.path())?.into_owned();
69 format!("//{}:{}{}", host, port, path)
70 };
71
72 let mut connector = oracle::Connector::new(user.as_str(), password.as_str(), conn_str.as_str());
73 if user.is_empty() && password.is_empty() {
74 debug!("No username or password provided, assuming system auth.");
75 connector.external_auth(true);
76 }
77 connector
78}
79
80impl OracleSource {
81 #[throws(OracleSourceError)]
82 pub fn new(conn: &str, nconn: usize) -> Self {
83 let conn = Url::parse(conn)?;
84 let connector = connect_oracle(&conn)?;
85 let manager = OracleConnectionManager::from_connector(connector);
86 let pool = r2d2::Pool::builder()
87 .max_size(nconn as u32)
88 .build(manager)?;
89
90 let params: HashMap<String, String> = conn.query_pairs().into_owned().collect();
91 let current_schema = params.get("schema").cloned();
92
93 Self {
94 pool,
95 origin_query: None,
96 queries: vec![],
97 names: vec![],
98 schema: vec![],
99 current_schema,
100 }
101 }
102 pub fn get_conn(&self) -> Result<OracleConn, OracleSourceError> {
103 let conn = self.pool.get()?;
104 if let Some(schema) = &self.current_schema {
105 conn.set_current_schema(schema)?;
106 }
107 Ok(conn)
108 }
109}
110
111impl Source for OracleSource
112where
113 OracleSourcePartition:
114 SourcePartition<TypeSystem = OracleTypeSystem, Error = OracleSourceError>,
115{
116 const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
117 type Partition = OracleSourcePartition;
118 type TypeSystem = OracleTypeSystem;
119 type Error = OracleSourceError;
120
121 #[throws(OracleSourceError)]
122 fn set_data_order(&mut self, data_order: DataOrder) {
123 if !matches!(data_order, DataOrder::RowMajor) {
124 throw!(ConnectorXError::UnsupportedDataOrder(data_order));
125 }
126 }
127
128 fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
129 self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
130 }
131
132 fn set_origin_query(&mut self, query: Option<String>) {
133 self.origin_query = query;
134 }
135
136 #[throws(OracleSourceError)]
137 fn fetch_metadata(&mut self) {
138 assert!(!self.queries.is_empty());
139
140 let conn = self.get_conn()?;
141 for (i, query) in self.queries.iter().enumerate() {
142 match conn.query(limit1_query_oracle(query)?.as_str(), &[]) {
147 Ok(rows) => {
148 let (names, types) = rows
149 .column_info()
150 .iter()
151 .map(|col| {
152 (
153 col.name().to_string(),
154 OracleTypeSystem::from(col.oracle_type()),
155 )
156 })
157 .unzip();
158 self.names = names;
159 self.schema = types;
160 return;
161 }
162 Err(e) if i == self.queries.len() - 1 => {
163 debug!("cannot get metadata for '{}': {}", query, e);
165 throw!(e);
166 }
167 Err(_) => {}
168 }
169 }
170 let iter = conn.query(self.queries[0].as_str(), &[])?;
172 let (names, types) = iter
173 .column_info()
174 .iter()
175 .map(|col| (col.name().to_string(), OracleTypeSystem::VarChar(false)))
176 .unzip();
177 self.names = names;
178 self.schema = types;
179 }
180
181 #[throws(OracleSourceError)]
182 fn result_rows(&mut self) -> Option<usize> {
183 match &self.origin_query {
184 Some(q) => {
185 let cxq = CXQuery::Naked(q.clone());
186 let conn = self.get_conn()?;
187
188 let nrows = conn
189 .query_row_as::<usize>(count_query(&cxq, &OracleDialect {})?.as_str(), &[])?;
190 Some(nrows)
191 }
192 None => None,
193 }
194 }
195
196 fn names(&self) -> Vec<String> {
197 self.names.clone()
198 }
199
200 fn schema(&self) -> Vec<Self::TypeSystem> {
201 self.schema.clone()
202 }
203
204 #[throws(OracleSourceError)]
205 fn partition(self) -> Vec<Self::Partition> {
206 let mut ret = vec![];
207 for query in &self.queries {
208 let conn = self.get_conn()?;
209 ret.push(OracleSourcePartition::new(conn, &query, &self.schema));
210 }
211 ret
212 }
213}
214
215pub struct OracleSourcePartition {
216 conn: OracleConn,
217 query: CXQuery<String>,
218 schema: Vec<OracleTypeSystem>,
219 nrows: usize,
220 ncols: usize,
221}
222
223impl OracleSourcePartition {
224 pub fn new(conn: OracleConn, query: &CXQuery<String>, schema: &[OracleTypeSystem]) -> Self {
225 Self {
226 conn,
227 query: query.clone(),
228 schema: schema.to_vec(),
229 nrows: 0,
230 ncols: schema.len(),
231 }
232 }
233}
234
235impl SourcePartition for OracleSourcePartition {
236 type TypeSystem = OracleTypeSystem;
237 type Parser<'a> = OracleTextSourceParser<'a>;
238 type Error = OracleSourceError;
239
240 #[throws(OracleSourceError)]
241 fn result_rows(&mut self) {
242 self.nrows = self
243 .conn
244 .query_row_as::<usize>(count_query(&self.query, &OracleDialect {})?.as_str(), &[])?;
245 }
246
247 #[throws(OracleSourceError)]
248 fn parser(&mut self) -> Self::Parser<'_> {
249 let query = self.query.clone();
250
251 OracleTextSourceParser::new(&self.conn, query.as_str(), &self.schema)?
253 }
254
255 fn nrows(&self) -> usize {
256 self.nrows
257 }
258
259 fn ncols(&self) -> usize {
260 self.ncols
261 }
262}
263
264unsafe impl<'a> Send for OracleTextSourceParser<'a> {}
265
266pub struct OracleTextSourceParser<'a> {
267 rows: OwningHandle<Box<Statement>, DummyBox<ResultSet<'a, Row>>>,
268 rowbuf: Vec<Row>,
269 ncols: usize,
270 current_col: usize,
271 current_row: usize,
272 is_finished: bool,
273}
274
275impl<'a> OracleTextSourceParser<'a> {
276 #[throws(OracleSourceError)]
277 pub fn new(conn: &'a OracleConn, query: &str, schema: &[OracleTypeSystem]) -> Self {
278 let stmt = conn
279 .statement(query)
280 .prefetch_rows(ORACLE_ARRAY_SIZE)
281 .fetch_array_size(ORACLE_ARRAY_SIZE)
282 .build()?;
283 let rows: OwningHandle<Box<Statement>, DummyBox<ResultSet<'a, Row>>> =
284 OwningHandle::new_with_fn(Box::new(stmt), |stmt: *const Statement| unsafe {
285 DummyBox((*(stmt as *mut Statement)).query(&[]).unwrap())
286 });
287
288 Self {
289 rows,
290 rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
291 ncols: schema.len(),
292 current_row: 0,
293 current_col: 0,
294 is_finished: false,
295 }
296 }
297
298 #[throws(OracleSourceError)]
299 fn next_loc(&mut self) -> (usize, usize) {
300 let ret = (self.current_row, self.current_col);
301 self.current_row += (self.current_col + 1) / self.ncols;
302 self.current_col = (self.current_col + 1) % self.ncols;
303 ret
304 }
305}
306
307impl<'a> PartitionParser<'a> for OracleTextSourceParser<'a> {
308 type TypeSystem = OracleTypeSystem;
309 type Error = OracleSourceError;
310
311 #[throws(OracleSourceError)]
312 fn fetch_next(&mut self) -> (usize, bool) {
313 assert!(self.current_col == 0);
314 let remaining_rows = self.rowbuf.len() - self.current_row;
315 if remaining_rows > 0 {
316 return (remaining_rows, self.is_finished);
317 } else if self.is_finished {
318 return (0, self.is_finished);
319 }
320
321 if !self.rowbuf.is_empty() {
322 self.rowbuf.drain(..);
323 }
324 for _ in 0..DB_BUFFER_SIZE {
325 if let Some(item) = (*self.rows).next() {
326 self.rowbuf.push(item?);
327 } else {
328 self.is_finished = true;
329 break;
330 }
331 }
332 self.current_row = 0;
333 self.current_col = 0;
334 (self.rowbuf.len(), self.is_finished)
335 }
336}
337
338macro_rules! impl_produce_text {
339 ($($t: ty,)+) => {
340 $(
341 impl<'r, 'a> Produce<'r, $t> for OracleTextSourceParser<'a> {
342 type Error = OracleSourceError;
343
344 #[throws(OracleSourceError)]
345 fn produce(&'r mut self) -> $t {
346 let (ridx, cidx) = self.next_loc()?;
347 let res = self.rowbuf[ridx].get(cidx)?;
348 res
349 }
350 }
351
352 impl<'r, 'a> Produce<'r, Option<$t>> for OracleTextSourceParser<'a> {
353 type Error = OracleSourceError;
354
355 #[throws(OracleSourceError)]
356 fn produce(&'r mut self) -> Option<$t> {
357 let (ridx, cidx) = self.next_loc()?;
358 let res = self.rowbuf[ridx].get(cidx)?;
359 res
360 }
361 }
362 )+
363 };
364}
365
366impl_produce_text!(
367 i64,
368 f64,
369 String,
370 NaiveDate,
371 NaiveDateTime,
372 DateTime<Utc>,
373 Vec<u8>,
374);