use crate::sources::postgres::errors::PostgresSourceError;
use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
use postgres::{config::SslMode, Config};
use postgres_openssl::MakeTlsConnector;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::path::PathBuf;
use url::Url;
/// Inputs needed to build a [`MakeTlsConnector`] for a Postgres connection.
///
/// Consumed by the `TryFrom<TlsConfig>` implementation for
/// [`MakeTlsConnector`] in this file.
#[derive(Clone, Debug)]
pub struct TlsConfig {
/// Connection configuration; its ssl mode decides how certificates are verified.
pub pg_config: Config,
/// Optional client identity as `(certificate file, private key file)`; both
/// files are loaded as PEM when the connector is built.
pub client_cert: Option<(PathBuf, PathBuf)>,
/// Optional CA file handed to OpenSSL as the trusted root store.
pub root_cert: Option<PathBuf>,
}
impl TryFrom<TlsConfig> for MakeTlsConnector {
type Error = PostgresSourceError;
fn try_from(tls_config: TlsConfig) -> Result<Self, Self::Error> {
let mut builder = SslConnector::builder(SslMethod::tls_client())?;
let ssl_mode = tls_config.pg_config.get_ssl_mode();
let (verify_ca, verify_hostname) = match ssl_mode {
SslMode::Disable | SslMode::Prefer => (false, false),
SslMode::Require => match tls_config.root_cert {
Some(_) => (true, false),
None => (false, false),
},
_ => panic!("unexpected sslmode {:?}", ssl_mode),
};
if let Some((cert, key)) = tls_config.client_cert {
builder.set_certificate_file(cert, SslFiletype::PEM)?;
builder.set_private_key_file(key, SslFiletype::PEM)?;
}
if let Some(root_cert) = tls_config.root_cert {
builder.set_ca_file(root_cert)?;
}
if !verify_ca {
builder.set_verify(SslVerifyMode::NONE); }
let mut tls_connector = MakeTlsConnector::new(builder.build());
if !verify_hostname {
tls_connector.set_callback(|connect, _| {
connect.set_verify_hostname(false);
Ok(())
});
}
Ok(tls_connector)
}
}
/// Returns a copy of `url` with the `sslkey`, `sslcert` and `sslrootcert`
/// query parameters removed.
///
/// Every other query pair is kept, in its original order. If no pair survives,
/// the returned URL has no query component at all.
///
/// NOTE(review): surviving pairs are decoded by `query_pairs` and re-encoded by
/// `append_pair`, so their textual encoding may change (for example `%20` may
/// come back as `+`). Confirm that the consumer of the stripped URL decodes
/// both forms the same way.
fn strip_bad_opts(url: &Url) -> Url {
    // Options that carry TLS file paths and must not stay in the URL.
    const TLS_FILE_OPTS: [&str; 3] = ["sslkey", "sslcert", "sslrootcert"];

    let mut stripped = url.clone();
    stripped.set_query(None);

    let kept = url.query_pairs().filter(|(key, _)| {
        let name: &str = key;
        !TLS_FILE_OPTS.contains(&name)
    });

    // `query_pairs_mut` is only called when there is a pair to append, so a
    // URL left with no parameters does not gain an empty `?`.
    for (key, value) in kept {
        stripped.query_pairs_mut().append_pair(&key, &value);
    }
    stripped
}
/// Splits a connection URL into a [`Config`] and, unless TLS is disabled, a
/// matching [`MakeTlsConnector`].
///
/// The `sslcert`, `sslkey` and `sslrootcert` query parameters are read from
/// `conn` as file paths. They are then removed from the URL by
/// `strip_bad_opts`, and the remainder is parsed into the [`Config`].
///
/// A client certificate is only used when both `sslcert` and `sslkey` are
/// present. If only one of them is given, it is ignored.
///
/// # Errors
/// Returns an error when the stripped URL cannot be parsed into a [`Config`]
/// or when building the TLS connector fails.
///
/// # Panics
/// Building the connector panics if the parsed ssl mode is one it does not
/// handle.
pub fn rewrite_tls_args(
    conn: &Url,
) -> Result<(Config, Option<MakeTlsConnector>), PostgresSourceError> {
    let params: HashMap<String, String> = conn.query_pairs().into_owned().collect();
    let path_param = |name: &str| params.get(name).map(PathBuf::from);

    let root_cert = path_param("sslrootcert");
    let client_cert = match (path_param("sslcert"), path_param("sslkey")) {
        (Some(cert), Some(key)) => Some((cert, key)),
        _ => None,
    };

    // A malformed connection string is caller input, so report it as an error
    // instead of panicking.
    // NOTE(review): `?` relies on `PostgresSourceError: From<postgres::Error>`;
    // confirm that conversion exists in the errors module.
    let pg_config: Config = strip_bad_opts(conn).as_str().parse()?;

    let tls_connector = match pg_config.get_ssl_mode() {
        SslMode::Disable => None,
        // The config is only cloned when a connector is actually built.
        _ => Some(MakeTlsConnector::try_from(TlsConfig {
            pg_config: pg_config.clone(),
            client_cert,
            root_cert,
        })?),
    };
    Ok((pg_config, tls_connector))
}