Skip to content

Commit

Permalink
feat: improvements to inference engine
Browse files Browse the repository at this point in the history
feat: remove unnecessary unwraps
  • Loading branch information
louisjoecodes committed Feb 25, 2024
1 parent 21d11fb commit 7ac7f72
Show file tree
Hide file tree
Showing 11 changed files with 172 additions and 146 deletions.
9 changes: 5 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions rust/cli/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
[package]
name = "quary"
version = "0.0.87"
version = "0.0.91"
edition = "2021"
rust-version = "1.75.0"
rust-version = "1.76.0"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
yup-oauth2 = { version = "8", default-features = false }
clap = { version = "4", features = ["derive"] }
quary-core = { path = "../core" }
dbt-converter = { path = "../dbt-converter" }
Expand All @@ -17,7 +18,7 @@ async-trait = "0.1"
indicatif = "0.17"
gcp-bigquery-client = "0.18"
prost = "0.12"
google-cloud-auth = { version = "0.13", default-features = false, features = [ "external-account", "rustls-tls"] }
google-cloud-auth = { version = "0.13", default-features = false, features = ["external-account", "rustls-tls"] }
snowflake-api = "0.6"
arrow-array = "50"
duckdb = { version = "0.10", features = ["bundled"] }
Expand Down
65 changes: 36 additions & 29 deletions rust/cli/src/databases_bigquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use quary_core::databases::{
};
use std::fmt::Debug;
use std::sync::Arc;
use yup_oauth2::error::AuthErrorCode;

pub struct BigQuery {
project_id: String,
Expand Down Expand Up @@ -61,7 +62,13 @@ impl Authenticator for AccessTokenProviderHolder {
impl Authenticator for AccessTokenProvider {
async fn access_token(&self) -> Result<String, BQError> {
let token_source = &self.token_source;
let token = token_source.token().await.unwrap();
let token = token_source.token().await.map_err(|_| {
BQError::AuthError(yup_oauth2::error::AuthError {
error: AuthErrorCode::ExpiredToken,
error_description: None,
error_uri: None,
})
})?;
Ok(token.access_token)
}
}
Expand Down Expand Up @@ -131,7 +138,7 @@ impl BigQuery {
.list(&self.project_id, &self.dataset_id, options)
.await
.map_err(|e| format!("Failed to list tables: {}", e))?;
collected_tables.extend(tables.tables.unwrap());
collected_tables.extend(tables.tables.unwrap_or_default());
if tables.next_page_token.is_none() {
break;
}
Expand All @@ -145,8 +152,8 @@ impl BigQuery {
impl DatabaseConnection for BigQuery {
// TODO Return an iterator
async fn list_tables(&self) -> Result<Vec<TableAddress>, String> {
let collected_tables = self.get_all_table_like_things().await?;
let tables = collected_tables
self.get_all_table_like_things()
.await?
.iter()
.filter(|table| {
if let Some(kind) = &table.kind {
Expand All @@ -155,23 +162,21 @@ impl DatabaseConnection for BigQuery {
false
}
})
.map(|t| TableAddress {
full_path: format!(
"{}.{}.{}",
self.project_id,
self.dataset_id,
t.friendly_name.clone().unwrap()
),
name: t.friendly_name.clone().unwrap(),
.map(|t| {
let name = t
.friendly_name
.clone()
.ok_or("Failed to get friendly name of table".to_string())?;
Ok(TableAddress {
full_path: format!("{}.{}.{}", self.project_id, self.dataset_id, name),
name,
})
})
.collect();

Ok(tables)
.collect()
}

async fn list_views(&self) -> Result<Vec<TableAddress>, String> {
let collected_tables = self.get_all_table_like_things().await?;
let tables = collected_tables
self.get_all_table_like_things().await?
.iter()
.filter(|table| {
if let Some(kind) = &table.kind {
Expand All @@ -180,18 +185,20 @@ impl DatabaseConnection for BigQuery {
false
}
})
.map(|t| TableAddress {
full_path: format!(
"{}.{}.{}",
self.project_id,
self.dataset_id,
t.friendly_name.clone().unwrap()
),
name: t.friendly_name.clone().unwrap(),
.map(|t| {
let friendly_name = t
.friendly_name
.clone()
.ok_or("Failed to get friendly name of table".to_string())?;
Ok(TableAddress {
full_path: format!(
"{}.{}.{}",
self.project_id, self.dataset_id, friendly_name,
),
name: friendly_name,
})
})
.collect();

Ok(tables)
.collect()
}

async fn list_columns(&self, table: &str) -> Result<Vec<String>, String> {
Expand All @@ -206,7 +213,7 @@ impl DatabaseConnection for BigQuery {
)
.await
.map_err(|e| format!("Failed to get table {}: {}", table, e))?;
let fields = tables.schema.fields.unwrap();
let fields = tables.schema.fields.unwrap_or_default();
let columns = fields.iter().map(|f| f.name.clone()).collect();
Ok(columns)
}
Expand Down
12 changes: 9 additions & 3 deletions rust/cli/src/databases_duckdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl DatabaseConnection for DuckDB {
"SELECT table_name FROM information_schema.tables WHERE table_schema = '{}' AND type='table' ORDER BY name",
schema
))
.await?
.await?
} else {
self.query("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
.await?
Expand Down Expand Up @@ -111,14 +111,20 @@ impl DatabaseConnection for DuckDB {
}

async fn exec(&self, query: &str) -> Result<(), String> {
let conn = self.connection.lock().unwrap();
let conn = self
.connection
.lock()
.map_err(|e| format!("Failed to get connection lock: {}", e))?;
conn.execute(query, params![])
.map_err(|e| format!("Failed to execute query {}: {}", query, e))?;
return Ok(());
}

async fn query(&self, query: &str) -> Result<QueryResult, String> {
let conn = self.connection.lock().unwrap();
let conn = self
.connection
.lock()
.map_err(|e| format!("Failed to get connection lock: {}", e))?;

let mut stmt = conn
.prepare(query)
Expand Down
16 changes: 14 additions & 2 deletions rust/cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
#![deny(clippy::expect_used)]
#![deny(clippy::needless_lifetimes)]
#![deny(clippy::needless_borrow)]
#![deny(clippy::useless_conversion)]
#![deny(clippy::unwrap_used)]
#![deny(unused_imports)]
#![deny(unused_import_braces)]

use crate::commands::{mode_to_test_runner, Cli, Commands, InitType};
use crate::databases_connection::{database_from_config, database_query_generator_from_config};
use clap::Parser;
Expand Down Expand Up @@ -48,7 +56,9 @@ async fn main() -> Result<(), String> {
for file in Asset::iter() {
let filename = file.as_ref();
let path = Path::new(filename);
let prefix = path.parent().expect("no parent");
let prefix = path
.parent()
.ok_or("Could not get parent directory for file in Asset::iter()")?;
if !prefix.exists() {
fs::create_dir_all(prefix).map_err(|e| e.to_string())?;
}
Expand All @@ -65,7 +75,9 @@ async fn main() -> Result<(), String> {
for file in DuckDBAsset::iter() {
let filename = file.as_ref();
let path = Path::new(filename);
let prefix = path.parent().expect("no parent");
let prefix = path.parent().ok_or(
"Could not get parent directory for file in DuckDBAsset::iter()",
)?;
if !prefix.exists() {
fs::create_dir_all(prefix).map_err(|e| e.to_string())?;
}
Expand Down
4 changes: 2 additions & 2 deletions rust/core/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
[package]
name = "quary-core"
version = "0.1.0"
version = "0.0.91"
authors = ["Ben King <9087625+benfdking@users.noreply.github.com>"]
edition = "2021"
rust-version = "1.75.0"
rust-version = "1.76.0"

[lib]
crate-type = ["cdylib", "rlib"]
Expand Down
32 changes: 32 additions & 0 deletions rust/core/src/automatic_branching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,4 +955,36 @@ sources:
let result = cache_view_name_to_table_name_and_hash(cache_view_name);
assert!(result.is_err());
}

#[test]
// Regression test: deriving hash views for a project whose only model source
// is a seed (CSV) file should succeed rather than error/panic.
// NOTE(review): parse_project / project_to_graph / derive_hash_views are
// project-local APIs not visible in this hunk — behavior described here is
// inferred from their names and this test's assertions; confirm against
// their definitions.
fn test_derive_hash_views_on_seed() {
// Build an in-memory project: one seed CSV (header row only) plus a
// minimal quary.yaml selecting an in-memory SQLite database.
let fs = quary_proto::FileSystem {
files: HashMap::from([
(
"seeds/seed_checkout_disputes.csv".to_string(),
File {
name: "seeds/seed_checkout_disputes.csv".to_string(),
contents: prost::bytes::Bytes::from("id,order_id,payment_method,amount"),
},
),
(
"quary.yaml".to_string(),
File {
name: "quary.yaml".to_string(),
// Raw string keeps the YAML byte-exact (leading newline and
// indentation included) — do not reformat.
contents: prost::bytes::Bytes::from(
r#"
sqliteInMemory: {}
"#
.as_bytes(),
),
},
),
]),
};
let database = DatabaseQueryGeneratorSqlite {};
// unwrap() is acceptable in tests: a parse/graph failure should fail the test.
let project = parse_project(&fs, &database, "").unwrap();
let graph = project_to_graph(project.clone()).unwrap();

// The actual property under test: hash-view derivation succeeds for a
// seed-only project.
assert!(derive_hash_views(&database, &project, &graph).is_ok());
}
}
Loading

0 comments on commit 7ac7f72

Please sign in to comment.