From 8a3982b39f6dc8065ba99a0d212b3e18a5cd2d1d Mon Sep 17 00:00:00 2001 From: Flavian Desverne Date: Fri, 23 Aug 2024 15:31:29 +0200 Subject: [PATCH] feat(typedsql): support column & param nullability (#4979) --- Cargo.lock | 10 +- libs/test-cli/src/main.rs | 1 - quaint/src/connector.rs | 4 +- .../{parsed_query.rs => describe.rs} | 29 +- quaint/src/connector/mssql/native/mod.rs | 6 +- quaint/src/connector/mysql/native/mod.rs | 18 +- .../src/connector/postgres/native/explain.rs | 57 ++ quaint/src/connector/postgres/native/mod.rs | 194 ++++- quaint/src/connector/queryable.rs | 4 +- quaint/src/connector/result_set/result_row.rs | 10 +- quaint/src/connector/sqlite/native/mod.rs | 28 +- quaint/src/connector/transaction.rs | 4 +- quaint/src/pooled/manager.rs | 4 +- quaint/src/single.rs | 4 +- .../sql-query-connector/src/database/js.rs | 4 +- query-engine/driver-adapters/src/queryable.rs | 10 +- .../driver-adapters/src/transaction.rs | 6 +- .../schema-connector/src/introspect_sql.rs | 19 +- .../sql-schema-connector/src/flavour.rs | 4 +- .../sql-schema-connector/src/flavour/mssql.rs | 6 +- .../sql-schema-connector/src/flavour/mysql.rs | 6 +- .../src/flavour/mysql/connection.rs | 8 +- .../src/flavour/postgres.rs | 6 +- .../src/flavour/postgres/connection.rs | 8 +- .../src/flavour/sqlite.rs | 8 +- .../src/flavour/sqlite/connection.rs | 25 +- .../sql-schema-connector/src/lib.rs | 28 +- .../src/sql_doc_parser.rs | 404 ++++++++-- schema-engine/core/src/state.rs | 2 + .../methods/introspectSql.toml | 7 +- .../src/assertions/quaint_result_set_ext.rs | 11 + .../sql-migration-tests/src/test_api.rs | 46 +- .../tests/query_introspection/docs.rs | 61 +- .../tests/query_introspection/mysql.rs | 81 ++ .../tests/query_introspection/pg.rs | 709 +++++++++++++++++- .../tests/query_introspection/sqlite.rs | 146 +++- .../tests/query_introspection/utils.rs | 13 + 37 files changed, 1759 insertions(+), 232 deletions(-) rename quaint/src/connector/{parsed_query.rs => describe.rs} (71%) create mode 100644 quaint/src/connector/postgres/native/explain.rs diff --git a/Cargo.lock b/Cargo.lock index fc1e942f7a5..92f9486bfea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3462,7 +3462,7 @@ dependencies = [ [[package]] name = "postgres-native-tls" version = "0.5.0" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#a1a2dc6d9584deaf70a14293c428e7b6ca614d98" +source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#54a490bc6afa315abb9867304fb67e8b12a8fbf3" dependencies = [ "native-tls", "tokio", @@ -3473,7 +3473,7 @@ dependencies = [ [[package]] name = "postgres-protocol" version = "0.6.4" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#a1a2dc6d9584deaf70a14293c428e7b6ca614d98" +source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#54a490bc6afa315abb9867304fb67e8b12a8fbf3" dependencies = [ "base64 0.13.1", "byteorder", @@ -3490,7 +3490,7 @@ dependencies = [ [[package]] name = "postgres-types" version = "0.2.4" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#a1a2dc6d9584deaf70a14293c428e7b6ca614d98" +source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#54a490bc6afa315abb9867304fb67e8b12a8fbf3" dependencies = [ "bit-vec", "bytes", @@ -5795,7 +5795,7 @@ dependencies = [ [[package]] name = "tokio-postgres" version = "0.7.7" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#a1a2dc6d9584deaf70a14293c428e7b6ca614d98" +source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#54a490bc6afa315abb9867304fb67e8b12a8fbf3" dependencies = [ "async-trait", "byteorder", @@ -6112,7 +6112,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" dependencies = [ "cfg-if", - "rand 0.8.5", + "rand 0.3.23", "static_assertions", ] diff --git a/libs/test-cli/src/main.rs b/libs/test-cli/src/main.rs index 63932ee48eb..1341e9258ab 100644 --- a/libs/test-cli/src/main.rs +++ b/libs/test-cli/src/main.rs @@ -212,7 +212,6 @@ async fn main() -> anyhow::Result<()> { .first_datasource() .load_url(|key| std::env::var(key).ok()) .unwrap(), - force: false, queries: vec![SqlQueryInput { name: "query".to_string(), source: query_str, diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index b0cf16658ea..ad07b6698f3 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -12,11 +12,11 @@ mod column_type; mod connection_info; +mod describe; pub mod external; pub mod metrics; #[cfg(native)] pub mod native; -mod parsed_query; mod queryable; mod result_set; #[cfg(any(feature = "mssql-native", feature = "postgresql-native", feature = "mysql-native"))] @@ -32,8 +32,8 @@ pub use connection_info::*; #[cfg(native)] pub use native::*; +pub use describe::*; pub use external::*; -pub use parsed_query::*; pub use queryable::*; pub use transaction::*; diff --git a/quaint/src/connector/parsed_query.rs b/quaint/src/connector/describe.rs similarity index 71% rename from quaint/src/connector/parsed_query.rs rename to quaint/src/connector/describe.rs index 525a84ee336..b719d594a56 100644 --- a/quaint/src/connector/parsed_query.rs +++ b/quaint/src/connector/describe.rs @@ -3,26 +3,34 @@ use std::borrow::Cow; use super::ColumnType; #[derive(Debug)] -pub struct ParsedRawQuery { - pub parameters: Vec, - pub columns: Vec, +pub struct DescribedQuery { + pub parameters: Vec, + pub columns: Vec, + pub enum_names: Option>, +} + +impl DescribedQuery { + pub fn param_enum_names(&self) -> Vec<&str> { + self.parameters.iter().filter_map(|p| p.enum_name.as_deref()).collect() + } } #[derive(Debug)] -pub struct ParsedRawParameter { +pub struct DescribedParameter { pub name: String, pub typ: ColumnType, pub enum_name: Option, } #[derive(Debug)] -pub struct ParsedRawColumn { +pub struct DescribedColumn { pub name: String, pub typ: ColumnType, + pub nullable: bool, pub enum_name: Option, } -impl ParsedRawParameter { +impl DescribedParameter { pub fn new_named<'a>(name: impl Into>, typ: impl Into) -> Self { let name: Cow<'_, str> = name.into(); @@ -52,7 +60,7 @@ impl ParsedRawParameter { } } -impl ParsedRawColumn { +impl DescribedColumn { pub fn new_named<'a>(name: impl Into>, typ: impl Into) -> Self { let name: Cow<'_, str> = name.into(); @@ -60,6 +68,7 @@ impl ParsedRawColumn { name: name.into_owned(), typ: typ.into(), enum_name: None, + nullable: false, } } @@ -68,6 +77,7 @@ impl ParsedRawColumn { name: format!("_{idx}"), typ: typ.into(), enum_name: None, + nullable: false, } } @@ -75,4 +85,9 @@ impl ParsedRawColumn { self.enum_name = enum_name; self } + + pub fn is_nullable(mut self, nullable: bool) -> Self { + self.nullable = nullable; + self + } } diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index b71635f229a..7383e503d0a 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -6,7 +6,7 @@ mod conversion; mod error; pub(crate) use crate::connector::mssql::MssqlUrl; -use crate::connector::{timeout, IsolationLevel, ParsedRawQuery, Transaction, TransactionOptions}; +use crate::connector::{timeout, DescribedQuery, IsolationLevel, Transaction, TransactionOptions}; use crate::{ ast::{Query, Value}, @@ -183,8 +183,8 @@ impl Queryable for Mssql { self.query_raw(sql, params).await } - async fn parse_raw_query(&self, _sql: &str) -> crate::Result { - unimplemented!("SQL Server support for raw query parsing is not implemented yet.") + async fn describe_query(&self, _sql: &str) -> crate::Result { + unimplemented!("SQL Server does not support describe_query yet.") } async fn execute(&self, q: Query<'_>) -> crate::Result { diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index 38bcc8993d4..b4b23ab94cb 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -6,7 +6,7 @@ mod conversion; mod error; pub(crate) use crate::connector::mysql::MysqlUrl; -use crate::connector::{timeout, ColumnType, IsolationLevel, ParsedRawColumn, ParsedRawParameter, ParsedRawQuery}; +use crate::connector::{timeout, ColumnType, DescribedColumn, DescribedParameter, DescribedQuery, IsolationLevel}; use crate::{ ast::{Query, Value}, @@ -16,6 +16,7 @@ use crate::{ }; use async_trait::async_trait; use lru_cache::LruCache; +use mysql_async::consts::ColumnFlags; use mysql_async::{ self as my, prelude::{Query as _, Queryable as _}, @@ -247,21 +248,28 @@ impl Queryable for Mysql { self.query_raw(sql, params).await } - async fn parse_raw_query(&self, sql: &str) -> crate::Result { + async fn describe_query(&self, sql: &str) -> crate::Result { self.prepared(sql, |stmt| async move { let columns = stmt .columns() .iter() - .map(|col| ParsedRawColumn::new_named(col.name_str(), col)) + .map(|col| { + DescribedColumn::new_named(col.name_str(), col) + .is_nullable(!col.flags().contains(ColumnFlags::NOT_NULL_FLAG)) + }) .collect(); let parameters = stmt .params() .iter() .enumerate() - .map(|(idx, col)| ParsedRawParameter::new_unnamed(idx, col)) + .map(|(idx, col)| DescribedParameter::new_unnamed(idx, col)) .collect(); - Ok(ParsedRawQuery { columns, parameters }) + Ok(DescribedQuery { + columns, + parameters, + enum_names: None, + }) }) .await } diff --git a/quaint/src/connector/postgres/native/explain.rs b/quaint/src/connector/postgres/native/explain.rs new file mode 100644 index 00000000000..6a3e8594f8d --- /dev/null +++ b/quaint/src/connector/postgres/native/explain.rs @@ -0,0 +1,57 @@ +#[derive(serde::Deserialize, Debug)] +#[serde(untagged)] +pub(crate) enum Explain { + // NOTE: the returned JSON may not contain a `plan` field, for example, with `CALL` statements: + // https://github.com/launchbadge/sqlx/issues/1449 + // + // In this case, we should just fall back to assuming all is nullable. + // + // It may also contain additional fields we don't care about, which should not break parsing: + // https://github.com/launchbadge/sqlx/issues/2587 + // https://github.com/launchbadge/sqlx/issues/2622 + Plan { + #[serde(rename = "Plan")] + plan: Plan, + }, + + // This ensures that parsing never technically fails. + // + // We don't want to specifically expect `"Utility Statement"` because there might be other cases + // and we don't care unless it contains a query plan anyway. + Other(serde::de::IgnoredAny), +} + +#[derive(serde::Deserialize, Debug)] +pub(crate) struct Plan { + #[serde(rename = "Join Type")] + pub(crate) join_type: Option, + #[serde(rename = "Parent Relationship")] + pub(crate) parent_relation: Option, + #[serde(rename = "Output")] + pub(crate) output: Option>, + #[serde(rename = "Plans")] + pub(crate) plans: Option>, +} + +pub(crate) fn visit_plan(plan: &Plan, outputs: &[String], nullables: &mut Vec>) { + if let Some(plan_outputs) = &plan.output { + // all outputs of a Full Join must be marked nullable + // otherwise, all outputs of the inner half of an outer join must be marked nullable + if plan.join_type.as_deref() == Some("Full") || plan.parent_relation.as_deref() == Some("Inner") { + for output in plan_outputs { + if let Some(i) = outputs.iter().position(|o| o == output) { + // N.B. this may produce false positives but those don't cause runtime errors + nullables[i] = Some(true); + } + } + } + } + + if let Some(plans) = &plan.plans { + if let Some("Left") | Some("Right") = plan.join_type.as_deref() { + for plan in plans { + visit_plan(plan, outputs, nullables); + } + } + } +} diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 0c55f301733..805ba13a602 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -4,14 +4,16 @@ pub(crate) mod column_type; mod conversion; mod error; +mod explain; pub(crate) use crate::connector::postgres::url::PostgresUrl; use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams}; use crate::connector::{ - timeout, ColumnType, IsolationLevel, ParsedRawColumn, ParsedRawParameter, ParsedRawQuery, Transaction, + timeout, ColumnType, DescribedColumn, DescribedParameter, DescribedQuery, IsolationLevel, Transaction, }; use crate::error::NativeErrorKind; +use crate::ValueType; use crate::{ ast::{Query, Value}, connector::{metrics, queryable::*, ResultSet}, @@ -57,6 +59,8 @@ pub struct PostgreSql { socket_timeout: Option, statement_cache: Mutex, is_healthy: AtomicBool, + is_cockroachdb: bool, + is_materialize: bool, } /// Key uniquely representing an SQL statement in the prepared statements cache. @@ -247,6 +251,9 @@ impl PostgreSql { let tls = MakeTlsConnector::new(tls_builder.build()?); let (client, conn) = timeout::connect(url.connect_timeout(), config.connect(tls)).await?; + let is_cockroachdb = conn.parameter("crdb_version").is_some(); + let is_materialize = conn.parameter("mz_version").is_some(); + tokio::spawn(conn.map(|r| match r { Ok(_) => (), Err(e) => { @@ -280,6 +287,8 @@ impl PostgreSql { pg_bouncer: url.query_params.pg_bouncer, statement_cache: Mutex::new(url.cache()), is_healthy: AtomicBool::new(true), + is_cockroachdb, + is_materialize, }) } @@ -353,6 +362,112 @@ impl PostgreSql { Ok(()) } } + + // All credits go to sqlx: https://github.com/launchbadge/sqlx/blob/a892ebc6e283f443145f92bbc7fce4ae44547331/sqlx-postgres/src/connection/describe.rs#L417 + pub(crate) async fn get_nullable_for_columns(&self, stmt: &Statement) -> crate::Result>> { + let columns = stmt.columns(); + + if columns.is_empty() { + return Ok(vec![]); + } + + let mut nullable_query = String::from("SELECT NOT pg_attribute.attnotnull as nullable FROM (VALUES "); + let mut args = Vec::with_capacity(columns.len() * 3); + + for (i, (column, bind)) in columns.iter().zip((1..).step_by(3)).enumerate() { + if !args.is_empty() { + nullable_query += ", "; + } + + nullable_query.push_str(&format!("(${}::int4, ${}::int8, ${}::int4)", bind, bind + 1, bind + 2)); + + args.push(Value::from(i as i32)); + args.push(ValueType::Int64(column.table_oid().map(i64::from)).into()); + args.push(ValueType::Int32(column.column_id().map(i32::from)).into()); + } + + nullable_query.push_str( + ") as col(idx, table_id, col_idx) \ + LEFT JOIN pg_catalog.pg_attribute \ + ON table_id IS NOT NULL \ + AND attrelid = table_id \ + AND attnum = col_idx \ + ORDER BY col.idx", + ); + + let nullable_query_result = self.query_raw(&nullable_query, &args).await?; + let mut nullables = Vec::with_capacity(nullable_query_result.len()); + + for row in nullable_query_result { + let nullable = row.at(0).and_then(|v| v.as_bool()); + + nullables.push(nullable) + } + + // If the server is CockroachDB or Materialize, skip this step (#1248). + if !self.is_cockroachdb && !self.is_materialize { + // patch up our null inference with data from EXPLAIN + let nullable_patch = self.nullables_from_explain(stmt).await?; + + for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) { + *nullable = patch.or(*nullable); + } + } + + Ok(nullables) + } + + /// Infer nullability for columns of this statement using EXPLAIN VERBOSE. + /// + /// This currently only marks columns that are on the inner half of an outer join + /// and returns `None` for all others. + /// All credits go to sqlx: https://github.com/launchbadge/sqlx/blob/a892ebc6e283f443145f92bbc7fce4ae44547331/sqlx-postgres/src/connection/describe.rs#L482 + async fn nullables_from_explain(&self, stmt: &Statement) -> Result>, Error> { + use explain::{visit_plan, Explain, Plan}; + + let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {}", stmt.name()); + let params_len = stmt.params().len(); + let mut comma = false; + + if params_len > 0 { + explain += "("; + + // fill the arguments list with NULL, which should theoretically be valid + for _ in 0..params_len { + if comma { + explain += ", "; + } + + explain += "NULL"; + comma = true; + } + + explain += ")"; + } + + let explain_result = self.query_raw(&explain, &[]).await?.into_single()?; + let explains = explain_result + .into_single()? + .into_json() + .map(serde_json::from_value::<[Explain; 1]>) + .transpose()?; + let explain = explains.as_ref().and_then(|x| x.first()); + + let mut nullables = Vec::new(); + + if let Some(Explain::Plan { + plan: plan @ Plan { + output: Some(ref outputs), + .. + }, + }) = explain + { + nullables.resize(outputs.len(), None); + visit_plan(plan, outputs, &mut nullables); + } + + Ok(nullables) + } } // A SearchPath connection parameter (Display-impl) for connection initialization. @@ -474,44 +589,79 @@ impl Queryable for PostgreSql { .await } - async fn parse_raw_query(&self, sql: &str) -> crate::Result { + async fn describe_query(&self, sql: &str) -> crate::Result { let stmt = self.fetch_cached(sql, &[]).await?; - let mut columns: Vec = Vec::with_capacity(stmt.columns().len()); - let mut parameters: Vec = Vec::with_capacity(stmt.params().len()); - async fn infer_type(this: &PostgreSql, ty: &PostgresType) -> crate::Result<(ColumnType, Option)> { + let mut columns: Vec = Vec::with_capacity(stmt.columns().len()); + let mut parameters: Vec = Vec::with_capacity(stmt.params().len()); + + let enums_results = self + .query_raw("SELECT oid, typname FROM pg_type WHERE typtype = 'e';", &[]) + .await?; + + fn find_enum_by_oid(enums: &ResultSet, enum_oid: u32) -> Option<&str> { + enums.iter().find_map(|row| { + let oid = row.get("oid")?.as_i64()?; + let name = row.get("typname")?.as_str()?; + + if enum_oid == u32::try_from(oid).unwrap() { + Some(name) + } else { + None + } + }) + } + + fn resolve_type(ty: &PostgresType, enums: &ResultSet) -> (ColumnType, Option) { let column_type = ColumnType::from(ty); match ty.kind() { PostgresKind::Enum => { - let enum_name = this - .query_raw("SELECT typname FROM pg_type WHERE oid = $1;", &[Value::int64(ty.oid())]) - .await? - .into_single()? - .at(0) - .expect("could not find enum name") - .to_string() - .expect("enum name is not a string"); - - Ok((column_type, Some(enum_name))) + let enum_name = find_enum_by_oid(enums, ty.oid()) + .unwrap_or_else(|| panic!("Could not find enum with oid {}", ty.oid())); + + (column_type, Some(enum_name.to_string())) } - _ => Ok((column_type, None)), + _ => (column_type, None), } } - for col in stmt.columns() { - let (typ, enum_name) = infer_type(self, col.type_()).await?; + let nullables = self.get_nullable_for_columns(&stmt).await?; + + for (idx, (col, nullable)) in stmt.columns().iter().zip(nullables).enumerate() { + let (typ, enum_name) = resolve_type(col.type_(), &enums_results); - columns.push(ParsedRawColumn::new_named(col.name(), typ).with_enum_name(enum_name)); + if col.name() == "?column?" { + let kind = ErrorKind::QueryInvalidInput(format!("Invalid column name '?column?' for index {idx}. Your SQL query must explicitly alias that column name.")); + + return Err(Error::builder(kind).build()); + } + + columns.push( + DescribedColumn::new_named(col.name(), typ) + .with_enum_name(enum_name) + // Make fields nullable by default if we can't infer nullability. + .is_nullable(nullable.unwrap_or(true)), + ); } for param in stmt.params() { - let (typ, enum_name) = infer_type(self, param).await?; + let (typ, enum_name) = resolve_type(param, &enums_results); - parameters.push(ParsedRawParameter::new_named(param.name(), typ).with_enum_name(enum_name)); + parameters.push(DescribedParameter::new_named(param.name(), typ).with_enum_name(enum_name)); } - Ok(ParsedRawQuery { columns, parameters }) + let enum_names = enums_results + .into_iter() + .filter_map(|row| row.take("typname")) + .filter_map(|v| v.to_string()) + .collect::>(); + + Ok(DescribedQuery { + columns, + parameters, + enum_names: Some(enum_names), + }) } async fn execute(&self, q: Query<'_>) -> crate::Result { diff --git a/quaint/src/connector/queryable.rs b/quaint/src/connector/queryable.rs index 3d1df4457f3..5f0fd54dad6 100644 --- a/quaint/src/connector/queryable.rs +++ b/quaint/src/connector/queryable.rs @@ -1,4 +1,4 @@ -use super::{IsolationLevel, ParsedRawQuery, ResultSet, Transaction}; +use super::{DescribedQuery, IsolationLevel, ResultSet, Transaction}; use crate::ast::*; use async_trait::async_trait; @@ -58,7 +58,7 @@ pub trait Queryable: Send + Sync { async fn version(&self) -> crate::Result>; /// Prepares a statement and returns type information. - async fn parse_raw_query(&self, sql: &str) -> crate::Result; + async fn describe_query(&self, sql: &str) -> crate::Result; /// Returns false, if connection is considered to not be in a working state. fn is_healthy(&self) -> bool; diff --git a/quaint/src/connector/result_set/result_row.rs b/quaint/src/connector/result_set/result_row.rs index ae2b068a581..58c389b294d 100644 --- a/quaint/src/connector/result_set/result_row.rs +++ b/quaint/src/connector/result_set/result_row.rs @@ -62,12 +62,20 @@ impl ResultRow { } } - /// Take a value with the given column name from the row. Usage + /// Get a value with the given column name from the row. Usage /// documentation in [ResultRowRef](struct.ResultRowRef.html). pub fn get(&self, name: &str) -> Option<&Value<'static>> { self.columns.iter().position(|c| c == name).map(|idx| &self.values[idx]) } + /// Take a value with the given column name from the row. + pub fn take(mut self, name: &str) -> Option> { + self.columns + .iter() + .position(|c| c == name) + .map(|idx| self.values.remove(idx)) + } + /// Make a referring [ResultRowRef](struct.ResultRowRef.html). pub fn as_ref(&self) -> ResultRowRef { ResultRowRef { diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index 6dd23138c48..abcec7410a6 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -5,8 +5,8 @@ mod column_type; mod conversion; mod error; -use crate::connector::{sqlite::params::SqliteParams, ColumnType, ParsedRawQuery}; -use crate::connector::{IsolationLevel, ParsedRawColumn, ParsedRawParameter}; +use crate::connector::IsolationLevel; +use crate::connector::{sqlite::params::SqliteParams, ColumnType, DescribedQuery}; pub use rusqlite::{params_from_iter, version as sqlite_version}; @@ -124,28 +124,8 @@ impl Queryable for Sqlite { self.query_raw(sql, params).await } - async fn parse_raw_query(&self, sql: &str) -> crate::Result { - let conn = self.client.lock().await; - let stmt = conn.prepare_cached(sql)?; - - let parameters = (1..=stmt.parameter_count()) - .map(|idx| match stmt.parameter_name(idx) { - Some(name) => { - // SQLite parameter names are prefixed with a colon. We remove it here so that the js doc parser can match the names. - let name = name.strip_prefix(':').unwrap_or(name); - - ParsedRawParameter::new_named(name, ColumnType::Unknown) - } - None => ParsedRawParameter::new_unnamed(idx, ColumnType::Unknown), - }) - .collect(); - let columns = stmt - .columns() - .iter() - .map(|col| ParsedRawColumn::new_named(col.name(), col)) - .collect(); - - Ok(ParsedRawQuery { columns, parameters }) + async fn describe_query(&self, _sql: &str) -> crate::Result { + unimplemented!("SQLite describe_query is implemented in the schema engine.") } async fn execute(&self, q: Query<'_>) -> crate::Result { diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index 6a843d6f852..df4084883e8 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -108,8 +108,8 @@ impl<'a> Queryable for DefaultTransaction<'a> { self.inner.query_raw_typed(sql, params).await } - async fn parse_raw_query(&self, sql: &str) -> crate::Result { - self.inner.parse_raw_query(sql).await + async fn describe_query(&self, sql: &str) -> crate::Result { + self.inner.describe_query(sql).await } async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index bae2ab3eec9..7533dffcfcc 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -34,8 +34,8 @@ impl Queryable for PooledConnection { self.inner.query_raw_typed(sql, params).await } - async fn parse_raw_query(&self, sql: &str) -> crate::Result { - self.inner.parse_raw_query(sql).await + async fn describe_query(&self, sql: &str) -> crate::Result { + self.inner.describe_query(sql).await } async fn execute(&self, q: ast::Query<'_>) -> crate::Result { diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 0e1eda6c00c..cbf460c4150 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -209,8 +209,8 @@ impl Queryable for Quaint { self.inner.query_raw_typed(sql, params).await } - async fn parse_raw_query(&self, sql: &str) -> crate::Result { - self.inner.parse_raw_query(sql).await + async fn describe_query(&self, sql: &str) -> crate::Result { + self.inner.describe_query(sql).await } async fn execute(&self, q: ast::Query<'_>) -> crate::Result { diff --git a/query-engine/connectors/sql-query-connector/src/database/js.rs b/query-engine/connectors/sql-query-connector/src/database/js.rs index e0e5d9e6a52..d771eb51e40 100644 --- a/query-engine/connectors/sql-query-connector/src/database/js.rs +++ b/query-engine/connectors/sql-query-connector/src/database/js.rs @@ -90,8 +90,8 @@ impl QuaintQueryable for DriverAdapter { self.connector.query_raw_typed(sql, params).await } - async fn parse_raw_query(&self, sql: &str) -> quaint::Result { - self.connector.parse_raw_query(sql).await + async fn describe_query(&self, sql: &str) -> quaint::Result { + self.connector.describe_query(sql).await } async fn execute(&self, q: Query<'_>) -> quaint::Result { diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index fd7affc8d42..4e47e9c5163 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -6,7 +6,7 @@ use super::conversion; use crate::send_future::UnsafeFuture; use async_trait::async_trait; use futures::Future; -use quaint::connector::{ExternalConnectionInfo, ExternalConnector, ParsedRawQuery}; +use quaint::connector::{DescribedQuery, ExternalConnectionInfo, ExternalConnector}; use quaint::{ connector::{metrics, IsolationLevel, Transaction}, error::{Error, ErrorKind}, @@ -94,8 +94,8 @@ impl QuaintQueryable for JsBaseQueryable { self.query_raw(sql, params).await } - async fn parse_raw_query(&self, sql: &str) -> quaint::Result { - self.parse_raw_query(sql).await + async fn describe_query(&self, sql: &str) -> quaint::Result { + self.describe_query(sql).await } async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { @@ -264,8 +264,8 @@ impl QuaintQueryable for JsQueryable { self.inner.query_raw_typed(sql, params).await } - async fn parse_raw_query(&self, sql: &str) -> quaint::Result { - self.inner.parse_raw_query(sql).await + async fn describe_query(&self, sql: &str) -> quaint::Result { + self.inner.describe_query(sql).await } async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs index 60fecdbd16c..b3dd6463089 100644 --- a/query-engine/driver-adapters/src/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; use metrics::decrement_gauge; use quaint::{ - connector::{IsolationLevel, ParsedRawQuery, Transaction as QuaintTransaction}, + connector::{DescribedQuery, IsolationLevel, Transaction as QuaintTransaction}, prelude::{Query as QuaintQuery, Queryable, ResultSet}, Value, }; @@ -86,8 +86,8 @@ impl Queryable for JsTransaction { self.inner.query_raw_typed(sql, params).await } - async fn parse_raw_query(&self, sql: &str) -> quaint::Result { - self.inner.parse_raw_query(sql).await + async fn describe_query(&self, sql: &str) -> quaint::Result { + self.inner.describe_query(sql).await } async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { diff --git a/schema-engine/connectors/schema-connector/src/introspect_sql.rs b/schema-engine/connectors/schema-connector/src/introspect_sql.rs index 5a06393958a..06bbb34f54b 100644 --- a/schema-engine/connectors/schema-connector/src/introspect_sql.rs +++ b/schema-engine/connectors/schema-connector/src/introspect_sql.rs @@ -33,6 +33,7 @@ pub struct IntrospectSqlQueryParameterOutput { pub documentation: Option, pub name: String, pub typ: String, + pub nullable: bool, } #[allow(missing_docs)] @@ -40,13 +41,27 @@ pub struct IntrospectSqlQueryParameterOutput { pub struct IntrospectSqlQueryColumnOutput { pub name: String, pub typ: String, + pub nullable: bool, } -impl From for IntrospectSqlQueryColumnOutput { - fn from(item: quaint::connector::ParsedRawColumn) -> Self { +impl From for IntrospectSqlQueryColumnOutput { + fn from(item: quaint::connector::DescribedColumn) -> Self { + let nullable_override = parse_nullability_override(&item.name); + Self { name: item.name, typ: item.enum_name.unwrap_or_else(|| item.typ.to_string()), + nullable: nullable_override.unwrap_or(item.nullable), } } } + +fn parse_nullability_override(column_name: &str) -> Option { + if column_name.ends_with('?') { + Some(true) + } else if column_name.ends_with('!') { + Some(false) + } else { + None + } +} diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour.rs b/schema-engine/connectors/sql-schema-connector/src/flavour.rs index 344c7a3cb83..3b2e901bcd6 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour.rs @@ -175,10 +175,10 @@ pub(crate) trait SqlFlavour: self.describe_schema(namespaces) } - fn parse_raw_query<'a>( + fn describe_query<'a>( &'a mut self, sql: &'a str, - ) -> BoxFuture<'a, ConnectorResult>; + ) -> BoxFuture<'a, ConnectorResult>; fn load_migrations_table( &mut self, diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/mssql.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/mssql.rs index 2b6b7f0f110..92843aae628 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/mssql.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/mssql.rs @@ -502,11 +502,11 @@ impl SqlFlavour for MssqlFlavour { self.schema_name() } - fn parse_raw_query<'a>( + fn describe_query<'a>( &'a mut self, _sql: &str, - ) -> BoxFuture<'a, ConnectorResult> { - unimplemented!("SQL Server support for raw query parsing is not implemented yet.") + ) -> BoxFuture<'a, ConnectorResult> { + unimplemented!("SQL Server does not support describe_query yet.") } } diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/mysql.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/mysql.rs index 9951eb23c7b..e169d825e9d 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/mysql.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/mysql.rs @@ -406,12 +406,12 @@ impl SqlFlavour for MysqlFlavour { self.database_name() } - fn parse_raw_query<'a>( + fn describe_query<'a>( &'a mut self, sql: &'a str, - ) -> BoxFuture<'a, ConnectorResult> { + ) -> BoxFuture<'a, ConnectorResult> { with_connection(&mut self.state, move |conn_params, circumstances, conn| { - conn.parse_raw_query(sql, &conn_params.url, circumstances) + conn.describe_query(sql, &conn_params.url, circumstances) }) } } diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/mysql/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/mysql/connection.rs index 57713761359..fd470fc298d 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/mysql/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/mysql/connection.rs @@ -115,14 +115,14 @@ impl Connection { self.0.query_raw(sql, params).await.map_err(quaint_err(url)) } - pub(super) async fn parse_raw_query( + pub(super) async fn describe_query( &self, sql: &str, url: &MysqlUrl, circumstances: BitFlags, - ) -> ConnectorResult { - tracing::debug!(query_type = "parse_raw_query", sql); - let mut parsed = self.0.parse_raw_query(sql).await.map_err(quaint_err(url))?; + ) -> ConnectorResult { + tracing::debug!(query_type = "describe_query", sql); + let mut parsed = self.0.describe_query(sql).await.map_err(quaint_err(url))?; if circumstances.contains(super::Circumstances::IsMysql56) || circumstances.contains(super::Circumstances::IsMysql57) diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs index f1fcf5545a1..dca3b89f6f2 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs @@ -221,12 +221,12 @@ impl SqlFlavour for PostgresFlavour { }) } - fn parse_raw_query<'a>( + fn describe_query<'a>( &'a mut self, sql: &'a str, - ) -> BoxFuture<'a, ConnectorResult> { + ) -> BoxFuture<'a, ConnectorResult> { with_connection(self, move |conn_params, _, conn| { - conn.parse_raw_query(sql, &conn_params.url) + conn.describe_query(sql, &conn_params.url) }) } diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs index 5c5c3057051..3ca9b673b0a 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs @@ -146,13 +146,13 @@ impl Connection { self.0.query_raw(sql, params).await.map_err(quaint_err(url)) } - pub(super) async fn parse_raw_query( + pub(super) async fn describe_query( &self, sql: &str, url: &PostgresUrl, - ) -> ConnectorResult { - tracing::debug!(query_type = "parse_raw_query", sql); - self.0.parse_raw_query(sql).await.map_err(quaint_err(url)) + ) -> ConnectorResult { + tracing::debug!(query_type = "describe_query", sql); + self.0.describe_query(sql).await.map_err(quaint_err(url)) } pub(super) async fn apply_migration_script(&mut self, migration_name: &str, script: &str) -> ConnectorResult<()> { diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite.rs index 75727446f01..85bff2133cb 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite.rs @@ -255,13 +255,13 @@ impl SqlFlavour for SqliteFlavour { ready(with_connection(&mut self.state, |_, conn| conn.query_raw(sql, params))) } - fn parse_raw_query<'a>( + fn describe_query<'a>( &'a mut self, sql: &'a str, - ) -> BoxFuture<'a, ConnectorResult> { - tracing::debug!(sql, query_type = "parse_raw_query"); + ) -> BoxFuture<'a, ConnectorResult> { + tracing::debug!(sql, query_type = "describe_query"); ready(with_connection(&mut self.state, |params, conn| { - conn.parse_raw_query(sql, params) + conn.describe_query(sql, params) })) } diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/connection.rs index 9669297774f..995a86e87c9 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/connection.rs @@ -2,7 +2,7 @@ pub(crate) use quaint::connector::rusqlite; -use quaint::connector::{ColumnType, GetRow, ParsedRawColumn, ParsedRawParameter, ToColumnNames}; +use quaint::connector::{ColumnType, DescribedColumn, DescribedParameter, GetRow, ToColumnNames}; use schema_connector::{ConnectorError, ConnectorResult}; use sql_schema_describer::{sqlite as describer, DescriberErrorKind, SqlSchema}; use sqlx_core::{column::Column, type_info::TypeInfo}; @@ -75,12 +75,12 @@ impl Connection { )) } - pub(super) fn parse_raw_query( + pub(super) fn describe_query( &mut self, sql: &str, params: &super::Params, - ) -> ConnectorResult { - tracing::debug!(query_type = "parse_raw_query", sql); + ) -> ConnectorResult { + tracing::debug!(query_type = "describe_query", sql); // SQLite only provides type information for _declared_ column types. That means any expression will not contain type information. // Sqlx works around this by running an `EXPLAIN` query and inferring types by interpreting sqlite bytecode. // If you're curious, here's the code: https://github.com/launchbadge/sqlx/blob/16e3f1025ad1e106d1acff05f591b8db62d688e2/sqlx-sqlite/src/connection/explain.rs#L557 @@ -96,27 +96,32 @@ impl Connection { // SQLite parameter names are prefixed with a colon. We remove it here so that the js doc parser can match the names. let name = name.strip_prefix(':').unwrap_or(name); - ParsedRawParameter::new_named(name, ColumnType::Unknown) + DescribedParameter::new_named(name, ColumnType::Unknown) } - None => ParsedRawParameter::new_unnamed(idx, ColumnType::Unknown), + None => DescribedParameter::new_unnamed(idx, ColumnType::Unknown), }) .collect(); let columns = stmt .columns() .iter() + .zip(&describe.nullable) .enumerate() - .map(|(idx, col)| { + .map(|(idx, (col, nullable))| { let typ = match ColumnType::from(col) { // If the column type is unknown, we try to infer it from the describe. ColumnType::Unknown => describe.column(idx).to_column_type(), typ => typ, }; - ParsedRawColumn::new_named(col.name(), typ) + DescribedColumn::new_named(col.name(), typ).is_nullable(nullable.unwrap_or(true)) }) .collect(); - Ok(quaint::connector::ParsedRawQuery { columns, parameters }) + Ok(quaint::connector::DescribedQuery { + columns, + parameters, + enum_names: None, + }) } } @@ -162,7 +167,7 @@ impl ToColumnTypeExt for &SqliteColumn { "TEXT" => ColumnType::Text, "REAL" => ColumnType::Double, "BLOB" => ColumnType::Bytes, - "INTEGER" => ColumnType::Int32, + "INTEGER" => ColumnType::Int64, // Not supported by sqlx-sqlite "NUMERIC" => ColumnType::Numeric, diff --git a/schema-engine/connectors/sql-schema-connector/src/lib.rs b/schema-engine/connectors/sql-schema-connector/src/lib.rs index 51cfaeb0373..f78ac9b60fe 100644 --- a/schema-engine/connectors/sql-schema-connector/src/lib.rs +++ b/schema-engine/connectors/sql-schema-connector/src/lib.rs @@ -21,6 +21,7 @@ use enumflags2::BitFlags; use flavour::{MssqlFlavour, MysqlFlavour, PostgresFlavour, SqlFlavour, SqliteFlavour}; use migration_pair::MigrationPair; use psl::{datamodel_connector::NativeTypeInstance, parser_database::ScalarType, ValidatedSchema}; +use quaint::connector::DescribedQuery; use schema_connector::{migrations_directory::MigrationDirectory, *}; use sql_doc_parser::parse_sql_doc; use sql_migration::{DropUserDefinedType, DropView, SqlMigration, SqlMigrationStep}; @@ -361,37 +362,38 @@ impl SchemaConnector for SqlSchemaConnector { input: IntrospectSqlQueryInput, ) -> BoxFuture<'_, ConnectorResult> { Box::pin(async move { - let parsed_query = self.flavour.parse_raw_query(&input.source).await?; + let DescribedQuery { + parameters, + columns, + enum_names, + } = self.flavour.describe_query(&input.source).await?; + let enum_names = enum_names.unwrap_or_default(); let sql_source = input.source.clone(); - let parsed_doc = parse_sql_doc(&sql_source)?; + let parsed_doc = parse_sql_doc(&sql_source, enum_names.as_slice())?; - let parameters = parsed_query - .parameters + let parameters = parameters .into_iter() - .enumerate() - .map(|(idx, param)| { + .zip(1..) + .map(|(param, idx)| { let parsed_param = parsed_doc - .get_param_at(idx + 1) + .get_param_at(idx) .or_else(|| parsed_doc.get_param_by_alias(¶m.name)); IntrospectSqlQueryParameterOutput { typ: parsed_param .and_then(|p| p.typ()) - .map(ToOwned::to_owned) .unwrap_or_else(|| param.typ.to_string()), name: parsed_param .and_then(|p| p.alias()) .map(ToOwned::to_owned) .unwrap_or_else(|| param.name), documentation: parsed_param.and_then(|p| p.documentation()).map(ToOwned::to_owned), + // Params are required by default unless overridden by sql doc. + nullable: parsed_param.and_then(|p| p.nullable()).unwrap_or(false), } }) .collect(); - let columns = parsed_query - .columns - .into_iter() - .map(IntrospectSqlQueryColumnOutput::from) - .collect(); + let columns = columns.into_iter().map(IntrospectSqlQueryColumnOutput::from).collect(); Ok(IntrospectSqlQueryOutput { name: input.name, diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs b/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs index 46d8654990b..a7f9c4912d2 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs @@ -1,12 +1,21 @@ use psl::parser_database::ScalarType; +use quaint::prelude::ColumnType; use schema_connector::{ConnectorError, ConnectorResult}; +use crate::sql_renderer::IteratorJoin; + #[derive(Debug, Default)] pub(crate) struct ParsedSqlDoc<'a> { parameters: Vec>, description: Option<&'a str>, } +#[derive(Debug)] +pub enum ParsedParamType<'a> { + ColumnType(ColumnType), + Enum(&'a str), +} + impl<'a> ParsedSqlDoc<'a> { fn add_parameter(&mut self, param: ParsedParameterDoc<'a>) -> ConnectorResult<()> { if self @@ -44,7 +53,8 @@ impl<'a> ParsedSqlDoc<'a> { #[derive(Debug, Default)] pub(crate) struct ParsedParameterDoc<'a> { alias: Option<&'a str>, - typ: Option, + typ: Option>, + nullable: Option, position: Option, documentation: Option<&'a str>, } @@ -54,10 +64,14 @@ impl<'a> ParsedParameterDoc<'a> { self.alias = name; } - fn set_typ(&mut self, typ: Option) { + fn set_typ(&mut self, typ: Option>) { self.typ = typ; } + fn set_nullable(&mut self, nullable: Option) { + self.nullable = nullable; + } + fn set_position(&mut self, position: Option) { self.position = position; } @@ -67,20 +81,31 @@ impl<'a> ParsedParameterDoc<'a> { } fn is_empty(&self) -> bool { - self.alias.is_none() && self.position.is_none() && self.typ.is_none() && self.documentation.is_none() + self.alias.is_none() + && self.position.is_none() + && self.typ.is_none() + && self.documentation.is_none() + && self.nullable.is_none() } pub(crate) fn alias(&self) -> Option<&str> { self.alias } - pub(crate) fn typ(&self) -> Option<&str> { - self.typ.map(|typ| typ.as_str()) + pub(crate) fn typ(&self) -> Option { + self.typ.as_ref().map(|typ| match typ { + ParsedParamType::ColumnType(ct) => ct.to_string(), + ParsedParamType::Enum(enm) => enm.to_string(), + }) } pub(crate) fn documentation(&self) -> Option<&str> { self.documentation } + + pub(crate) fn nullable(&self) -> Option { + self.nullable + } } #[derive(Debug, Clone, Copy)] @@ -99,6 +124,10 @@ impl<'a> Input<'a> { self.0.strip_prefix(pat).map(Self) } + fn strip_suffix_char(&self, pat: char) -> Option { + self.0.strip_suffix(pat).map(Self) + } + fn starts_with(&self, pat: &str) -> bool { self.0.starts_with(pat) } @@ -131,11 +160,11 @@ impl<'a> Input<'a> { Self(self.0.trim_end()) } - fn take_until_pattern_or_eol(&self, pattern: &[char]) -> (Input<'a>, &'a str) { + fn take_until_pattern_or_eol(&self, pattern: &[char]) -> (Input<'a>, Input<'a>) { if let Some(end) = self.find(pattern) { - (self.move_from(end), self.move_to(end).inner()) + (self.move_from(end), self.move_to(end)) } else { - (self.move_to_end(), self.inner()) + (self.move_to_end(), *self) } } @@ -155,7 +184,21 @@ fn build_error(input: Input<'_>, msg: &str) -> ConnectorError { ConnectorError::from_msg(format!("SQL documentation parsing: {msg} at '{input}'.")) } -fn parse_typ_opt(input: Input<'_>) -> ConnectorResult<(Input<'_>, Option)> { +fn render_enum_names(enum_names: &[String]) -> String { + if enum_names.is_empty() { + String::new() + } else { + format!( + ", {enum_names}", + enum_names = enum_names.iter().map(|name| format!("'{name}'")).join(", ") + ) + } +} + +fn parse_typ_opt<'a>( + input: Input<'a>, + enum_names: &'a [String], +) -> ConnectorResult<(Input<'a>, Option>)> { if let Some(start) = input.find(&['{']) { if let Some(end) = input.find(&['}']) { let typ = input.move_between(start + 1, end); @@ -164,17 +207,29 @@ fn parse_typ_opt(input: Input<'_>) -> ConnectorResult<(Input<'_>, Option { - Ok((input.move_from(end + 1), Some(st))) - } - None => { - Err(build_error( - input, - &format!("invalid type: '{typ}' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal')"), - )) - } - } + let parsed_typ = ScalarType::try_from_str(typ.inner(), false) + .map(|st| match st { + ScalarType::Int => ColumnType::Int32, + ScalarType::BigInt => ColumnType::Int64, + ScalarType::Float => ColumnType::Float, + ScalarType::Boolean => ColumnType::Boolean, + ScalarType::String => ColumnType::Text, + ScalarType::DateTime => ColumnType::DateTime, + ScalarType::Json => ColumnType::Json, + ScalarType::Bytes => ColumnType::Bytes, + ScalarType::Decimal => ColumnType::Numeric, + }) + .map(ParsedParamType::ColumnType) + .or_else(|| { + enum_names.iter().any(|enum_name| *enum_name == typ.inner()) + .then(|| ParsedParamType::Enum(typ.inner())) + }) + .ok_or_else(|| build_error( + input, + &format!("invalid type: '{typ}' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal'{})", render_enum_names(enum_names)), + ))?; + + Ok((input.move_from(end + 1), Some(parsed_typ))) } else { Err(build_error(input, "missing closing bracket")) } @@ -189,7 +244,7 @@ fn parse_position_opt(input: Input<'_>) -> ConnectorResult<(Input<'_>, Option().map_err(|_| { + match param_pos.inner().parse::().map_err(|_| { build_error( input, &format!("invalid position. Expected a number found: {param_pos}"), @@ -203,15 +258,19 @@ fn parse_position_opt(input: Input<'_>) -> ConnectorResult<(Input<'_>, Option) -> ConnectorResult<(Input<'_>, Option<&'_ str>)> { +fn parse_alias_opt(input: Input<'_>) -> ConnectorResult<(Input<'_>, Option<&'_ str>, Option)> { if let Some((input, alias)) = input .trim_start() .strip_prefix_char(':') .map(|input| input.take_until_pattern_or_eol(&[' '])) { - Ok((input, Some(alias))) + if let Some(alias) = alias.strip_suffix_char('?') { + Ok((input, Some(alias.inner()), Some(true))) + } else { + Ok((input, Some(alias.inner()), None)) + } } else { - Ok((input, None)) + Ok((input, None, None)) } } @@ -237,17 +296,18 @@ fn validate_param(param: &ParsedParameterDoc<'_>, input: Input<'_>) -> Connector Ok(()) } -fn parse_param(param_input: Input<'_>) -> ConnectorResult> { +fn parse_param<'a>(param_input: Input<'a>, enum_names: &'a [String]) -> ConnectorResult> { let input = param_input.strip_prefix_str("@param").unwrap().trim_start(); - let (input, typ) = parse_typ_opt(input)?; + let (input, typ) = parse_typ_opt(input, enum_names)?; let (input, position) = parse_position_opt(input)?; - let (input, alias) = parse_alias_opt(input)?; + let (input, alias, nullable) = parse_alias_opt(input)?; let documentation = parse_rest(input)?; let mut param = ParsedParameterDoc::default(); param.set_typ(typ); + param.set_nullable(nullable); param.set_position(position); param.set_alias(alias); param.set_documentation(documentation); @@ -263,7 +323,7 @@ fn parse_description(input: Input<'_>) -> ConnectorResult> { parse_rest(input) } -pub(crate) fn parse_sql_doc(sql: &str) -> ConnectorResult> { +pub(crate) fn parse_sql_doc<'a>(sql: &'a str, enum_names: &'a [String]) -> ConnectorResult> { let mut parsed_sql = ParsedSqlDoc::default(); let lines = sql.lines(); @@ -278,7 +338,7 @@ pub(crate) fn parse_sql_doc(sql: &str) -> ConnectorResult> { parsed_sql.set_description(parse_description(input)?); } else if input.starts_with("@param") { parsed_sql - .add_parameter(parse_param(input)?) + .add_parameter(parse_param(input, enum_names)?) .map_err(|err| build_error(input, err.message().unwrap()))?; } } @@ -295,7 +355,7 @@ mod tests { fn parse_param_1() { use expect_test::expect; - let res = parse_param(Input("@param $1:userId")); + let res = parse_param(Input("@param $1:userId"), &[]); let expected = expect![[r#" Ok( @@ -304,6 +364,7 @@ mod tests { "userId", ), typ: None, + nullable: None, position: Some( 1, ), @@ -319,7 +380,7 @@ mod tests { fn parse_param_2() { use expect_test::expect; - let res = parse_param(Input("@param $1:userId valid user identifier")); + let res = parse_param(Input("@param $1:userId valid user identifier"), &[]); let expected = expect![[r#" Ok( @@ -328,6 +389,7 @@ mod tests { "userId", ), typ: None, + nullable: None, position: Some( 1, ), @@ -345,7 +407,7 @@ mod tests { fn parse_param_3() { use expect_test::expect; - let res = parse_param(Input("@param {Int} :userId")); + let res = parse_param(Input("@param {Int} :userId"), &[]); let expected = expect![[r#" Ok( @@ -354,8 +416,11 @@ mod tests { "userId", ), typ: Some( - Int, + ColumnType( + Int32, + ), ), + nullable: None, position: None, documentation: None, }, @@ -369,7 +434,7 @@ mod tests { fn parse_param_4() { use expect_test::expect; - let res = parse_param(Input("@param {Int} $1:userId")); + let res = parse_param(Input("@param {Int} $1:userId"), &[]); let expected = expect![[r#" Ok( @@ -378,8 +443,11 @@ mod tests { "userId", ), typ: Some( - Int, + ColumnType( + Int32, + ), ), + nullable: None, position: Some( 1, ), @@ -395,7 +463,7 @@ mod tests { fn parse_param_5() { use expect_test::expect; - let res = parse_param(Input("@param {Int} $1:userId valid user identifier")); + let res = parse_param(Input("@param {Int} $1:userId valid user identifier"), &[]); let expected = expect![[r#" Ok( @@ -404,8 +472,11 @@ mod tests { "userId", ), typ: Some( - Int, + ColumnType( + Int32, + ), ), + nullable: None, position: Some( 1, ), @@ -423,15 +494,18 @@ mod tests { fn parse_param_6() { use expect_test::expect; - let res = parse_param(Input("@param {Int} $1 valid user identifier")); + let res = parse_param(Input("@param {Int} $1 valid user identifier"), &[]); let expected = expect![[r#" Ok( ParsedParameterDoc { alias: None, typ: Some( - Int, + ColumnType( + Int32, + ), ), + nullable: None, position: Some( 1, ), @@ -449,7 +523,7 @@ mod tests { fn parse_param_7() { use expect_test::expect; - let res = parse_param(Input("@param {Int} $1f valid user identifier")); + let res = parse_param(Input("@param {Int} $1f valid user identifier"), &[]); let expected = expect![[r#" Err( @@ -473,7 +547,7 @@ mod tests { fn parse_param_8() { use expect_test::expect; - let res = parse_param(Input("@param {} valid user identifier")); + let res = parse_param(Input("@param {} valid user identifier"), &[]); let expected = expect![[r#" Err( @@ -497,7 +571,7 @@ mod tests { fn parse_param_9() { use expect_test::expect; - let res = parse_param(Input("@param {Int $1f valid user identifier")); + let res = parse_param(Input("@param {Int $1f valid user identifier"), &[]); let expected = expect![[r#" Err( @@ -521,7 +595,7 @@ mod tests { fn parse_param_10() { use expect_test::expect; - let res = parse_param(Input("@param {Int} valid user identifier $10")); + let res = parse_param(Input("@param {Int} valid user identifier $10"), &[]); let expected = expect![[r#" Err( @@ -545,7 +619,7 @@ mod tests { fn parse_param_11() { use expect_test::expect; - let res = parse_param(Input("@param ")); + let res = parse_param(Input("@param "), &[]); let expected = expect![[r#" Err( @@ -569,15 +643,18 @@ mod tests { fn parse_param_12() { use expect_test::expect; - let res = parse_param(Input("@param {Int}$1 some documentation")); + let res = parse_param(Input("@param {Int}$1 some documentation"), &[]); let expected = expect![[r#" Ok( ParsedParameterDoc { alias: None, typ: Some( - Int, + ColumnType( + Int32, + ), ), + nullable: None, position: Some( 1, ), @@ -595,15 +672,18 @@ mod tests { fn parse_param_13() { use expect_test::expect; - let res = parse_param(Input("@param {Int} $1 some documentation")); + let res = parse_param(Input("@param {Int} $1 some documentation"), &[]); let expected = expect![[r#" Ok( ParsedParameterDoc { alias: None, typ: Some( - Int, + ColumnType( + Int32, + ), ), + nullable: None, position: Some( 1, ), @@ -621,7 +701,7 @@ mod tests { fn parse_param_14() { use expect_test::expect; - let res = parse_param(Input("@param {Unknown} $1")); + let res = parse_param(Input("@param {Unknown} $1"), &[]); let expected = expect![[r#" Err( @@ -641,11 +721,216 @@ mod tests { expected.assert_debug_eq(&res); } + #[test] + fn parse_param_15() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int} $1:alias!"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "alias!", + ), + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: None, + position: Some( + 1, + ), + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_16() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int} $1:alias?"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "alias", + ), + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: Some( + true, + ), + position: Some( + 1, + ), + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_17() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int} $1:alias!?"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "alias!", + ), + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: Some( + true, + ), + position: Some( + 1, + ), + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_18() { + use expect_test::expect; + + let res = parse_param(Input("@param $1:alias?"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "alias", + ), + typ: None, + nullable: Some( + true, + ), + position: Some( + 1, + ), + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_19() { + use expect_test::expect; + + let enums = ["MyEnum".to_string()]; + let res = parse_param(Input("@param {MyEnum} $1:alias?"), &enums); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "alias", + ), + typ: Some( + Enum( + "MyEnum", + ), + ), + nullable: Some( + true, + ), + position: Some( + 1, + ), + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_20() { + use expect_test::expect; + + let enums = ["MyEnum".to_string()]; + let res = parse_param(Input("@param {MyEnum} $12567:alias"), &enums); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "alias", + ), + typ: Some( + Enum( + "MyEnum", + ), + ), + nullable: None, + position: Some( + 12567, + ), + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_21() { + use expect_test::expect; + + let enums = ["MyEnum".to_string(), "MyEnum2".to_string()]; + let res = parse_param(Input("@param {UnknownTyp} $12567:alias"), &enums); + + let expected = expect![[r#" + Err( + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "SQL documentation parsing: invalid type: 'UnknownTyp' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal', 'MyEnum', 'MyEnum2') at '{UnknownTyp} $12567:alias'.", + ), + source: None, + context: SpanTrace [], + } + SQL documentation parsing: invalid type: 'UnknownTyp' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal', 'MyEnum', 'MyEnum2') at '{UnknownTyp} $12567:alias'. + , + ) + "#]]; + + expected.assert_debug_eq(&res); + } + #[test] fn parse_sql_1() { use expect_test::expect; - let res = parse_sql_doc("-- @param {Int} $1 some documentation "); + let res = parse_sql_doc("-- @param {Int} $1 some documentation ", &[]); let expected = expect![[r#" Ok( @@ -654,8 +939,11 @@ mod tests { ParsedParameterDoc { alias: None, typ: Some( - Int, + ColumnType( + Int32, + ), ), + nullable: None, position: Some( 1, ), @@ -683,6 +971,7 @@ mod tests { SELECT enum FROM "test_introspect_sql"."model" WHERE id = $1;"#, + &[], ); let expected = expect![[r#" @@ -694,8 +983,11 @@ mod tests { "userId", ), typ: Some( - Int, + ColumnType( + Int32, + ), ), + nullable: None, position: Some( 1, ), @@ -708,8 +1000,11 @@ mod tests { "parentId", ), typ: Some( - String, + ColumnType( + Text, + ), ), + nullable: None, position: Some( 2, ), @@ -737,6 +1032,7 @@ mod tests { -- @param {Int} $1:userId valid user identifier -- @param {String} $1:parentId valid parent identifier SELECT enum FROM "test_introspect_sql"."model" WHERE id = $1;"#, + &[], ); let expected = expect![[r#" @@ -766,6 +1062,7 @@ SELECT enum FROM "test_introspect_sql"."model" WHERE id = $1;"#, -- @param {Int} $1:userId valid user identifier -- @param {String} $2:userId valid parent identifier SELECT enum FROM "test_introspect_sql"."model" WHERE id = $1;"#, + &[], ); let expected = expect![[r#" @@ -796,6 +1093,7 @@ SELECT enum FROM "test_introspect_sql"."model" WHERE id = $1;"#, * Unhandled multi-line comment */ SELECT enum FROM "test_introspect_sql"."model" WHERE id = $1;"#, + &[], ); let expected = expect![[r#" diff --git a/schema-engine/core/src/state.rs b/schema-engine/core/src/state.rs index 9b60ccbf97f..b0ae38d95e4 100644 --- a/schema-engine/core/src/state.rs +++ b/schema-engine/core/src/state.rs @@ -431,6 +431,7 @@ impl GenericApi for EngineState { name: p.name, typ: p.typ, documentation: p.documentation, + nullable: p.nullable, }) .collect(), result_columns: q @@ -439,6 +440,7 @@ impl GenericApi for EngineState { .map(|c| SqlQueryColumnOutput { name: c.name, typ: c.typ, + nullable: c.nullable, }) .collect(), }) diff --git a/schema-engine/json-rpc-api-build/methods/introspectSql.toml b/schema-engine/json-rpc-api-build/methods/introspectSql.toml index e7ea73897be..19eb888da32 100644 --- a/schema-engine/json-rpc-api-build/methods/introspectSql.toml +++ b/schema-engine/json-rpc-api-build/methods/introspectSql.toml @@ -15,9 +15,6 @@ shape = "string" shape = "sqlQueryInput" isList = true -[recordShapes.introspectSqlParams.fields.force] -shape = "bool" - # Result [recordShapes.introspectSqlResult] @@ -58,10 +55,14 @@ shape = "string" [recordShapes.sqlQueryParameterOutput.fields.documentation] isNullable = true shape = "string" +[recordShapes.sqlQueryParameterOutput.fields.nullable] +shape = "bool" [recordShapes.sqlQueryColumnOutput] [recordShapes.sqlQueryColumnOutput.fields.name] shape = "string" [recordShapes.sqlQueryColumnOutput.fields.typ] shape = "string" +[recordShapes.sqlQueryColumnOutput.fields.nullable] +shape = "bool" diff --git a/schema-engine/sql-migration-tests/src/assertions/quaint_result_set_ext.rs b/schema-engine/sql-migration-tests/src/assertions/quaint_result_set_ext.rs index ed8058e7d1d..ad0e01a7fe2 100644 --- a/schema-engine/sql-migration-tests/src/assertions/quaint_result_set_ext.rs +++ b/schema-engine/sql-migration-tests/src/assertions/quaint_result_set_ext.rs @@ -113,6 +113,17 @@ impl<'a> RowAssertion<'a> { self } + pub fn assert_bigint_value(self, column_name: &str, expected_value: i64) -> Self { + let actual_value = self.0.get(column_name).and_then(|col: &Value<'_>| (*col).as_i64()); + + assert!( + actual_value == Some(expected_value), + "Value assertion failed for {column_name}. Expected: {expected_value:?}, got: {actual_value:?}", + ); + + self + } + pub fn assert_bytes_value(self, column_name: &str, expected_value: &[u8]) -> Self { let actual_value = self.0.get(column_name).and_then(|col: &Value<'_>| (*col).as_bytes()); diff --git a/schema-engine/sql-migration-tests/src/test_api.rs b/schema-engine/sql-migration-tests/src/test_api.rs index 91f8e2e3f0e..6d67bbd9849 100644 --- a/schema-engine/sql-migration-tests/src/test_api.rs +++ b/schema-engine/sql-migration-tests/src/test_api.rs @@ -166,29 +166,51 @@ impl TestApi { IntrospectSql::new(&mut self.connector, name, sanitized) } + // Replaces `?` with the appropriate positional parameter syntax for the current database. pub fn sanitize_sql(&self, sql: &str) -> String { let mut counter = 1; - let mut sql = sql.to_string(); if self.is_mysql() || self.is_mariadb() || self.is_sqlite() { - return sql; + return sql.to_string(); } - while let Some(idx) = sql.find('?') { - let replacer = if self.is_postgres() || self.is_cockroach() { - format!("${}", counter) - } else if self.is_mssql() { - format!("@P{}", counter) + let mut out = String::with_capacity(sql.len()); + let mut lines = sql.lines().peekable(); + + while let Some(line) = lines.next() { + // Avoid replacing query params in comments + if line.trim_start().starts_with("--") { + out.push_str(line); + + if lines.peek().is_some() { + out.push('\n'); + } } else { - unimplemented!() - }; + let mut line = line.to_string(); + + while let Some(idx) = line.find('?') { + let replacer = if self.is_postgres() || self.is_cockroach() { + format!("${}", counter) + } else if self.is_mssql() { + format!("@P{}", counter) + } else { + unimplemented!() + }; - sql.replace_range(idx..idx + 1, &replacer); + line.replace_range(idx..idx + 1, &replacer); - counter += 1; + counter += 1; + } + + out.push_str(&line); + + if lines.peek().is_some() { + out.push('\n'); + } + } } - sql + out } /// Returns true only when testing on MSSQL. diff --git a/schema-engine/sql-migration-tests/tests/query_introspection/docs.rs b/schema-engine/sql-migration-tests/tests/query_introspection/docs.rs index a5a933a23af..e7fe7bc630c 100644 --- a/schema-engine/sql-migration-tests/tests/query_introspection/docs.rs +++ b/schema-engine/sql-migration-tests/tests/query_introspection/docs.rs @@ -8,7 +8,7 @@ fn parses_doc_complex(api: TestApi) { let expected = expect![[r#" IntrospectSqlQueryOutput { name: "test_1", - source: "\n -- @description some fancy query\n -- @param {Int} $1:myInt some integer\n -- @param {String}$2:myString some string\n SELECT int FROM model WHERE int = $1 and string = $2;\n ", + source: "\n -- @description some fancy query\n -- @param {Int} $1:myInt some integer\n -- @param {String}$2:myString? some string\n SELECT int FROM model WHERE int = $1 and string = $2;\n ", documentation: Some( "some fancy query", ), @@ -18,20 +18,23 @@ fn parses_doc_complex(api: TestApi) { "some integer", ), name: "myInt", - typ: "Int", + typ: "int", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: Some( "some string", ), name: "myString", - typ: "String", + typ: "string", + nullable: true, }, ], result_columns: [ IntrospectSqlQueryColumnOutput { name: "int", typ: "int", + nullable: false, }, ], } @@ -40,7 +43,7 @@ fn parses_doc_complex(api: TestApi) { let sql = r#" -- @description some fancy query -- @param {Int} $1:myInt some integer - -- @param {String}$2:myString some string + -- @param {String}$2:myString? some string SELECT int FROM model WHERE int = ? and string = ?; "#; @@ -62,18 +65,21 @@ fn parses_doc_no_position(api: TestApi) { "some integer", ), name: "myInt", - typ: "String", + typ: "string", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "_2", typ: "unknown", + nullable: false, }, ], result_columns: [ IntrospectSqlQueryColumnOutput { name: "int", typ: "int", + nullable: false, }, ], } @@ -101,19 +107,22 @@ fn parses_doc_no_alias(api: TestApi) { documentation: None, name: "int4", typ: "int", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: Some( "some string", ), name: "text", - typ: "String", + typ: "string", + nullable: false, }, ], result_columns: [ IntrospectSqlQueryColumnOutput { name: "int", typ: "int", + nullable: false, }, ], } @@ -127,6 +136,46 @@ fn parses_doc_no_alias(api: TestApi) { api.introspect_sql("test_1", sql).send_sync().expect_result(expected) } +#[test_connector(tags(Postgres))] +fn parses_doc_enum_name(api: TestApi) { + api.schema_push(ENUM_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "\n -- @param {MyFancyEnum} $1\n SELECT * FROM model WHERE id = $1;\n ", + documentation: None, + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "int4", + typ: "MyFancyEnum", + nullable: false, + }, + ], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "enum", + typ: "MyFancyEnum", + nullable: false, + }, + ], + } + "#]]; + + let sql = r#" + -- @param {MyFancyEnum} $1 + SELECT * FROM model WHERE id = ?; + "#; + + api.introspect_sql("test_1", sql).send_sync().expect_result(expected) +} + #[test_connector(tags(Postgres))] fn invalid_position_fails(api: TestApi) { api.schema_push(SIMPLE_SCHEMA).send().assert_green(); diff --git a/schema-engine/sql-migration-tests/tests/query_introspection/mysql.rs b/schema-engine/sql-migration-tests/tests/query_introspection/mysql.rs index 2112112254b..7b9df66d51c 100644 --- a/schema-engine/sql-migration-tests/tests/query_introspection/mysql.rs +++ b/schema-engine/sql-migration-tests/tests/query_introspection/mysql.rs @@ -22,36 +22,43 @@ fn insert_mysql(api: TestApi) { documentation: None, name: "_0", typ: "bigint", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "_1", typ: "string", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "_2", typ: "bigint", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "_3", typ: "double", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "_4", typ: "bytes", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "_5", typ: "bigint", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "_6", typ: "datetime", + nullable: false, }, ], result_columns: [], @@ -82,30 +89,97 @@ fn select_mysql(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "int", typ: "int", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "string", typ: "string", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "bigint", typ: "bigint", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "float", typ: "double", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "bytes", typ: "bytes", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "bool", typ: "int", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "dt", typ: "datetime", + nullable: false, + }, + ], + } + "#]]; + + res.expect_result(expected); +} + +#[test_connector(tags(Mysql, Mariadb))] +fn select_nullable_mysql(api: TestApi) { + api.schema_push(SIMPLE_NULLABLE_SCHEMA).send().assert_green(); + + let res = api + .introspect_sql( + "test_1", + "SELECT `int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt` FROM `model`;", + ) + .send_sync(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT `int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt` FROM `model`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "string", + typ: "string", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bigint", + typ: "bigint", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "float", + typ: "double", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bytes", + typ: "bytes", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bool", + typ: "int", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "dt", + typ: "datetime", + nullable: true, }, ], } @@ -128,12 +202,14 @@ fn empty_result(api: TestApi) { documentation: None, name: "_0", typ: "bigint", + nullable: false, }, ], result_columns: [ IntrospectSqlQueryColumnOutput { name: "int", typ: "int", + nullable: false, }, ], } @@ -158,6 +234,7 @@ fn unnamed_expr(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "1 + 1", typ: "bigint", + nullable: false, }, ], } @@ -182,6 +259,7 @@ fn named_expr(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "add", typ: "bigint", + nullable: false, }, ], } @@ -206,6 +284,7 @@ fn mixed_named_expr(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "add", typ: "bigint", + nullable: false, }, ], } @@ -230,6 +309,7 @@ fn mixed_unnamed_expr(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "`int` + 1", typ: "bigint", + nullable: false, }, ], } @@ -254,6 +334,7 @@ fn mixed_expr_cast(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "test", typ: "string", + nullable: true, }, ], } diff --git a/schema-engine/sql-migration-tests/tests/query_introspection/pg.rs b/schema-engine/sql-migration-tests/tests/query_introspection/pg.rs index eb1315ff0b5..b4b59251b9b 100644 --- a/schema-engine/sql-migration-tests/tests/query_introspection/pg.rs +++ b/schema-engine/sql-migration-tests/tests/query_introspection/pg.rs @@ -24,66 +24,179 @@ mod common { documentation: None, name: "int4", typ: "int", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "text", typ: "string", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "int8", typ: "bigint", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "float8", typ: "double", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "bytea", typ: "bytes", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "bool", typ: "bool", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "timestamp", typ: "datetime", + nullable: false, }, ], result_columns: [ IntrospectSqlQueryColumnOutput { name: "int", typ: "int", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "string", typ: "string", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "bigint", typ: "bigint", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "float", typ: "double", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "bytes", typ: "bytes", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "bool", typ: "bool", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "dt", typ: "datetime", + nullable: false, + }, + ], + } + "#]]; + + res.expect_result(expected); + } + + #[test_connector(tags(Postgres))] + fn insert_nullable(api: TestApi) { + api.schema_push(SIMPLE_NULLABLE_SCHEMA).send().assert_green(); + + let query = "INSERT INTO model (int, string, bigint, float, bytes, bool, dt) VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING int, string, bigint, float, bytes, bool, dt;"; + let res = api.introspect_sql("test_1", query).send_sync(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "INSERT INTO model (int, string, bigint, float, bytes, bool, dt) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING int, string, bigint, float, bytes, bool, dt;", + documentation: None, + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "int4", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "text", + typ: "string", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "int8", + typ: "bigint", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "float8", + typ: "double", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "bytea", + typ: "bytes", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "bool", + typ: "bool", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "timestamp", + typ: "datetime", + nullable: false, + }, + ], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "string", + typ: "string", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bigint", + typ: "bigint", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "float", + typ: "double", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bytes", + typ: "bytes", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bool", + typ: "bool", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "dt", + typ: "datetime", + nullable: true, }, ], } @@ -106,12 +219,14 @@ mod common { documentation: None, name: "int4", typ: "int", + nullable: false, }, ], result_columns: [ IntrospectSqlQueryColumnOutput { name: "int", typ: "int", + nullable: false, }, ], } @@ -136,21 +251,25 @@ mod common { documentation: None, name: "int4", typ: "int", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "MyFancyEnum", typ: "string", + nullable: false, }, ], result_columns: [ IntrospectSqlQueryColumnOutput { name: "id", typ: "int", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "enum", typ: "MyFancyEnum", + nullable: false, }, ], } @@ -182,47 +301,329 @@ mod postgres { let expected = expect![[r#" IntrospectSqlQueryOutput { name: "test_1", - source: "SELECT 1 + 1 as add;", + source: "SELECT 1 + 1 as \"add\";", documentation: None, parameters: [], result_columns: [ IntrospectSqlQueryColumnOutput { name: "add", typ: "int", + nullable: true, }, ], } "#]]; - api.introspect_sql("test_1", "SELECT 1 + 1 as add;") + api.introspect_sql("test_1", "SELECT 1 + 1 as \"add\";") .send_sync() .expect_result(expected) } #[test_connector(tags(Postgres), exclude(CockroachDb))] - fn unnamed_expr(api: TestApi) { + fn mixed_named_expr(api: TestApi) { api.schema_push(SIMPLE_SCHEMA).send().assert_green(); let expected = expect![[r#" IntrospectSqlQueryOutput { name: "test_1", - source: "SELECT 1 + 1;", + source: "SELECT \"int\" + 1 as \"add\" FROM \"model\";", documentation: None, parameters: [], result_columns: [ IntrospectSqlQueryColumnOutput { - name: "?column?", + name: "add", typ: "int", + nullable: true, }, ], } "#]]; - api.introspect_sql("test_1", "SELECT 1 + 1;") + api.introspect_sql("test_1", "SELECT \"int\" + 1 as \"add\" FROM \"model\";") .send_sync() .expect_result(expected) } + #[test_connector(tags(Postgres), exclude(CockroachDb))] + fn mixed_unnamed_expr(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "Invalid input provided to query: Invalid column name '?column?' for index 0. Your SQL query must explicitly alias that column name.", + ), + source: None, + context: SpanTrace [], + } + Invalid input provided to query: Invalid column name '?column?' for index 0. Your SQL query must explicitly alias that column name. + + "#]]; + + expected.assert_debug_eq( + &api.introspect_sql("test_1", "SELECT \"int\" + 1 FROM \"model\";") + .send_unwrap_err(), + ); + } + + #[test_connector(tags(Postgres), exclude(CockroachDb))] + fn mixed_expr_cast(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT CAST(\"int\" + 1 as int) FROM model;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int4", + typ: "int", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT CAST(\"int\" + 1 as int) FROM model;") + .send_sync() + .expect_result(expected) + } + + #[test_connector(tags(Postgres), exclude(CockroachDb))] + fn subquery(api: TestApi) { + api.schema_push(SIMPLE_NULLABLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT int, foo.int, foo.string FROM (SELECT * FROM model) AS foo", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "string", + typ: "string", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql( + "test_1", + "SELECT int, foo.int, foo.string FROM (SELECT * FROM model) AS foo", + ) + .send_sync() + .expect_result(expected) + } + + #[test_connector(tags(Postgres), exclude(CockroachDb))] + fn left_join(api: TestApi) { + api.schema_push(RELATION_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT parent.id as parentId, parent.nullable as parentNullable, child.id as childId, child.nullable as childNullable FROM parent LEFT JOIN child ON parent.id = child.parent_id", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "parentid", + typ: "int", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "parentnullable", + typ: "string", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "childid", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "childnullable", + typ: "string", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT parent.id as parentId, parent.nullable as parentNullable, child.id as childId, child.nullable as childNullable FROM parent LEFT JOIN child ON parent.id = child.parent_id") + .send_sync() + .expect_result(expected) + } + + // test nullability inference for various joins + #[test_connector(tags(Postgres), exclude(CockroachDb))] + fn outer_join(api: TestApi) { + api.schema_push( + "model products { + product_no Int @id + name String? + } + + model tweet { + id Int @id @default(autoincrement()) + text String + }", + ) + .send() + .assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "select tweet.id from (values (null)) vals(val) inner join tweet on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // inner join, nullability should not be overridden + api.introspect_sql( + "test_1", + "select tweet.id from (values (null)) vals(val) inner join tweet on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_2", + source: "select tweet.id from (values (null)) vals(val) left join tweet on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: true, + }, + ], + } + "#]]; + + // tweet.id is marked NOT NULL but it's brought in from a left-join here + // which should make it nullable + api.introspect_sql( + "test_2", + "select tweet.id from (values (null)) vals(val) left join tweet on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_3", + source: "select tweet1.id, tweet2.id from tweet tweet1 left join tweet tweet2 on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: true, + }, + ], + } + "#]]; + + // make sure we don't mis-infer for the outer half of the join + api.introspect_sql( + "test_3", + "select tweet1.id, tweet2.id from tweet tweet1 left join tweet tweet2 on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_4", + source: "select tweet1.id, tweet2.id from tweet tweet1 right join tweet tweet2 on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // right join, nullability should be inverted + api.introspect_sql( + "test_4", + "select tweet1.id, tweet2.id from tweet tweet1 right join tweet tweet2 on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_5", + source: "select tweet1.id, tweet2.id from tweet tweet1 full join tweet tweet2 on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: true, + }, + ], + } + "#]]; + + // right join, nullability should be inverted + api.introspect_sql( + "test_5", + "select tweet1.id, tweet2.id from tweet tweet1 full join tweet tweet2 on false", + ) + .send_sync() + .expect_result(expected); + } + macro_rules! test_native_types_pg { ( $($test_name:ident($nt:expr) => $ct:ident,)* @@ -285,53 +686,335 @@ mod crdb { use super::*; #[test_connector(tags(CockroachDb))] - fn unnamed_expr_crdb(api: TestApi) { + fn named_expr(api: TestApi) { api.schema_push(SIMPLE_SCHEMA).send().assert_green(); let expected = expect![[r#" IntrospectSqlQueryOutput { name: "test_1", - source: "SELECT 1 + 1;", + source: "SELECT 1 + 1 as \"add\";", documentation: None, parameters: [], result_columns: [ IntrospectSqlQueryColumnOutput { - name: "?column?", + name: "add", typ: "bigint", + nullable: true, }, ], } "#]]; - api.introspect_sql("test_1", "SELECT 1 + 1;") + api.introspect_sql("test_1", "SELECT 1 + 1 as \"add\";") .send_sync() .expect_result(expected) } #[test_connector(tags(CockroachDb))] - fn named_expr_crdb(api: TestApi) { + fn mixed_named_expr(api: TestApi) { api.schema_push(SIMPLE_SCHEMA).send().assert_green(); let expected = expect![[r#" IntrospectSqlQueryOutput { name: "test_1", - source: "SELECT 1 + 1 as add;", + source: "SELECT \"int\" + 1 as \"add\" FROM \"model\";", documentation: None, parameters: [], result_columns: [ IntrospectSqlQueryColumnOutput { name: "add", typ: "bigint", + nullable: true, }, ], } "#]]; - api.introspect_sql("test_1", "SELECT 1 + 1 as add;") + api.introspect_sql("test_1", "SELECT \"int\" + 1 as \"add\" FROM \"model\";") .send_sync() .expect_result(expected) } + #[test_connector(tags(CockroachDb))] + fn mixed_unnamed_expr(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "Invalid input provided to query: Invalid column name '?column?' for index 0. Your SQL query must explicitly alias that column name.", + ), + source: None, + context: SpanTrace [], + } + Invalid input provided to query: Invalid column name '?column?' for index 0. Your SQL query must explicitly alias that column name. + + "#]]; + + expected.assert_debug_eq( + &api.introspect_sql("test_1", "SELECT \"int\" + 1 FROM \"model\";") + .send_unwrap_err(), + ); + } + + #[test_connector(tags(CockroachDb))] + fn mixed_expr_cast(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT CAST(\"int\" + 1 as int) FROM model;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int8", + typ: "bigint", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT CAST(\"int\" + 1 as int) FROM model;") + .send_sync() + .expect_result(expected) + } + + #[test_connector(tags(CockroachDb))] + fn subquery(api: TestApi) { + api.schema_push(SIMPLE_NULLABLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT int, foo.int, foo.string FROM (SELECT * FROM model) AS foo", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "string", + typ: "string", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql( + "test_1", + "SELECT int, foo.int, foo.string FROM (SELECT * FROM model) AS foo", + ) + .send_sync() + .expect_result(expected) + } + + #[test_connector(tags(CockroachDb))] + fn left_join(api: TestApi) { + api.schema_push(RELATION_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT parent.id as parentId, parent.nullable as parentNullable, child.id as childId, child.nullable as childNullable FROM parent LEFT JOIN child ON parent.id = child.parent_id", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "parentid", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "parentnullable", + typ: "string", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "childid", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "childnullable", + typ: "string", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT parent.id as parentId, parent.nullable as parentNullable, child.id as childId, child.nullable as childNullable FROM parent LEFT JOIN child ON parent.id = child.parent_id") + .send_sync() + .expect_result(expected) + } + + // test nullability inference for various joins + #[test_connector(tags(CockroachDb))] + fn outer_join(api: TestApi) { + api.schema_push( + "model products { + product_no Int @id + name String? + } + + model tweet { + id Int @id @default(autoincrement()) + text String + }", + ) + .send() + .assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "select tweet.id from (values (null)) vals(val) inner join tweet on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // inner join, nullability should not be overridden + api.introspect_sql( + "test_1", + "select tweet.id from (values (null)) vals(val) inner join tweet on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_2", + source: "select tweet.id from (values (null)) vals(val) left join tweet on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // tweet.id is marked NOT NULL but it's brought in from a left-join here + // which should make it nullable + api.introspect_sql( + "test_2", + "select tweet.id from (values (null)) vals(val) left join tweet on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_3", + source: "select tweet1.id, tweet2.id from tweet tweet1 left join tweet tweet2 on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // make sure we don't mis-infer for the outer half of the join + api.introspect_sql( + "test_3", + "select tweet1.id, tweet2.id from tweet tweet1 left join tweet tweet2 on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_4", + source: "select tweet1.id, tweet2.id from tweet tweet1 right join tweet tweet2 on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // right join, nullability should be inverted + api.introspect_sql( + "test_4", + "select tweet1.id, tweet2.id from tweet tweet1 right join tweet tweet2 on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_5", + source: "select tweet1.id, tweet2.id from tweet tweet1 full join tweet tweet2 on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // right join, nullability should be inverted + api.introspect_sql( + "test_5", + "select tweet1.id, tweet2.id from tweet tweet1 full join tweet tweet2 on false", + ) + .send_sync() + .expect_result(expected); + } + macro_rules! test_native_types_crdb { ( $($test_name:ident($nt:expr) => $ct:ident,)* diff --git a/schema-engine/sql-migration-tests/tests/query_introspection/sqlite.rs b/schema-engine/sql-migration-tests/tests/query_introspection/sqlite.rs index f5ebde73637..d018202da39 100644 --- a/schema-engine/sql-migration-tests/tests/query_introspection/sqlite.rs +++ b/schema-engine/sql-migration-tests/tests/query_introspection/sqlite.rs @@ -21,36 +21,43 @@ fn insert_sqlite(api: TestApi) { documentation: None, name: "_1", typ: "unknown", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "_2", typ: "unknown", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "_3", typ: "unknown", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "_4", typ: "unknown", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "_5", typ: "unknown", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "_6", typ: "unknown", + nullable: false, }, IntrospectSqlQueryParameterOutput { documentation: None, name: "_7", typ: "unknown", + nullable: false, }, ], result_columns: [], @@ -81,30 +88,97 @@ fn select_sqlite(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "int", typ: "int", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "string", typ: "string", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "bigint", typ: "bigint", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "float", typ: "double", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "bytes", typ: "bytes", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "bool", typ: "bool", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "dt", typ: "datetime", + nullable: false, + }, + ], + } + "#]]; + + res.expect_result(expected); +} + +#[test_connector(tags(Sqlite))] +fn select_nullable_sqlite(api: TestApi) { + api.schema_push(SIMPLE_NULLABLE_SCHEMA).send().assert_green(); + + let res = api + .introspect_sql( + "test_1", + "SELECT `int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt` FROM `model`;", + ) + .send_sync(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT `int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt` FROM `model`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "string", + typ: "string", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bigint", + typ: "bigint", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "float", + typ: "double", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bytes", + typ: "bytes", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bool", + typ: "bool", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "dt", + typ: "datetime", + nullable: true, }, ], } @@ -127,12 +201,14 @@ fn empty_result(api: TestApi) { documentation: None, name: "_1", typ: "unknown", + nullable: false, }, ], result_columns: [ IntrospectSqlQueryColumnOutput { name: "int", typ: "int", + nullable: false, }, ], } @@ -156,7 +232,8 @@ fn unnamed_expr_int(api: TestApi) { result_columns: [ IntrospectSqlQueryColumnOutput { name: "1 + 1", - typ: "int", + typ: "bigint", + nullable: false, }, ], } @@ -164,7 +241,10 @@ fn unnamed_expr_int(api: TestApi) { api.introspect_sql("test_1", "SELECT 1 + 1;") .send_sync() - .expect_result(expected) + .expect_result(expected); + + api.query_raw("SELECT 1 + 1;", &[]) + .assert_single_row(|row| row.assert_bigint_value("1 + 1", 2)); } #[test_connector(tags(Sqlite))] @@ -180,7 +260,8 @@ fn named_expr_int(api: TestApi) { result_columns: [ IntrospectSqlQueryColumnOutput { name: "add", - typ: "int", + typ: "bigint", + nullable: false, }, ], } @@ -188,7 +269,38 @@ fn named_expr_int(api: TestApi) { api.introspect_sql("test_1", "SELECT 1 + 1 as \"add\";") .send_sync() - .expect_result(expected) + .expect_result(expected); + + api.query_raw("SELECT 1 + 1 as \"add\";", &[]) + .assert_single_row(|row| row.assert_bigint_value("add", 2)); +} + +#[test_connector(tags(Sqlite))] +fn named_expr_int_optional(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT 1 + 1 as `add?`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "add?", + typ: "bigint", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT 1 + 1 as `add?`;") + .send_sync() + .expect_result(expected); + + api.query_raw("SELECT 1 + 1 as \"add?\";", &[]) + .assert_single_row(|row| row.assert_bigint_value("add?", 2)); } #[test_connector(tags(Sqlite))] @@ -204,7 +316,8 @@ fn mixed_named_expr_int(api: TestApi) { result_columns: [ IntrospectSqlQueryColumnOutput { name: "add", - typ: "int", + typ: "bigint", + nullable: false, }, ], } @@ -228,7 +341,8 @@ fn mixed_unnamed_expr_int(api: TestApi) { result_columns: [ IntrospectSqlQueryColumnOutput { name: "`int` + 1", - typ: "int", + typ: "bigint", + nullable: false, }, ], } @@ -252,7 +366,8 @@ fn mixed_expr_cast_int(api: TestApi) { result_columns: [ IntrospectSqlQueryColumnOutput { name: "CAST(`int` + 1 as int)", - typ: "int", + typ: "bigint", + nullable: false, }, ], } @@ -277,6 +392,7 @@ fn unnamed_expr_string(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "'hello world'", typ: "string", + nullable: false, }, ], } @@ -303,11 +419,13 @@ fn unnamed_expr_bool(api: TestApi) { result_columns: [ IntrospectSqlQueryColumnOutput { name: "1=1", - typ: "int", + typ: "bigint", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "1=0", - typ: "int", + typ: "bigint", + nullable: false, }, ], } @@ -335,14 +453,17 @@ fn unnamed_expr_real(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "1.2", typ: "double", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "2.34567891023", typ: "double", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "round(2.345)", typ: "double", + nullable: true, }, ], } @@ -374,6 +495,7 @@ fn unnamed_expr_blob(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "unhex('537475666673')", typ: "bytes", + nullable: true, }, ], } @@ -401,6 +523,7 @@ fn unnamed_expr_date(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "date('2025-05-29 14:16:00')", typ: "string", + nullable: true, }, ], } @@ -428,6 +551,7 @@ fn unnamed_expr_time(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "time('2025-05-29 14:16:00')", typ: "string", + nullable: true, }, ], } @@ -455,6 +579,7 @@ fn unnamed_expr_datetime(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "datetime('2025-05-29 14:16:00')", typ: "string", + nullable: true, }, ], } @@ -482,6 +607,7 @@ fn subquery(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "int", typ: "int", + nullable: false, }, ], } @@ -506,10 +632,12 @@ fn left_join(api: TestApi) { IntrospectSqlQueryColumnOutput { name: "parentId", typ: "int", + nullable: false, }, IntrospectSqlQueryColumnOutput { name: "childId", typ: "int", + nullable: false, }, ], } diff --git a/schema-engine/sql-migration-tests/tests/query_introspection/utils.rs b/schema-engine/sql-migration-tests/tests/query_introspection/utils.rs index 4f9295a8ffb..f1666ee112a 100644 --- a/schema-engine/sql-migration-tests/tests/query_introspection/utils.rs +++ b/schema-engine/sql-migration-tests/tests/query_introspection/utils.rs @@ -14,6 +14,17 @@ model model { dt DateTime }"#; +pub(crate) const SIMPLE_NULLABLE_SCHEMA: &str = r#" +model model { + int Int @id + string String? + bigint BigInt? + float Float? + bytes Bytes? + bool Boolean? + dt DateTime? +}"#; + pub(crate) const ENUM_SCHEMA: &str = r#" model model { id Int @id @@ -30,12 +41,14 @@ enum MyFancyEnum { pub(crate) const RELATION_SCHEMA: &str = r#" model parent { id Int @id + nullable String? children child[] } model child { id Int @id + nullable String? parent_id Int? parent parent? @relation(fields: [parent_id], references: [id])