connectorx/destinations/arrowstream/mod.rs

//! Destination implementation for Arrow and Polars.
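//!
//! The destination streams results rather than materializing them in one go:
//! each [`ArrowPartitionWriter`] accumulates values into Arrow builders and,
//! every `batch_size` rows, sends a finished [`RecordBatch`] through an `mpsc`
//! channel back to the [`ArrowDestination`], where batches can be drained
//! incrementally with [`ArrowDestination::record_batch`] or collected at the
//! end with [`ArrowDestination::arrow`].
//!
//! A minimal sketch of the flow (not a doctest; with the `Destination`,
//! `DestinationPartition` and `Consume` traits in scope, and assuming an
//! `ArrowTypeSystem::Int64(bool)` variant that `i64` maps to; see
//! `typesystem.rs` for the real definitions):
//!
//! ```ignore
//! let mut dst = ArrowDestination::new_with_batch_size(1024);
//! dst.allocate(0, &["a"], &[ArrowTypeSystem::Int64(false)], DataOrder::RowMajor)?;
//! let mut writers = dst.partition(1)?;
//!
//! // Writer side (normally driven by a source partition parser, possibly on
//! // another thread).
//! let mut w = writers.pop().unwrap();
//! w.consume(1i64)?;
//! w.finalize()?; // flush the partial batch and drop this writer's sender
//!
//! // Reader side: drain everything that was sent.
//! let batches = dst.arrow()?;
//! ```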

mod arrow_assoc;
mod errors;
mod funcs;
pub mod typesystem;

pub use self::errors::{ArrowDestinationError, Result};
pub use self::typesystem::ArrowTypeSystem;
use super::{Consume, Destination, DestinationPartition};
use crate::constants::RECORD_BATCH_SIZE;
use crate::data_order::DataOrder;
use crate::typesystem::{Realize, TypeAssoc, TypeSystem};
use anyhow::anyhow;
use arrow::{datatypes::Schema, record_batch::RecordBatch};
use arrow_assoc::ArrowAssoc;
use fehler::{throw, throws};
use funcs::{FFinishBuilder, FNewBuilder, FNewField};
use itertools::Itertools;
use std::{
    any::Any,
    sync::{
        mpsc::{channel, Receiver, Sender},
        Arc,
    },
};

type Builder = Box<dyn Any + Send>;
type Builders = Vec<Builder>;

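/// Destination that collects streamed [`RecordBatch`]es.
///
/// The destination keeps the receiving end of an `mpsc` channel; partition
/// writers hold clones of the sending end and push finished batches into it.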
pub struct ArrowDestination {
    schema: Vec<ArrowTypeSystem>,
    names: Vec<String>,
    arrow_schema: Arc<Schema>,
    batch_size: usize,
    sender: Option<Sender<RecordBatch>>,
    receiver: Receiver<RecordBatch>,
}

impl Default for ArrowDestination {
    fn default() -> Self {
        let (tx, rx) = channel();
        ArrowDestination {
            schema: vec![],
            names: vec![],
            arrow_schema: Arc::new(Schema::empty()),
            batch_size: RECORD_BATCH_SIZE,
            sender: Some(tx),
            receiver: rx,
        }
    }
}

impl ArrowDestination {
    pub fn new() -> Self {
        Self::default()
    }

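    /// Create a destination whose partition writers flush a [`RecordBatch`]
    /// every `batch_size` rows instead of the default `RECORD_BATCH_SIZE`.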
    pub fn new_with_batch_size(batch_size: usize) -> Self {
        let (tx, rx) = channel();
        ArrowDestination {
            schema: vec![],
            names: vec![],
            arrow_schema: Arc::new(Schema::empty()),
            batch_size,
            sender: Some(tx),
            receiver: rx,
        }
    }
}

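// `Destination` impl: `allocate` records the schema and builds the Arrow
// schema; `partition` hands out the writers that feed the batch channel.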
impl Destination for ArrowDestination {
    const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::ColumnMajor, DataOrder::RowMajor];
    type TypeSystem = ArrowTypeSystem;
    type Partition<'a> = ArrowPartitionWriter;
    type Error = ArrowDestinationError;

    fn needs_count(&self) -> bool {
        false
    }

    #[throws(ArrowDestinationError)]
    fn allocate<S: AsRef<str>>(
        &mut self,
        _nrow: usize,
        names: &[S],
        schema: &[ArrowTypeSystem],
        data_order: DataOrder,
    ) {
        // TODO: support ColumnMajor (advertised in DATA_ORDERS but not handled yet)
        if !matches!(data_order, DataOrder::RowMajor) {
            throw!(crate::errors::ConnectorXError::UnsupportedDataOrder(
                data_order
            ))
        }

        // Record the metadata and build the Arrow schema.
        self.schema = schema.to_vec();
        self.names = names.iter().map(|n| n.as_ref().to_string()).collect();
        let fields = self
            .schema
            .iter()
            .zip_eq(&self.names)
            .map(|(&dt, h)| Ok(Realize::<FNewField>::realize(dt)?(h.as_str())))
            .collect::<Result<Vec<_>>>()?;
        self.arrow_schema = Arc::new(Schema::new(fields));
    }

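    // Create one writer per partition; all writers share clones of the batch
    // sender.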
    #[throws(ArrowDestinationError)]
    fn partition(&mut self, counts: usize) -> Vec<Self::Partition<'_>> {
        let mut partitions = vec![];
        let sender = self.sender.take().unwrap();
        for _ in 0..counts {
            partitions.push(ArrowPartitionWriter::new(
                self.schema.clone(),
                Arc::clone(&self.arrow_schema),
                self.batch_size,
                sender.clone(),
            )?);
        }
        partitions
        // `sender` (taken from self) goes out of scope here, so the channel
        // closes once every writer has dropped its clone.
    }

    fn schema(&self) -> &[ArrowTypeSystem] {
        self.schema.as_slice()
    }
}

impl ArrowDestination {
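    /// Consume the destination and collect every remaining [`RecordBatch`],
    /// blocking until all partition writers have dropped their senders.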
    #[throws(ArrowDestinationError)]
    pub fn arrow(self) -> Vec<RecordBatch> {
        if self.sender.is_some() {
            // The sender is normally taken in `partition`, but drop it here as
            // a safeguard; otherwise `recv` below would block forever.
            std::mem::drop(self.sender);
        }
        let mut data = vec![];
        while let Ok(rb) = self.receiver.recv() {
            data.push(rb);
        }
        data
    }

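    /// Block until the next [`RecordBatch`] arrives, or return `None` once all
    /// partition writers have dropped their senders.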
    #[throws(ArrowDestinationError)]
    pub fn record_batch(&mut self) -> Option<RecordBatch> {
        self.receiver.recv().ok()
    }

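    /// An empty [`RecordBatch`] with this destination's Arrow schema.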
    pub fn empty_batch(&self) -> RecordBatch {
        RecordBatch::new_empty(self.arrow_schema.clone())
    }

    pub fn arrow_schema(&self) -> Arc<Schema> {
        self.arrow_schema.clone()
    }

    pub fn names(&self) -> &[String] {
        self.names.as_slice()
    }
}

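/// Per-partition writer that appends values cell by cell into Arrow builders
/// and sends a [`RecordBatch`] through the shared channel whenever
/// `batch_size` rows have been accumulated.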
pub struct ArrowPartitionWriter {
    schema: Vec<ArrowTypeSystem>,
    builders: Option<Builders>,
    current_row: usize,
    current_col: usize,
    arrow_schema: Arc<Schema>,
    batch_size: usize,
    sender: Option<Sender<RecordBatch>>,
}

// unsafe impl Sync for ArrowPartitionWriter {}

impl ArrowPartitionWriter {
    #[throws(ArrowDestinationError)]
    fn new(
        schema: Vec<ArrowTypeSystem>,
        arrow_schema: Arc<Schema>,
        batch_size: usize,
        sender: Sender<RecordBatch>,
    ) -> Self {
        let mut pw = ArrowPartitionWriter {
            schema,
            builders: None,
            current_row: 0,
            current_col: 0,
            arrow_schema,
            batch_size,
            sender: Some(sender),
        };
        pw.allocate()?;
        pw
    }

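    // Create one fresh builder per column, pre-sized for `batch_size` rows.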
    #[throws(ArrowDestinationError)]
    fn allocate(&mut self) {
        let builders = self
            .schema
            .iter()
            .map(|dt| Ok(Realize::<FNewBuilder>::realize(*dt)?(self.batch_size)))
            .collect::<Result<Vec<_>>>()?;
        self.builders.replace(builders);
    }

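    // Finish the current builders into arrays, assemble them into a
    // `RecordBatch`, send it to the destination, and reset the row/column
    // cursors. The builders are consumed, so `allocate` must be called again
    // before more rows can be written.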
    #[throws(ArrowDestinationError)]
    fn flush(&mut self) {
        let builders = self
            .builders
            .take()
            .expect("arrow builders are missing at flush");
        let columns = builders
            .into_iter()
            .zip(self.schema.iter())
            .map(|(builder, &dt)| Realize::<FFinishBuilder>::realize(dt)?(builder))
            .collect::<std::result::Result<Vec<_>, crate::errors::ConnectorXError>>()?;
        let rb = RecordBatch::try_new(Arc::clone(&self.arrow_schema), columns)?;
        self.sender.as_ref().unwrap().send(rb).unwrap();

        self.current_row = 0;
        self.current_col = 0;
    }
}

impl<'a> DestinationPartition<'a> for ArrowPartitionWriter {
    type TypeSystem = ArrowTypeSystem;
    type Error = ArrowDestinationError;

    #[throws(ArrowDestinationError)]
    fn finalize(&mut self) {
        if self.builders.is_some() {
            self.flush()?;
        }
        // Release the sender so the receiver knows when the stream is exhausted.
        std::mem::drop(self.sender.take());
    }

    #[throws(ArrowDestinationError)]
    fn aquire_row(&mut self, _n: usize) -> usize {
        self.current_row
    }

    fn ncols(&self) -> usize {
        self.schema.len()
    }
}

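// Values arrive one cell at a time in row-major order; the writer tracks the
// current column itself and advances the row counter each time a row completes.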
impl<'a, T> Consume<T> for ArrowPartitionWriter
where
    T: TypeAssoc<<Self as DestinationPartition<'a>>::TypeSystem> + ArrowAssoc + 'static,
{
    type Error = ArrowDestinationError;

    #[throws(ArrowDestinationError)]
    fn consume(&mut self, value: T) {
        let col = self.current_col;
        self.current_col = (self.current_col + 1) % self.ncols();
        self.schema[col].check::<T>()?;

        loop {
            match &mut self.builders {
                Some(builders) => {
                    <T as ArrowAssoc>::append(
                        builders[col]
                            .downcast_mut::<T::Builder>()
                            .ok_or_else(|| anyhow!("cannot cast arrow builder for append"))?,
                        value,
                    )?;
                    break;
                }
                None => self.allocate()?, // builders were taken by `flush`; allocate a fresh set
            }
        }

        // A row is complete once the column index wraps to 0; flush when
        // `batch_size` rows have been accumulated.
        if self.current_col == 0 {
            self.current_row += 1;
            if self.current_row >= self.batch_size {
                self.flush()?;
                self.allocate()?;
            }
        }
    }
}