connectorx/
arrow_batch_iter.rs

1use crate::prelude::*;
2use arrow::record_batch::RecordBatch;
3use itertools::Itertools;
4use log::debug;
5use rayon::prelude::*;
6use std::marker::PhantomData;
7
8pub fn set_global_num_thread(num: usize) {
9    rayon::ThreadPoolBuilder::new()
10        .num_threads(num)
11        .build_global()
12        .unwrap();
13}
14
/// The iterator that returns arrow in `RecordBatch`.
///
/// Construction (`new`) prepares source/destination partitions via the
/// dispatcher; iteration pulls finished `RecordBatch`es out of the
/// streaming destination while a background thread (started by `run`)
/// fills it.
pub struct ArrowBatchIter<S, TP>
where
    S: Source,
    TP: Transport<
        TSS = S::TypeSystem,
        TSD = ArrowStreamTypeSystem,
        S = S,
        D = ArrowStreamDestination,
    >,
    // 'static bounds: these values are moved into the thread spawned by
    // `run`, and `std::thread::spawn` requires a 'static closure.
    <S as Source>::Partition: 'static,
    <S as Source>::TypeSystem: 'static,
    <TP as Transport>::Error: 'static,
{
    // Streaming destination; batches are read back out of it in `next`.
    dst: ArrowStreamDestination,
    // Per-partition writers. Wrapped in `Option` so `run` can `take()` them
    // and move them into the background thread exactly once.
    dst_parts: Option<Vec<ArrowStreamPartitionWriter>>,
    // Per-partition source readers; same `Option`/`take()` pattern as above.
    src_parts: Option<Vec<S::Partition>>,
    // Whether the source yields data row-major or column-major.
    dorder: DataOrder,
    // Source-side column types, index-aligned with `dst_schema`.
    src_schema: Vec<S::TypeSystem>,
    // Destination-side column types, index-aligned with `src_schema`.
    dst_schema: Vec<ArrowStreamTypeSystem>,
    // `TP` is only used for its associated types/`process`; no value stored.
    _phantom: PhantomData<TP>,
}
37
impl<'a, S, TP> ArrowBatchIter<S, TP>
where
    S: Source + 'a,
    TP: Transport<
        TSS = S::TypeSystem,
        TSD = ArrowStreamTypeSystem,
        S = S,
        D = ArrowStreamDestination,
    >,
{
    /// Builds the iterator by running the dispatcher's preparation phase.
    ///
    /// `origin_query` is the un-partitioned query (if any) and `queries` the
    /// per-partition queries. Returns an error if dispatcher preparation
    /// (schema inference / partition setup) fails.
    ///
    /// Note: this only *prepares*; no data is fetched until `run` is called
    /// (see `RecordBatchIterator::prepare`).
    pub fn new(
        src: S,
        mut dst: ArrowStreamDestination,
        origin_query: Option<String>,
        queries: &[CXQuery<String>],
    ) -> Result<Self, TP::Error> {
        let dispatcher = Dispatcher::<_, _, TP>::new(src, &mut dst, queries, origin_query);
        // `prepare` yields the data order, the paired source/destination
        // partitions, and the index-aligned source/destination schemas.
        let (dorder, src_parts, dst_parts, src_schema, dst_schema) = dispatcher.prepare()?;

        Ok(Self {
            dst,
            // Wrapped in Option so `run` can move them into the worker thread.
            dst_parts: Some(dst_parts),
            src_parts: Some(src_parts),
            dorder,
            src_schema,
            dst_schema,
            _phantom: PhantomData,
        })
    }

    /// Spawns a detached background thread that parses every source
    /// partition (in parallel via rayon) and writes rows into the matching
    /// destination partition.
    ///
    /// May only be called once: it `take()`s the partition vectors and the
    /// `unwrap()` panics on a second call.
    ///
    /// NOTE(review): the `JoinHandle` and the thread's `Result` are
    /// discarded, so a `TP::Error` raised inside the writer thread is
    /// dropped silently here — presumably the consumer just observes the
    /// stream ending early via `record_batch()`; confirm error surfacing.
    fn run(&mut self) {
        // Clone/move everything the worker needs so the closure is 'static.
        let src_schema = self.src_schema.clone();
        let dst_schema = self.dst_schema.clone();
        let src_partitions = self.src_parts.take().unwrap();
        let dst_partitions = self.dst_parts.take().unwrap();
        let dorder = self.dorder;

        std::thread::spawn(move || -> Result<(), TP::Error> {
            // Pair each source column type with its destination column type;
            // `zip_eq` panics if the two schemas have different lengths.
            let schemas: Vec<_> = src_schema
                .iter()
                .zip_eq(&dst_schema)
                .map(|(&src_ty, &dst_ty)| (src_ty, dst_ty))
                .collect();

            debug!("Start writing");
            // parse and write: one rayon task per (destination, source)
            // partition pair; `try_for_each` stops at the first error.
            dst_partitions
                .into_par_iter()
                .zip_eq(src_partitions)
                .enumerate()
                .try_for_each(|(i, (mut dst, mut src))| -> Result<(), TP::Error> {
                    let mut parser = src.parser()?;

                    // Both arms fetch `n` rows per round until `is_last`;
                    // they differ only in the nesting of the row/column
                    // loops to match the source's data order.
                    match dorder {
                        DataOrder::RowMajor => loop {
                            let (n, is_last) = parser.fetch_next()?;
                            // Reserve space for `n` rows in the writer
                            // ("aquire" sic — spelled as in the destination API).
                            dst.aquire_row(n)?;
                            for _ in 0..n {
                                #[allow(clippy::needless_range_loop)]
                                for col in 0..dst.ncols() {
                                    {
                                        // Convert one value from source type
                                        // s1 to destination type s2.
                                        let (s1, s2) = schemas[col];
                                        TP::process(s1, s2, &mut parser, &mut dst)?;
                                    }
                                }
                            }
                            if is_last {
                                break;
                            }
                        },
                        DataOrder::ColumnMajor => loop {
                            let (n, is_last) = parser.fetch_next()?;
                            dst.aquire_row(n)?;
                            #[allow(clippy::needless_range_loop)]
                            for col in 0..dst.ncols() {
                                // Column-major: stream all `n` values of one
                                // column before moving to the next.
                                for _ in 0..n {
                                    {
                                        let (s1, s2) = schemas[col];
                                        TP::process(s1, s2, &mut parser, &mut dst)?;
                                    }
                                }
                            }
                            if is_last {
                                break;
                            }
                        },
                    }

                    debug!("Finalize partition {}", i);
                    // Flush any pending batch for this partition.
                    dst.finalize()?;
                    debug!("Partition {} finished", i);
                    Ok(())
                })?;

            debug!("Writing finished");

            Ok(())
        });
    }
}
138
139impl<'a, S, TP> Iterator for ArrowBatchIter<S, TP>
140where
141    S: Source + 'a,
142    TP: Transport<
143        TSS = S::TypeSystem,
144        TSD = ArrowStreamTypeSystem,
145        S = S,
146        D = ArrowStreamDestination,
147    >,
148{
149    type Item = RecordBatch;
150    /// NOTE: not thread safe
151    fn next(&mut self) -> Option<Self::Item> {
152        self.dst.record_batch().unwrap()
153    }
154}
155
/// Object-safe interface for pulling Arrow `RecordBatch`es out of a
/// prepared query, without naming the concrete source/transport generics.
pub trait RecordBatchIterator {
    /// Returns a `RecordBatch` carrying the schema (empty of rows in the
    /// implementation below) together with the column names.
    fn get_schema(&self) -> (RecordBatch, &[String]);
    /// Starts producing batches (e.g. spawns the background writer);
    /// call this before `next_batch`.
    fn prepare(&mut self);
    /// Returns the next batch, or `None` when the stream is exhausted.
    fn next_batch(&mut self) -> Option<RecordBatch>;
}
161
162impl<'a, S, TP> RecordBatchIterator for ArrowBatchIter<S, TP>
163where
164    S: Source + 'a,
165    TP: Transport<
166        TSS = S::TypeSystem,
167        TSD = ArrowStreamTypeSystem,
168        S = S,
169        D = ArrowStreamDestination,
170    >,
171{
172    fn get_schema(&self) -> (RecordBatch, &[String]) {
173        (self.dst.empty_batch(), self.dst.names())
174    }
175
176    fn prepare(&mut self) {
177        self.run();
178    }
179
180    fn next_batch(&mut self) -> Option<RecordBatch> {
181        self.next()
182    }
183}