connectorx/sources/oracle/
mod.rs1mod errors;
2mod typesystem;
3
4use std::collections::HashMap;
5
6pub use self::errors::OracleSourceError;
7pub use self::typesystem::OracleTypeSystem;
8use crate::constants::{DB_BUFFER_SIZE, ORACLE_ARRAY_SIZE};
9use crate::{
10 data_order::DataOrder,
11 errors::ConnectorXError,
12 sources::{PartitionParser, Produce, Source, SourcePartition},
13 sql::{count_query, limit1_query_oracle, CXQuery},
14 utils::DummyBox,
15};
16use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
17use fehler::{throw, throws};
18use log::debug;
19use owning_ref::OwningHandle;
20use r2d2::{Pool, PooledConnection};
21use r2d2_oracle::oracle::ResultSet;
22use r2d2_oracle::{
23 oracle::{Connector, Row, Statement},
24 OracleConnectionManager,
25};
26use sqlparser::dialect::Dialect;
27use url::Url;
28use urlencoding::decode;
29
30type OracleManager = OracleConnectionManager;
31type OracleConn = PooledConnection<OracleManager>;
32
33#[derive(Debug)]
34pub struct OracleDialect {}
35
36impl Dialect for OracleDialect {
38 fn is_identifier_start(&self, ch: char) -> bool {
39 ch.is_ascii_lowercase() || ch.is_ascii_uppercase()
40 }
41
42 fn is_identifier_part(&self, ch: char) -> bool {
43 ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_'
44 }
45}
46
47pub struct OracleSource {
48 pool: Pool<OracleManager>,
49 origin_query: Option<String>,
50 queries: Vec<CXQuery<String>>,
51 names: Vec<String>,
52 schema: Vec<OracleTypeSystem>,
53}
54
55#[throws(OracleSourceError)]
56pub fn connect_oracle(conn: &Url) -> Connector {
57 let user = decode(conn.username())?.into_owned();
58 let password = decode(conn.password().unwrap_or(""))?.into_owned();
59 let host = decode(conn.host_str().unwrap_or("localhost"))?.into_owned();
60
61 let params: HashMap<String, String> = conn.query_pairs().into_owned().collect();
62
63 let conn_str = if params.get("alias").map_or(false, |v| v == "true") {
64 host.clone()
65 } else {
66 let port = conn.port().unwrap_or(1521);
67 let path = decode(conn.path())?.into_owned();
68 format!("//{}:{}{}", host, port, path)
69 };
70
71 let mut connector = oracle::Connector::new(user.as_str(), password.as_str(), conn_str.as_str());
72 if user.is_empty() && password.is_empty() {
73 debug!("No username or password provided, assuming system auth.");
74 connector.external_auth(true);
75 }
76 connector
77}
78
79impl OracleSource {
80 #[throws(OracleSourceError)]
81 pub fn new(conn: &str, nconn: usize) -> Self {
82 let conn = Url::parse(conn)?;
83 let connector = connect_oracle(&conn)?;
84 let manager = OracleConnectionManager::from_connector(connector);
85 let pool = r2d2::Pool::builder()
86 .max_size(nconn as u32)
87 .build(manager)?;
88
89 Self {
90 pool,
91 origin_query: None,
92 queries: vec![],
93 names: vec![],
94 schema: vec![],
95 }
96 }
97}
98
99impl Source for OracleSource
100where
101 OracleSourcePartition:
102 SourcePartition<TypeSystem = OracleTypeSystem, Error = OracleSourceError>,
103{
104 const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor];
105 type Partition = OracleSourcePartition;
106 type TypeSystem = OracleTypeSystem;
107 type Error = OracleSourceError;
108
109 #[throws(OracleSourceError)]
110 fn set_data_order(&mut self, data_order: DataOrder) {
111 if !matches!(data_order, DataOrder::RowMajor) {
112 throw!(ConnectorXError::UnsupportedDataOrder(data_order));
113 }
114 }
115
116 fn set_queries<Q: ToString>(&mut self, queries: &[CXQuery<Q>]) {
117 self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect();
118 }
119
120 fn set_origin_query(&mut self, query: Option<String>) {
121 self.origin_query = query;
122 }
123
124 #[throws(OracleSourceError)]
125 fn fetch_metadata(&mut self) {
126 assert!(!self.queries.is_empty());
127
128 let conn = self.pool.get()?;
129 for (i, query) in self.queries.iter().enumerate() {
130 match conn.query(limit1_query_oracle(query)?.as_str(), &[]) {
135 Ok(rows) => {
136 let (names, types) = rows
137 .column_info()
138 .iter()
139 .map(|col| {
140 (
141 col.name().to_string(),
142 OracleTypeSystem::from(col.oracle_type()),
143 )
144 })
145 .unzip();
146 self.names = names;
147 self.schema = types;
148 return;
149 }
150 Err(e) if i == self.queries.len() - 1 => {
151 debug!("cannot get metadata for '{}': {}", query, e);
153 throw!(e);
154 }
155 Err(_) => {}
156 }
157 }
158 let iter = conn.query(self.queries[0].as_str(), &[])?;
160 let (names, types) = iter
161 .column_info()
162 .iter()
163 .map(|col| (col.name().to_string(), OracleTypeSystem::VarChar(false)))
164 .unzip();
165 self.names = names;
166 self.schema = types;
167 }
168
169 #[throws(OracleSourceError)]
170 fn result_rows(&mut self) -> Option<usize> {
171 match &self.origin_query {
172 Some(q) => {
173 let cxq = CXQuery::Naked(q.clone());
174 let conn = self.pool.get()?;
175
176 let nrows = conn
177 .query_row_as::<usize>(count_query(&cxq, &OracleDialect {})?.as_str(), &[])?;
178 Some(nrows)
179 }
180 None => None,
181 }
182 }
183
184 fn names(&self) -> Vec<String> {
185 self.names.clone()
186 }
187
188 fn schema(&self) -> Vec<Self::TypeSystem> {
189 self.schema.clone()
190 }
191
192 #[throws(OracleSourceError)]
193 fn partition(self) -> Vec<Self::Partition> {
194 let mut ret = vec![];
195 for query in self.queries {
196 let conn = self.pool.get()?;
197 ret.push(OracleSourcePartition::new(conn, &query, &self.schema));
198 }
199 ret
200 }
201}
202
203pub struct OracleSourcePartition {
204 conn: OracleConn,
205 query: CXQuery<String>,
206 schema: Vec<OracleTypeSystem>,
207 nrows: usize,
208 ncols: usize,
209}
210
211impl OracleSourcePartition {
212 pub fn new(conn: OracleConn, query: &CXQuery<String>, schema: &[OracleTypeSystem]) -> Self {
213 Self {
214 conn,
215 query: query.clone(),
216 schema: schema.to_vec(),
217 nrows: 0,
218 ncols: schema.len(),
219 }
220 }
221}
222
223impl SourcePartition for OracleSourcePartition {
224 type TypeSystem = OracleTypeSystem;
225 type Parser<'a> = OracleTextSourceParser<'a>;
226 type Error = OracleSourceError;
227
228 #[throws(OracleSourceError)]
229 fn result_rows(&mut self) {
230 self.nrows = self
231 .conn
232 .query_row_as::<usize>(count_query(&self.query, &OracleDialect {})?.as_str(), &[])?;
233 }
234
235 #[throws(OracleSourceError)]
236 fn parser(&mut self) -> Self::Parser<'_> {
237 let query = self.query.clone();
238
239 OracleTextSourceParser::new(&self.conn, query.as_str(), &self.schema)?
241 }
242
243 fn nrows(&self) -> usize {
244 self.nrows
245 }
246
247 fn ncols(&self) -> usize {
248 self.ncols
249 }
250}
251
252unsafe impl<'a> Send for OracleTextSourceParser<'a> {}
253
254pub struct OracleTextSourceParser<'a> {
255 rows: OwningHandle<Box<Statement>, DummyBox<ResultSet<'a, Row>>>,
256 rowbuf: Vec<Row>,
257 ncols: usize,
258 current_col: usize,
259 current_row: usize,
260 is_finished: bool,
261}
262
263impl<'a> OracleTextSourceParser<'a> {
264 #[throws(OracleSourceError)]
265 pub fn new(conn: &'a OracleConn, query: &str, schema: &[OracleTypeSystem]) -> Self {
266 let stmt = conn
267 .statement(query)
268 .prefetch_rows(ORACLE_ARRAY_SIZE)
269 .fetch_array_size(ORACLE_ARRAY_SIZE)
270 .build()?;
271 let rows: OwningHandle<Box<Statement>, DummyBox<ResultSet<'a, Row>>> =
272 OwningHandle::new_with_fn(Box::new(stmt), |stmt: *const Statement| unsafe {
273 DummyBox((*(stmt as *mut Statement)).query(&[]).unwrap())
274 });
275
276 Self {
277 rows,
278 rowbuf: Vec::with_capacity(DB_BUFFER_SIZE),
279 ncols: schema.len(),
280 current_row: 0,
281 current_col: 0,
282 is_finished: false,
283 }
284 }
285
286 #[throws(OracleSourceError)]
287 fn next_loc(&mut self) -> (usize, usize) {
288 let ret = (self.current_row, self.current_col);
289 self.current_row += (self.current_col + 1) / self.ncols;
290 self.current_col = (self.current_col + 1) % self.ncols;
291 ret
292 }
293}
294
295impl<'a> PartitionParser<'a> for OracleTextSourceParser<'a> {
296 type TypeSystem = OracleTypeSystem;
297 type Error = OracleSourceError;
298
299 #[throws(OracleSourceError)]
300 fn fetch_next(&mut self) -> (usize, bool) {
301 assert!(self.current_col == 0);
302 let remaining_rows = self.rowbuf.len() - self.current_row;
303 if remaining_rows > 0 {
304 return (remaining_rows, self.is_finished);
305 } else if self.is_finished {
306 return (0, self.is_finished);
307 }
308
309 if !self.rowbuf.is_empty() {
310 self.rowbuf.drain(..);
311 }
312 for _ in 0..DB_BUFFER_SIZE {
313 if let Some(item) = (*self.rows).next() {
314 self.rowbuf.push(item?);
315 } else {
316 self.is_finished = true;
317 break;
318 }
319 }
320 self.current_row = 0;
321 self.current_col = 0;
322 (self.rowbuf.len(), self.is_finished)
323 }
324}
325
326macro_rules! impl_produce_text {
327 ($($t: ty,)+) => {
328 $(
329 impl<'r, 'a> Produce<'r, $t> for OracleTextSourceParser<'a> {
330 type Error = OracleSourceError;
331
332 #[throws(OracleSourceError)]
333 fn produce(&'r mut self) -> $t {
334 let (ridx, cidx) = self.next_loc()?;
335 let res = self.rowbuf[ridx].get(cidx)?;
336 res
337 }
338 }
339
340 impl<'r, 'a> Produce<'r, Option<$t>> for OracleTextSourceParser<'a> {
341 type Error = OracleSourceError;
342
343 #[throws(OracleSourceError)]
344 fn produce(&'r mut self) -> Option<$t> {
345 let (ridx, cidx) = self.next_loc()?;
346 let res = self.rowbuf[ridx].get(cidx)?;
347 res
348 }
349 }
350 )+
351 };
352}
353
354impl_produce_text!(
355 i64,
356 f64,
357 String,
358 NaiveDate,
359 NaiveDateTime,
360 DateTime<Utc>,
361 Vec<u8>,
362);