use crate::prelude::*;
use arrow::record_batch::RecordBatch;
use itertools::Itertools;
use log::debug;
use rayon::prelude::*;
use std::marker::PhantomData;

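/// Initialize the global rayon thread pool with `num` worker threads.
///
/// Note that `ThreadPoolBuilder::build_global` can only succeed once per
/// process, so calling this again after the pool exists will panic on the
/// `unwrap`.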
pub fn set_global_num_thread(num: usize) {
    rayon::ThreadPoolBuilder::new()
        .num_threads(num)
        .build_global()
        .unwrap();
}

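/// Streams the result of a query as Arrow [`RecordBatch`]es: source partitions
/// are parsed on a background thread and written into an
/// [`ArrowStreamDestination`], from which finished batches are pulled.
///
/// Rough usage sketch (illustrative only; the concrete `Source`/`Transport`
/// types and how `src`, `dst` and `queries` are built depend on the enabled
/// features):
///
/// ```ignore
/// let mut batches = ArrowBatchIter::<MySource, MyTransport>::new(src, dst, None, &queries)?;
/// batches.prepare(); // spawns the background writer thread
/// while let Some(batch) = batches.next_batch() {
///     // consume each RecordBatch as it becomes ready
/// }
/// ```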
pub struct ArrowBatchIter<S, TP>
where
    S: Source,
    TP: Transport<
        TSS = S::TypeSystem,
        TSD = ArrowStreamTypeSystem,
        S = S,
        D = ArrowStreamDestination,
    >,
    <S as Source>::Partition: 'static,
    <S as Source>::TypeSystem: 'static,
    <TP as Transport>::Error: 'static,
{
    dst: ArrowStreamDestination,
    dst_parts: Option<Vec<ArrowStreamPartitionWriter>>,
    src_parts: Option<Vec<S::Partition>>,
    dorder: DataOrder,
    src_schema: Vec<S::TypeSystem>,
    dst_schema: Vec<ArrowStreamTypeSystem>,
    _phantom: PhantomData<TP>,
}

impl<'a, S, TP> ArrowBatchIter<S, TP>
where
    S: Source + 'a,
    TP: Transport<
        TSS = S::TypeSystem,
        TSD = ArrowStreamTypeSystem,
        S = S,
        D = ArrowStreamDestination,
    >,
    <S as Source>::Partition: 'static,
    <S as Source>::TypeSystem: 'static,
    <TP as Transport>::Error: 'static,
{
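    /// Set up the transfer: partition the query on the source side, create the
    /// matching destination partition writers, and record the source and
    /// destination schemas. No data is moved until `run` is called (via
    /// `RecordBatchIterator::prepare`).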
    pub fn new(
        src: S,
        mut dst: ArrowStreamDestination,
        origin_query: Option<String>,
        queries: &[CXQuery<String>],
    ) -> Result<Self, TP::Error> {
        let dispatcher = Dispatcher::<_, _, TP>::new(src, &mut dst, queries, origin_query);
        let (dorder, src_parts, dst_parts, src_schema, dst_schema) = dispatcher.prepare()?;

        Ok(Self {
            dst,
            dst_parts: Some(dst_parts),
            src_parts: Some(src_parts),
            dorder,
            src_schema,
            dst_schema,
            _phantom: PhantomData,
        })
    }

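    /// Consume the prepared partitions and spawn a background thread that copies
    /// every partition into the destination. The partitions are `take`n out of
    /// `self`, so this must be called at most once.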
    fn run(&mut self) {
        let src_schema = self.src_schema.clone();
        let dst_schema = self.dst_schema.clone();
        let src_partitions = self.src_parts.take().unwrap();
        let dst_partitions = self.dst_parts.take().unwrap();
        let dorder = self.dorder;

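        // Run the copy on a detached thread so the caller can start pulling
        // record batches from the destination while data is still arriving.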
        std::thread::spawn(move || -> Result<(), TP::Error> {
            let schemas: Vec<_> = src_schema
                .iter()
                .zip_eq(&dst_schema)
                .map(|(&src_ty, &dst_ty)| (src_ty, dst_ty))
                .collect();

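            // Pair every destination partition writer with its source partition
            // and copy them in parallel on the rayon thread pool.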
            debug!("Start writing");
            dst_partitions
                .into_par_iter()
                .zip_eq(src_partitions)
                .enumerate()
                .try_for_each(|(i, (mut dst, mut src))| -> Result<(), TP::Error> {
                    let mut parser = src.parser()?;

                    match dorder {
                        DataOrder::RowMajor => loop {
                            let (n, is_last) = parser.fetch_next()?;
                            dst.aquire_row(n)?;
                            for _ in 0..n {
                                #[allow(clippy::needless_range_loop)]
                                for col in 0..dst.ncols() {
                                    {
                                        let (s1, s2) = schemas[col];
                                        TP::process(s1, s2, &mut parser, &mut dst)?;
                                    }
                                }
                            }
                            if is_last {
                                break;
                            }
                        },
                        DataOrder::ColumnMajor => loop {
                            let (n, is_last) = parser.fetch_next()?;
                            dst.aquire_row(n)?;
                            #[allow(clippy::needless_range_loop)]
                            for col in 0..dst.ncols() {
                                for _ in 0..n {
                                    {
                                        let (s1, s2) = schemas[col];
                                        TP::process(s1, s2, &mut parser, &mut dst)?;
                                    }
                                }
                            }
                            if is_last {
                                break;
                            }
                        },
                    }

                    debug!("Finalize partition {}", i);
                    dst.finalize()?;
                    debug!("Partition {} finished", i);
                    Ok(())
                })?;

            debug!("Writing finished");

            Ok(())
        });
    }
}

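/// Yields record batches as they are finished by the background writer thread.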
impl<'a, S, TP> Iterator for ArrowBatchIter<S, TP>
where
    S: Source + 'a,
    TP: Transport<
        TSS = S::TypeSystem,
        TSD = ArrowStreamTypeSystem,
        S = S,
        D = ArrowStreamDestination,
    >,
    <S as Source>::Partition: 'static,
    <S as Source>::TypeSystem: 'static,
    <TP as Transport>::Error: 'static,
{
    type Item = RecordBatch;
    fn next(&mut self) -> Option<Self::Item> {
        self.dst.record_batch().unwrap()
    }
}

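/// Object-safe interface for pulling record batches, allowing the iterator to
/// be held as a `Box<dyn RecordBatchIterator>` without naming the concrete
/// source and transport types.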
pub trait RecordBatchIterator {
    fn get_schema(&self) -> (RecordBatch, &[String]);
    fn prepare(&mut self);
    fn next_batch(&mut self) -> Option<RecordBatch>;
}

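/// The trait methods simply delegate: `prepare` starts the background transfer
/// and `next_batch` forwards to `Iterator::next`.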
impl<'a, S, TP> RecordBatchIterator for ArrowBatchIter<S, TP>
where
    S: Source + 'a,
    TP: Transport<
        TSS = S::TypeSystem,
        TSD = ArrowStreamTypeSystem,
        S = S,
        D = ArrowStreamDestination,
    >,
    <S as Source>::Partition: 'static,
    <S as Source>::TypeSystem: 'static,
    <TP as Transport>::Error: 'static,
{
    fn get_schema(&self) -> (RecordBatch, &[String]) {
        (self.dst.empty_batch(), self.dst.names())
    }

    fn prepare(&mut self) {
        self.run();
    }

    fn next_batch(&mut self) -> Option<RecordBatch> {
        self.next()
    }
}