connectorx/
fed_dispatcher.rs1use crate::{prelude::*, sql::CXQuery};
2use arrow::record_batch::RecordBatch;
3use datafusion::datasource::MemTable;
4use datafusion::prelude::*;
5use fehler::throws;
6use log::debug;
7use rayon::prelude::*;
8use std::collections::HashMap;
9use std::convert::TryFrom;
10use std::sync::{mpsc::channel, Arc};
11
12#[throws(ConnectorXOutError)]
13pub fn run(
14 sql: String,
15 db_map: HashMap<String, String>,
16 j4rs_base: Option<&str>,
17 strategy: &str,
18) -> Vec<RecordBatch> {
19 debug!("federated input sql: {}", sql);
20 let mut db_conn_map: HashMap<String, FederatedDataSourceInfo> = HashMap::new();
21 for (k, v) in db_map.into_iter() {
22 db_conn_map.insert(
23 k,
24 FederatedDataSourceInfo::new_from_conn_str(
25 SourceConn::try_from(v.as_str())?,
26 false,
27 "",
28 "",
29 ),
30 );
31 }
32 let fed_plan = rewrite_sql(sql.as_str(), &db_conn_map, j4rs_base, strategy)?;
33
34 debug!("fetch queries from remote");
35 let (sender, receiver) = channel();
36 fed_plan.into_par_iter().enumerate().try_for_each_with(
37 sender,
38 |s, (i, p)| -> Result<(), ConnectorXOutError> {
39 match p.db_name.as_str() {
40 "LOCAL" => {
41 s.send((p.sql, None)).expect("send error local");
42 }
43 _ => {
44 debug!("start query {}: {}", i, p.sql);
45 let mut queries = vec![];
46 p.sql.split(';').for_each(|ss| {
47 queries.push(CXQuery::naked(ss));
48 });
49 let source_conn = &db_conn_map[p.db_name.as_str()]
50 .conn_str_info
51 .as_ref()
52 .unwrap();
53
54 let destination = get_arrow(source_conn, None, queries.as_slice(), None)?;
55 let rbs = destination.arrow()?;
56
57 let provider = MemTable::try_new(rbs[0].schema(), vec![rbs])?;
58 s.send((p.db_alias, Some(Arc::new(provider))))
59 .expect(&format!("send error {}", i));
60 debug!("query {} finished", i);
61 }
62 }
63 Ok(())
64 },
65 )?;
66
67 let ctx = SessionContext::new();
68 let mut alias_names: Vec<String> = vec![];
69 let mut local_sql = String::new();
70 receiver
71 .iter()
72 .try_for_each(|(alias, provider)| -> Result<(), ConnectorXOutError> {
73 match provider {
74 Some(p) => {
75 ctx.register_table(alias.as_str(), p)?;
76 alias_names.push(alias);
77 }
78 None => local_sql = alias,
79 }
80
81 Ok(())
82 })?;
83
84 debug!("\nexecute query final...\n{}\n", local_sql);
85 let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime"));
86 for alias in alias_names {
88 local_sql = local_sql.replace(format!("\"{}\"", alias).as_str(), alias.as_str());
89 }
90
91 let df = rt.block_on(ctx.sql(local_sql.as_str()))?;
92 rt.block_on(df.collect())?
93}