From 7ac7f7210d8499cff6b996727cec760b5619d9c6 Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 24 Feb 2024 21:27:21 -0800 Subject: [PATCH] feat: improvements to inference engine feat: remove unnecessary unwraps --- Cargo.lock | 9 +- rust/cli/Cargo.toml | 7 +- rust/cli/src/databases_bigquery.rs | 65 ++++++++------ rust/cli/src/databases_duckdb.rs | 12 ++- rust/cli/src/main.rs | 16 +++- rust/core/Cargo.toml | 4 +- rust/core/src/automatic_branching.rs | 32 +++++++ rust/core/src/graph.rs | 127 +++++++++------------------ rust/dbt-converter/Cargo.toml | 4 +- rust/sqlinference/Cargo.toml | 4 +- rust/sqlinference/src/infer_tests.rs | 38 +++++--- 11 files changed, 172 insertions(+), 146 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ceebd537..3a5577ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -981,7 +981,7 @@ checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5" [[package]] name = "dbt-converter" -version = "0.1.0" +version = "0.0.91" dependencies = [ "quary_proto", "regex", @@ -2468,7 +2468,7 @@ dependencies = [ [[package]] name = "quary" -version = "0.0.87" +version = "0.0.91" dependencies = [ "arrow-array 50.0.0", "assert_cmd", @@ -2487,11 +2487,12 @@ dependencies = [ "sqlx", "tempfile", "tokio", + "yup-oauth2", ] [[package]] name = "quary-core" -version = "0.1.0" +version = "0.0.91" dependencies = [ "async-trait", "csv", @@ -3265,7 +3266,7 @@ dependencies = [ [[package]] name = "sqlinference" -version = "0.1.0" +version = "0.0.91" dependencies = [ "sqlparser", ] diff --git a/rust/cli/Cargo.toml b/rust/cli/Cargo.toml index 3268d664..d77e88d7 100644 --- a/rust/cli/Cargo.toml +++ b/rust/cli/Cargo.toml @@ -1,12 +1,13 @@ [package] name = "quary" -version = "0.0.87" +version = "0.0.91" edition = "2021" -rust-version = "1.75.0" +rust-version = "1.76.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +yup-oauth2 = { version = "8", default-features = false } clap = { version = "4", features = ["derive"] } quary-core = { path = "../core" } dbt-converter = { path = "../dbt-converter" } @@ -17,7 +18,7 @@ async-trait = "0.1" indicatif = "0.17" gcp-bigquery-client = "0.18" prost = "0.12" -google-cloud-auth = { version = "0.13", default-features = false, features = [ "external-account", "rustls-tls"] } +google-cloud-auth = { version = "0.13", default-features = false, features = ["external-account", "rustls-tls"] } snowflake-api = "0.6" arrow-array = "50" duckdb = { version = "0.10", features = ["bundled"] } diff --git a/rust/cli/src/databases_bigquery.rs b/rust/cli/src/databases_bigquery.rs index f121d8a3..64c08c5b 100644 --- a/rust/cli/src/databases_bigquery.rs +++ b/rust/cli/src/databases_bigquery.rs @@ -11,6 +11,7 @@ use quary_core::databases::{ }; use std::fmt::Debug; use std::sync::Arc; +use yup_oauth2::error::AuthErrorCode; pub struct BigQuery { project_id: String, @@ -61,7 +62,13 @@ impl Authenticator for AccessTokenProviderHolder { impl Authenticator for AccessTokenProvider { async fn access_token(&self) -> Result { let token_source = &self.token_source; - let token = token_source.token().await.unwrap(); + let token = token_source.token().await.map_err(|_| { + BQError::AuthError(yup_oauth2::error::AuthError { + error: AuthErrorCode::ExpiredToken, + error_description: None, + error_uri: None, + }) + })?; Ok(token.access_token) } } @@ -131,7 +138,7 @@ impl BigQuery { .list(&self.project_id, &self.dataset_id, options) .await .map_err(|e| format!("Failed to list tables: {}", e))?; - collected_tables.extend(tables.tables.unwrap()); + collected_tables.extend(tables.tables.unwrap_or_default()); if tables.next_page_token.is_none() { break; } @@ -145,8 +152,8 @@ impl BigQuery { impl DatabaseConnection for BigQuery { // TODO Return an iterator async fn list_tables(&self) -> Result, String> { - let collected_tables = self.get_all_table_like_things().await?; - let tables = collected_tables + self.get_all_table_like_things() + .await? .iter() .filter(|table| { if let Some(kind) = &table.kind { @@ -155,23 +162,21 @@ impl DatabaseConnection for BigQuery { false } }) - .map(|t| TableAddress { - full_path: format!( - "{}.{}.{}", - self.project_id, - self.dataset_id, - t.friendly_name.clone().unwrap() - ), - name: t.friendly_name.clone().unwrap(), + .map(|t| { + let name = t + .friendly_name + .clone() + .ok_or("Failed to get friendly name of table".to_string())?; + Ok(TableAddress { + full_path: format!("{}.{}.{}", self.project_id, self.dataset_id, name), + name, + }) }) - .collect(); - - Ok(tables) + .collect() } async fn list_views(&self) -> Result, String> { - let collected_tables = self.get_all_table_like_things().await?; - let tables = collected_tables + self.get_all_table_like_things().await? .iter() .filter(|table| { if let Some(kind) = &table.kind { @@ -180,18 +185,20 @@ impl DatabaseConnection for BigQuery { false } }) - .map(|t| TableAddress { - full_path: format!( - "{}.{}.{}", - self.project_id, - self.dataset_id, - t.friendly_name.clone().unwrap() - ), - name: t.friendly_name.clone().unwrap(), + .map(|t| { + let friendly_name = t + .friendly_name + .clone() + .ok_or("Failed to get friendly name of table".to_string())?; + Ok(TableAddress { + full_path: format!( + "{}.{}.{}", + self.project_id, self.dataset_id, friendly_name, + ), + name: friendly_name, + }) }) - .collect(); - - Ok(tables) + .collect() } async fn list_columns(&self, table: &str) -> Result, String> { @@ -206,7 +213,7 @@ impl DatabaseConnection for BigQuery { ) .await .map_err(|e| format!("Failed to get table {}: {}", table, e))?; - let fields = tables.schema.fields.unwrap(); + let fields = tables.schema.fields.unwrap_or_default(); let columns = fields.iter().map(|f| f.name.clone()).collect(); Ok(columns) } diff --git a/rust/cli/src/databases_duckdb.rs b/rust/cli/src/databases_duckdb.rs index c89e8328..e1ee00f8 100644 --- a/rust/cli/src/databases_duckdb.rs +++ b/rust/cli/src/databases_duckdb.rs @@ -57,7 +57,7 @@ impl DatabaseConnection for DuckDB { "SELECT table_name FROM information_schema.tables WHERE table_schema = '{}' AND type='table' ORDER BY name", schema )) - .await? + .await? } else { self.query("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name") .await? @@ -111,14 +111,20 @@ impl DatabaseConnection for DuckDB { } async fn exec(&self, query: &str) -> Result<(), String> { - let conn = self.connection.lock().unwrap(); + let conn = self + .connection + .lock() + .map_err(|e| format!("Failed to get connection lock: {}", e))?; conn.execute(query, params![]) .map_err(|e| format!("Failed to execute query {}: {}", query, e))?; return Ok(()); } async fn query(&self, query: &str) -> Result { - let conn = self.connection.lock().unwrap(); + let conn = self + .connection + .lock() + .map_err(|e| format!("Failed to get connection lock: {}", e))?; let mut stmt = conn .prepare(query) diff --git a/rust/cli/src/main.rs b/rust/cli/src/main.rs index 0f8174a2..62ed45f5 100644 --- a/rust/cli/src/main.rs +++ b/rust/cli/src/main.rs @@ -1,3 +1,11 @@ +#![deny(clippy::expect_used)] +#![deny(clippy::needless_lifetimes)] +#![deny(clippy::needless_borrow)] +#![deny(clippy::useless_conversion)] +#![deny(clippy::unwrap_used)] +#![deny(unused_imports)] +#![deny(unused_import_braces)] + use crate::commands::{mode_to_test_runner, Cli, Commands, InitType}; use crate::databases_connection::{database_from_config, database_query_generator_from_config}; use clap::Parser; @@ -48,7 +56,9 @@ async fn main() -> Result<(), String> { for file in Asset::iter() { let filename = file.as_ref(); let path = Path::new(filename); - let prefix = path.parent().expect("no parent"); + let prefix = path + .parent() + .ok_or("Could not get parent directory for file in Asset::iter()")?; if !prefix.exists() { fs::create_dir_all(prefix).map_err(|e| e.to_string())?; } @@ -65,7 +75,9 @@ async fn main() -> Result<(), String> { for file in DuckDBAsset::iter() { let filename = file.as_ref(); let path = Path::new(filename); - let prefix = path.parent().expect("no parent"); + let prefix = path.parent().ok_or( + "Could not get parent directory for file in DuckDBAsset::iter()", + )?; if !prefix.exists() { fs::create_dir_all(prefix).map_err(|e| e.to_string())?; } diff --git a/rust/core/Cargo.toml b/rust/core/Cargo.toml index 616e9455..fef3d2d8 100644 --- a/rust/core/Cargo.toml +++ b/rust/core/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "quary-core" -version = "0.1.0" +version = "0.0.91" authors = ["Ben King <9087625+benfdking@users.noreply.github.com>"] edition = "2021" -rust-version = "1.75.0" +rust-version = "1.76.0" [lib] crate-type = ["cdylib", "rlib"] diff --git a/rust/core/src/automatic_branching.rs b/rust/core/src/automatic_branching.rs index 74db3cb0..877ad3d1 100644 --- a/rust/core/src/automatic_branching.rs +++ b/rust/core/src/automatic_branching.rs @@ -955,4 +955,36 @@ sources: let result = cache_view_name_to_table_name_and_hash(cache_view_name); assert!(result.is_err()); } + + #[test] + fn test_derive_hash_views_on_seed() { + let fs = quary_proto::FileSystem { + files: HashMap::from([ + ( + "seeds/seed_checkout_disputes.csv".to_string(), + File { + name: "seeds/seed_checkout_disputes.csv".to_string(), + contents: prost::bytes::Bytes::from("id,order_id,payment_method,amount"), + }, + ), + ( + "quary.yaml".to_string(), + File { + name: "quary.yaml".to_string(), + contents: prost::bytes::Bytes::from( + r#" + sqliteInMemory: {} + "# + .as_bytes(), + ), + }, + ), + ]), + }; + let database = DatabaseQueryGeneratorSqlite {}; + let project = parse_project(&fs, &database, "").unwrap(); + let graph = project_to_graph(project.clone()).unwrap(); + + assert!(derive_hash_views(&database, &project, &graph).is_ok()); + } } diff --git a/rust/core/src/graph.rs b/rust/core/src/graph.rs index ff3d34d0..33f3df65 100644 --- a/rust/core/src/graph.rs +++ b/rust/core/src/graph.rs @@ -1,4 +1,5 @@ #![allow(clippy::unwrap_used)] + use crate::map_helpers::safe_adder_set; use crate::test_helpers::ToTest; use petgraph::algo::{is_cyclic_directed, toposort}; @@ -62,11 +63,11 @@ pub fn project_to_graph(project: Project) -> Result { for reference in test.references.clone() { if !taken.contains(&reference) { return Err(format!( - "reference to {} in model {} does not exist in reference-able objects {}", - reference, - name, - Vec::from_iter(taken).join(","), - )); + "reference to {} in model {} does not exist in reference-able objects {}", + reference, + name, + Vec::from_iter(taken).join(","), + )); }; edges.push((reference, name.clone())) } @@ -108,7 +109,7 @@ pub fn project_to_graph(project: Project) -> Result { } } - let graph = QGraph::new_from_edges(edges.clone())?; + let graph = QGraph::new_from_nodes_and_edges(taken.clone(), edges.clone())?; Ok(ProjectGraph { edges, graph }) } @@ -151,26 +152,6 @@ impl QGraph { Ok(QGraph { graph, dictionary }) } - // new_from_edges returns an error if the graph is cyclic. - fn new_from_edges(edges: Vec) -> Result { - let mut graph = Graph::::new(); - let mut dictionary = HashMap::::new(); - for (from, to) in edges { - let from_node: NodeIndex = *dictionary - .entry(from.clone()) - .or_insert_with(|| graph.add_node(from.clone())); - let to_node: NodeIndex = *dictionary - .entry(to.clone()) - .or_insert_with(|| graph.add_node(to.clone())); - - graph.add_edge(from_node, to_node, ()); - } - if is_cyclic_directed(&graph) { - return Err("graph is cyclic".to_string()); - } - Ok(QGraph { graph, dictionary }) - } - pub fn to_dot_vis(&self) -> String { format!( "{:?}", @@ -507,58 +488,6 @@ impl QGraph { mod tests { use super::*; - #[test] - fn test_quary_graph_new_from_edges() { - let tests = vec![ - ("empty", vec![], 0, false), - ("simple", vec![("a".to_string(), "b".to_string())], 2, false), - ( - "diamond", - vec![ - ("A".to_string(), "B".to_string()), - ("A".to_string(), "C".to_string()), - ("B".to_string(), "D".to_string()), - ("C".to_string(), "D".to_string()), - ], - 4, - false, - ), - ( - "cycle", - vec![ - ("A".to_string(), "B".to_string()), - ("B".to_string(), "A".to_string()), - ], - 0, - true, - ), - ]; - - for (name, edges, want_dictionary_length, want_err) in tests { - println!("test: {}", name); - - let got = QGraph::new_from_edges(edges); - - match (got, want_err) { - (Ok(got), false) => { - assert_eq!( - got.dictionary.len(), - want_dictionary_length, - "QGraph::new_from_edges() dictionary length mismatch" - ); - assert_eq!( - got.graph.node_count(), - got.dictionary.len(), - "Mismatch between dictionary length and node count" - ); - } - (Err(_), true) => {} // Expected an error and got one - (Ok(_), true) => panic!("QGraph::new_from_edges() error expected, but got Ok"), - (Err(err), false) => panic!("QGraph::new_from_edges() unexpected error: {}", err), - } - } - } - #[test] fn test_get_node_sorted() { let tests = vec![ @@ -634,7 +563,11 @@ mod tests { for (name, edges, want, want_err) in tests { println!("Running test: {}", name); - let g = QGraph::new_from_edges(edges).unwrap(); + let nodes = edges + .iter() + .flat_map(|(a, b)| vec![a.clone(), b.clone()]) + .collect::>(); + let g = QGraph::new_from_nodes_and_edges(nodes, edges).unwrap(); let got = match g.get_node_sorted() { Ok(got) => got, Err(e) => { @@ -668,7 +601,11 @@ mod tests { for (name, edges, node_name, want_value) in tests { println!("Running test: {}", name); - let g = QGraph::new_from_edges(edges).unwrap(); + let nodes = edges + .iter() + .flat_map(|(a, b)| vec![a.clone(), b.clone()]) + .collect::>(); + let g = QGraph::new_from_nodes_and_edges(nodes, edges).unwrap(); let got = g.get_node(node_name); if !want_value { @@ -683,7 +620,7 @@ mod tests { #[test] fn test_to_dot_vis() { - let tests = vec![ + let tests: Vec<(&str, Vec, &str)> = vec![ ("empty", vec![], "digraph {\n}\n"), // TODO Implement the tests below // ( @@ -706,7 +643,12 @@ mod tests { for (name, edges, want) in tests { println!("Running test: {}", name); - let g = QGraph::new_from_edges(edges).unwrap(); + let nodes = edges + .iter() + .flat_map(|(a, b)| vec![a.clone(), b.clone()]) + .collect::>(); + + let g = QGraph::new_from_nodes_and_edges(nodes, edges).unwrap(); let got = g.to_dot_vis(); assert_eq!(got, want); @@ -767,7 +709,12 @@ mod tests { for (name, edges, search, want, edges_want_length, edges_want) in tests { println!("Running test: {}", name); - let g = QGraph::new_from_edges(edges).unwrap(); + let nodes = edges + .iter() + .flat_map(|(a, b)| vec![a.clone(), b.clone()]) + .collect::>(); + + let g = QGraph::new_from_nodes_and_edges(nodes, edges).unwrap(); let got = g.return_sub_graph(search).unwrap(); let mut values = got @@ -825,12 +772,16 @@ mod tests { for (name, search, edges, want, expected_edges_length) in tests { println!("Running test: {}", name); - let edges = edges + let edges: Vec = edges .into_iter() .map(|(a, b)| (a.to_string(), b.to_string())) .collect(); + let nodes = edges + .iter() + .flat_map(|(a, b)| vec![a.clone(), b.clone()]) + .collect::>(); - let g = QGraph::new_from_edges(edges).unwrap(); + let g = QGraph::new_from_nodes_and_edges(nodes, edges).unwrap(); let got = g.return_upstream_graph(search).unwrap(); @@ -860,7 +811,11 @@ mod tests { .into_iter() .map(|[a, b]| (a.to_string(), b.to_string())) .collect(); - let g = QGraph::new_from_edges(edges).unwrap(); + let nodes = edges + .iter() + .flat_map(|(a, b)| vec![a.clone(), b.clone()]) + .collect::>(); + let g = QGraph::new_from_nodes_and_edges(nodes, edges).unwrap(); let (_, got) = g.return_parent_nods_to_apply_in_order(search).unwrap(); diff --git a/rust/dbt-converter/Cargo.toml b/rust/dbt-converter/Cargo.toml index 603cd370..dd91b7a2 100644 --- a/rust/dbt-converter/Cargo.toml +++ b/rust/dbt-converter/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "dbt-converter" -version = "0.1.0" +version = "0.0.91" edition = "2021" -rust-version = "1.75.0" +rust-version = "1.76.0" [dependencies] serde_yaml = "0.9" diff --git a/rust/sqlinference/Cargo.toml b/rust/sqlinference/Cargo.toml index f0b3c7e6..c89ddaec 100644 --- a/rust/sqlinference/Cargo.toml +++ b/rust/sqlinference/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "sqlinference" -version = "0.1.0" +version = "0.0.91" edition = "2021" -rust-version = "1.75.0" +rust-version = "1.76.0" [lib] diff --git a/rust/sqlinference/src/infer_tests.rs b/rust/sqlinference/src/infer_tests.rs index 5dfcab33..e544eaa8 100644 --- a/rust/sqlinference/src/infer_tests.rs +++ b/rust/sqlinference/src/infer_tests.rs @@ -463,7 +463,7 @@ fn extract_select(query: &Query) -> Result { Ok(ExtractedSelect::Star(value.clone())) } Extracted::ZeroMap(_) => { - return Err("Do not support zero maps for wildcard".to_string()) + return Err("Do not support zero maps for wildcard".to_string()); } } } else { @@ -637,6 +637,16 @@ fn extract_select(query: &Query) -> Result { unnamed.push(alias.value.clone()); } } + Expr::Cast { .. } => { + unnamed.push(alias.value.clone()); + } + Expr::Case { .. } => { + unnamed.push(alias.value.clone()); + } + Expr::Substring { .. } => { + unnamed.push(alias.value.clone()); + } + Expr::Wildcard => {} _ => { return Err(format!( "Expected Identifier/CompoundIdentifier or Function, not {:?}", @@ -2528,7 +2538,6 @@ LEFT JOIN q.shift_last sl ("b", ("q.model_b", "b")), ], vec![], vec![], - ), ("SELECT alias_a.a AS c, alias_b.b FROM q.model_a alias_a JOIN q.model_b alias_b ON alias_a.a=alias_b.a;", vec![ @@ -2536,33 +2545,30 @@ LEFT JOIN q.shift_last sl ("b", ("q.model_b", "b")), ], vec![], vec![], - ), ("WITH a AS (SELECT b, c AS d FROM q.table_c) SELECT b, d AS e FROM a", vec![ - ("b", ("q.table_c", "b")), - ("e", ("q.table_c", "c")), + ("b", ("q.table_c", "b")), + ("e", ("q.table_c", "c")), ], vec![], vec![], - - ), + ), ("WITH a AS (SELECT b FROM q.table_c), q AS (SELECT b AS v FROM a) SELECT v AS e FROM q", vec![ - ("e", ("q.table_c", "b")), + ("e", ("q.table_c", "b")), ], vec![], vec![], - - ), - ("SELECT a FROM (SELECT a FROM q.table_a)", vec![("a", ("q.table_a", "a"))],vec![],vec![]), + ), + ("SELECT a FROM (SELECT a FROM q.table_a)", vec![("a", ("q.table_a", "a"))], vec![], vec![]), ( "SELECT c FROM (SELECT a AS c FROM q.table_a)", vec![("c", ("q.table_a", "a"))], vec![], vec![], ), - ("SELECT a AS b FROM (SELECT c AS a FROM q.table_a)",vec![("b", ("q.table_a", "c"))],vec![],vec![]), + ("SELECT a AS b FROM (SELECT c AS a FROM q.table_a)", vec![("b", ("q.table_a", "c"))], vec![], vec![]), ("SELECT e.a AS b, g.b FROM (SELECT d.c AS a FROM q.table_a d) e INNER JOIN (SELECT b FROM q.table_b) g ON g.b=e.a" - , vec![("b", ("q.table_a", "c")),("b", ("q.table_b", "b"))], vec![], vec![]), + , vec![("b", ("q.table_a", "c")), ("b", ("q.table_b", "b"))], vec![], vec![]), ("SELECT COUNT(*) AS b FROM q.table_a" , vec![], vec![], vec!["b"]), ("SELECT count(*) AS b FROM (SELECT a.b AS c FROM q.table_a a)" @@ -2573,6 +2579,12 @@ LEFT JOIN q.shift_last sl , vec![], vec![], vec!["c"]), ("WITH bc AS (SELECT b AS c FROM q.table_a a) SELECT * FROM bc" , vec![("c", ("q.table_a", "b"))], vec![], vec![]), + // TODO Be smarter about type casting + ("SELECT date::date as cost_date FROM q.table_a" + , vec![], vec!["cost_date"], vec![]), + // TODO Be smarter about casting, here could do one of + ("SELECT CASE when market != 'THING' or receive_market != 'THING' then 1 when channel = 'THING' then 0 else 0 end as caq from q.caq", + vec![], vec!["caq"], vec![]), ]; for (sql, expected_map_entries, expected_not_parseable, expected_count) in tests {