connectorx/fed_dispatcher.rs

1use crate::{prelude::*, sql::CXQuery};
2use arrow::record_batch::RecordBatch;
3use datafusion::datasource::MemTable;
4use datafusion::prelude::*;
5use fehler::throws;
6use log::debug;
7use rayon::prelude::*;
8use std::collections::HashMap;
9use std::convert::TryFrom;
10use std::sync::{mpsc::channel, Arc};
11
12#[throws(ConnectorXOutError)]
13pub fn run(
14    sql: String,
15    db_map: HashMap<String, String>,
16    j4rs_base: Option<&str>,
17    strategy: &str,
18) -> Vec<RecordBatch> {
19    debug!("federated input sql: {}", sql);
20    let mut db_conn_map: HashMap<String, FederatedDataSourceInfo> = HashMap::new();
21    for (k, v) in db_map.into_iter() {
22        db_conn_map.insert(
23            k,
24            FederatedDataSourceInfo::new_from_conn_str(
25                SourceConn::try_from(v.as_str())?,
26                false,
27                "",
28                "",
29            ),
30        );
31    }
32    let fed_plan = rewrite_sql(sql.as_str(), &db_conn_map, j4rs_base, strategy)?;
33
34    debug!("fetch queries from remote");
35    let (sender, receiver) = channel();
36    fed_plan.into_par_iter().enumerate().try_for_each_with(
37        sender,
38        |s, (i, p)| -> Result<(), ConnectorXOutError> {
39            match p.db_name.as_str() {
40                "LOCAL" => {
41                    s.send((p.sql, None)).expect("send error local");
42                }
43                _ => {
44                    debug!("start query {}: {}", i, p.sql);
45                    let mut queries = vec![];
46                    p.sql.split(';').for_each(|ss| {
47                        queries.push(CXQuery::naked(ss));
48                    });
49                    let source_conn = &db_conn_map[p.db_name.as_str()]
50                        .conn_str_info
51                        .as_ref()
52                        .unwrap();
53
54                    let destination = get_arrow(source_conn, None, queries.as_slice(), None)?;
55                    let rbs = destination.arrow()?;
56
57                    let provider = MemTable::try_new(rbs[0].schema(), vec![rbs])?;
58                    s.send((p.db_alias, Some(Arc::new(provider))))
59                        .expect(&format!("send error {}", i));
60                    debug!("query {} finished", i);
61                }
62            }
63            Ok(())
64        },
65    )?;
66
67    let ctx = SessionContext::new();
68    let mut alias_names: Vec<String> = vec![];
69    let mut local_sql = String::new();
70    receiver
71        .iter()
72        .try_for_each(|(alias, provider)| -> Result<(), ConnectorXOutError> {
73            match provider {
74                Some(p) => {
75                    ctx.register_table(alias.as_str(), p)?;
76                    alias_names.push(alias);
77                }
78                None => local_sql = alias,
79            }
80
81            Ok(())
82        })?;
83
84    debug!("\nexecute query final...\n{}\n", local_sql);
85    let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime"));
86    // until datafusion fix the bug: https://github.com/apache/arrow-datafusion/issues/2147
87    for alias in alias_names {
88        local_sql = local_sql.replace(format!("\"{}\"", alias).as_str(), alias.as_str());
89    }
90
91    let df = rt.block_on(ctx.sql(local_sql.as_str()))?;
92    rt.block_on(df.collect())?
93}