diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 5578517ec3594..0f45d51835f41 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -28,16 +28,18 @@ runs: - name: Install Build Dependencies shell: bash run: | - apt-get update - apt-get install -y protobuf-compiler + RETRY="ci/scripts/retry" + "${RETRY}" apt-get update + "${RETRY}" apt-get install -y protobuf-compiler - name: Setup Rust toolchain shell: bash # rustfmt is needed for the substrait build script run: | + RETRY="ci/scripts/retry" echo "Installing ${{ inputs.rust-version }}" - rustup toolchain install ${{ inputs.rust-version }} - rustup default ${{ inputs.rust-version }} - rustup component add rustfmt + "${RETRY}" rustup toolchain install ${{ inputs.rust-version }} + "${RETRY}" rustup default ${{ inputs.rust-version }} + "${RETRY}" rustup component add rustfmt - name: Configure rust runtime env uses: ./.github/actions/setup-rust-runtime - name: Fixup git permissions diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index edaa49ec6e7ec..39b7b2b178570 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -521,7 +521,7 @@ jobs: run: taplo format --check config-docs-check: - name: check configs.md is up-to-date + name: check configs.md and ***_functions.md are up-to-date needs: [ linux-build-lib ] runs-on: ubuntu-latest container: @@ -542,6 +542,11 @@ jobs: # If you encounter an error, run './dev/update_config_docs.sh' and commit ./dev/update_config_docs.sh git diff --exit-code + - name: Check if any of the ***_functions.md files have been modified + run: | + # If you encounter an error, run './dev/update_function_docs.sh' and commit + ./dev/update_function_docs.sh + git diff --exit-code # Verify MSRV for the crates which are directly used by other projects: # - datafusion @@ -569,9 +574,9 @@ # # To reproduce: # 1. Install the version of Rust that is failing. Example: - # rustup install 1.78.0 + # rustup install 1.79.0 # 2. Run the command that failed with that version. Example: - # cargo +1.78.0 check -p datafusion + # cargo +1.79.0 check -p datafusion # # To resolve, either: # 1. Change your code to use older Rust features,
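The `ci/scripts/retry` helper referenced in the setup-builder steps above is not included in this diff, so its exact behavior is an assumption here; a minimal sketch of such a wrapper, re-running a flaky command a few times with a pause between attempts, could look like:

```bash
#!/usr/bin/env bash
# Hypothetical sketch of a helper like ci/scripts/retry (the real script
# is not shown in this diff and may differ): run the given command up to
# 3 times, sleeping briefly between failed attempts.
for attempt in 1 2 3; do
    if "$@"; then
        exit 0
    fi
    echo "attempt ${attempt} failed: $*" >&2
    sleep 5
done
echo "giving up after 3 attempts: $*" >&2
exit 1
```

Wrapping only the `apt-get` and `rustup` steps this way presumably targets the commands most prone to transient network failures.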
diff --git a/Cargo.toml b/Cargo.toml index 6299921779130..0a7184ad2d99d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,8 +58,8 @@ homepage = "https://datafusion.apache.org" license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" -rust-version = "1.78" -version = "42.0.0" +rust-version = "1.79" +version = "42.1.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -70,51 +70,51 @@ version = "42.0.0" ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } -arrow = { version = "53.0.0", features = [ +arrow = { version = "53.2.0", features = [ "prettyprint", ] } -arrow-array = { version = "53.0.0", default-features = false, features = [ +arrow-array = { version = "53.2.0", default-features = false, features = [ "chrono-tz", ] } -arrow-buffer = { version = "53.0.0", default-features = false } -arrow-flight = { version = "53.0.0", features = [ +arrow-buffer = { version = "53.2.0", default-features = false } +arrow-flight = { version = "53.2.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "53.0.0", default-features = false, features = [ +arrow-ipc = { version = "53.2.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "53.0.0", default-features = false } -arrow-schema = { version = "53.0.0", default-features = false } -arrow-string = { version = "53.0.0", default-features = false } +arrow-ord = { version = "53.2.0", default-features = false } +arrow-schema = { version = "53.2.0", default-features = false } +arrow-string = { version = "53.2.0", default-features = false } async-trait = "0.1.73" bigdecimal = "=0.4.1" bytes = "1.4" chrono = { version = "0.4.38", default-features = false } ctor = "0.2.0" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "42.0.0", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "42.0.0" } -datafusion-common = { path = "datafusion/common", version = "42.0.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "42.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "42.0.0" } -datafusion-expr = { path = "datafusion/expr", version = "42.0.0" } -datafusion-expr-common = { path = "datafusion/expr-common", version = "42.0.0" } -datafusion-functions = { path = "datafusion/functions", version = "42.0.0" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "42.0.0" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "42.0.0" } -datafusion-functions-nested = { path = "datafusion/functions-nested", version = "42.0.0" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "42.0.0" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "42.0.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "42.0.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "42.0.0", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "42.0.0", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "42.0.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version =
"42.0.0" } -datafusion-proto = { path = "datafusion/proto", version = "42.0.0" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "42.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "42.0.0" } -datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "42.0.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "42.0.0" } +datafusion = { path = "datafusion/core", version = "42.1.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "42.1.0" } +datafusion-common = { path = "datafusion/common", version = "42.1.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "42.1.0" } +datafusion-execution = { path = "datafusion/execution", version = "42.1.0" } +datafusion-expr = { path = "datafusion/expr", version = "42.1.0" } +datafusion-expr-common = { path = "datafusion/expr-common", version = "42.1.0" } +datafusion-functions = { path = "datafusion/functions", version = "42.1.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "42.1.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "42.1.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "42.1.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "42.1.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "42.1.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "42.1.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "42.1.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "42.1.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "42.1.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "42.1.0" } +datafusion-proto = { path = "datafusion/proto", version = "42.1.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "42.1.0" } +datafusion-sql = { path = "datafusion/sql", version = "42.1.0" } +datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "42.1.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "42.1.0" } doc-comment = "0.3" env_logger = "0.11" futures = "0.3" @@ -126,7 +126,7 @@ log = "^0.4" num_cpus = "1.13.0" object_store = { version = "0.11.0", default-features = false } parking_lot = "0.12" -parquet = { version = "53.0.0", default-features = false, features = [ +parquet = { version = "53.2.0", default-features = false, features = [ "arrow", "async", "object_store", @@ -137,7 +137,7 @@ prost = "0.13.1" prost-derive = "0.13.1" rand = "0.8" regex = "1.8" -rstest = "0.22.0" +rstest = "0.23.0" serde_json = "1" sqlparser = { version = "0.51.0", features = ["visitor"] } tempfile = "3" @@ -169,3 +169,4 @@ large_futures = "warn" [workspace.lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)"] } +unused_qualifications = "deny" diff --git a/README.md b/README.md index bb8526c24e2cb..f89935d597c2f 100644 --- a/README.md +++ b/README.md @@ -42,14 +42,23 @@ DataFusion is an extensible query engine written in [Rust] that -uses [Apache Arrow] as its in-memory format. 
DataFusion's target users are -developers building fast and feature rich database and analytic systems, -customized to particular workloads. See [use cases] for examples. +uses [Apache Arrow] as its in-memory format. -"Out of the box," DataFusion offers [SQL] and [`Dataframe`] APIs, -excellent [performance], built-in support for CSV, Parquet, JSON, and Avro, -extensive customization, and a great community. -[Python Bindings] are also available. +This crate provides libraries and binaries for developers building fast and +feature-rich database and analytic systems, customized to particular workloads. +See [use cases] for examples. The following related subprojects target end users: + +- [DataFusion Python](https://github.com/apache/datafusion-python/) offers a Python interface for SQL and DataFrame + queries. +- [DataFusion Ray](https://github.com/apache/datafusion-ray/) provides a distributed version of DataFusion that scales + out on Ray clusters. +- [DataFusion Comet](https://github.com/apache/datafusion-comet/) is an accelerator for Apache Spark based on + DataFusion. + +"Out of the box," +DataFusion offers [SQL] and [`Dataframe`] APIs, excellent [performance], +built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and +a great community. DataFusion features a full query planner, a columnar, streaming, multi-threaded, vectorized execution engine, and partitioned data sources. You can @@ -125,3 +134,8 @@ For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81 If a hotfix is released for the minimum supported Rust version (MSRV), the MSRV will be the minor version with all hotfixes, even if it surpasses the four-month window. We enforce this policy using an [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) + +## DataFusion API evolution policy + +Public methods in Apache DataFusion may evolve as part of the API lifecycle. +Deprecated methods will be phased out in accordance with the [policy](https://datafusion.apache.org/library-user-guide/api-health.html), ensuring the API remains stable and healthy. diff --git a/benchmarks/README.md b/benchmarks/README.md index afaf28bb75769..a9aa1afb97a1c 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -330,6 +330,16 @@ steps. The tests sort the entire dataset using several different sort orders. +## IMDB + +Run the Join Order Benchmark (JOB) on the IMDB dataset. + +The Internet Movie Database (IMDB) dataset contains real-world movie data. Unlike synthetic datasets like TPCH, which assume uniform data distribution and uncorrelated columns, the IMDB dataset includes skewed data and correlated columns (which are common in real-world datasets), making it more suitable for testing query optimizers, particularly for cardinality estimation. + +This benchmark is derived from the [Join Order Benchmark](https://github.com/gregrahn/join-order-benchmark). + +See the paper [How Good Are Query Optimizers, Really?](http://www.vldb.org/pvldb/vol9/p204-leis.pdf) for more details. + ## TPCH Run the tpch benchmark. This benchmark is derived from the [TPC-H][1] version [2]: https://github.com/databricks/tpch-dbgen.git, [2.17.1]: https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf +## External Aggregation + +Run the benchmark for aggregations with limited memory. + +When the memory limit is exceeded, intermediate aggregation results are spilled to disk and later read back with a sort-merge.
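The spill point follows from how the limit is divided up: as the `bench.sh` comment below notes, the per-operator budget is the total memory limit divided by the number of partitions. A quick sketch with the example values used on this page (30M limit, 4 partitions; illustrative numbers, not defaults):

```bash
# Illustrative only: per-partition memory budget under an even split.
TOTAL_LIMIT_MB=30   # --memory-limit 30M
PARTITIONS=4        # -n 4 (or --partitions 4 in bench.sh)
echo "~$((TOTAL_LIMIT_MB / PARTITIONS)) MB per partition before spilling"
```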
+ +External aggregation benchmarks run several aggregation queries with different memory limits on the TPCH `lineitem` table. Queries can be found in [`external_aggr.rs`](src/bin/external_aggr.rs). + +This benchmark is inspired by [DuckDB's external aggregation paper](https://hannes.muehleisen.org/publications/icde2024-out-of-core-kuiper-boncz-muehleisen.pdf), specifically Section VI. + +### External Aggregation Example Runs +1. Run all queries with predefined memory limits: +```bash +# Under 'benchmarks/' directory +cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' +``` + +2. Run a query with a specific memory limit: +```bash +cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' --query 1 --memory-limit 30M +``` + +3. Run all queries with the `bench.sh` script: +```bash +./bench.sh data external_aggr +./bench.sh run external_aggr +``` + # Older Benchmarks diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 24efab6c6ca56..47c5d1261605b 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -78,6 +78,7 @@ sort: Benchmark of sorting speed clickbench_1: ClickBench queries against a single parquet file clickbench_partitioned: ClickBench queries against a partitioned (100 files) parquet clickbench_extended: ClickBench \"inspired\" queries against a single parquet (DataFusion specific) +external_aggr: External aggregation benchmark ********** * Supported Configuration (Environment Variables) @@ -170,6 +171,10 @@ main() { imdb) data_imdb ;; + external_aggr) + # same data as for tpch + data_tpch "1" + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for data generation" usage @@ -211,6 +216,8 @@ main() { run_clickbench_1 run_clickbench_partitioned run_clickbench_extended + run_imdb + run_external_aggr ;; tpch) run_tpch "1" @@ -239,6 +246,12 @@ main() { clickbench_extended) run_clickbench_extended ;; + imdb) + run_imdb + ;; + external_aggr) + run_external_aggr + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for run" usage @@ -353,7 +366,7 @@ run_parquet() { RESULTS_FILE="${RESULTS_DIR}/parquet.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running parquet filter benchmark..." - $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" + $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" } # Runs the sort benchmark @@ -361,7 +374,7 @@ run_sort() { RESULTS_FILE="${RESULTS_DIR}/sort.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running sort benchmark..." - $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" + $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" } @@ -510,7 +523,31 @@ data_imdb() { fi } +# Runs the imdb benchmark +run_imdb() { + IMDB_DIR="${DATA_DIR}/imdb" + + RESULTS_FILE="${RESULTS_DIR}/imdb.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running imdb benchmark..." + $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" +}
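The new `data` and `run` arms above make the IMDB suite runnable end to end. A usage sketch (assuming `bench.sh` keeps its existing `PREFER_HASH_JOIN` environment toggle, which `run_imdb` passes through):

```bash
# From the benchmarks/ directory: generate the IMDB dataset, then run
# the Join Order Benchmark queries against it.
./bench.sh data imdb
./bench.sh run imdb

# Same run, but preferring sort-merge joins over hash joins:
PREFER_HASH_JOIN=false ./bench.sh run imdb
```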
+# Runs the external aggregation benchmark +run_external_aggr() { + # Use TPC-H SF1 dataset + TPCH_DIR="${DATA_DIR}/tpch_sf1" + RESULTS_FILE="${RESULTS_DIR}/external_aggr.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running external aggregation benchmark..." + + # Only parquet is supported. + # Since the per-operator memory limit is calculated as (total-memory-limit / + # number-of-partitions), and by default `--partitions` is set to the number of + # CPU cores, we set a constant number of partitions to prevent this + # benchmark from failing on some machines. + $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" +} compare_benchmarks() { diff --git a/benchmarks/queries/imdb/10a.sql b/benchmarks/queries/imdb/10a.sql new file mode 100644 index 0000000000000..95b049b774799 --- /dev/null +++ b/benchmarks/queries/imdb/10a.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS uncredited_voiced_character, MIN(t.title) AS russian_movie FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t WHERE ci.note like '%(voice)%' and ci.note like '%(uncredited)%' AND cn.country_code = '[ru]' AND rt.role = 'actor' AND t.production_year > 2005 AND t.id = mc.movie_id AND t.id = ci.movie_id AND ci.movie_id = mc.movie_id AND chn.id = ci.person_role_id AND rt.id = ci.role_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/10b.sql b/benchmarks/queries/imdb/10b.sql new file mode 100644 index 0000000000000..c321536314129 --- /dev/null +++ b/benchmarks/queries/imdb/10b.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS character, MIN(t.title) AS russian_mov_with_actor_producer FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t WHERE ci.note like '%(producer)%' AND cn.country_code = '[ru]' AND rt.role = 'actor' AND t.production_year > 2010 AND t.id = mc.movie_id AND t.id = ci.movie_id AND ci.movie_id = mc.movie_id AND chn.id = ci.person_role_id AND rt.id = ci.role_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/10c.sql b/benchmarks/queries/imdb/10c.sql new file mode 100644 index 0000000000000..b862cf4fa7ac2 --- /dev/null +++ b/benchmarks/queries/imdb/10c.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS character, MIN(t.title) AS movie_with_american_producer FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t WHERE ci.note like '%(producer)%' AND cn.country_code = '[us]' AND t.production_year > 1990 AND t.id = mc.movie_id AND t.id = ci.movie_id AND ci.movie_id = mc.movie_id AND chn.id = ci.person_role_id AND rt.id = ci.role_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/11a.sql b/benchmarks/queries/imdb/11a.sql new file mode 100644 index 0000000000000..f835968e900b8 --- /dev/null +++ b/benchmarks/queries/imdb/11a.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS from_company, MIN(lt.link) AS movie_link_type, MIN(t.title) AS non_polish_sequel_movie FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' AND (cn.name LIKE
'%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND t.production_year BETWEEN 1950 AND 2000 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/11b.sql b/benchmarks/queries/imdb/11b.sql new file mode 100644 index 0000000000000..2411e19ea6088 --- /dev/null +++ b/benchmarks/queries/imdb/11b.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS from_company, MIN(lt.link) AS movie_link_type, MIN(t.title) AS sequel_movie FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follows%' AND mc.note IS NULL AND t.production_year = 1998 and t.title like '%Money%' AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/11c.sql b/benchmarks/queries/imdb/11c.sql new file mode 100644 index 0000000000000..3bf7946789184 --- /dev/null +++ b/benchmarks/queries/imdb/11c.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS from_company, MIN(mc.note) AS production_note, MIN(t.title) AS movie_based_on_book FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' and (cn.name like '20th Century Fox%' or cn.name like 'Twentieth Century Fox%') AND ct.kind != 'production companies' and ct.kind is not NULL AND k.keyword in ('sequel', 'revenge', 'based-on-novel') AND mc.note is not NULL AND t.production_year > 1950 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/11d.sql b/benchmarks/queries/imdb/11d.sql new file mode 100644 index 0000000000000..0bc33e1d6e88a --- /dev/null +++ b/benchmarks/queries/imdb/11d.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS from_company, MIN(mc.note) AS production_note, MIN(t.title) AS movie_based_on_book FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' AND ct.kind != 'production companies' and ct.kind is not NULL AND k.keyword in ('sequel', 'revenge', 'based-on-novel') AND mc.note is not NULL AND t.production_year > 1950 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/12a.sql b/benchmarks/queries/imdb/12a.sql new file mode 100644 index 0000000000000..22add74bd55d6 --- /dev/null +++ 
b/benchmarks/queries/imdb/12a.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS drama_horror_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t WHERE cn.country_code = '[us]' AND ct.kind = 'production companies' AND it1.info = 'genres' AND it2.info = 'rating' AND mi.info in ('Drama', 'Horror') AND mi_idx.info > '8.0' AND t.production_year between 2005 and 2008 AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND mi.info_type_id = it1.id AND mi_idx.info_type_id = it2.id AND t.id = mc.movie_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id; diff --git a/benchmarks/queries/imdb/12b.sql b/benchmarks/queries/imdb/12b.sql new file mode 100644 index 0000000000000..fc30ad550d10f --- /dev/null +++ b/benchmarks/queries/imdb/12b.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS budget, MIN(t.title) AS unsuccsessful_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t WHERE cn.country_code ='[us]' AND ct.kind is not NULL and (ct.kind ='production companies' or ct.kind = 'distributors') AND it1.info ='budget' AND it2.info ='bottom 10 rank' AND t.production_year >2000 AND (t.title LIKE 'Birdemic%' OR t.title LIKE '%Movie%') AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND mi.info_type_id = it1.id AND mi_idx.info_type_id = it2.id AND t.id = mc.movie_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id; diff --git a/benchmarks/queries/imdb/12c.sql b/benchmarks/queries/imdb/12c.sql new file mode 100644 index 0000000000000..64a340b2381ef --- /dev/null +++ b/benchmarks/queries/imdb/12c.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS mainstream_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t WHERE cn.country_code = '[us]' AND ct.kind = 'production companies' AND it1.info = 'genres' AND it2.info = 'rating' AND mi.info in ('Drama', 'Horror', 'Western', 'Family') AND mi_idx.info > '7.0' AND t.production_year between 2000 and 2010 AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND mi.info_type_id = it1.id AND mi_idx.info_type_id = it2.id AND t.id = mc.movie_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id; diff --git a/benchmarks/queries/imdb/13a.sql b/benchmarks/queries/imdb/13a.sql new file mode 100644 index 0000000000000..95eb439d1e226 --- /dev/null +++ b/benchmarks/queries/imdb/13a.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS release_date, MIN(miidx.info) AS rating, MIN(t.title) AS german_movie FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t WHERE cn.country_code ='[de]' AND ct.kind ='production companies' AND it.info ='rating' AND it2.info ='release dates' AND kt.kind ='movie' AND mi.movie_id = t.id AND it2.id = mi.info_type_id AND kt.id = t.kind_id AND mc.movie_id = t.id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND 
miidx.movie_id = t.id AND it.id = miidx.info_type_id AND mi.movie_id = miidx.movie_id AND mi.movie_id = mc.movie_id AND miidx.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/13b.sql b/benchmarks/queries/imdb/13b.sql new file mode 100644 index 0000000000000..4b6f75ab0ae66 --- /dev/null +++ b/benchmarks/queries/imdb/13b.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie_about_winning FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t WHERE cn.country_code ='[us]' AND ct.kind ='production companies' AND it.info ='rating' AND it2.info ='release dates' AND kt.kind ='movie' AND t.title != '' AND (t.title LIKE '%Champion%' OR t.title LIKE '%Loser%') AND mi.movie_id = t.id AND it2.id = mi.info_type_id AND kt.id = t.kind_id AND mc.movie_id = t.id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND miidx.movie_id = t.id AND it.id = miidx.info_type_id AND mi.movie_id = miidx.movie_id AND mi.movie_id = mc.movie_id AND miidx.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/13c.sql b/benchmarks/queries/imdb/13c.sql new file mode 100644 index 0000000000000..9e8c92327bd5b --- /dev/null +++ b/benchmarks/queries/imdb/13c.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie_about_winning FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t WHERE cn.country_code ='[us]' AND ct.kind ='production companies' AND it.info ='rating' AND it2.info ='release dates' AND kt.kind ='movie' AND t.title != '' AND (t.title LIKE 'Champion%' OR t.title LIKE 'Loser%') AND mi.movie_id = t.id AND it2.id = mi.info_type_id AND kt.id = t.kind_id AND mc.movie_id = t.id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND miidx.movie_id = t.id AND it.id = miidx.info_type_id AND mi.movie_id = miidx.movie_id AND mi.movie_id = mc.movie_id AND miidx.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/13d.sql b/benchmarks/queries/imdb/13d.sql new file mode 100644 index 0000000000000..a8bc567cabe14 --- /dev/null +++ b/benchmarks/queries/imdb/13d.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t WHERE cn.country_code ='[us]' AND ct.kind ='production companies' AND it.info ='rating' AND it2.info ='release dates' AND kt.kind ='movie' AND mi.movie_id = t.id AND it2.id = mi.info_type_id AND kt.id = t.kind_id AND mc.movie_id = t.id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND miidx.movie_id = t.id AND it.id = miidx.info_type_id AND mi.movie_id = miidx.movie_id AND mi.movie_id = mc.movie_id AND miidx.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/14a.sql b/benchmarks/queries/imdb/14a.sql new file mode 100644 index 0000000000000..af1a7c8983a62 --- /dev/null +++ b/benchmarks/queries/imdb/14a.sql @@ -0,0 +1 @@ +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS northern_dark_movie FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE it1.info = 'countries' AND it2.info = 'rating' AND 
k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind = 'movie' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info < '8.5' AND t.production_year > 2010 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/14b.sql b/benchmarks/queries/imdb/14b.sql new file mode 100644 index 0000000000000..c606ebc73dd48 --- /dev/null +++ b/benchmarks/queries/imdb/14b.sql @@ -0,0 +1 @@ +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS western_dark_production FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title') AND kt.kind = 'movie' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info > '6.0' AND t.production_year > 2010 and (t.title like '%murder%' or t.title like '%Murder%' or t.title like '%Mord%') AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/14c.sql b/benchmarks/queries/imdb/14c.sql new file mode 100644 index 0000000000000..2a6dffde26393 --- /dev/null +++ b/benchmarks/queries/imdb/14c.sql @@ -0,0 +1 @@ +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS north_european_dark_production FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE it1.info = 'countries' AND it2.info = 'rating' AND k.keyword is not null and k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info < '8.5' AND t.production_year > 2005 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/15a.sql b/benchmarks/queries/imdb/15a.sql new file mode 100644 index 0000000000000..1d052f0044267 --- /dev/null +++ b/benchmarks/queries/imdb/15a.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS internet_movie FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cn.country_code = '[us]' AND it1.info = 'release dates' AND mc.note like '%(200%)%' and mc.note like '%(worldwide)%' AND mi.note like '%internet%' AND mi.info like 'USA:% 200%' AND t.production_year > 2000 AND t.id = at.movie_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = at.movie_id AND mi.movie_id = 
mc.movie_id AND mi.movie_id = at.movie_id AND mc.movie_id = at.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/15b.sql b/benchmarks/queries/imdb/15b.sql new file mode 100644 index 0000000000000..21c81358fa7a8 --- /dev/null +++ b/benchmarks/queries/imdb/15b.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS youtube_movie FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cn.country_code = '[us]' and cn.name = 'YouTube' AND it1.info = 'release dates' AND mc.note like '%(200%)%' and mc.note like '%(worldwide)%' AND mi.note like '%internet%' AND mi.info like 'USA:% 200%' AND t.production_year between 2005 and 2010 AND t.id = at.movie_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = at.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = at.movie_id AND mc.movie_id = at.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/15c.sql b/benchmarks/queries/imdb/15c.sql new file mode 100644 index 0000000000000..2d08c52039743 --- /dev/null +++ b/benchmarks/queries/imdb/15c.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS modern_american_internet_movie FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cn.country_code = '[us]' AND it1.info = 'release dates' AND mi.note like '%internet%' AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') AND t.production_year > 1990 AND t.id = at.movie_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = at.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = at.movie_id AND mc.movie_id = at.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/15d.sql b/benchmarks/queries/imdb/15d.sql new file mode 100644 index 0000000000000..040e9815d86ca --- /dev/null +++ b/benchmarks/queries/imdb/15d.sql @@ -0,0 +1 @@ +SELECT MIN(at.title) AS aka_title, MIN(t.title) AS internet_movie_title FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cn.country_code = '[us]' AND it1.info = 'release dates' AND mi.note like '%internet%' AND t.production_year > 1990 AND t.id = at.movie_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = at.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = at.movie_id AND mc.movie_id = at.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/16a.sql b/benchmarks/queries/imdb/16a.sql new file mode 100644 index 0000000000000..aaa0020269d28 --- /dev/null +++ b/benchmarks/queries/imdb/16a.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char FROM aka_name AS an, 
cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND t.episode_nr >= 50 AND t.episode_nr < 100 AND an.person_id = n.id AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND an.person_id = ci.person_id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/16b.sql b/benchmarks/queries/imdb/16b.sql new file mode 100644 index 0000000000000..c6c0bef319de2 --- /dev/null +++ b/benchmarks/queries/imdb/16b.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND an.person_id = n.id AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND an.person_id = ci.person_id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/16c.sql b/benchmarks/queries/imdb/16c.sql new file mode 100644 index 0000000000000..5c3b35752195a --- /dev/null +++ b/benchmarks/queries/imdb/16c.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND t.episode_nr < 100 AND an.person_id = n.id AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND an.person_id = ci.person_id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/16d.sql b/benchmarks/queries/imdb/16d.sql new file mode 100644 index 0000000000000..c9e1b5f25ce55 --- /dev/null +++ b/benchmarks/queries/imdb/16d.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND t.episode_nr >= 5 AND t.episode_nr < 100 AND an.person_id = n.id AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND an.person_id = ci.person_id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/17a.sql b/benchmarks/queries/imdb/17a.sql new file mode 100644 index 0000000000000..e854a957e4294 --- /dev/null +++ b/benchmarks/queries/imdb/17a.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS member_in_charnamed_american_movie, MIN(n.name) AS a1 FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND n.name LIKE 'B%' AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = 
mc.movie_id AND mc.company_id = cn.id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/17b.sql b/benchmarks/queries/imdb/17b.sql new file mode 100644 index 0000000000000..903f2196b2783 --- /dev/null +++ b/benchmarks/queries/imdb/17b.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS member_in_charnamed_movie, MIN(n.name) AS a1 FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword ='character-name-in-title' AND n.name LIKE 'Z%' AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/17c.sql b/benchmarks/queries/imdb/17c.sql new file mode 100644 index 0000000000000..a96faa0b43390 --- /dev/null +++ b/benchmarks/queries/imdb/17c.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS member_in_charnamed_movie, MIN(n.name) AS a1 FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword ='character-name-in-title' AND n.name LIKE 'X%' AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/17d.sql b/benchmarks/queries/imdb/17d.sql new file mode 100644 index 0000000000000..73e1f2c309763 --- /dev/null +++ b/benchmarks/queries/imdb/17d.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS member_in_charnamed_movie FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword ='character-name-in-title' AND n.name LIKE '%Bert%' AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/17e.sql b/benchmarks/queries/imdb/17e.sql new file mode 100644 index 0000000000000..65ea73ed05102 --- /dev/null +++ b/benchmarks/queries/imdb/17e.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS member_in_charnamed_movie FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/17f.sql b/benchmarks/queries/imdb/17f.sql new file mode 100644 index 0000000000000..542233d63e9dd --- /dev/null +++ b/benchmarks/queries/imdb/17f.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS member_in_charnamed_movie FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword ='character-name-in-title' AND n.name LIKE '%B%' AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git 
a/benchmarks/queries/imdb/18a.sql b/benchmarks/queries/imdb/18a.sql new file mode 100644 index 0000000000000..275e04bdb1848 --- /dev/null +++ b/benchmarks/queries/imdb/18a.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t WHERE ci.note in ('(producer)', '(executive producer)') AND it1.info = 'budget' AND it2.info = 'votes' AND n.gender = 'm' and n.name like '%Tim%' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/18b.sql b/benchmarks/queries/imdb/18b.sql new file mode 100644 index 0000000000000..3ae40ed93d2f3 --- /dev/null +++ b/benchmarks/queries/imdb/18b.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'rating' AND mi.info in ('Horror', 'Thriller') and mi.note is NULL AND mi_idx.info > '8.0' AND n.gender is not null and n.gender = 'f' AND t.production_year between 2008 and 2014 AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/18c.sql b/benchmarks/queries/imdb/18c.sql new file mode 100644 index 0000000000000..01f28ea527feb --- /dev/null +++ b/benchmarks/queries/imdb/18c.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') AND n.gender = 'm' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/19a.sql b/benchmarks/queries/imdb/19a.sql new file mode 100644 index 0000000000000..ceaae671fd201 --- /dev/null +++ b/benchmarks/queries/imdb/19a.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND mc.note is not NULL and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') AND n.gender ='f' and n.name like '%Ang%' AND 
rt.role ='actress' AND t.production_year between 2005 and 2009 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mi.movie_id = ci.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id; diff --git a/benchmarks/queries/imdb/19b.sql b/benchmarks/queries/imdb/19b.sql new file mode 100644 index 0000000000000..62e852ba3ec61 --- /dev/null +++ b/benchmarks/queries/imdb/19b.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS kung_fu_panda FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t WHERE ci.note = '(voice)' AND cn.country_code ='[us]' AND it.info = 'release dates' AND mc.note like '%(200%)%' and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') AND mi.info is not null and (mi.info like 'Japan:%2007%' or mi.info like 'USA:%2008%') AND n.gender ='f' and n.name like '%Angel%' AND rt.role ='actress' AND t.production_year between 2007 and 2008 and t.title like '%Kung%Fu%Panda%' AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mi.movie_id = ci.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id; diff --git a/benchmarks/queries/imdb/19c.sql b/benchmarks/queries/imdb/19c.sql new file mode 100644 index 0000000000000..6885af5012fc9 --- /dev/null +++ b/benchmarks/queries/imdb/19c.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS jap_engl_voiced_movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND t.production_year > 2000 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mi.movie_id = ci.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id; diff --git a/benchmarks/queries/imdb/19d.sql b/benchmarks/queries/imdb/19d.sql new file mode 100644 index 0000000000000..06fcc76ba7adc --- /dev/null +++ b/benchmarks/queries/imdb/19d.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS jap_engl_voiced_movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND n.gender ='f' AND rt.role ='actress' AND t.production_year > 2000 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND mc.movie_id = ci.movie_id AND 
mc.movie_id = mi.movie_id AND mi.movie_id = ci.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id;
diff --git a/benchmarks/queries/imdb/1a.sql b/benchmarks/queries/imdb/1a.sql
new file mode 100644
index 0000000000000..07b3516388570
--- /dev/null
+++ b/benchmarks/queries/imdb/1a.sql
@@ -0,0 +1 @@
+SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t WHERE ct.kind = 'production companies' AND it.info = 'top 250 rank' AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' and (mc.note like '%(co-production)%' or mc.note like '%(presents)%') AND ct.id = mc.company_type_id AND t.id = mc.movie_id AND t.id = mi_idx.movie_id AND mc.movie_id = mi_idx.movie_id AND it.id = mi_idx.info_type_id;
diff --git a/benchmarks/queries/imdb/1b.sql b/benchmarks/queries/imdb/1b.sql
new file mode 100644
index 0000000000000..f2901e8b52621
--- /dev/null
+++ b/benchmarks/queries/imdb/1b.sql
@@ -0,0 +1 @@
+SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t WHERE ct.kind = 'production companies' AND it.info = 'bottom 10 rank' AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' AND t.production_year between 2005 and 2010 AND ct.id = mc.company_type_id AND t.id = mc.movie_id AND t.id = mi_idx.movie_id AND mc.movie_id = mi_idx.movie_id AND it.id = mi_idx.info_type_id;
diff --git a/benchmarks/queries/imdb/1c.sql b/benchmarks/queries/imdb/1c.sql
new file mode 100644
index 0000000000000..94e66c30aa144
--- /dev/null
+++ b/benchmarks/queries/imdb/1c.sql
@@ -0,0 +1 @@
+SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t WHERE ct.kind = 'production companies' AND it.info = 'top 250 rank' AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' and (mc.note like '%(co-production)%') AND t.production_year >2010 AND ct.id = mc.company_type_id AND t.id = mc.movie_id AND t.id = mi_idx.movie_id AND mc.movie_id = mi_idx.movie_id AND it.id = mi_idx.info_type_id;
diff --git a/benchmarks/queries/imdb/1d.sql b/benchmarks/queries/imdb/1d.sql
new file mode 100644
index 0000000000000..52f58e80c8113
--- /dev/null
+++ b/benchmarks/queries/imdb/1d.sql
@@ -0,0 +1 @@
+SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t WHERE ct.kind = 'production companies' AND it.info = 'bottom 10 rank' AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' AND t.production_year >2000 AND ct.id = mc.company_type_id AND t.id = mc.movie_id AND t.id = mi_idx.movie_id AND mc.movie_id = mi_idx.movie_id AND it.id = mi_idx.info_type_id;
diff --git a/benchmarks/queries/imdb/20a.sql b/benchmarks/queries/imdb/20a.sql
new file mode 100644
index 0000000000000..2a1c269d6a51c
--- /dev/null
+++ b/benchmarks/queries/imdb/20a.sql
@@ -0,0 +1 @@
+SELECT MIN(t.title) AS complete_downey_ironman_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like '%complete%' AND chn.name not like '%Sherlock%' and (chn.name like '%Tony%Stark%' or chn.name like '%Iron%Man%') AND k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') AND kt.kind = 'movie' AND t.production_year > 1950 AND kt.id = t.kind_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND t.id = cc.movie_id AND mk.movie_id = ci.movie_id AND mk.movie_id = cc.movie_id AND ci.movie_id = cc.movie_id AND chn.id = ci.person_role_id AND n.id = ci.person_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/20b.sql b/benchmarks/queries/imdb/20b.sql
new file mode 100644
index 0000000000000..4c2455a52eb12
--- /dev/null
+++ b/benchmarks/queries/imdb/20b.sql
@@ -0,0 +1 @@
+SELECT MIN(t.title) AS complete_downey_ironman_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like '%complete%' AND chn.name not like '%Sherlock%' and (chn.name like '%Tony%Stark%' or chn.name like '%Iron%Man%') AND k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') AND kt.kind = 'movie' AND n.name LIKE '%Downey%Robert%' AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND t.id = cc.movie_id AND mk.movie_id = ci.movie_id AND mk.movie_id = cc.movie_id AND ci.movie_id = cc.movie_id AND chn.id = ci.person_role_id AND n.id = ci.person_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/20c.sql b/benchmarks/queries/imdb/20c.sql
new file mode 100644
index 0000000000000..b85b22f6b4f2c
--- /dev/null
+++ b/benchmarks/queries/imdb/20c.sql
@@ -0,0 +1 @@
+SELECT MIN(n.name) AS cast_member, MIN(t.title) AS complete_dynamic_hero_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like '%complete%' AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') AND kt.kind = 'movie' AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND t.id = cc.movie_id AND mk.movie_id = ci.movie_id AND mk.movie_id = cc.movie_id AND ci.movie_id = cc.movie_id AND chn.id = ci.person_role_id AND n.id = ci.person_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/21a.sql b/benchmarks/queries/imdb/21a.sql
new file mode 100644
index 0000000000000..8a66a00be6cb9
--- /dev/null
+++ b/benchmarks/queries/imdb/21a.sql
@@ -0,0 +1 @@
+SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS western_follow_up FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') AND t.production_year BETWEEN 1950 AND 2000 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id;
diff --git a/benchmarks/queries/imdb/21b.sql b/benchmarks/queries/imdb/21b.sql
new file mode 100644
index 0000000000000..90d3a5a4c0786
--- /dev/null
+++ b/benchmarks/queries/imdb/21b.sql
@@ -0,0 +1 @@
+SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS german_follow_up FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND mi.info IN ('Germany', 'German') AND t.production_year BETWEEN 2000 AND 2010 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id;
diff --git a/benchmarks/queries/imdb/21c.sql b/benchmarks/queries/imdb/21c.sql
new file mode 100644
index 0000000000000..16a42ae6f426f
--- /dev/null
+++ b/benchmarks/queries/imdb/21c.sql
@@ -0,0 +1 @@
+SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS western_follow_up FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'English') AND t.production_year BETWEEN 1950 AND 2010 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id;
diff --git a/benchmarks/queries/imdb/22a.sql b/benchmarks/queries/imdb/22a.sql
new file mode 100644
index 0000000000000..e513799698c5c
--- /dev/null
+++ b/benchmarks/queries/imdb/22a.sql
@@ -0,0 +1 @@
+SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' AND mi.info IN ('Germany', 'German', 'USA', 'American') AND mi_idx.info < '7.0' AND t.production_year > 2008 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mc.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id;
diff --git a/benchmarks/queries/imdb/22b.sql b/benchmarks/queries/imdb/22b.sql
new file mode 100644
index 0000000000000..f98d0ea8099d4
--- /dev/null
+++ b/benchmarks/queries/imdb/22b.sql
@@ -0,0 +1 @@
+SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' AND mi.info IN ('Germany', 'German', 'USA', 'American') AND mi_idx.info < '7.0' AND t.production_year > 2009 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mc.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id;
diff --git a/benchmarks/queries/imdb/22c.sql b/benchmarks/queries/imdb/22c.sql
new file mode 100644
index 0000000000000..cf757956e0dec
--- /dev/null
+++ b/benchmarks/queries/imdb/22c.sql
@@ -0,0 +1 @@
+SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info < '8.5' AND t.production_year > 2005 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mc.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id;
diff --git a/benchmarks/queries/imdb/22d.sql b/benchmarks/queries/imdb/22d.sql
new file mode 100644
index 0000000000000..a47feeb051575
--- /dev/null
+++ b/benchmarks/queries/imdb/22d.sql
@@ -0,0 +1 @@
+SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info < '8.5' AND t.production_year > 2005 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mc.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id;
diff --git a/benchmarks/queries/imdb/23a.sql b/benchmarks/queries/imdb/23a.sql
new file mode 100644
index 0000000000000..724da913b51a9
--- /dev/null
+++ b/benchmarks/queries/imdb/23a.sql
@@ -0,0 +1 @@
+SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_us_internet_movie FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cct1.kind = 'complete+verified' AND cn.country_code = '[us]' AND it1.info = 'release dates' AND kt.kind in ('movie') AND mi.note like '%internet%' AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND t.id = cc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = cc.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = cc.movie_id AND mc.movie_id = cc.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND cct1.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/23b.sql b/benchmarks/queries/imdb/23b.sql
new file mode 100644
index 0000000000000..e39f0ecc28a27
--- /dev/null
+++ b/benchmarks/queries/imdb/23b.sql
@@ -0,0 +1 @@
+SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_nerdy_internet_movie FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cct1.kind = 'complete+verified' AND cn.country_code = '[us]' AND it1.info = 'release dates' AND k.keyword in ('nerd', 'loner', 'alienation', 'dignity') AND kt.kind in ('movie') AND mi.note like '%internet%' AND mi.info like 'USA:% 200%' AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND t.id = cc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = cc.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = cc.movie_id AND mc.movie_id = cc.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND cct1.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/23c.sql b/benchmarks/queries/imdb/23c.sql
new file mode 100644
index 0000000000000..839d762d05332
--- /dev/null
+++ b/benchmarks/queries/imdb/23c.sql
@@ -0,0 +1 @@
+SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_us_internet_movie FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cct1.kind = 'complete+verified' AND cn.country_code = '[us]' AND it1.info = 'release dates' AND kt.kind in ('movie', 'tv movie', 'video movie', 'video game') AND mi.note like '%internet%' AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') AND t.production_year > 1990 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND t.id = cc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = cc.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = cc.movie_id AND mc.movie_id = cc.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND cct1.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/24a.sql b/benchmarks/queries/imdb/24a.sql
new file mode 100644
index 0000000000000..8f10621e02092
--- /dev/null
+++ b/benchmarks/queries/imdb/24a.sql
@@ -0,0 +1 @@
+SELECT MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress_name, MIN(t.title) AS voiced_action_movie_jap_eng FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND k.keyword in ('hero', 'martial-arts', 'hand-to-hand-combat') AND mi.info is not null and (mi.info like 'Japan:%201%' or mi.info like 'USA:%201%') AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND t.production_year > 2010 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mk.movie_id AND mi.movie_id = ci.movie_id AND mi.movie_id = mk.movie_id AND ci.movie_id = mk.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id AND k.id = mk.keyword_id;
diff --git a/benchmarks/queries/imdb/24b.sql b/benchmarks/queries/imdb/24b.sql
new file mode 100644
index 0000000000000..d8a2836000b2a
--- /dev/null
+++ b/benchmarks/queries/imdb/24b.sql
@@ -0,0 +1 @@
+SELECT MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress_name, MIN(t.title) AS kung_fu_panda FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND cn.name = 'DreamWorks Animation' AND it.info = 'release dates' AND k.keyword in ('hero', 'martial-arts', 'hand-to-hand-combat', 'computer-animated-movie') AND mi.info is not null and (mi.info like 'Japan:%201%' or mi.info like 'USA:%201%') AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND t.production_year > 2010 AND t.title like 'Kung Fu Panda%' AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mk.movie_id AND mi.movie_id = ci.movie_id AND mi.movie_id = mk.movie_id AND ci.movie_id = mk.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id AND k.id = mk.keyword_id;
diff --git a/benchmarks/queries/imdb/25a.sql b/benchmarks/queries/imdb/25a.sql
new file mode 100644
index 0000000000000..bc55cc01d26b5
--- /dev/null
+++ b/benchmarks/queries/imdb/25a.sql
@@ -0,0 +1 @@
+SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'blood', 'gore', 'death', 'female-nudity') AND mi.info = 'Horror' AND n.gender = 'm' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi_idx.movie_id = mk.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id;
diff --git a/benchmarks/queries/imdb/25b.sql b/benchmarks/queries/imdb/25b.sql
new file mode 100644
index 0000000000000..3457655bb9eb9
--- /dev/null
+++ b/benchmarks/queries/imdb/25b.sql
@@ -0,0 +1 @@
+SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'blood', 'gore', 'death', 'female-nudity') AND mi.info = 'Horror' AND n.gender = 'm' AND t.production_year > 2010 AND t.title like 'Vampire%' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi_idx.movie_id = mk.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id;
diff --git a/benchmarks/queries/imdb/25c.sql b/benchmarks/queries/imdb/25c.sql
new file mode 100644
index 0000000000000..cf56a313d8613
--- /dev/null
+++ b/benchmarks/queries/imdb/25c.sql
@@ -0,0 +1 @@
+SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') AND n.gender = 'm' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi_idx.movie_id = mk.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id;
diff --git a/benchmarks/queries/imdb/26a.sql b/benchmarks/queries/imdb/26a.sql
new file mode 100644
index 0000000000000..b431f204c6dc3
--- /dev/null
+++ b/benchmarks/queries/imdb/26a.sql
@@ -0,0 +1 @@
+SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(n.name) AS playing_actor, MIN(t.title) AS complete_hero_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like '%complete%' AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') AND it2.info = 'rating' AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') AND kt.kind = 'movie' AND mi_idx.info > '7.0' AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND t.id = cc.movie_id AND t.id = mi_idx.movie_id AND mk.movie_id = ci.movie_id AND mk.movie_id = cc.movie_id AND mk.movie_id = mi_idx.movie_id AND ci.movie_id = cc.movie_id AND ci.movie_id = mi_idx.movie_id AND cc.movie_id = mi_idx.movie_id AND chn.id = ci.person_role_id AND n.id = ci.person_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id AND it2.id = mi_idx.info_type_id;
diff --git a/benchmarks/queries/imdb/26b.sql b/benchmarks/queries/imdb/26b.sql
new file mode 100644
index 0000000000000..882d234d77e00
--- /dev/null
+++ b/benchmarks/queries/imdb/26b.sql
@@ -0,0 +1 @@
+SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_hero_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like '%complete%' AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') AND it2.info = 'rating' AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'fight') AND kt.kind = 'movie' AND mi_idx.info > '8.0' AND t.production_year > 2005 AND kt.id = t.kind_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND t.id = cc.movie_id AND t.id = mi_idx.movie_id AND mk.movie_id = ci.movie_id AND mk.movie_id = cc.movie_id AND mk.movie_id = mi_idx.movie_id AND ci.movie_id = cc.movie_id AND ci.movie_id = mi_idx.movie_id AND cc.movie_id = mi_idx.movie_id AND chn.id = ci.person_role_id AND n.id = ci.person_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id AND it2.id = mi_idx.info_type_id;
diff --git a/benchmarks/queries/imdb/26c.sql b/benchmarks/queries/imdb/26c.sql
new file mode 100644
index 0000000000000..4b9eae0b76332
--- /dev/null
+++ b/benchmarks/queries/imdb/26c.sql
@@ -0,0 +1 @@
+SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_hero_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like '%complete%' AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') AND it2.info = 'rating' AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') AND kt.kind = 'movie' AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND t.id = cc.movie_id AND t.id = mi_idx.movie_id AND mk.movie_id = ci.movie_id AND mk.movie_id = cc.movie_id AND mk.movie_id = mi_idx.movie_id AND ci.movie_id = cc.movie_id AND ci.movie_id = mi_idx.movie_id AND cc.movie_id = mi_idx.movie_id AND chn.id = ci.person_role_id AND n.id = ci.person_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id AND it2.id = mi_idx.info_type_id;
diff --git a/benchmarks/queries/imdb/27a.sql b/benchmarks/queries/imdb/27a.sql
new file mode 100644
index 0000000000000..239673cd8147e
--- /dev/null
+++ b/benchmarks/queries/imdb/27a.sql
@@ -0,0 +1 @@
+SELECT MIN(cn.name) AS producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cct1.kind in ('cast', 'crew') AND cct2.kind = 'complete' AND cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND mi.info IN ('Sweden', 'Germany','Swedish', 'German') AND t.production_year BETWEEN 1950 AND 2000 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND t.id = cc.movie_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id AND ml.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND mc.movie_id = cc.movie_id AND mi.movie_id = cc.movie_id;
diff --git a/benchmarks/queries/imdb/27b.sql b/benchmarks/queries/imdb/27b.sql
new file mode 100644
index 0000000000000..4bf85260f22de
--- /dev/null
+++ b/benchmarks/queries/imdb/27b.sql
@@ -0,0 +1 @@
+SELECT MIN(cn.name) AS producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cct1.kind in ('cast', 'crew') AND cct2.kind = 'complete' AND cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND mi.info IN ('Sweden', 'Germany','Swedish', 'German') AND t.production_year = 1998 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND t.id = cc.movie_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id AND ml.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND mc.movie_id = cc.movie_id AND mi.movie_id = cc.movie_id;
diff --git a/benchmarks/queries/imdb/27c.sql b/benchmarks/queries/imdb/27c.sql
new file mode 100644
index 0000000000000..dc26ebff68513
--- /dev/null
+++ b/benchmarks/queries/imdb/27c.sql
@@ -0,0 +1 @@
+SELECT MIN(cn.name) AS producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like 'complete%' AND cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'English') AND t.production_year BETWEEN 1950 AND 2010 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND t.id = cc.movie_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id AND ml.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND mc.movie_id = cc.movie_id AND mi.movie_id = cc.movie_id;
diff --git a/benchmarks/queries/imdb/28a.sql b/benchmarks/queries/imdb/28a.sql
new file mode 100644
index 0000000000000..8cb1177386da5
--- /dev/null
+++ b/benchmarks/queries/imdb/28a.sql
@@ -0,0 +1 @@
+SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cct1.kind = 'crew' AND cct2.kind != 'complete+verified' AND cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info < '8.5' AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND t.id = cc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = cc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = cc.movie_id AND mc.movie_id = mi_idx.movie_id AND mc.movie_id = cc.movie_id AND mi_idx.movie_id = cc.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/28b.sql b/benchmarks/queries/imdb/28b.sql
new file mode 100644
index 0000000000000..10f43c8982261
--- /dev/null
+++ b/benchmarks/queries/imdb/28b.sql
@@ -0,0 +1 @@
+SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cct1.kind = 'crew' AND cct2.kind != 'complete+verified' AND cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' AND mi.info IN ('Sweden', 'Germany', 'Swedish', 'German') AND mi_idx.info > '6.5' AND t.production_year > 2005 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND t.id = cc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = cc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = cc.movie_id AND mc.movie_id = mi_idx.movie_id AND mc.movie_id = cc.movie_id AND mi_idx.movie_id = cc.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/28c.sql b/benchmarks/queries/imdb/28c.sql
new file mode 100644
index 0000000000000..6b2e4047ae8a1
--- /dev/null
+++ b/benchmarks/queries/imdb/28c.sql
@@ -0,0 +1 @@
+SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cct1.kind = 'cast' AND cct2.kind = 'complete' AND cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info < '8.5' AND t.production_year > 2005 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND t.id = cc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = cc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = cc.movie_id AND mc.movie_id = mi_idx.movie_id AND mc.movie_id = cc.movie_id AND mi_idx.movie_id = cc.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/29a.sql b/benchmarks/queries/imdb/29a.sql
new file mode 100644
index 0000000000000..3033acbe6cf39
--- /dev/null
+++ b/benchmarks/queries/imdb/29a.sql
@@ -0,0 +1 @@
+SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t WHERE cct1.kind ='cast' AND cct2.kind ='complete+verified' AND chn.name = 'Queen' AND ci.note in ('(voice)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND it3.info = 'trivia' AND k.keyword = 'computer-animation' AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND t.title = 'Shrek 2' AND t.production_year between 2000 and 2010 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = cc.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mk.movie_id AND mc.movie_id = cc.movie_id AND mi.movie_id = ci.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = cc.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id AND n.id = pi.person_id AND ci.person_id = pi.person_id AND it3.id = pi.info_type_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/29b.sql b/benchmarks/queries/imdb/29b.sql
new file mode 100644
index 0000000000000..88d50fc7b783a
--- /dev/null
+++ b/benchmarks/queries/imdb/29b.sql
@@ -0,0 +1 @@
+SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t WHERE cct1.kind ='cast' AND cct2.kind ='complete+verified' AND chn.name = 'Queen' AND ci.note in ('(voice)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND it3.info = 'height' AND k.keyword = 'computer-animation' AND mi.info like 'USA:%200%' AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND t.title = 'Shrek 2' AND t.production_year between 2000 and 2005 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = cc.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mk.movie_id AND mc.movie_id = cc.movie_id AND mi.movie_id = ci.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = cc.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id AND n.id = pi.person_id AND ci.person_id = pi.person_id AND it3.id = pi.info_type_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/29c.sql b/benchmarks/queries/imdb/29c.sql
new file mode 100644
index 0000000000000..cb951781827c9
--- /dev/null
+++ b/benchmarks/queries/imdb/29c.sql
@@ -0,0 +1 @@
+SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t WHERE cct1.kind ='cast' AND cct2.kind ='complete+verified' AND ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND it3.info = 'trivia' AND k.keyword = 'computer-animation' AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND t.production_year between 2000 and 2010 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = cc.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mk.movie_id AND mc.movie_id = cc.movie_id AND mi.movie_id = ci.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = cc.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id AND n.id = pi.person_id AND ci.person_id = pi.person_id AND it3.id = pi.info_type_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/2a.sql b/benchmarks/queries/imdb/2a.sql
new file mode 100644
index 0000000000000..f3ef4db75fea3
--- /dev/null
+++ b/benchmarks/queries/imdb/2a.sql
@@ -0,0 +1 @@
+SELECT MIN(t.title) AS movie_title FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t WHERE cn.country_code ='[de]' AND k.keyword ='character-name-in-title' AND cn.id = mc.company_id AND mc.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND mc.movie_id = mk.movie_id;
diff --git a/benchmarks/queries/imdb/2b.sql b/benchmarks/queries/imdb/2b.sql
new file mode 100644
index 0000000000000..82b2123fbccde
--- /dev/null
+++ b/benchmarks/queries/imdb/2b.sql
@@ -0,0 +1 @@
+SELECT MIN(t.title) AS movie_title FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t WHERE cn.country_code ='[nl]' AND k.keyword ='character-name-in-title' AND cn.id = mc.company_id AND mc.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND mc.movie_id = mk.movie_id;
diff --git a/benchmarks/queries/imdb/2c.sql b/benchmarks/queries/imdb/2c.sql
new file mode 100644
index 0000000000000..b5f9b75dd68bb
--- /dev/null
+++ b/benchmarks/queries/imdb/2c.sql
@@ -0,0 +1 @@
+SELECT MIN(t.title) AS movie_title FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t WHERE cn.country_code ='[sm]' AND k.keyword ='character-name-in-title' AND cn.id = mc.company_id AND mc.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND mc.movie_id = mk.movie_id;
diff --git a/benchmarks/queries/imdb/2d.sql b/benchmarks/queries/imdb/2d.sql
new file mode 100644
index 0000000000000..4a27919465488
--- /dev/null
+++ b/benchmarks/queries/imdb/2d.sql
@@ -0,0 +1 @@
+SELECT MIN(t.title) AS movie_title FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND cn.id = mc.company_id AND mc.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND mc.movie_id = mk.movie_id;
diff --git a/benchmarks/queries/imdb/30a.sql b/benchmarks/queries/imdb/30a.sql
new file mode 100644
index 0000000000000..698872fa8337e
--- /dev/null
+++ b/benchmarks/queries/imdb/30a.sql
@@ -0,0 +1 @@
+SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_violent_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind in ('cast', 'crew') AND cct2.kind ='complete+verified' AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mi.info in ('Horror', 'Thriller') AND n.gender = 'm' AND t.production_year > 2000 AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = cc.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = cc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = cc.movie_id AND mi_idx.movie_id = mk.movie_id AND mi_idx.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/30b.sql b/benchmarks/queries/imdb/30b.sql
new file mode 100644
index 0000000000000..5fdb8493496ce
--- /dev/null
+++ b/benchmarks/queries/imdb/30b.sql
@@ -0,0 +1 @@
+SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_gore_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind in ('cast', 'crew') AND cct2.kind ='complete+verified' AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mi.info in ('Horror', 'Thriller') AND n.gender = 'm' AND t.production_year > 2000 and (t.title like '%Freddy%' or t.title like '%Jason%' or t.title like 'Saw%') AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = cc.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = cc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = cc.movie_id AND mi_idx.movie_id = mk.movie_id AND mi_idx.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/30c.sql b/benchmarks/queries/imdb/30c.sql
new file mode 100644
index 0000000000000..a18087e392220
--- /dev/null
+++ b/benchmarks/queries/imdb/30c.sql
@@ -0,0 +1 @@
+SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_violent_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind ='complete+verified' AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') AND n.gender = 'm' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = cc.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = cc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = cc.movie_id AND mi_idx.movie_id = mk.movie_id AND mi_idx.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id;
diff --git a/benchmarks/queries/imdb/31a.sql b/benchmarks/queries/imdb/31a.sql
new file mode 100644
index 0000000000000..7dd855011f2af
--- /dev/null
+++ b/benchmarks/queries/imdb/31a.sql
@@ -0,0 +1 @@
+SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND cn.name like 'Lionsgate%' AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mi.info in ('Horror', 'Thriller') AND n.gender = 'm' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = mc.movie_id AND mi_idx.movie_id = mk.movie_id AND mi_idx.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id AND cn.id = mc.company_id;
diff --git a/benchmarks/queries/imdb/31b.sql b/benchmarks/queries/imdb/31b.sql
new file mode 100644
index 0000000000000..3be5680f7d001
--- /dev/null
+++ b/benchmarks/queries/imdb/31b.sql
@@ -0,0 +1 @@
+SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND cn.name like 'Lionsgate%' AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mc.note like '%(Blu-ray)%' AND mi.info in ('Horror', 'Thriller') AND n.gender = 'm' AND t.production_year > 2000 and (t.title like '%Freddy%' or t.title like '%Jason%' or t.title like 'Saw%') AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = mc.movie_id AND mi_idx.movie_id = mk.movie_id AND mi_idx.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id AND cn.id = mc.company_id;
diff --git a/benchmarks/queries/imdb/31c.sql b/benchmarks/queries/imdb/31c.sql
new file mode 100644
index 0000000000000..156ea2d5eee29
--- /dev/null
+++ b/benchmarks/queries/imdb/31c.sql
@@ -0,0 +1 @@
+SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND cn.name like 'Lionsgate%' AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = mc.movie_id AND mi_idx.movie_id = mk.movie_id AND mi_idx.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id AND cn.id = mc.company_id;
diff --git a/benchmarks/queries/imdb/32a.sql b/benchmarks/queries/imdb/32a.sql
new file mode 100644
index 0000000000000..9647fb71065d9
--- /dev/null
+++ b/benchmarks/queries/imdb/32a.sql
@@ -0,0 +1 @@
+SELECT MIN(lt.link) AS link_type, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie FROM keyword AS k, link_type AS lt, movie_keyword AS mk, movie_link AS ml, title AS t1, title AS t2 WHERE k.keyword ='10,000-mile-club' AND mk.keyword_id = k.id AND t1.id = mk.movie_id AND ml.movie_id = t1.id AND ml.linked_movie_id = t2.id AND lt.id = ml.link_type_id AND mk.movie_id = t1.id;
diff --git a/benchmarks/queries/imdb/32b.sql b/benchmarks/queries/imdb/32b.sql
new file mode 100644
index 0000000000000..6d096ab434053
--- /dev/null
+++ b/benchmarks/queries/imdb/32b.sql
@@ -0,0 +1 @@
+SELECT MIN(lt.link) AS link_type, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie FROM keyword AS k, link_type AS lt, movie_keyword AS mk, movie_link AS ml, title AS t1, title AS t2 WHERE k.keyword ='character-name-in-title' AND mk.keyword_id = k.id AND t1.id = mk.movie_id AND ml.movie_id = t1.id AND ml.linked_movie_id = t2.id AND lt.id = ml.link_type_id AND mk.movie_id = t1.id;
diff --git a/benchmarks/queries/imdb/33a.sql b/benchmarks/queries/imdb/33a.sql
new file mode 100644
index 0000000000000..24aac4e207970
--- /dev/null
+++ b/benchmarks/queries/imdb/33a.sql
@@ -0,0 +1 @@
+SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 WHERE cn1.country_code = '[us]' AND it1.info = 'rating' AND it2.info = 'rating' AND kt1.kind in ('tv series') AND kt2.kind in ('tv series') AND lt.link in ('sequel', 'follows', 'followed by') AND mi_idx2.info < '3.0' AND t2.production_year between 2005 and 2008 AND lt.id = ml.link_type_id AND t1.id = ml.movie_id AND t2.id = ml.linked_movie_id AND it1.id = mi_idx1.info_type_id AND t1.id = mi_idx1.movie_id AND kt1.id = t1.kind_id AND cn1.id = mc1.company_id AND t1.id = mc1.movie_id AND ml.movie_id = mi_idx1.movie_id AND ml.movie_id = mc1.movie_id AND mi_idx1.movie_id = mc1.movie_id AND it2.id = mi_idx2.info_type_id AND t2.id = mi_idx2.movie_id AND kt2.id = t2.kind_id AND cn2.id = mc2.company_id AND t2.id = mc2.movie_id AND ml.linked_movie_id = mi_idx2.movie_id AND ml.linked_movie_id = mc2.movie_id AND mi_idx2.movie_id = mc2.movie_id;
diff --git a/benchmarks/queries/imdb/33b.sql b/benchmarks/queries/imdb/33b.sql
new file mode 100644
index 0000000000000..fe6fd75a69485
--- /dev/null
+++ b/benchmarks/queries/imdb/33b.sql
@@ -0,0 +1 @@
+SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 WHERE cn1.country_code = '[nl]' AND it1.info = 'rating' AND it2.info = 'rating' AND kt1.kind in ('tv series') AND kt2.kind in ('tv series') AND lt.link LIKE '%follow%' AND mi_idx2.info < '3.0' AND t2.production_year = 2007 AND lt.id = ml.link_type_id AND t1.id = ml.movie_id AND t2.id = ml.linked_movie_id AND it1.id = mi_idx1.info_type_id AND t1.id = mi_idx1.movie_id AND kt1.id = t1.kind_id AND cn1.id = mc1.company_id AND t1.id = mc1.movie_id AND ml.movie_id = mi_idx1.movie_id AND ml.movie_id = mc1.movie_id AND mi_idx1.movie_id = mc1.movie_id AND it2.id = mi_idx2.info_type_id AND t2.id = mi_idx2.movie_id AND kt2.id = t2.kind_id AND cn2.id = mc2.company_id AND t2.id = mc2.movie_id AND ml.linked_movie_id = mi_idx2.movie_id AND ml.linked_movie_id = mc2.movie_id AND mi_idx2.movie_id = mc2.movie_id;
diff --git a/benchmarks/queries/imdb/33c.sql b/benchmarks/queries/imdb/33c.sql
new file mode 100644
index 0000000000000..c9f0907d3f902
--- /dev/null
+++ b/benchmarks/queries/imdb/33c.sql
@@ -0,0 +1 @@
+SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 WHERE cn1.country_code != '[us]' AND it1.info = 'rating' AND it2.info = 'rating' AND kt1.kind in ('tv series', 'episode') AND kt2.kind in ('tv series', 'episode') AND lt.link in ('sequel', 'follows', 'followed by') AND mi_idx2.info < '3.5' AND t2.production_year between 2000 and 2010 AND lt.id = ml.link_type_id AND t1.id = ml.movie_id AND t2.id = ml.linked_movie_id AND it1.id = mi_idx1.info_type_id AND t1.id = mi_idx1.movie_id AND kt1.id = t1.kind_id AND cn1.id = mc1.company_id AND t1.id = mc1.movie_id AND ml.movie_id = mi_idx1.movie_id AND ml.movie_id = mc1.movie_id AND mi_idx1.movie_id = mc1.movie_id AND it2.id = mi_idx2.info_type_id AND t2.id = mi_idx2.movie_id AND kt2.id = t2.kind_id AND cn2.id = mc2.company_id AND t2.id = mc2.movie_id AND ml.linked_movie_id = mi_idx2.movie_id AND ml.linked_movie_id = mc2.movie_id AND mi_idx2.movie_id = mc2.movie_id;
diff --git a/benchmarks/queries/imdb/3a.sql b/benchmarks/queries/imdb/3a.sql
new file mode 100644
index 0000000000000..231c957be2078
--- /dev/null
+++ b/benchmarks/queries/imdb/3a.sql
@@ -0,0 +1 @@
+SELECT MIN(t.title) AS movie_title FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t WHERE k.keyword like '%sequel%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') AND t.production_year > 2005 AND t.id = mi.movie_id AND t.id = mk.movie_id AND mk.movie_id = mi.movie_id AND k.id = mk.keyword_id;
diff --git a/benchmarks/queries/imdb/3b.sql b/benchmarks/queries/imdb/3b.sql
new file mode 100644
index 0000000000000..fd21efc81014c
--- /dev/null
+++ b/benchmarks/queries/imdb/3b.sql
@@ -0,0 +1 @@
+SELECT MIN(t.title) AS movie_title FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t WHERE k.keyword like '%sequel%' AND mi.info IN ('Bulgaria') AND t.production_year > 2010 AND t.id = mi.movie_id AND t.id = mk.movie_id AND mk.movie_id = mi.movie_id AND k.id = mk.keyword_id;
diff --git a/benchmarks/queries/imdb/3c.sql b/benchmarks/queries/imdb/3c.sql
new file mode 100644
index 0000000000000..5f34232a2e61c
--- /dev/null
+++ b/benchmarks/queries/imdb/3c.sql
@@ -0,0 +1 @@
+SELECT MIN(t.title) AS movie_title FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t WHERE k.keyword like '%sequel%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') AND t.production_year > 1990 AND t.id = mi.movie_id AND t.id = mk.movie_id AND mk.movie_id = mi.movie_id AND k.id = mk.keyword_id;
diff --git a/benchmarks/queries/imdb/4a.sql b/benchmarks/queries/imdb/4a.sql
new file mode 100644
index 0000000000000..636afab02c8ac
--- /dev/null
+++ b/benchmarks/queries/imdb/4a.sql
@@ -0,0 +1 @@
+SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE it.info ='rating' AND k.keyword like '%sequel%' AND mi_idx.info > '5.0' AND t.production_year > 2005 AND t.id = mi_idx.movie_id AND t.id = mk.movie_id AND mk.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it.id = mi_idx.info_type_id;
diff --git a/benchmarks/queries/imdb/4b.sql b/benchmarks/queries/imdb/4b.sql
new file mode 100644
index 0000000000000..ebd3e89920604
--- /dev/null
+++ b/benchmarks/queries/imdb/4b.sql
@@ -0,0 +1 @@
+SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE it.info ='rating' AND k.keyword like '%sequel%' AND mi_idx.info > '9.0' AND t.production_year > 2010 AND t.id = mi_idx.movie_id AND t.id = mk.movie_id AND mk.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it.id = mi_idx.info_type_id;
diff --git a/benchmarks/queries/imdb/4c.sql b/benchmarks/queries/imdb/4c.sql
new file mode 100644
index 0000000000000..309281200f985
--- /dev/null
+++ b/benchmarks/queries/imdb/4c.sql
@@ -0,0 +1 @@
+SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE it.info ='rating' AND k.keyword like '%sequel%' AND mi_idx.info > '2.0' AND t.production_year > 1990 AND t.id = mi_idx.movie_id AND t.id = mk.movie_id AND mk.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it.id = mi_idx.info_type_id;
diff --git a/benchmarks/queries/imdb/5a.sql b/benchmarks/queries/imdb/5a.sql
new file mode 100644
index 0000000000000..04aae9881f7e5
--- /dev/null
+++ b/benchmarks/queries/imdb/5a.sql
@@ -0,0 +1 @@
+SELECT MIN(t.title) AS typical_european_movie FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t WHERE ct.kind = 'production companies' AND mc.note like '%(theatrical)%' and mc.note like '%(France)%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') AND t.production_year > 2005 AND t.id = mi.movie_id AND t.id = mc.movie_id AND mc.movie_id = mi.movie_id AND ct.id = mc.company_type_id AND it.id = mi.info_type_id;
diff --git a/benchmarks/queries/imdb/5b.sql b/benchmarks/queries/imdb/5b.sql
new file mode 100644
index 0000000000000..f03a519d61b3f
--- /dev/null
+++ b/benchmarks/queries/imdb/5b.sql
@@ -0,0 +1 @@
+SELECT MIN(t.title) AS american_vhs_movie FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t WHERE ct.kind = 'production companies' AND mc.note like '%(VHS)%' and mc.note like '%(USA)%' and mc.note like '%(1994)%' AND mi.info IN ('USA', 'America') AND t.production_year > 2010 AND t.id = mi.movie_id AND t.id = mc.movie_id AND mc.movie_id = mi.movie_id AND ct.id = mc.company_type_id AND it.id = mi.info_type_id;
diff --git a/benchmarks/queries/imdb/5c.sql b/benchmarks/queries/imdb/5c.sql
new file mode 100644
index 0000000000000..2705e7e2c7a05
--- /dev/null
+++ b/benchmarks/queries/imdb/5c.sql
@@ -0,0 +1 @@
+SELECT MIN(t.title) AS american_movie FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t WHERE ct.kind = 'production companies' AND mc.note not like '%(TV)%' and mc.note like '%(USA)%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') AND t.production_year > 1990 AND t.id = mi.movie_id AND t.id = mc.movie_id AND mc.movie_id = mi.movie_id AND ct.id = mc.company_type_id AND it.id = mi.info_type_id;
diff --git a/benchmarks/queries/imdb/6a.sql b/benchmarks/queries/imdb/6a.sql
new file mode 100644
index 0000000000000..34b3a6da5fd2f
--- /dev/null
+++ b/benchmarks/queries/imdb/6a.sql
@@ -0,0 +1 @@
+SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword = 'marvel-cinematic-universe' AND n.name LIKE '%Downey%Robert%' AND t.production_year > 2010 AND k.id = mk.keyword_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND ci.movie_id = mk.movie_id AND n.id = ci.person_id;
diff --git a/benchmarks/queries/imdb/6b.sql b/benchmarks/queries/imdb/6b.sql
new file mode 100644
index 0000000000000..1233c41e66b0c
--- /dev/null
+++ b/benchmarks/queries/imdb/6b.sql
@@ -0,0 +1 @@
+SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') AND n.name LIKE '%Downey%Robert%' AND t.production_year > 2014 AND k.id = mk.keyword_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND ci.movie_id = mk.movie_id AND n.id = ci.person_id;
diff --git a/benchmarks/queries/imdb/6c.sql b/benchmarks/queries/imdb/6c.sql
new file mode 100644
index 0000000000000..d1f97746e15e5
--- /dev/null
+++ b/benchmarks/queries/imdb/6c.sql
@@ -0,0 +1 @@
+SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword = 'marvel-cinematic-universe' AND n.name LIKE '%Downey%Robert%' AND t.production_year > 2014 AND k.id = mk.keyword_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND ci.movie_id = mk.movie_id AND n.id = ci.person_id;
diff --git a/benchmarks/queries/imdb/6d.sql b/benchmarks/queries/imdb/6d.sql
new file mode 100644
index 0000000000000..07729510a454e
--- /dev/null
+++ b/benchmarks/queries/imdb/6d.sql
@@ -0,0 +1 @@
+SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') AND n.name LIKE '%Downey%Robert%' AND t.production_year > 2000 AND k.id = mk.keyword_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND ci.movie_id = mk.movie_id AND n.id = ci.person_id;
diff --git a/benchmarks/queries/imdb/6e.sql b/benchmarks/queries/imdb/6e.sql
new file mode 100644
index 0000000000000..2e77873fd81df
--- /dev/null
+++ b/benchmarks/queries/imdb/6e.sql
@@ -0,0 +1 @@
+SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword = 'marvel-cinematic-universe' AND n.name LIKE '%Downey%Robert%' AND t.production_year > 2000 AND k.id = mk.keyword_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND ci.movie_id = mk.movie_id AND n.id = ci.person_id;
diff --git a/benchmarks/queries/imdb/6f.sql b/benchmarks/queries/imdb/6f.sql
new file mode 100644
index 0000000000000..603901129107d
--- /dev/null
+++ b/benchmarks/queries/imdb/6f.sql
@@ -0,0 +1 @@
+SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n,
title AS t WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') AND t.production_year > 2000 AND k.id = mk.keyword_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND ci.movie_id = mk.movie_id AND n.id = ci.person_id; diff --git a/benchmarks/queries/imdb/7a.sql b/benchmarks/queries/imdb/7a.sql new file mode 100644 index 0000000000000..c6b26ce36f11a --- /dev/null +++ b/benchmarks/queries/imdb/7a.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS of_person, MIN(t.title) AS biography_movie FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t WHERE an.name LIKE '%a%' AND it.info ='mini biography' AND lt.link ='features' AND n.name_pcode_cf BETWEEN 'A' AND 'F' AND (n.gender='m' OR (n.gender = 'f' AND n.name LIKE 'B%')) AND pi.note ='Volker Boehm' AND t.production_year BETWEEN 1980 AND 1995 AND n.id = an.person_id AND n.id = pi.person_id AND ci.person_id = n.id AND t.id = ci.movie_id AND ml.linked_movie_id = t.id AND lt.id = ml.link_type_id AND it.id = pi.info_type_id AND pi.person_id = an.person_id AND pi.person_id = ci.person_id AND an.person_id = ci.person_id AND ci.movie_id = ml.linked_movie_id; diff --git a/benchmarks/queries/imdb/7b.sql b/benchmarks/queries/imdb/7b.sql new file mode 100644 index 0000000000000..4e4f6e7615cb5 --- /dev/null +++ b/benchmarks/queries/imdb/7b.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS of_person, MIN(t.title) AS biography_movie FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t WHERE an.name LIKE '%a%' AND it.info ='mini biography' AND lt.link ='features' AND n.name_pcode_cf LIKE 'D%' AND n.gender='m' AND pi.note ='Volker Boehm' AND t.production_year BETWEEN 1980 AND 1984 AND n.id = an.person_id AND n.id = pi.person_id AND ci.person_id = n.id AND t.id = ci.movie_id AND ml.linked_movie_id = t.id AND lt.id = ml.link_type_id AND it.id = pi.info_type_id AND pi.person_id = an.person_id AND pi.person_id = ci.person_id AND an.person_id = ci.person_id AND ci.movie_id = ml.linked_movie_id; diff --git a/benchmarks/queries/imdb/7c.sql b/benchmarks/queries/imdb/7c.sql new file mode 100644 index 0000000000000..a399342fae026 --- /dev/null +++ b/benchmarks/queries/imdb/7c.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS cast_member_name, MIN(pi.info) AS cast_member_info FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t WHERE an.name is not NULL and (an.name LIKE '%a%' or an.name LIKE 'A%') AND it.info ='mini biography' AND lt.link in ('references', 'referenced in', 'features', 'featured in') AND n.name_pcode_cf BETWEEN 'A' AND 'F' AND (n.gender='m' OR (n.gender = 'f' AND n.name LIKE 'A%')) AND pi.note is not NULL AND t.production_year BETWEEN 1980 AND 2010 AND n.id = an.person_id AND n.id = pi.person_id AND ci.person_id = n.id AND t.id = ci.movie_id AND ml.linked_movie_id = t.id AND lt.id = ml.link_type_id AND it.id = pi.info_type_id AND pi.person_id = an.person_id AND pi.person_id = ci.person_id AND an.person_id = ci.person_id AND ci.movie_id = ml.linked_movie_id; diff --git a/benchmarks/queries/imdb/8a.sql b/benchmarks/queries/imdb/8a.sql new file mode 100644 index 0000000000000..66ed05880d5f3 --- /dev/null +++ b/benchmarks/queries/imdb/8a.sql @@ -0,0 +1 @@ +SELECT MIN(an1.name) AS actress_pseudonym, MIN(t.title) AS japanese_movie_dubbed FROM aka_name AS an1, cast_info AS ci, 
company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t WHERE ci.note ='(voice: English version)' AND cn.country_code ='[jp]' AND mc.note like '%(Japan)%' and mc.note not like '%(USA)%' AND n1.name like '%Yo%' and n1.name not like '%Yu%' AND rt.role ='actress' AND an1.person_id = n1.id AND n1.id = ci.person_id AND ci.movie_id = t.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND an1.person_id = ci.person_id AND ci.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/8b.sql b/benchmarks/queries/imdb/8b.sql new file mode 100644 index 0000000000000..044b5f8e86499 --- /dev/null +++ b/benchmarks/queries/imdb/8b.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS acress_pseudonym, MIN(t.title) AS japanese_anime_movie FROM aka_name AS an, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t WHERE ci.note ='(voice: English version)' AND cn.country_code ='[jp]' AND mc.note like '%(Japan)%' and mc.note not like '%(USA)%' and (mc.note like '%(2006)%' or mc.note like '%(2007)%') AND n.name like '%Yo%' and n.name not like '%Yu%' AND rt.role ='actress' AND t.production_year between 2006 and 2007 and (t.title like 'One Piece%' or t.title like 'Dragon Ball Z%') AND an.person_id = n.id AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND an.person_id = ci.person_id AND ci.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/8c.sql b/benchmarks/queries/imdb/8c.sql new file mode 100644 index 0000000000000..d02b74c02c5ee --- /dev/null +++ b/benchmarks/queries/imdb/8c.sql @@ -0,0 +1 @@ +SELECT MIN(a1.name) AS writer_pseudo_name, MIN(t.title) AS movie_title FROM aka_name AS a1, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t WHERE cn.country_code ='[us]' AND rt.role ='writer' AND a1.person_id = n1.id AND n1.id = ci.person_id AND ci.movie_id = t.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND a1.person_id = ci.person_id AND ci.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/8d.sql b/benchmarks/queries/imdb/8d.sql new file mode 100644 index 0000000000000..0834c0ff5cb71 --- /dev/null +++ b/benchmarks/queries/imdb/8d.sql @@ -0,0 +1 @@ +SELECT MIN(an1.name) AS costume_designer_pseudo, MIN(t.title) AS movie_with_costumes FROM aka_name AS an1, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t WHERE cn.country_code ='[us]' AND rt.role ='costume designer' AND an1.person_id = n1.id AND n1.id = ci.person_id AND ci.movie_id = t.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND an1.person_id = ci.person_id AND ci.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/9a.sql b/benchmarks/queries/imdb/9a.sql new file mode 100644 index 0000000000000..593b16213b06b --- /dev/null +++ b/benchmarks/queries/imdb/9a.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS character_name, MIN(t.title) AS movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND mc.note is not NULL and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') AND n.gender ='f' and n.name like '%Ang%' AND rt.role ='actress' AND t.production_year between 2005 
and 2015 AND ci.movie_id = t.id AND t.id = mc.movie_id AND ci.movie_id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND n.id = ci.person_id AND chn.id = ci.person_role_id AND an.person_id = n.id AND an.person_id = ci.person_id; diff --git a/benchmarks/queries/imdb/9b.sql b/benchmarks/queries/imdb/9b.sql new file mode 100644 index 0000000000000..a4933fd6856e8 --- /dev/null +++ b/benchmarks/queries/imdb/9b.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_character, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t WHERE ci.note = '(voice)' AND cn.country_code ='[us]' AND mc.note like '%(200%)%' and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') AND n.gender ='f' and n.name like '%Angel%' AND rt.role ='actress' AND t.production_year between 2007 and 2010 AND ci.movie_id = t.id AND t.id = mc.movie_id AND ci.movie_id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND n.id = ci.person_id AND chn.id = ci.person_role_id AND an.person_id = n.id AND an.person_id = ci.person_id; diff --git a/benchmarks/queries/imdb/9c.sql b/benchmarks/queries/imdb/9c.sql new file mode 100644 index 0000000000000..0be511810cf66 --- /dev/null +++ b/benchmarks/queries/imdb/9c.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_character_name, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND ci.movie_id = t.id AND t.id = mc.movie_id AND ci.movie_id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND n.id = ci.person_id AND chn.id = ci.person_role_id AND an.person_id = n.id AND an.person_id = ci.person_id; diff --git a/benchmarks/queries/imdb/9d.sql b/benchmarks/queries/imdb/9d.sql new file mode 100644 index 0000000000000..51262ca5ebae4 --- /dev/null +++ b/benchmarks/queries/imdb/9d.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND n.gender ='f' AND rt.role ='actress' AND ci.movie_id = t.id AND t.id = mc.movie_id AND ci.movie_id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND n.id = ci.person_id AND chn.id = ci.person_role_id AND an.person_id = n.id AND an.person_id = ci.person_id; diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs index 9ce6848a063aa..f7b84116e793a 100644 --- a/benchmarks/src/bin/dfbench.rs +++ b/benchmarks/src/bin/dfbench.rs @@ -33,7 +33,7 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; #[global_allocator] static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; -use datafusion_benchmarks::{clickbench, parquet_filter, sort, tpch}; +use datafusion_benchmarks::{clickbench, imdb, parquet_filter, sort, tpch}; #[derive(Debug, StructOpt)] 
#[structopt(about = "benchmark command")] @@ -43,6 +43,7 @@ enum Options { Clickbench(clickbench::RunOpt), ParquetFilter(parquet_filter::RunOpt), Sort(sort::RunOpt), + Imdb(imdb::RunOpt), } // Main benchmark runner entrypoint @@ -56,5 +57,6 @@ pub async fn main() -> Result<()> { Options::Clickbench(opt) => opt.run().await, Options::ParquetFilter(opt) => opt.run().await, Options::Sort(opt) => opt.run().await, + Options::Imdb(opt) => opt.run().await, } } diff --git a/benchmarks/src/bin/external_aggr.rs b/benchmarks/src/bin/external_aggr.rs new file mode 100644 index 0000000000000..1bc74e22ccfae --- /dev/null +++ b/benchmarks/src/bin/external_aggr.rs @@ -0,0 +1,390 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! external_aggr binary entrypoint
+
+use std::collections::HashMap;
+use std::path::PathBuf;
+use std::sync::Arc;
+use std::sync::OnceLock;
+use structopt::StructOpt;
+
+use arrow::record_batch::RecordBatch;
+use arrow::util::pretty;
+use datafusion::datasource::file_format::parquet::ParquetFormat;
+use datafusion::datasource::listing::{
+ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
+};
+use datafusion::datasource::{MemTable, TableProvider};
+use datafusion::error::Result;
+use datafusion::execution::memory_pool::FairSpillPool;
+use datafusion::execution::memory_pool::{human_readable_size, units};
+use datafusion::execution::runtime_env::RuntimeConfig;
+use datafusion::physical_plan::display::DisplayableExecutionPlan;
+use datafusion::physical_plan::{collect, displayable};
+use datafusion::prelude::*;
+use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt};
+use datafusion_common::instant::Instant;
+use datafusion_common::{exec_datafusion_err, exec_err, DEFAULT_PARQUET_EXTENSION};
+
+#[derive(Debug, StructOpt)]
+#[structopt(
+ name = "datafusion-external-aggregation",
+ about = "DataFusion external aggregation benchmark"
+)]
+enum ExternalAggrOpt {
+ Benchmark(ExternalAggrConfig),
+}
+
+#[derive(Debug, StructOpt)]
+struct ExternalAggrConfig {
+ /// Query number. If not specified, runs all queries
+ #[structopt(short, long)]
+ query: Option<usize>,
+
+ /// Memory limit (e.g. '100M', '1.5G'). If not specified, run all pre-defined memory limits for given query.
+ #[structopt(long)]
+ memory_limit: Option<String>,
+
+ /// Common options
+ #[structopt(flatten)]
+ common: CommonOpt,
+
+ /// Path to data files (lineitem).
Only parquet format is supported
+ #[structopt(parse(from_os_str), required = true, short = "p", long = "path")]
+ path: PathBuf,
+
+ /// Load the data into a MemTable before executing the query
+ #[structopt(short = "m", long = "mem-table")]
+ mem_table: bool,
+
+ /// Path to JSON benchmark result to be compared using `compare.py`
+ #[structopt(parse(from_os_str), short = "o", long = "output")]
+ output_path: Option<PathBuf>,
+}
+
+struct QueryResult {
+ elapsed: std::time::Duration,
+ row_count: usize,
+}
+
+/// Query Memory Limits
+/// Map query id to predefined memory limits
+///
+/// Q1 requires 36MiB for aggregation
+/// Memory limits to run: 64MiB, 32MiB, 16MiB
+/// Q2 requires 250MiB for aggregation
+/// Memory limits to run: 512MiB, 256MiB, 128MiB, 64MiB, 32MiB
+static QUERY_MEMORY_LIMITS: OnceLock<HashMap<usize, Vec<u64>>> = OnceLock::new();
+
+impl ExternalAggrConfig {
+ const AGGR_TABLES: [&'static str; 1] = ["lineitem"];
+ const AGGR_QUERIES: [&'static str; 2] = [
+ // Q1: Output size is ~25% of lineitem table
+ r#"
+ SELECT count(*)
+ FROM (
+ SELECT DISTINCT l_orderkey
+ FROM lineitem
+ )
+ "#,
+ // Q2: Output size is ~99% of lineitem table
+ r#"
+ SELECT count(*)
+ FROM (
+ SELECT DISTINCT l_orderkey, l_suppkey
+ FROM lineitem
+ )
+ "#,
+ ];
+
+ fn init_query_memory_limits() -> &'static HashMap<usize, Vec<u64>> {
+ use units::*;
+ QUERY_MEMORY_LIMITS.get_or_init(|| {
+ let mut map = HashMap::new();
+ map.insert(1, vec![64 * MB, 32 * MB, 16 * MB]);
+ map.insert(2, vec![512 * MB, 256 * MB, 128 * MB, 64 * MB, 32 * MB]);
+ map
+ })
+ }
+
+ /// If `--query` and `--memory-limit` are not specified, run all queries
+ /// with pre-configured memory limits
+ /// If only `--query` is specified, run the query with all memory limits
+ /// for this query
+ /// If both `--query` and `--memory-limit` are specified, run the query
+ /// with the specified memory limit
+ pub async fn run(&self) -> Result<()> {
+ let mut benchmark_run = BenchmarkRun::new();
+
+ let memory_limit = match &self.memory_limit {
+ Some(limit) => Some(Self::parse_memory_limit(limit)?),
+ None => None,
+ };
+
+ let query_range = match self.query {
+ Some(query_id) => query_id..=query_id,
+ None => 1..=Self::AGGR_QUERIES.len(),
+ };
+
+ // Each element is (query_id, memory_limit)
+ // e.g. [(1, 64_000), (1, 32_000)...] means first run Q1 with 64KiB
+ // memory limit, next run Q1 with 32KiB memory limit, etc.
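+ // Illustrative note (editorial sketch, grounded in QUERY_MEMORY_LIMITS
+ // above): with neither flag set, the expansion below yields
+ //   [(1, 64*MB), (1, 32*MB), (1, 16*MB),
+ //    (2, 512*MB), (2, 256*MB), (2, 128*MB), (2, 64*MB), (2, 32*MB)];
+ // `--query 2` alone keeps only the Q2 pairs, and adding `--memory-limit 100M`
+ // reduces the list to the single pair (2, parse_memory_limit("100M")?).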
+ let mut query_executions = vec![];
+ // Setup `query_executions`
+ for query_id in query_range {
+ if query_id > Self::AGGR_QUERIES.len() {
+ return exec_err!(
+ "Invalid '--query'(query number) {} for external aggregation benchmark.",
+ query_id
+ );
+ }
+
+ match memory_limit {
+ Some(limit) => {
+ query_executions.push((query_id, limit));
+ }
+ None => {
+ let memory_limits_table = Self::init_query_memory_limits();
+ let memory_limits = memory_limits_table.get(&query_id).unwrap();
+ for limit in memory_limits {
+ query_executions.push((query_id, *limit));
+ }
+ }
+ }
+ }
+
+ for (query_id, mem_limit) in query_executions {
+ benchmark_run.start_new_case(&format!(
+ "{query_id}({})",
+ human_readable_size(mem_limit as usize)
+ ));
+
+ let query_results = self.benchmark_query(query_id, mem_limit).await?;
+ for iter in query_results {
+ benchmark_run.write_iter(iter.elapsed, iter.row_count);
+ }
+ }
+
+ benchmark_run.maybe_write_json(self.output_path.as_ref())?;
+
+ Ok(())
+ }
+
+ /// Benchmark query `query_id` in `AGGR_QUERIES`
+ async fn benchmark_query(
+ &self,
+ query_id: usize,
+ mem_limit: u64,
+ ) -> Result<Vec<QueryResult>> {
+ let query_name =
+ format!("Q{query_id}({})", human_readable_size(mem_limit as usize));
+ let mut config = self.common.config();
+ config
+ .options_mut()
+ .execution
+ .parquet
+ .schema_force_view_types = self.common.force_view_types;
+ let runtime_config = RuntimeConfig::new()
+ .with_memory_pool(Arc::new(FairSpillPool::new(mem_limit as usize)))
+ .build_arc()?;
+ let ctx = SessionContext::new_with_config_rt(config, runtime_config);
+
+ // register tables
+ self.register_tables(&ctx).await?;
+
+ let mut millis = vec![];
+ // run benchmark
+ let mut query_results = vec![];
+ for i in 0..self.iterations() {
+ let start = Instant::now();
+
+ let query_idx = query_id - 1; // 1-indexed -> 0-indexed
+ let sql = Self::AGGR_QUERIES[query_idx];
+
+ let result = self.execute_query(&ctx, sql).await?;
+
+ let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0;
+ let ms = elapsed.as_secs_f64() * 1000.0;
+ millis.push(ms);
+
+ let row_count = result.iter().map(|b| b.num_rows()).sum();
+ println!(
+ "{query_name} iteration {i} took {ms:.1} ms and returned {row_count} rows"
+ );
+ query_results.push(QueryResult { elapsed, row_count });
+ }
+
+ let avg = millis.iter().sum::<f64>() / millis.len() as f64;
+ println!("{query_name} avg time: {avg:.2} ms");
+
+ Ok(query_results)
+ }
+
+ async fn register_tables(&self, ctx: &SessionContext) -> Result<()> {
+ for table in Self::AGGR_TABLES {
+ let table_provider = { self.get_table(ctx, table).await?
};
+
+ if self.mem_table {
+ println!("Loading table '{table}' into memory");
+ let start = Instant::now();
+ let memtable =
+ MemTable::load(table_provider, Some(self.partitions()), &ctx.state())
+ .await?;
+ println!(
+ "Loaded table '{}' into memory in {} ms",
+ table,
+ start.elapsed().as_millis()
+ );
+ ctx.register_table(table, Arc::new(memtable))?;
+ } else {
+ ctx.register_table(table, table_provider)?;
+ }
+ }
+ Ok(())
+ }
+
+ async fn execute_query(
+ &self,
+ ctx: &SessionContext,
+ sql: &str,
+ ) -> Result<Vec<RecordBatch>> {
+ let debug = self.common.debug;
+ let plan = ctx.sql(sql).await?;
+ let (state, plan) = plan.into_parts();
+
+ if debug {
+ println!("=== Logical plan ===\n{plan}\n");
+ }
+
+ let plan = state.optimize(&plan)?;
+ if debug {
+ println!("=== Optimized logical plan ===\n{plan}\n");
+ }
+ let physical_plan = state.create_physical_plan(&plan).await?;
+ if debug {
+ println!(
+ "=== Physical plan ===\n{}\n",
+ displayable(physical_plan.as_ref()).indent(true)
+ );
+ }
+ let result = collect(physical_plan.clone(), state.task_ctx()).await?;
+ if debug {
+ println!(
+ "=== Physical plan with metrics ===\n{}\n",
+ DisplayableExecutionPlan::with_metrics(physical_plan.as_ref())
+ .indent(true)
+ );
+ if !result.is_empty() {
+ // do not call print_batches if there are no batches as the result is confusing
+ // and makes it look like there is a batch with no columns
+ pretty::print_batches(&result)?;
+ }
+ }
+ Ok(result)
+ }
+
+ async fn get_table(
+ &self,
+ ctx: &SessionContext,
+ table: &str,
+ ) -> Result<Arc<dyn TableProvider>> {
+ let path = self.path.to_str().unwrap();
+
+ // Obtain a snapshot of the SessionState
+ let state = ctx.state();
+ let path = format!("{path}/{table}");
+ let format = Arc::new(
+ ParquetFormat::default()
+ .with_options(ctx.state().table_options().parquet.clone()),
+ );
+ let extension = DEFAULT_PARQUET_EXTENSION;
+
+ let options = ListingOptions::new(format)
+ .with_file_extension(extension)
+ .with_collect_stat(state.config().collect_statistics());
+
+ let table_path = ListingTableUrl::parse(path)?;
+ let config = ListingTableConfig::new(table_path).with_listing_options(options);
+ let config = config.infer_schema(&state).await?;
+
+ Ok(Arc::new(ListingTable::try_new(config)?))
+ }
+
+ fn iterations(&self) -> usize {
+ self.common.iterations
+ }
+
+ fn partitions(&self) -> usize {
+ self.common.partitions.unwrap_or(num_cpus::get())
+ }
+
+ /// Parse memory limit from string to number of bytes
+ /// e.g.
'100K' -> 102400, '1.5M' -> 1572864
+ fn parse_memory_limit(limit: &str) -> Result<u64> {
+ let (number, unit) = limit.split_at(limit.len() - 1);
+ let number: f64 = number.parse().map_err(|_| {
+ exec_datafusion_err!("Failed to parse number from memory limit '{}'", limit)
+ })?;
+
+ match unit {
+ "K" => Ok((number * 1024.0) as u64),
+ "M" => Ok((number * 1024.0 * 1024.0) as u64),
+ "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as u64),
+ _ => exec_err!("Unsupported unit '{}' in memory limit '{}'", unit, limit),
+ }
+ }
+}
+
+#[tokio::main]
+pub async fn main() -> Result<()> {
+ env_logger::init();
+
+ match ExternalAggrOpt::from_args() {
+ ExternalAggrOpt::Benchmark(opt) => opt.run().await?,
+ }
+
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_parse_memory_limit_all() {
+ // Test valid inputs
+ assert_eq!(
+ ExternalAggrConfig::parse_memory_limit("100K").unwrap(),
+ 102400
+ );
+ assert_eq!(
+ ExternalAggrConfig::parse_memory_limit("1.5M").unwrap(),
+ 1572864
+ );
+ assert_eq!(
+ ExternalAggrConfig::parse_memory_limit("2G").unwrap(),
+ 2147483648
+ );
+
+ // Test invalid unit
+ assert!(ExternalAggrConfig::parse_memory_limit("500X").is_err());
+
+ // Test invalid number
+ assert!(ExternalAggrConfig::parse_memory_limit("abcM").is_err());
+ }
+} diff --git a/benchmarks/src/bin/h2o.rs b/benchmarks/src/bin/h2o.rs index 1bb8cb9d43e4b..1ddeb786a5911 100644 --- a/benchmarks/src/bin/h2o.rs +++ b/benchmarks/src/bin/h2o.rs @@ -26,7 +26,7 @@ use datafusion::datasource::listing::{ use datafusion::datasource::MemTable; use datafusion::prelude::CsvReadOptions; use datafusion::{arrow::util::pretty, error::Result, prelude::SessionContext}; -use datafusion_benchmarks::BenchmarkRun; +use datafusion_benchmarks::util::BenchmarkRun; use std::path::PathBuf; use std::sync::Arc; use structopt::StructOpt; diff --git a/benchmarks/src/bin/imdb.rs b/benchmarks/src/bin/imdb.rs index 40efb84b05011..13421f8a89a9b 100644 --- a/benchmarks/src/bin/imdb.rs +++ b/benchmarks/src/bin/imdb.rs @@ -34,9 +34,17 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; #[global_allocator] static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
+#[derive(Debug, StructOpt)]
+#[structopt(about = "benchmark command")]
+enum BenchmarkSubCommandOpt {
+ #[structopt(name = "datafusion")]
+ DataFusionBenchmark(imdb::RunOpt),
+}
+
 #[derive(Debug, StructOpt)] #[structopt(name = "IMDB", about = "IMDB Dataset Processing.")] enum ImdbOpt {
+ Benchmark(BenchmarkSubCommandOpt),
 Convert(imdb::ConvertOpt), } @@ -44,6 +52,9 @@ enum ImdbOpt { pub async fn main() -> Result<()> { env_logger::init(); match ImdbOpt::from_args() {
+ ImdbOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => {
+ opt.run().await
+ }
 ImdbOpt::Convert(opt) => opt.run().await, } } diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index 207da4020b588..3564ae82585a6 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -18,6 +18,7 @@ use std::path::Path; use std::path::PathBuf; +use crate::util::{BenchmarkRun, CommonOpt}; use datafusion::{ error::{DataFusionError, Result}, prelude::SessionContext, @@ -26,8 +27,6 @@ use datafusion_common::exec_datafusion_err; use datafusion_common::instant::Instant; use structopt::StructOpt; -use crate::{BenchmarkRun, CommonOpt}; - /// Run the clickbench benchmark /// /// The ClickBench[1] benchmarks are widely cited in the industry and @@ -116,12 +115,15 @@ impl RunOpt { None => queries.min_query_id()..=queries.max_query_id(), }; + // configure parquet options
let mut config = self.common.config(); - config - .options_mut() - .execution - .parquet - .schema_force_view_types = self.common.force_view_types; + { + let parquet_options = &mut config.options_mut().execution.parquet; + parquet_options.schema_force_view_types = self.common.force_view_types; + // The hits_partitioned dataset specifies string columns + // as binary due to how it was written. Force it to strings + parquet_options.binary_as_string = true; + } let ctx = SessionContext::new_with_config(config); self.register_hits(&ctx).await?; @@ -149,7 +151,7 @@ impl RunOpt { Ok(()) } - /// Registrs the `hits.parquet` as a table named `hits` + /// Registers the `hits.parquet` as a table named `hits` async fn register_hits(&self, ctx: &SessionContext) -> Result<()> { let options = Default::default(); let path = self.path.as_os_str().to_str().unwrap(); diff --git a/benchmarks/src/imdb/convert.rs b/benchmarks/src/imdb/convert.rs index c95f7f8bf564f..4e470d711da5d 100644 --- a/benchmarks/src/imdb/convert.rs +++ b/benchmarks/src/imdb/convert.rs @@ -51,11 +51,12 @@ impl ConvertOpt { pub async fn run(self) -> Result<()> { let input_path = self.input_path.to_str().unwrap(); let output_path = self.output_path.to_str().unwrap(); + let config = SessionConfig::new().with_batch_size(self.batch_size); + let ctx = SessionContext::new_with_config(config); for table in IMDB_TABLES { let start = Instant::now(); let schema = get_imdb_table_schema(table); - let input_path = format!("{input_path}/{table}.csv"); let output_path = format!("{output_path}/{table}.parquet"); let options = CsvReadOptions::new() @@ -65,9 +66,6 @@ impl ConvertOpt { .escape(b'\\') .file_extension(".csv"); - let config = SessionConfig::new().with_batch_size(self.batch_size); - let ctx = SessionContext::new_with_config(config); - let mut csv = ctx.read_csv(&input_path, options).await?; // Select all apart from the padding column diff --git a/benchmarks/src/imdb/mod.rs b/benchmarks/src/imdb/mod.rs index 8e2977c0384e6..6a45242e6ff4b 100644 --- a/benchmarks/src/imdb/mod.rs +++ b/benchmarks/src/imdb/mod.rs @@ -17,10 +17,18 @@ //! Benchmark derived from IMDB dataset. 
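//!
//! A hedged usage sketch of the helpers this module exposes (the query name
//! `"1a"` is just an example; error handling is elided):
//!
//! ```ignore
//! use datafusion_benchmarks::imdb::{get_imdb_table_schema, get_query_sql, IMDB_TABLES};
//!
//! // Every IMDB table name maps to an explicit Arrow schema.
//! for &table in IMDB_TABLES {
//!     let schema = get_imdb_table_schema(table);
//!     assert!(!schema.fields().is_empty());
//! }
//!
//! // Query files are resolved relative to the working directory and split
//! // into individual ';'-terminated statements.
//! let statements: Vec<String> = get_query_sql("1a")?;
//! ```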
-use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::{ + arrow::datatypes::{DataType, Field, Schema}, + common::plan_err, + error::Result, +}; mod convert; pub use convert::ConvertOpt; +use std::fs; +mod run; +pub use run::RunOpt; + // we have 21 tables in the IMDB dataset pub const IMDB_TABLES: &[&str] = &[ "aka_name", @@ -51,7 +59,7 @@ pub const IMDB_TABLES: &[&str] = &[ pub fn get_imdb_table_schema(table: &str) -> Schema { match table { "aka_name" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("person_id", DataType::Int32, false), Field::new("name", DataType::Utf8, true), Field::new("imdb_index", DataType::Utf8, true), @@ -61,7 +69,7 @@ pub fn get_imdb_table_schema(table: &str) -> Schema { Field::new("md5sum", DataType::Utf8, true), ]), "aka_title" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("movie_id", DataType::Int32, false), Field::new("title", DataType::Utf8, true), Field::new("imdb_index", DataType::Utf8, true), @@ -75,7 +83,7 @@ pub fn get_imdb_table_schema(table: &str) -> Schema { Field::new("md5sum", DataType::Utf8, true), ]), "cast_info" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("person_id", DataType::Int32, false), Field::new("movie_id", DataType::Int32, false), Field::new("person_role_id", DataType::Int32, true), @@ -84,7 +92,7 @@ pub fn get_imdb_table_schema(table: &str) -> Schema { Field::new("role_id", DataType::Int32, false), ]), "char_name" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("name", DataType::Utf8, false), Field::new("imdb_index", DataType::Utf8, true), Field::new("imdb_id", DataType::Int32, true), @@ -93,11 +101,11 @@ pub fn get_imdb_table_schema(table: &str) -> Schema { Field::new("md5sum", DataType::Utf8, true), ]), "comp_cast_type" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("kind", DataType::Utf8, false), ]), "company_name" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("name", DataType::Utf8, false), Field::new("country_code", DataType::Utf8, true), Field::new("imdb_id", DataType::Int32, true), @@ -106,59 +114,59 @@ pub fn get_imdb_table_schema(table: &str) -> Schema { Field::new("md5sum", DataType::Utf8, true), ]), "company_type" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("kind", DataType::Utf8, true), ]), "complete_cast" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("movie_id", DataType::Int32, true), Field::new("subject_id", DataType::Int32, false), Field::new("status_id", DataType::Int32, false), ]), "info_type" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("info", DataType::Utf8, false), ]), "keyword" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("keyword", DataType::Utf8, false), Field::new("phonetic_code", DataType::Utf8, true), ]), "kind_type" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("kind", 
DataType::Utf8, true), ]), "link_type" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("link", DataType::Utf8, false), ]), "movie_companies" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("movie_id", DataType::Int32, false), Field::new("company_id", DataType::Int32, false), Field::new("company_type_id", DataType::Int32, false), Field::new("note", DataType::Utf8, true), ]), "movie_info_idx" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("movie_id", DataType::Int32, false), Field::new("info_type_id", DataType::Int32, false), Field::new("info", DataType::Utf8, false), Field::new("note", DataType::Utf8, true), ]), "movie_keyword" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("movie_id", DataType::Int32, false), Field::new("keyword_id", DataType::Int32, false), ]), "movie_link" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("movie_id", DataType::Int32, false), Field::new("linked_movie_id", DataType::Int32, false), Field::new("link_type_id", DataType::Int32, false), ]), "name" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("name", DataType::Utf8, false), Field::new("imdb_index", DataType::Utf8, true), Field::new("imdb_id", DataType::Int32, true), @@ -169,11 +177,11 @@ pub fn get_imdb_table_schema(table: &str) -> Schema { Field::new("md5sum", DataType::Utf8, true), ]), "role_type" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("role", DataType::Utf8, false), ]), "title" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("title", DataType::Utf8, false), Field::new("imdb_index", DataType::Utf8, true), Field::new("kind_id", DataType::Int32, false), @@ -187,14 +195,14 @@ pub fn get_imdb_table_schema(table: &str) -> Schema { Field::new("md5sum", DataType::Utf8, true), ]), "movie_info" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("movie_id", DataType::Int32, false), Field::new("info_type_id", DataType::Int32, false), Field::new("info", DataType::Utf8, false), Field::new("note", DataType::Utf8, true), ]), "person_info" => Schema::new(vec![ - Field::new("id", DataType::Int32, false), + Field::new("id", DataType::UInt32, false), Field::new("person_id", DataType::Int32, false), Field::new("info_type_id", DataType::Int32, false), Field::new("info", DataType::Utf8, false), @@ -203,3 +211,26 @@ pub fn get_imdb_table_schema(table: &str) -> Schema { _ => unimplemented!("Schema for table {} is not implemented", table), } }
+
+/// Get the SQL statements from the specified query file
+pub fn get_query_sql(query: &str) -> Result<Vec<String>> {
+ let possibilities = vec![
+ format!("queries/imdb/{query}.sql"),
+ format!("benchmarks/queries/imdb/{query}.sql"),
+ ];
+ let mut errors = vec![];
+ for filename in possibilities {
+ match fs::read_to_string(&filename) {
+ Ok(contents) => {
+ return Ok(contents
+ .split(';')
+ .map(|s| s.trim())
+ .filter(|s| !s.is_empty())
+ .map(|s| s.to_string())
+ .collect());
+ }
+ Err(e) => errors.push(format!("{filename}: {e}")),
+ };
+ }
+
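+ // Neither candidate path could be read (which one resolves depends on the
+ // directory the benchmark is launched from), so surface every attempted
+ // path together with its I/O error.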
plan_err!("invalid query. Could not find query: {:?}", errors) +} diff --git a/benchmarks/src/imdb/run.rs b/benchmarks/src/imdb/run.rs new file mode 100644 index 0000000000000..fd49606061104 --- /dev/null +++ b/benchmarks/src/imdb/run.rs @@ -0,0 +1,828 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::path::PathBuf; +use std::sync::Arc; + +use super::{get_imdb_table_schema, get_query_sql, IMDB_TABLES}; +use crate::util::{BenchmarkRun, CommonOpt}; + +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::{self, pretty_format_batches}; +use datafusion::datasource::file_format::csv::CsvFormat; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::file_format::FileFormat; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::datasource::{MemTable, TableProvider}; +use datafusion::error::Result; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::{collect, displayable}; +use datafusion::prelude::*; +use datafusion_common::instant::Instant; +use datafusion_common::{DEFAULT_CSV_EXTENSION, DEFAULT_PARQUET_EXTENSION}; + +use log::info; +use structopt::StructOpt; + +// hack to avoid `default_value is meaningless for bool` errors +type BoolDefaultTrue = bool; + +/// Run the imdb benchmark (a.k.a. JOB). +/// +/// This benchmarks is derived from the [Join Order Benchmark / JOB] proposed in paper [How Good Are Query Optimizers, Really?][1]. +/// The data and answers are downloaded from +/// [2] and [3]. +/// +/// [1]: https://www.vldb.org/pvldb/vol9/p204-leis.pdf +/// [2]: http://homepages.cwi.nl/~boncz/job/imdb.tgz +/// [3]: https://db.in.tum.de/~leis/qo/job.tgz + +#[derive(Debug, StructOpt, Clone)] +#[structopt(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number. If not specified, runs all queries + #[structopt(short, long)] + query: Option, + + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// Path to data files + #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + path: PathBuf, + + /// File format: `csv` or `parquet` + #[structopt(short = "f", long = "format", default_value = "csv")] + file_format: String, + + /// Load the data into a MemTable before executing the query + #[structopt(short = "m", long = "mem-table")] + mem_table: bool, + + /// Path to machine readable output file + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, + + /// Whether to disable collection of statistics (and cost based optimizations) or not. 
+ #[structopt(short = "S", long = "disable-statistics")] + disable_statistics: bool, + + /// If true then hash join used, if false then sort merge join + /// True by default. + #[structopt(short = "j", long = "prefer_hash_join", default_value = "true")] + prefer_hash_join: BoolDefaultTrue, +} + +const IMDB_QUERY_START_ID: usize = 1; +const IMDB_QUERY_END_ID: usize = 113; + +fn map_query_id_to_str(query_id: usize) -> &'static str { + match query_id { + // 1 + 1 => "1a", + 2 => "1b", + 3 => "1c", + 4 => "1d", + + // 2 + 5 => "2a", + 6 => "2b", + 7 => "2c", + 8 => "2d", + + // 3 + 9 => "3a", + 10 => "3b", + 11 => "3c", + + // 4 + 12 => "4a", + 13 => "4b", + 14 => "4c", + + // 5 + 15 => "5a", + 16 => "5b", + 17 => "5c", + + // 6 + 18 => "6a", + 19 => "6b", + 20 => "6c", + 21 => "6d", + 22 => "6e", + 23 => "6f", + + // 7 + 24 => "7a", + 25 => "7b", + 26 => "7c", + + // 8 + 27 => "8a", + 28 => "8b", + 29 => "8c", + 30 => "8d", + + // 9 + 31 => "9a", + 32 => "9b", + 33 => "9c", + 34 => "9d", + + // 10 + 35 => "10a", + 36 => "10b", + 37 => "10c", + + // 11 + 38 => "11a", + 39 => "11b", + 40 => "11c", + 41 => "11d", + + // 12 + 42 => "12a", + 43 => "12b", + 44 => "12c", + + // 13 + 45 => "13a", + 46 => "13b", + 47 => "13c", + 48 => "13d", + + // 14 + 49 => "14a", + 50 => "14b", + 51 => "14c", + + // 15 + 52 => "15a", + 53 => "15b", + 54 => "15c", + 55 => "15d", + + // 16 + 56 => "16a", + 57 => "16b", + 58 => "16c", + 59 => "16d", + + // 17 + 60 => "17a", + 61 => "17b", + 62 => "17c", + 63 => "17d", + 64 => "17e", + 65 => "17f", + + // 18 + 66 => "18a", + 67 => "18b", + 68 => "18c", + + // 19 + 69 => "19a", + 70 => "19b", + 71 => "19c", + 72 => "19d", + + // 20 + 73 => "20a", + 74 => "20b", + 75 => "20c", + + // 21 + 76 => "21a", + 77 => "21b", + 78 => "21c", + + // 22 + 79 => "22a", + 80 => "22b", + 81 => "22c", + 82 => "22d", + + // 23 + 83 => "23a", + 84 => "23b", + 85 => "23c", + + // 24 + 86 => "24a", + 87 => "24b", + + // 25 + 88 => "25a", + 89 => "25b", + 90 => "25c", + + // 26 + 91 => "26a", + 92 => "26b", + 93 => "26c", + + // 27 + 94 => "27a", + 95 => "27b", + 96 => "27c", + + // 28 + 97 => "28a", + 98 => "28b", + 99 => "28c", + + // 29 + 100 => "29a", + 101 => "29b", + 102 => "29c", + + // 30 + 103 => "30a", + 104 => "30b", + 105 => "30c", + + // 31 + 106 => "31a", + 107 => "31b", + 108 => "31c", + + // 32 + 109 => "32a", + 110 => "32b", + + // 33 + 111 => "33a", + 112 => "33b", + 113 => "33c", + + // Fallback for unknown query_id + _ => "unknown", + } +} + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running benchmarks with the following options: {self:?}"); + let query_range = match self.query { + Some(query_id) => query_id..=query_id, + None => IMDB_QUERY_START_ID..=IMDB_QUERY_END_ID, + }; + + let mut benchmark_run = BenchmarkRun::new(); + for query_id in query_range { + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(query_id).await?; + for iter in query_run { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + Ok(()) + } + + async fn benchmark_query(&self, query_id: usize) -> Result> { + let mut config = self + .common + .config() + .with_collect_statistics(!self.disable_statistics); + config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; + config + .options_mut() + .execution + .parquet + .schema_force_view_types = self.common.force_view_types; + let ctx = SessionContext::new_with_config(config); + + // register 
tables
+ self.register_tables(&ctx).await?;
+
+ let mut millis = vec![];
+ // run benchmark
+ let mut query_results = vec![];
+ for i in 0..self.iterations() {
+ let start = Instant::now();
+
+ let query_id_str = map_query_id_to_str(query_id);
+ let sql = &get_query_sql(query_id_str)?;
+
+ let mut result = vec![];
+
+ for query in sql {
+ result = self.execute_query(&ctx, query).await?;
+ }
+
+ let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0;
+ let ms = elapsed.as_secs_f64() * 1000.0;
+ millis.push(ms);
+ info!("output:\n\n{}\n\n", pretty_format_batches(&result)?);
+ let row_count = result.iter().map(|b| b.num_rows()).sum();
+ println!(
+ "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows"
+ );
+ query_results.push(QueryResult { elapsed, row_count });
+ }
+
+ let avg = millis.iter().sum::<f64>() / millis.len() as f64;
+ println!("Query {query_id} avg time: {avg:.2} ms");
+
+ Ok(query_results)
+ }
+
+ async fn register_tables(&self, ctx: &SessionContext) -> Result<()> {
+ for table in IMDB_TABLES {
+ let table_provider = { self.get_table(ctx, table).await? };
+
+ if self.mem_table {
+ println!("Loading table '{table}' into memory");
+ let start = Instant::now();
+ let memtable =
+ MemTable::load(table_provider, Some(self.partitions()), &ctx.state())
+ .await?;
+ println!(
+ "Loaded table '{}' into memory in {} ms",
+ table,
+ start.elapsed().as_millis()
+ );
+ ctx.register_table(*table, Arc::new(memtable))?;
+ } else {
+ ctx.register_table(*table, table_provider)?;
+ }
+ }
+ Ok(())
+ }
+
+ async fn execute_query(
+ &self,
+ ctx: &SessionContext,
+ sql: &str,
+ ) -> Result<Vec<RecordBatch>> {
+ let debug = self.common.debug;
+ let plan = ctx.sql(sql).await?;
+ let (state, plan) = plan.into_parts();
+
+ if debug {
+ println!("=== Logical plan ===\n{plan}\n");
+ }
+
+ let plan = state.optimize(&plan)?;
+ if debug {
+ println!("=== Optimized logical plan ===\n{plan}\n");
+ }
+ let physical_plan = state.create_physical_plan(&plan).await?;
+ if debug {
+ println!(
+ "=== Physical plan ===\n{}\n",
+ displayable(physical_plan.as_ref()).indent(true)
+ );
+ }
+ let result = collect(physical_plan.clone(), state.task_ctx()).await?;
+ if debug {
+ println!(
+ "=== Physical plan with metrics ===\n{}\n",
+ DisplayableExecutionPlan::with_metrics(physical_plan.as_ref())
+ .indent(true)
+ );
+ if !result.is_empty() {
+ // do not call print_batches if there are no batches as the result is confusing
+ // and makes it look like there is a batch with no columns
+ pretty::print_batches(&result)?;
+ }
+ }
+ Ok(result)
+ }
+
+ async fn get_table(
+ &self,
+ ctx: &SessionContext,
+ table: &str,
+ ) -> Result<Arc<dyn TableProvider>> {
+ let path = self.path.to_str().unwrap();
+ let table_format = self.file_format.as_str();
+
+ // Obtain a snapshot of the SessionState
+ let state = ctx.state();
+ let (format, path, extension): (Arc<dyn FileFormat>, String, &'static str) =
+ match table_format {
+ // dbgen creates .tbl ('|' delimited) files without header
+ "tbl" => {
+ let path = format!("{path}/{table}.tbl");
+
+ let format = CsvFormat::default()
+ .with_delimiter(b'|')
+ .with_has_header(false);
+
+ (Arc::new(format), path, ".tbl")
+ }
+ "csv" => {
+ let path = format!("{path}/{table}.csv");
+ let format = CsvFormat::default()
+ .with_delimiter(b',')
+ .with_escape(Some(b'\\'))
+ .with_has_header(false);
+
+ (Arc::new(format), path, DEFAULT_CSV_EXTENSION)
+ }
+ "parquet" => {
+ let path = format!("{path}/{table}.parquet");
+ let format = ParquetFormat::default()
+ .with_options(ctx.state().table_options().parquet.clone());
+
(Arc::new(format), path, DEFAULT_PARQUET_EXTENSION)
+ }
+ other => {
+ unimplemented!("Invalid file format '{}'", other);
+ }
+ };
+
+ let options = ListingOptions::new(format)
+ .with_file_extension(extension)
+ .with_collect_stat(state.config().collect_statistics());
+
+ let table_path = ListingTableUrl::parse(path)?;
+ let config = ListingTableConfig::new(table_path).with_listing_options(options);
+ let config = match table_format {
+ "parquet" => config.with_schema(Arc::new(get_imdb_table_schema(table))),
+ "csv" => config.with_schema(Arc::new(get_imdb_table_schema(table))),
+ _ => unreachable!(),
+ };
+
+ Ok(Arc::new(ListingTable::try_new(config)?))
+ }
+
+ fn iterations(&self) -> usize {
+ self.common.iterations
+ }
+
+ fn partitions(&self) -> usize {
+ self.common.partitions.unwrap_or(num_cpus::get())
+ }
+}
+
+struct QueryResult {
+ elapsed: std::time::Duration,
+ row_count: usize,
+}
+
+#[cfg(test)]
+// Only run with "ci" mode when we have the data
+#[cfg(feature = "ci")]
+mod tests {
+ use std::path::Path;
+
+ use super::*;
+
+ use crate::util::CommonOpt;
+ use datafusion::common::exec_err;
+ use datafusion::error::Result;
+ use datafusion_proto::bytes::{
+ logical_plan_from_bytes, logical_plan_to_bytes, physical_plan_from_bytes,
+ physical_plan_to_bytes,
+ };
+
+ fn get_imdb_data_path() -> Result<String> {
+ let path =
+ std::env::var("IMDB_DATA").unwrap_or_else(|_| "benchmarks/data".to_string());
+ if !Path::new(&path).exists() {
+ return exec_err!(
+ "Benchmark data not found (set IMDB_DATA env var to override): {}",
+ path
+ );
+ }
+ Ok(path)
+ }
+
+ async fn round_trip_logical_plan(query: usize) -> Result<()> {
+ let ctx = SessionContext::default();
+ let path = get_imdb_data_path()?;
+ let common = CommonOpt {
+ iterations: 1,
+ partitions: Some(2),
+ batch_size: 8192,
+ debug: false,
+ force_view_types: false,
+ };
+ let opt = RunOpt {
+ query: Some(query),
+ common,
+ path: PathBuf::from(path.to_string()),
+ file_format: "parquet".to_string(),
+ mem_table: false,
+ output_path: None,
+ disable_statistics: false,
+ prefer_hash_join: true,
+ };
+ opt.register_tables(&ctx).await?;
+ let queries = get_query_sql(map_query_id_to_str(query))?;
+ for query in queries {
+ let plan = ctx.sql(&query).await?;
+ let plan = plan.into_optimized_plan()?;
+ let bytes = logical_plan_to_bytes(&plan)?;
+ let plan2 = logical_plan_from_bytes(&bytes, &ctx)?;
+ let plan_formatted = format!("{}", plan.display_indent());
+ let plan2_formatted = format!("{}", plan2.display_indent());
+ assert_eq!(plan_formatted, plan2_formatted);
+ }
+ Ok(())
+ }
+
+ async fn round_trip_physical_plan(query: usize) -> Result<()> {
+ let ctx = SessionContext::default();
+ let path = get_imdb_data_path()?;
+ let common = CommonOpt {
+ iterations: 1,
+ partitions: Some(2),
+ batch_size: 8192,
+ debug: false,
+ force_view_types: false,
+ };
+ let opt = RunOpt {
+ query: Some(query),
+ common,
+ path: PathBuf::from(path.to_string()),
+ file_format: "parquet".to_string(),
+ mem_table: false,
+ output_path: None,
+ disable_statistics: false,
+ prefer_hash_join: true,
+ };
+ opt.register_tables(&ctx).await?;
+ let queries = get_query_sql(map_query_id_to_str(query))?;
+ for query in queries {
+ let plan = ctx.sql(&query).await?;
+ let plan = plan.create_physical_plan().await?;
+ let bytes = physical_plan_to_bytes(plan.clone())?;
+ let plan2 = physical_plan_from_bytes(&bytes, &ctx)?;
+ let plan_formatted = format!("{}", displayable(plan.as_ref()).indent(false));
+ let plan2_formatted =
+ format!("{}",
displayable(plan2.as_ref()).indent(false)); + assert_eq!(plan_formatted, plan2_formatted); + } + Ok(()) + } + + macro_rules! test_round_trip_logical { + ($tn:ident, $query:expr) => { + #[tokio::test] + async fn $tn() -> Result<()> { + round_trip_logical_plan($query).await + } + }; + } + + macro_rules! test_round_trip_physical { + ($tn:ident, $query:expr) => { + #[tokio::test] + async fn $tn() -> Result<()> { + round_trip_physical_plan($query).await + } + }; + } + + // logical plan tests + test_round_trip_logical!(round_trip_logical_plan_1a, 1); + test_round_trip_logical!(round_trip_logical_plan_1b, 2); + test_round_trip_logical!(round_trip_logical_plan_1c, 3); + test_round_trip_logical!(round_trip_logical_plan_1d, 4); + test_round_trip_logical!(round_trip_logical_plan_2a, 5); + test_round_trip_logical!(round_trip_logical_plan_2b, 6); + test_round_trip_logical!(round_trip_logical_plan_2c, 7); + test_round_trip_logical!(round_trip_logical_plan_2d, 8); + test_round_trip_logical!(round_trip_logical_plan_3a, 9); + test_round_trip_logical!(round_trip_logical_plan_3b, 10); + test_round_trip_logical!(round_trip_logical_plan_3c, 11); + test_round_trip_logical!(round_trip_logical_plan_4a, 12); + test_round_trip_logical!(round_trip_logical_plan_4b, 13); + test_round_trip_logical!(round_trip_logical_plan_4c, 14); + test_round_trip_logical!(round_trip_logical_plan_5a, 15); + test_round_trip_logical!(round_trip_logical_plan_5b, 16); + test_round_trip_logical!(round_trip_logical_plan_5c, 17); + test_round_trip_logical!(round_trip_logical_plan_6a, 18); + test_round_trip_logical!(round_trip_logical_plan_6b, 19); + test_round_trip_logical!(round_trip_logical_plan_6c, 20); + test_round_trip_logical!(round_trip_logical_plan_6d, 21); + test_round_trip_logical!(round_trip_logical_plan_6e, 22); + test_round_trip_logical!(round_trip_logical_plan_6f, 23); + test_round_trip_logical!(round_trip_logical_plan_7a, 24); + test_round_trip_logical!(round_trip_logical_plan_7b, 25); + test_round_trip_logical!(round_trip_logical_plan_7c, 26); + test_round_trip_logical!(round_trip_logical_plan_8a, 27); + test_round_trip_logical!(round_trip_logical_plan_8b, 28); + test_round_trip_logical!(round_trip_logical_plan_8c, 29); + test_round_trip_logical!(round_trip_logical_plan_8d, 30); + test_round_trip_logical!(round_trip_logical_plan_9a, 31); + test_round_trip_logical!(round_trip_logical_plan_9b, 32); + test_round_trip_logical!(round_trip_logical_plan_9c, 33); + test_round_trip_logical!(round_trip_logical_plan_9d, 34); + test_round_trip_logical!(round_trip_logical_plan_10a, 35); + test_round_trip_logical!(round_trip_logical_plan_10b, 36); + test_round_trip_logical!(round_trip_logical_plan_10c, 37); + test_round_trip_logical!(round_trip_logical_plan_11a, 38); + test_round_trip_logical!(round_trip_logical_plan_11b, 39); + test_round_trip_logical!(round_trip_logical_plan_11c, 40); + test_round_trip_logical!(round_trip_logical_plan_11d, 41); + test_round_trip_logical!(round_trip_logical_plan_12a, 42); + test_round_trip_logical!(round_trip_logical_plan_12b, 43); + test_round_trip_logical!(round_trip_logical_plan_12c, 44); + test_round_trip_logical!(round_trip_logical_plan_13a, 45); + test_round_trip_logical!(round_trip_logical_plan_13b, 46); + test_round_trip_logical!(round_trip_logical_plan_13c, 47); + test_round_trip_logical!(round_trip_logical_plan_13d, 48); + test_round_trip_logical!(round_trip_logical_plan_14a, 49); + test_round_trip_logical!(round_trip_logical_plan_14b, 50); + test_round_trip_logical!(round_trip_logical_plan_14c, 
51); + test_round_trip_logical!(round_trip_logical_plan_15a, 52); + test_round_trip_logical!(round_trip_logical_plan_15b, 53); + test_round_trip_logical!(round_trip_logical_plan_15c, 54); + test_round_trip_logical!(round_trip_logical_plan_15d, 55); + test_round_trip_logical!(round_trip_logical_plan_16a, 56); + test_round_trip_logical!(round_trip_logical_plan_16b, 57); + test_round_trip_logical!(round_trip_logical_plan_16c, 58); + test_round_trip_logical!(round_trip_logical_plan_16d, 59); + test_round_trip_logical!(round_trip_logical_plan_17a, 60); + test_round_trip_logical!(round_trip_logical_plan_17b, 61); + test_round_trip_logical!(round_trip_logical_plan_17c, 62); + test_round_trip_logical!(round_trip_logical_plan_17d, 63); + test_round_trip_logical!(round_trip_logical_plan_17e, 64); + test_round_trip_logical!(round_trip_logical_plan_17f, 65); + test_round_trip_logical!(round_trip_logical_plan_18a, 66); + test_round_trip_logical!(round_trip_logical_plan_18b, 67); + test_round_trip_logical!(round_trip_logical_plan_18c, 68); + test_round_trip_logical!(round_trip_logical_plan_19a, 69); + test_round_trip_logical!(round_trip_logical_plan_19b, 70); + test_round_trip_logical!(round_trip_logical_plan_19c, 71); + test_round_trip_logical!(round_trip_logical_plan_19d, 72); + test_round_trip_logical!(round_trip_logical_plan_20a, 73); + test_round_trip_logical!(round_trip_logical_plan_20b, 74); + test_round_trip_logical!(round_trip_logical_plan_20c, 75); + test_round_trip_logical!(round_trip_logical_plan_21a, 76); + test_round_trip_logical!(round_trip_logical_plan_21b, 77); + test_round_trip_logical!(round_trip_logical_plan_21c, 78); + test_round_trip_logical!(round_trip_logical_plan_22a, 79); + test_round_trip_logical!(round_trip_logical_plan_22b, 80); + test_round_trip_logical!(round_trip_logical_plan_22c, 81); + test_round_trip_logical!(round_trip_logical_plan_22d, 82); + test_round_trip_logical!(round_trip_logical_plan_23a, 83); + test_round_trip_logical!(round_trip_logical_plan_23b, 84); + test_round_trip_logical!(round_trip_logical_plan_23c, 85); + test_round_trip_logical!(round_trip_logical_plan_24a, 86); + test_round_trip_logical!(round_trip_logical_plan_24b, 87); + test_round_trip_logical!(round_trip_logical_plan_25a, 88); + test_round_trip_logical!(round_trip_logical_plan_25b, 89); + test_round_trip_logical!(round_trip_logical_plan_25c, 90); + test_round_trip_logical!(round_trip_logical_plan_26a, 91); + test_round_trip_logical!(round_trip_logical_plan_26b, 92); + test_round_trip_logical!(round_trip_logical_plan_26c, 93); + test_round_trip_logical!(round_trip_logical_plan_27a, 94); + test_round_trip_logical!(round_trip_logical_plan_27b, 95); + test_round_trip_logical!(round_trip_logical_plan_27c, 96); + test_round_trip_logical!(round_trip_logical_plan_28a, 97); + test_round_trip_logical!(round_trip_logical_plan_28b, 98); + test_round_trip_logical!(round_trip_logical_plan_28c, 99); + test_round_trip_logical!(round_trip_logical_plan_29a, 100); + test_round_trip_logical!(round_trip_logical_plan_29b, 101); + test_round_trip_logical!(round_trip_logical_plan_29c, 102); + test_round_trip_logical!(round_trip_logical_plan_30a, 103); + test_round_trip_logical!(round_trip_logical_plan_30b, 104); + test_round_trip_logical!(round_trip_logical_plan_30c, 105); + test_round_trip_logical!(round_trip_logical_plan_31a, 106); + test_round_trip_logical!(round_trip_logical_plan_31b, 107); + test_round_trip_logical!(round_trip_logical_plan_31c, 108); + test_round_trip_logical!(round_trip_logical_plan_32a, 109); + 
diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs
index 52d81ca91816a..02410e0cfa01e 100644
--- a/benchmarks/src/lib.rs
+++ b/benchmarks/src/lib.rs
@@ -21,5 +21,4 @@ pub mod imdb;
 pub mod parquet_filter;
 pub mod sort;
 pub mod tpch;
-mod util;
-pub use util::*;
+pub mod util;
diff --git a/benchmarks/src/parquet_filter.rs b/benchmarks/src/parquet_filter.rs
index 5c98a2f8be3de..34103af0ffd21 100644
--- a/benchmarks/src/parquet_filter.rs
+++ b/benchmarks/src/parquet_filter.rs
@@ -17,7 +17,7 @@
 use std::path::PathBuf;

-use crate::{AccessLogOpt, BenchmarkRun, CommonOpt};
+use crate::util::{AccessLogOpt, BenchmarkRun, CommonOpt};
 use arrow::util::pretty;
 use datafusion::common::Result;
diff --git a/benchmarks/src/sort.rs b/benchmarks/src/sort.rs
index 19eec2949ef61..247727e1b4840 100644
--- a/benchmarks/src/sort.rs
+++ b/benchmarks/src/sort.rs
@@ -18,7 +18,7 @@
 use std::path::PathBuf;
 use std::sync::Arc;

-use crate::{AccessLogOpt, BenchmarkRun, CommonOpt};
+use crate::util::{AccessLogOpt, BenchmarkRun, CommonOpt};
 use arrow::util::pretty;
 use datafusion::common::Result;
diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs
index 1a1f51f700651..e316a66e1c600 100644
--- a/benchmarks/src/tpch/run.rs
+++ b/benchmarks/src/tpch/run.rs
@@ -21,7 +21,7 @@ use std::sync::Arc;

 use super::{
     get_query_sql, get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_TABLES,
 };
-use crate::{BenchmarkRun, CommonOpt};
+use crate::util::{BenchmarkRun, CommonOpt};
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::{self, pretty_format_batches};
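// Editor's note (not part of the patch): a minimal sketch of the visibility
// change above. The old `pub use util::*;` re-exported every helper at the
// crate root; `pub mod util;` keeps them behind an explicit module path, which
// is why the imports in parquet_filter.rs, sort.rs and tpch/run.rs above switch
// from `use crate::{...}` to `use crate::util::{...}`. Side by side, with a
// stand-in type:

#[allow(dead_code)]
mod before {
    mod util {
        pub struct BenchmarkRun;
    }
    pub use util::*; // callers write `before::BenchmarkRun`
}

#[allow(dead_code)]
mod after {
    pub mod util {
        pub struct BenchmarkRun; // callers write `after::util::BenchmarkRun`
    }
}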
diff --git a/ci/scripts/retry b/ci/scripts/retry
new file mode 100755
index 0000000000000..0569dea58c94a
--- /dev/null
+++ b/ci/scripts/retry
@@ -0,0 +1,21 @@
+#!/usr/bin/env bash
+
+set -euo pipefail
+
+x() {
+  echo "+ $*" >&2
+  "$@"
+}
+
+max_retry_time_seconds=$(( 3 * 60 ))
+retry_delay_seconds=10
+
+END=$(( $(date +%s) + ${max_retry_time_seconds} ))
+
+while (( $(date +%s) < $END )); do
+  x "$@" && exit 0
+  sleep "${retry_delay_seconds}"
+done
+
+echo "$0: retrying [$*] timed out" >&2
+exit 1
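// Editor's note (not part of the patch): the shell script above retries a
// command every 10 seconds until a hard 3-minute deadline, echoing each attempt
// to stderr and exiting 0 on the first success. For readers skimming the shell,
// the same control flow as a Rust sketch; the function name is illustrative:

use std::process::Command;
use std::thread::sleep;
use std::time::{Duration, Instant};

fn retry(cmd: &str, args: &[&str]) -> bool {
    // mirror ci/scripts/retry: a fixed deadline, a fixed delay between attempts
    let deadline = Instant::now() + Duration::from_secs(3 * 60);
    while Instant::now() < deadline {
        eprintln!("+ {} {}", cmd, args.join(" "));
        let ok = Command::new(cmd)
            .args(args)
            .status()
            .map_or(false, |status| status.success());
        if ok {
            return true; // the script's `exit 0`
        }
        sleep(Duration::from_secs(10));
    }
    eprintln!("retrying [{} {}] timed out", cmd, args.join(" "));
    false // the script's `exit 1`
}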
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index fbe7d5c04b9bf..15ba8c3d5f26e 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -4,9 +4,9 @@ version = 3

 [[package]]
 name = "addr2line"
-version = "0.24.1"
+version = "0.24.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f5fb1d8e4442bd405fdfd1dacb42792696b0cf9cb15882e5d097b742a676d375"
+checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1"
 dependencies = [
  "gimli",
 ]
@@ -84,9 +84,9 @@ dependencies = [

 [[package]]
 name = "anstream"
-version = "0.6.15"
+version = "0.6.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526"
+checksum = "23a1e53f0f5d86382dafe1cf314783b2044280f406e7e1506368220ad11b1338"
 dependencies = [
  "anstyle",
  "anstyle-parse",
@@ -99,36 +99,36 @@ dependencies = [

 [[package]]
 name = "anstyle"
-version = "1.0.8"
+version = "1.0.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1"
+checksum =
"8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" [[package]] name = "anstyle-parse" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.4" +version = "3.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" dependencies = [ "anstyle", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -173,9 +173,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45aef0d9cf9a039bf6cd1acc451b137aca819977b0928dece52bd92811b640ba" +checksum = "4caf25cdc4a985f91df42ed9e9308e1adbcd341a31a72605c697033fcef163e3" dependencies = [ "arrow-arith", "arrow-array", @@ -194,9 +194,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03675e42d1560790f3524800e41403b40d0da1c793fe9528929fde06d8c7649a" +checksum = "91f2dfd1a7ec0aca967dfaa616096aec49779adc8eccec005e2f5e4111b1192a" dependencies = [ "arrow-array", "arrow-buffer", @@ -209,9 +209,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd2bf348cf9f02a5975c5962c7fa6dee107a2009a7b41ac5fb1a027e12dc033f" +checksum = "d39387ca628be747394890a6e47f138ceac1aa912eab64f02519fed24b637af8" dependencies = [ "ahash", "arrow-buffer", @@ -220,15 +220,15 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown", + "hashbrown 0.14.5", "num", ] [[package]] name = "arrow-buffer" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3092e37715f168976012ce52273c3989b5793b0db5f06cbaa246be25e5f0924d" +checksum = "9e51e05228852ffe3eb391ce7178a0f97d2cf80cc6ef91d3c4a6b3cb688049ec" dependencies = [ "bytes", "half", @@ -237,9 +237,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ce1018bb710d502f9db06af026ed3561552e493e989a79d0d0f5d9cf267a785" +checksum = "d09aea56ec9fa267f3f3f6cdab67d8a9974cbba90b3aa38c8fe9d0bb071bd8c1" dependencies = [ "arrow-array", "arrow-buffer", @@ -258,9 +258,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd178575f45624d045e4ebee714e246a05d9652e41363ee3f57ec18cca97f740" +checksum = "c07b5232be87d115fde73e32f2ca7f1b353bff1b44ac422d3c6fc6ae38f11f0d" dependencies = [ "arrow-array", "arrow-buffer", @@ -277,9 
+277,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4ac0c4ee79150afe067dc4857154b3ee9c1cd52b5f40d59a77306d0ed18d65" +checksum = "b98ae0af50890b494cebd7d6b04b35e896205c1d1df7b29a6272c5d0d0249ef5" dependencies = [ "arrow-buffer", "arrow-schema", @@ -289,9 +289,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb307482348a1267f91b0912e962cd53440e5de0f7fb24c5f7b10da70b38c94a" +checksum = "0ed91bdeaff5a1c00d28d8f73466bcb64d32bbd7093b5a30156b4b9f4dba3eee" dependencies = [ "arrow-array", "arrow-buffer", @@ -304,9 +304,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d24805ba326758effdd6f2cbdd482fcfab749544f21b134701add25b33f474e6" +checksum = "0471f51260a5309307e5d409c9dc70aede1cd9cf1d4ff0f0a1e8e1a2dd0e0d3c" dependencies = [ "arrow-array", "arrow-buffer", @@ -324,9 +324,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "644046c479d80ae8ed02a7f1e1399072ea344ca6a7b0e293ab2d5d9ed924aa3b" +checksum = "2883d7035e0b600fb4c30ce1e50e66e53d8656aa729f2bfa4b51d359cf3ded52" dependencies = [ "arrow-array", "arrow-buffer", @@ -339,9 +339,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a29791f8eb13b340ce35525b723f5f0df17ecb955599e11f65c2a94ab34e2efb" +checksum = "552907e8e587a6fde4f8843fd7a27a576a260f65dab6c065741ea79f633fc5be" dependencies = [ "ahash", "arrow-array", @@ -353,15 +353,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85320a3a2facf2b2822b57aa9d6d9d55edb8aee0b6b5d3b8df158e503d10858" +checksum = "539ada65246b949bd99ffa0881a9a15a4a529448af1a07a9838dd78617dafab1" [[package]] name = "arrow-select" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cc7e6b582e23855fd1625ce46e51647aa440c20ea2e71b1d748e0839dd73cba" +checksum = "6259e566b752da6dceab91766ed8b2e67bf6270eb9ad8a6e07a33c1bede2b125" dependencies = [ "ahash", "arrow-array", @@ -373,9 +373,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0775b6567c66e56ded19b87a954b6b1beffbdd784ef95a3a2b03f59570c1d230" +checksum = "f3179ccbd18ebf04277a095ba7321b93fd1f774f18816bd5f6b3ce2f594edb6c" dependencies = [ "arrow-array", "arrow-buffer", @@ -406,9 +406,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.12" +version = "0.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fec134f64e2bc57411226dfc4e52dec859ddfc7e711fc5e07b612584f000e4aa" +checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" dependencies = [ "bzip2", "flate2", @@ -424,9 +424,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.82" +version = "0.1.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", @@ -450,15 +450,15 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.5.6" +version = "1.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "848d7b9b605720989929279fa644ce8f244d0ce3146fcca5b70e4eb7b3c020fc" +checksum = "2d6448cfb224dd6a9b9ac734f58622dd0d4751f3589f3b777345745f46b2eb14" dependencies = [ "aws-credential-types", "aws-runtime", @@ -523,9 +523,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.43.0" +version = "1.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70a9d27ed1c12b1140c47daf1bc541606c43fdafd918c4797d520db0043ceef2" +checksum = "a8776850becacbd3a82a4737a9375ddb5c6832a51379f24443a98e61513f852c" dependencies = [ "aws-credential-types", "aws-runtime", @@ -545,9 +545,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.44.0" +version = "1.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44514a6ca967686cde1e2a1b81df6ef1883d0e3e570da8d8bc5c491dcb6fc29b" +checksum = "0007b5b8004547133319b6c4e87193eee2a0bcb3e4c18c75d09febe9dab7b383" dependencies = [ "aws-credential-types", "aws-runtime", @@ -567,9 +567,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.43.0" +version = "1.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd7a4d279762a35b9df97209f6808b95d4fe78547fe2316b4d200a0283960c5a" +checksum = "9fffaa356e7f1c725908b75136d53207fa714e348f365671df14e95a60530ad3" dependencies = [ "aws-credential-types", "aws-runtime", @@ -590,9 +590,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.4" +version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc8db6904450bafe7473c6ca9123f88cc11089e41a025408f992db4e22d3be68" +checksum = "5619742a0d8f253be760bfbb8e8e8368c69e3587e4637af5754e488a611499b1" dependencies = [ "aws-credential-types", "aws-smithy-http", @@ -663,9 +663,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.1" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1ce695746394772e7000b39fe073095db6d45a862d0767dd5ad0ac0d7f8eb87" +checksum = "be28bd063fa91fd871d131fc8b68d7cd4c5fa0869bea68daca50dcb1cbd76be2" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -678,7 +678,7 @@ dependencies = [ "http-body 0.4.6", "http-body 1.0.1", "httparse", - "hyper 0.14.30", + "hyper 0.14.31", "hyper-rustls 0.24.2", "once_cell", "pin-project-lite", @@ -707,9 +707,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.6" +version = "1.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03701449087215b5369c7ea17fef0dd5d24cb93439ec5af0c7615f58c3f22605" +checksum = "07c9cdc179e6afbf5d391ab08c85eac817b51c87e1892a5edb5f7bbdc64314b4" dependencies = [ "base64-simd", "bytes", @@ -836,9 +836,9 @@ dependencies = [ [[package]] name = "brotli" -version = "6.0.0" +version = "7.0.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" +checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -880,9 +880,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "bytes-utils" @@ -917,9 +917,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.21" +version = "1.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0" +checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" dependencies = [ "jobserver", "libc", @@ -938,6 +938,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.38" @@ -953,9 +959,9 @@ dependencies = [ [[package]] name = "chrono-tz" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93698b29de5e97ad0ae26447b344c482a7284c737d9ddc5f9e52b74a336671bb" +checksum = "cd6dd8046d00723a59a2f8c5f295c515b9bb9a331ee4f8f3d4dd49e428acd3b6" dependencies = [ "chrono", "chrono-tz-build", @@ -964,20 +970,19 @@ dependencies = [ [[package]] name = "chrono-tz-build" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c088aee841df9c3041febbb73934cfc39708749bf96dc827e3359cd39ef11b1" +checksum = "e94fea34d77a245229e7746bd2beb786cd2a896f306ff491fb8cecb3074b10a7" dependencies = [ "parse-zoneinfo", - "phf", "phf_codegen", ] [[package]] name = "clap" -version = "4.5.17" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac" +checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" dependencies = [ "clap_builder", "clap_derive", @@ -985,9 +990,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.17" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73" +checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" dependencies = [ "anstream", "anstyle", @@ -997,9 +1002,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.13" +version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -1024,9 +1029,9 @@ dependencies = [ [[package]] name = "colorchoice" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "comfy-table" @@ -1163,9 +1168,9 @@ dependencies = [ [[package]] name = "dary_heap" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" +checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" [[package]] name = "dashmap" @@ -1175,7 +1180,7 @@ checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" dependencies = [ "cfg-if", "crossbeam-utils", - "hashbrown", + "hashbrown 0.14.5", "lock_api", "once_cell", "parking_lot_core", @@ -1183,7 +1188,7 @@ dependencies = [ [[package]] name = "datafusion" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "apache-avro", @@ -1216,7 +1221,7 @@ dependencies = [ "futures", "glob", "half", - "hashbrown", + "hashbrown 0.14.5", "indexmap", "itertools", "log", @@ -1240,7 +1245,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow-schema", "async-trait", @@ -1253,7 +1258,7 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "assert_cmd", @@ -1283,7 +1288,7 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "apache-avro", @@ -1293,7 +1298,8 @@ dependencies = [ "arrow-schema", "chrono", "half", - "hashbrown", + "hashbrown 0.14.5", + "indexmap", "instant", "libc", "num_cpus", @@ -1306,7 +1312,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "42.0.0" +version = "42.1.0" dependencies = [ "log", "tokio", @@ -1314,7 +1320,7 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "chrono", @@ -1322,7 +1328,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "futures", - "hashbrown", + "hashbrown 0.14.5", "log", "object_store", "parking_lot", @@ -1333,7 +1339,7 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1345,6 +1351,7 @@ dependencies = [ "datafusion-functions-aggregate-common", "datafusion-functions-window-common", "datafusion-physical-expr-common", + "indexmap", "paste", "serde_json", "sqlparser", @@ -1354,16 +1361,17 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "datafusion-common", + "itertools", "paste", ] [[package]] name = "datafusion-functions" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "arrow-buffer", @@ -1374,7 +1382,7 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", - "hashbrown", + "hashbrown 0.14.5", "hex", "itertools", "log", @@ -1388,7 +1396,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1400,14 +1408,14 @@ dependencies = [ "datafusion-physical-expr", "datafusion-physical-expr-common", "half", + "indexmap", "log", "paste", - "sqlparser", ] [[package]] name = "datafusion-functions-aggregate-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1419,7 +1427,7 @@ dependencies = [ [[package]] name = 
"datafusion-functions-nested" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "arrow-array", @@ -1440,25 +1448,28 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "42.0.0" +version = "42.1.0" dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-functions-window-common", + "datafusion-physical-expr", "datafusion-physical-expr-common", "log", + "paste", ] [[package]] name = "datafusion-functions-window-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "datafusion-common", + "datafusion-physical-expr-common", ] [[package]] name = "datafusion-optimizer" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "async-trait", @@ -1466,7 +1477,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "hashbrown", + "hashbrown 0.14.5", "indexmap", "itertools", "log", @@ -1476,7 +1487,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1485,17 +1496,14 @@ dependencies = [ "arrow-ord", "arrow-schema", "arrow-string", - "base64 0.22.1", "chrono", "datafusion-common", - "datafusion-execution", "datafusion-expr", "datafusion-expr-common", "datafusion-functions-aggregate-common", "datafusion-physical-expr-common", "half", - "hashbrown", - "hex", + "hashbrown 0.14.5", "indexmap", "itertools", "log", @@ -1506,23 +1514,25 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", "datafusion-common", "datafusion-expr-common", - "hashbrown", + "hashbrown 0.14.5", "rand", ] [[package]] name = "datafusion-physical-optimizer" -version = "42.0.0" +version = "42.1.0" dependencies = [ + "arrow", "arrow-schema", "datafusion-common", "datafusion-execution", + "datafusion-expr-common", "datafusion-physical-expr", "datafusion-physical-plan", "itertools", @@ -1530,7 +1540,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1544,14 +1554,13 @@ dependencies = [ "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", - "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", "datafusion-functions-window-common", "datafusion-physical-expr", "datafusion-physical-expr-common", "futures", "half", - "hashbrown", + "hashbrown 0.14.5", "indexmap", "itertools", "log", @@ -1564,13 +1573,14 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "arrow-array", "arrow-schema", "datafusion-common", "datafusion-expr", + "indexmap", "log", "regex", "sqlparser", @@ -1722,9 +1732,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.33" +version = "1.0.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" +checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" dependencies = [ "crc32fast", "miniz_oxide", @@ -1756,9 +1766,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -1771,9 +1781,9 @@ dependencies = [ 
[[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -1781,15 +1791,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -1798,15 +1808,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", @@ -1815,15 +1825,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-timer" @@ -1833,9 +1843,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -1872,9 +1882,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.31.0" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32085ea23f3234fc7846555e85283ba4de91e21016dc0455a16286d87a292d64" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" @@ -1941,6 +1951,12 @@ dependencies = [ "allocator-api2", ] +[[package]] +name = "hashbrown" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" + [[package]] name = "heck" version = "0.4.1" @@ -2041,9 +2057,9 @@ dependencies = [ [[package]] name = 
"httparse" -version = "1.9.4" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fcc0b4a115bf80b728eb8ea024ad5bd707b615bfed49e0665b6e0f86fd082d9" +checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" [[package]] name = "httpdate" @@ -2059,9 +2075,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.30" +version = "0.14.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" +checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" dependencies = [ "bytes", "futures-channel", @@ -2083,9 +2099,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" dependencies = [ "bytes", "futures-channel", @@ -2109,7 +2125,7 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.12", - "hyper 0.14.30", + "hyper 0.14.31", "log", "rustls 0.21.12", "rustls-native-certs 0.6.3", @@ -2125,9 +2141,9 @@ checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-util", - "rustls 0.23.13", + "rustls 0.23.16", "rustls-native-certs 0.8.0", "rustls-pki-types", "tokio", @@ -2137,20 +2153,19 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.8" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" dependencies = [ "bytes", "futures-channel", "futures-util", "http 1.1.0", "http-body 1.0.1", - "hyper 1.4.1", + "hyper 1.5.0", "pin-project-lite", "socket2", "tokio", - "tower", "tower-service", "tracing", ] @@ -2190,12 +2205,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" +checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.0", ] [[package]] @@ -2218,9 +2233,9 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "ipnet" -version = "2.10.0" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" +checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" [[package]] name = "is_terminal_polyfill" @@ -2254,9 +2269,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] @@ -2269,9 +2284,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "lexical-core" -version = "0.8.5" 
+version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +checksum = "0431c65b318a590c1de6b8fd6e72798c92291d27762d94c9e6c37ed7a73d8458" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -2282,9 +2297,9 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "0.8.5" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +checksum = "eb17a4bdb9b418051aa59d41d65b1c9be5affab314a872e5ad7f06231fb3b4e0" dependencies = [ "lexical-parse-integer", "lexical-util", @@ -2293,9 +2308,9 @@ dependencies = [ [[package]] name = "lexical-parse-integer" -version = "0.8.6" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +checksum = "5df98f4a4ab53bf8b175b363a34c7af608fe31f93cc1fb1bf07130622ca4ef61" dependencies = [ "lexical-util", "static_assertions", @@ -2303,18 +2318,18 @@ dependencies = [ [[package]] name = "lexical-util" -version = "0.8.5" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +checksum = "85314db53332e5c192b6bca611fb10c114a80d1b831ddac0af1e9be1b9232ca0" dependencies = [ "static_assertions", ] [[package]] name = "lexical-write-float" -version = "0.8.5" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +checksum = "6e7c3ad4e37db81c1cbe7cf34610340adc09c322871972f74877a712abc6c809" dependencies = [ "lexical-util", "lexical-write-integer", @@ -2323,9 +2338,9 @@ dependencies = [ [[package]] name = "lexical-write-integer" -version = "0.8.5" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +checksum = "eb89e9f6958b83258afa3deed90b5de9ef68eef090ad5086c791cd2345610162" dependencies = [ "lexical-util", "static_assertions", @@ -2333,9 +2348,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.158" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "libflate" @@ -2357,15 +2372,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e0d73b369f386f1c44abd9c570d5318f55ccde816ff4b562fa452e5182863d" dependencies = [ "core2", - "hashbrown", + "hashbrown 0.14.5", "rle-decode-fast", ] [[package]] name = "libm" -version = "0.2.8" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" [[package]] name = "libmimalloc-sys" @@ -2498,7 +2513,7 @@ checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" dependencies = [ "bitflags 2.6.0", "cfg-if", - "cfg_aliases", + "cfg_aliases 0.1.1", "libc", ] @@ -2600,18 +2615,18 @@ dependencies = [ [[package]] name = "object" -version = "0.36.4" +version = "0.36.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" dependencies = [ "memchr", ] [[package]] name = "object_store" -version = "0.11.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25a0c4b3a0e31f8b66f71ad8064521efa773910196e2cde791436f13409f3b45" +checksum = "6eb4c22c6154a1e759d7099f9ffad7cc5ef8245f9efbab4a41b92623079c82f3" dependencies = [ "async-trait", "base64 0.22.1", @@ -2619,7 +2634,7 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.4.1", + "hyper 1.5.0", "itertools", "md-5", "parking_lot", @@ -2628,7 +2643,7 @@ dependencies = [ "rand", "reqwest", "ring", - "rustls-pemfile 2.1.3", + "rustls-pemfile 2.2.0", "serde", "serde_json", "snafu", @@ -2640,9 +2655,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "openssl-probe" @@ -2696,9 +2711,9 @@ dependencies = [ [[package]] name = "parquet" -version = "53.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0fbf928021131daaa57d334ca8e3904fe9ae22f73c56244fc7db9b04eedc3d8" +checksum = "dea02606ba6f5e856561d8d507dba8bac060aefca2a6c0f1aa1d361fed91ff3e" dependencies = [ "ahash", "arrow-array", @@ -2715,7 +2730,7 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown", + "hashbrown 0.14.5", "lz4_flex", "num", "num-bigint", @@ -2799,31 +2814,11 @@ dependencies = [ "siphasher", ] -[[package]] -name = "pin-project" -version = "1.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "pin-utils" @@ -2833,9 +2828,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" [[package]] name = "powerfmt" @@ -2893,9 +2888,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] @@ -2908,9 +2903,9 @@ checksum = "b76f1009795ca44bb5aaae8fd3f18953e209259c33d9b059b1f53d58ab7511db" [[package]] name = "quick-xml" -version = "0.36.1" +version = 
"0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96a05e2e8efddfa51a84ca47cec303fac86c8541b686d37cac5efc0e094417bc" +checksum = "f7649a7b4df05aed9ea7ec6f628c67c9953a43869b8bc50929569b2999d443fe" dependencies = [ "memchr", "serde", @@ -2927,7 +2922,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.13", + "rustls 0.23.16", "socket2", "thiserror", "tokio", @@ -2944,7 +2939,7 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls 0.23.13", + "rustls 0.23.16", "slab", "thiserror", "tinyvec", @@ -2953,10 +2948,11 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fe68c2e9e1a1234e218683dbdf9f9dfcb094113c5ac2b938dfcb9bab4c4140b" +checksum = "e346e016eacfff12233c243718197ca12f148c84e1e84268a896699b41c71780" dependencies = [ + "cfg_aliases 0.2.1", "libc", "once_cell", "socket2", @@ -3015,9 +3011,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.4" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0884ad60e090bf1345b93da0a5de8923c93884cd03f40dfcfddd3b4bee661853" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" dependencies = [ "bitflags 2.6.0", ] @@ -3035,9 +3031,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.6" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -3047,9 +3043,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", @@ -3064,9 +3060,9 @@ checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "relative-path" @@ -3076,9 +3072,9 @@ checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" [[package]] name = "reqwest" -version = "0.12.7" +version = "0.12.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63" +checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" dependencies = [ "base64 0.22.1", "bytes", @@ -3088,7 +3084,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-rustls 0.27.3", "hyper-util", "ipnet", @@ -3099,9 +3095,9 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.13", - "rustls-native-certs 0.7.3", - "rustls-pemfile 2.1.3", + "rustls 0.23.16", + "rustls-native-certs 0.8.0", + "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", "serde_json", @@ -3193,9 +3189,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.38" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "aa260229e6538e52293eeb577aabd09945a09d6d9cc0fc550ed7529056c2e32a" dependencies = [ "bitflags 2.6.0", "errno", @@ -3218,9 +3214,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.13" +version = "0.23.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" +checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" dependencies = [ "once_cell", "ring", @@ -3242,19 +3238,6 @@ dependencies = [ "security-framework", ] -[[package]] -name = "rustls-native-certs" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" -dependencies = [ - "openssl-probe", - "rustls-pemfile 2.1.3", - "rustls-pki-types", - "schannel", - "security-framework", -] - [[package]] name = "rustls-native-certs" version = "0.8.0" @@ -3262,7 +3245,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" dependencies = [ "openssl-probe", - "rustls-pemfile 2.1.3", + "rustls-pemfile 2.2.0", "rustls-pki-types", "schannel", "security-framework", @@ -3279,19 +3262,18 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.3" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" dependencies = [ - "base64 0.22.1", "rustls-pki-types", ] [[package]] name = "rustls-pki-types" -version = "1.8.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" +checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" [[package]] name = "rustls-webpki" @@ -3316,9 +3298,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" [[package]] name = "rustyline" @@ -3359,9 +3341,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.24" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b" +checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" dependencies = [ "windows-sys 0.59.0", ] @@ -3397,9 +3379,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.11.1" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" +checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" dependencies = [ "core-foundation-sys", "libc", @@ -3419,18 +3401,18 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", @@ -3439,9 +3421,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", "memchr", @@ -3510,18 +3492,18 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "snafu" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b835cb902660db3415a672d862905e791e54d306c6e8189168c7f3d9ae1c79d" +checksum = "223891c85e2a29c3fe8fb900c1fae5e69c2e42415e3177752e8718475efa5019" dependencies = [ "snafu-derive", ] [[package]] name = "snafu-derive" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d1e02fca405f6280643174a50c942219f0bbf4dbf7d480f1dd864d6f211ae5" +checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -3633,9 +3615,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.77" +version = "2.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" +checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" dependencies = [ "proc-macro2", "quote", @@ -3653,9 +3635,9 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.12.0" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" +checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" dependencies = [ "cfg-if", "fastrand", @@ -3672,18 +3654,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" [[package]] name = "thiserror" -version = "1.0.63" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" dependencies = [ "proc-macro2", "quote", @@ -3757,9 +3739,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.40.0" +version = "1.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" +checksum = 
"145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" dependencies = [ "backtrace", "bytes", @@ -3800,7 +3782,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.13", + "rustls 0.23.16", "rustls-pki-types", "tokio", ] @@ -3826,36 +3808,15 @@ checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" [[package]] name = "toml_edit" -version = "0.22.21" +version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b072cee73c449a636ffd6f32bd8de3a9f7119139aff882f44943ce2986dc5cf" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ "indexmap", "toml_datetime", "winnow", ] -[[package]] -name = "tower" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" -dependencies = [ - "futures-core", - "futures-util", - "pin-project", - "pin-project-lite", - "tokio", - "tower-layer", - "tower-service", -] - -[[package]] -name = "tower-layer" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" - [[package]] name = "tower-service" version = "0.3.3" @@ -3937,9 +3898,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" -version = "0.3.15" +version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" +checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" [[package]] name = "unicode-ident" @@ -3964,9 +3925,9 @@ checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-width" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "untrusted" @@ -3999,9 +3960,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "getrandom", "serde", @@ -4055,9 +4016,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -4066,9 +4027,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", @@ -4081,9 +4042,9 @@ 
dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -4093,9 +4054,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4103,9 +4064,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", @@ -4116,15 +4077,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "wasm-streams" -version = "0.4.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" dependencies = [ "futures-util", "js-sys", @@ -4135,9 +4096,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", @@ -4341,9 +4302,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.18" +version = "0.6.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68a9bda4691f099d435ad181000724da8e5899daa10713c2d432552b9ccd3a6f" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" dependencies = [ "memchr", ] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index b86dbd2a38027..8e43526128896 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "datafusion-cli" description = "Command Line Client for DataFusion query engine." 
-version = "42.0.0" +version = "42.1.0" authors = ["Apache DataFusion "] edition = "2021" keywords = ["arrow", "datafusion", "query", "sql"] @@ -26,7 +26,7 @@ license = "Apache-2.0" homepage = "https://datafusion.apache.org" repository = "https://github.com/apache/datafusion" # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.78" +rust-version = "1.79" readme = "README.md" [dependencies] @@ -39,7 +39,7 @@ aws-sdk-sts = "1.43.0" # end pin aws-sdk crates aws-credential-types = "1.2.0" clap = { version = "4.5.16", features = ["derive", "cargo"] } -datafusion = { path = "../datafusion/core", version = "42.0.0", features = [ +datafusion = { path = "../datafusion/core", version = "42.1.0", features = [ "avro", "crypto_expressions", "datetime_expressions", diff --git a/datafusion-cli/Dockerfile b/datafusion-cli/Dockerfile index 7adead64db57c..79c24f6baf3ef 100644 --- a/datafusion-cli/Dockerfile +++ b/datafusion-cli/Dockerfile @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -FROM rust:1.78-bookworm AS builder +FROM rust:1.79-bookworm AS builder COPY . /usr/src/datafusion COPY ./datafusion /usr/src/datafusion/datafusion diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index 9b9afc1c24208..ceb72dbc546bd 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -34,6 +34,7 @@ use dirs::home_dir; use parking_lot::RwLock; /// Wraps another catalog, automatically register require object stores for the file locations +#[derive(Debug)] pub struct DynamicObjectStoreCatalog { inner: Arc, state: Weak>, @@ -74,6 +75,7 @@ impl CatalogProviderList for DynamicObjectStoreCatalog { } /// Wraps another catalog provider +#[derive(Debug)] struct DynamicObjectStoreCatalogProvider { inner: Arc, state: Weak>, @@ -115,6 +117,7 @@ impl CatalogProvider for DynamicObjectStoreCatalogProvider { /// Wraps another schema provider. [DynamicObjectStoreSchemaProvider] is responsible for registering the required /// object stores for the file locations. 
+#[derive(Debug)]
struct DynamicObjectStoreSchemaProvider {
    inner: Arc<dyn SchemaProvider>,
    state: Weak<RwLock<SessionState>>,
diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs
index db4242d971758..18906536691ef 100644
--- a/datafusion-cli/src/exec.rs
+++ b/datafusion-cli/src/exec.rs
@@ -383,7 +383,7 @@ pub(crate) async fn register_object_store_and_config_extensions(
    ctx.register_table_options_extension_from_scheme(scheme);

    // Clone and modify the default table options based on the provided options
-    let mut table_options = ctx.session_state().default_table_options().clone();
+    let mut table_options = ctx.session_state().default_table_options();
    if let Some(format) = format {
        table_options.set_config_format(format);
    }
diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs
index dd56b0196dd5b..c622463de0331 100644
--- a/datafusion-cli/src/functions.rs
+++ b/datafusion-cli/src/functions.rs
@@ -315,6 +315,7 @@ fn fixed_len_byte_array_to_string(val: Option<&FixedLenByteArray>) -> Option().unwrap(); let builder =
@@ -540,7 +540,7 @@ mod tests {
        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan {
            ctx.register_table_options_extension_from_scheme(scheme);
-            let mut table_options = ctx.state().default_table_options().clone();
+            let mut table_options = ctx.state().default_table_options();
            table_options.alter_with_string_hash_map(&cmd.options)?;
            let aws_options = table_options.extensions.get::<AwsOptions>().unwrap();
            let err = get_s3_object_store_builder(table_url.as_ref(), aws_options)
@@ -566,7 +566,7 @@ mod tests {
        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan {
            ctx.register_table_options_extension_from_scheme(scheme);
-            let mut table_options = ctx.state().default_table_options().clone();
+            let mut table_options = ctx.state().default_table_options();
            table_options.alter_with_string_hash_map(&cmd.options)?;
            let aws_options = table_options.extensions.get::<AwsOptions>().unwrap();
            // ensure this isn't an error
@@ -594,7 +594,7 @@ mod tests {
        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan {
            ctx.register_table_options_extension_from_scheme(scheme);
-            let mut table_options = ctx.state().default_table_options().clone();
+            let mut table_options = ctx.state().default_table_options();
            table_options.alter_with_string_hash_map(&cmd.options)?;
            let aws_options = table_options.extensions.get::<AwsOptions>().unwrap();
            let builder = get_oss_object_store_builder(table_url.as_ref(), aws_options)?;
@@ -631,7 +631,7 @@ mod tests {
        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan {
            ctx.register_table_options_extension_from_scheme(scheme);
-            let mut table_options = ctx.state().default_table_options().clone();
+            let mut table_options = ctx.state().default_table_options();
            table_options.alter_with_string_hash_map(&cmd.options)?;
            let gcp_options = table_options.extensions.get::<GcpOptions>().unwrap();
            let builder = get_gcs_object_store_builder(table_url.as_ref(), gcp_options)?;
diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml
index f430a87e190db..e2432abdc1384 100644
--- a/datafusion-examples/Cargo.toml
+++ b/datafusion-examples/Cargo.toml
@@ -62,6 +62,7 @@ dashmap = { workspace = true }
datafusion = { workspace = true, default-features = true, features = ["avro"] }
datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
+datafusion-functions-window-common = { workspace = true }
datafusion-optimizer = { workspace = true, default-features = true }
datafusion-physical-expr = {
workspace = true, default-features = true }
datafusion-proto = { workspace = true }
diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs
index 1259f90d64496..414596bdc6787 100644
--- a/datafusion-examples/examples/advanced_udaf.rs
+++ b/datafusion-examples/examples/advanced_udaf.rs
@@ -193,7 +193,7 @@ impl Accumulator for GeometricMean {
    }

    fn size(&self) -> usize {
-        std::mem::size_of_val(self)
+        size_of_val(self)
    }
}
@@ -394,8 +394,8 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
    }

    fn size(&self) -> usize {
-        self.counts.capacity() * std::mem::size_of::<u64>()
-            + self.prods.capacity() * std::mem::size_of::<f64>()
+        self.counts.capacity() * size_of::<u64>()
+            + self.prods.capacity() * size_of::<f64>()
    }
}
diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs
index fd1b84070cf68..1c20e292f0916 100644
--- a/datafusion-examples/examples/advanced_udwf.rs
+++ b/datafusion-examples/examples/advanced_udwf.rs
@@ -30,6 +30,7 @@ use datafusion_expr::function::WindowUDFFieldArgs;
use datafusion_expr::{
    PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl,
};
+use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;

/// This example shows how to use the full WindowUDFImpl API to implement a user
/// defined window function. As in the `simple_udwf.rs` example, this struct implements
@@ -74,7 +75,10 @@ impl WindowUDFImpl for SmoothItUdf {

    /// Create a `PartitionEvaluator` to evaluate this function on a new
    /// partition.
-    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+    fn partition_evaluator(
+        &self,
+        _partition_evaluator_args: PartitionEvaluatorArgs,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
        Ok(Box::new(MyPartitionEvaluator::new()))
    }
diff --git a/datafusion-examples/examples/catalog.rs b/datafusion-examples/examples/catalog.rs
index 8c2b1aad56c64..f40f1dfb5a159 100644
--- a/datafusion-examples/examples/catalog.rs
+++ b/datafusion-examples/examples/catalog.rs
@@ -135,6 +135,7 @@ struct DirSchemaOpts<'a> {
    format: Arc<dyn FileFormat>,
}
/// Schema where every file with extension `ext` in a given `dir` is a table.
+#[derive(Debug)]
struct DirSchema {
    ext: String,
    tables: RwLock<HashMap<String, Arc<dyn TableProvider>>>,
@@ -218,6 +219,7 @@ impl SchemaProvider for DirSchema {
    }
}
/// Catalog holds multiple schemas
+#[derive(Debug)]
struct DirCatalog {
    schemas: RwLock<HashMap<String, Arc<dyn SchemaProvider>>>,
}
@@ -259,6 +261,7 @@ impl CatalogProvider for DirCatalog {
    }
}
/// A catalog list holds multiple catalog providers. Each context has a single catalog list.
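The `#[derive(Debug)]` lines sprinkled through these examples line up with the new `Debug` supertrait bounds added to the catalog traits later in this diff. As a minimal sketch (not part of the patch; `EmptySchema` is a hypothetical name, and the paths assume the re-exports of the `datafusion` facade crate), deriving `Debug` is usually all a custom provider needs:

```rust
use std::any::Any;
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::catalog::SchemaProvider;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;

/// Hypothetical schema that exposes no tables. `Debug` is derived because
/// `SchemaProvider` now requires `Debug + Sync + Send`.
#[derive(Debug, Default)]
struct EmptySchema;

#[async_trait]
impl SchemaProvider for EmptySchema {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn table_names(&self) -> Vec<String> {
        vec![]
    }

    async fn table(&self, _name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
        // No tables to resolve in this sketch.
        Ok(None)
    }

    fn table_exist(&self, _name: &str) -> bool {
        false
    }
}
```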
+#[derive(Debug)]
struct CustomCatalogProviderList {
    catalogs: RwLock<HashMap<String, Arc<dyn CatalogProvider>>>,
}
diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs
index 0f7748b133650..7440e592962b3 100644
--- a/datafusion-examples/examples/custom_datasource.rs
+++ b/datafusion-examples/examples/custom_datasource.rs
@@ -110,7 +110,7 @@ struct CustomDataSourceInner {
}

impl Debug for CustomDataSource {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("custom_db")
    }
}
@@ -220,7 +220,7 @@ impl CustomExec {
}

impl DisplayAs for CustomExec {
-    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> fmt::Result {
        write!(f, "CustomExec")
    }
}
diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs
index 1d9b587f15b93..95168597ebaaf 100644
--- a/datafusion-examples/examples/custom_file_format.rs
+++ b/datafusion-examples/examples/custom_file_format.rs
@@ -74,10 +74,7 @@ impl FileFormat for TSVFileFormat {
        "tsv".to_string()
    }

-    fn get_ext_with_compression(
-        &self,
-        c: &FileCompressionType,
-    ) -> datafusion::error::Result<String> {
+    fn get_ext_with_compression(&self, c: &FileCompressionType) -> Result<String> {
        if c == &FileCompressionType::UNCOMPRESSED {
            Ok("tsv".to_string())
        } else {
@@ -154,7 +151,7 @@ impl FileFormatFactory for TSVFileFactory {
        &self,
        state: &SessionState,
        format_options: &std::collections::HashMap<String, String>,
-    ) -> Result<Arc<dyn FileFormat>> {
        let mut new_options = format_options.clone();
        new_options.insert("format.delimiter".to_string(), "\t".to_string());
@@ -164,7 +161,7 @@ impl FileFormatFactory for TSVFileFactory {
        Ok(tsv_file_format)
    }

-    fn default(&self) -> std::sync::Arc<dyn FileFormat> {
+    fn default(&self) -> Arc<dyn FileFormat> {
        todo!()
    }
diff --git a/datafusion-examples/examples/flight/flight_server.rs b/datafusion-examples/examples/flight/flight_server.rs
index f9d1b8029f04b..cc5f43746ddfb 100644
--- a/datafusion-examples/examples/flight/flight_server.rs
+++ b/datafusion-examples/examples/flight/flight_server.rs
@@ -105,7 +105,7 @@ impl FlightService for FlightServiceImpl {
        }

        // add an initial FlightData message that sends schema
-        let options = datafusion::arrow::ipc::writer::IpcWriteOptions::default();
+        let options = arrow::ipc::writer::IpcWriteOptions::default();
        let schema_flight_data = SchemaAsIpc::new(&schema, &options);

        let mut flights = vec![FlightData::from(schema_flight_data)];
diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs
index f57b3bf604048..b42f25437d772 100644
--- a/datafusion-examples/examples/function_factory.rs
+++ b/datafusion-examples/examples/function_factory.rs
@@ -121,7 +121,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
        &self.name
    }

-    fn signature(&self) -> &datafusion_expr::Signature {
+    fn signature(&self) -> &Signature {
        &self.signature
    }
diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs
index 140fc0d3572da..ef97bf9763b0f 100644
--- a/datafusion-examples/examples/simple_udaf.rs
+++ b/datafusion-examples/examples/simple_udaf.rs
@@ -131,7 +131,7 @@ impl Accumulator for GeometricMean {
    }

    fn size(&self) -> usize {
-        std::mem::size_of_val(self)
+        size_of_val(self)
    }
}
diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs
index baa783fce9e42..6faa397ef60f3 100644
--- a/datafusion-examples/examples/simple_udtf.rs
+++ b/datafusion-examples/examples/simple_udtf.rs
@@ -128,6 +128,7 @@ impl TableProvider for LocalCsvTable {
    }
}

+#[derive(Debug)]
struct LocalCsvTableFunc {}

impl TableFunctionImpl for LocalCsvTableFunc {
diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs
index aedc511c62fef..52a27317e3c3d 100644
--- a/datafusion-examples/examples/simplify_udaf_expression.rs
+++ b/datafusion-examples/examples/simplify_udaf_expression.rs
@@ -70,7 +70,7 @@ impl AggregateUDFImpl for BetterAvgUdaf {
        unimplemented!("should not be invoked")
    }

-    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<arrow_schema::Field>> {
+    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
        unimplemented!("should not be invoked")
    }
@@ -90,8 +90,7 @@ impl AggregateUDFImpl for BetterAvgUdaf {
    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
        // as an example for this functionality we replace UDF function
        // with built-in aggregate function to illustrate the use
-        let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction,
-                        _: &dyn SimplifyInfo| {
+        let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| {
            Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
                avg_udaf(),
                // yes it is the same Avg, `BetterAvgUdaf` was just a
diff --git a/datafusion-examples/examples/simplify_udwf_expression.rs b/datafusion-examples/examples/simplify_udwf_expression.rs
index 1ff629eef1966..117063df4e0d8 100644
--- a/datafusion-examples/examples/simplify_udwf_expression.rs
+++ b/datafusion-examples/examples/simplify_udwf_expression.rs
@@ -27,6 +27,7 @@ use datafusion_expr::{
    expr::WindowFunction, simplify::SimplifyInfo, Expr, PartitionEvaluator, Signature,
    Volatility, WindowUDF, WindowUDFImpl,
};
+use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;

/// This UDWF will show how to use the WindowUDFImpl::simplify() API
#[derive(Debug, Clone)]
@@ -60,14 +61,16 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
        &self.signature
    }

-    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+    fn partition_evaluator(
+        &self,
+        _partition_evaluator_args: PartitionEvaluatorArgs,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
        todo!()
    }

    /// this function will simplify `SimplifySmoothItUdf` to `SmoothItUdf`.
    fn simplify(&self) -> Option<WindowFunctionSimplification> {
-        let simplify = |window_function: datafusion_expr::expr::WindowFunction,
-                        _: &dyn SimplifyInfo| {
+        let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| {
            Ok(Expr::WindowFunction(WindowFunction {
                fun: datafusion_expr::WindowFunctionDefinition::AggregateUDF(avg_udaf()),
                args: window_function.args,
diff --git a/datafusion-examples/examples/sql_analysis.rs b/datafusion-examples/examples/sql_analysis.rs
index 9a2aabaa79c2e..2158b8e4b016e 100644
--- a/datafusion-examples/examples/sql_analysis.rs
+++ b/datafusion-examples/examples/sql_analysis.rs
@@ -39,7 +39,7 @@ fn total_join_count(plan: &LogicalPlan) -> usize {
    // We can use the TreeNode API to walk over a LogicalPlan.
    plan.apply(|node| {
        // if we encounter a join we update the running count
-        if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) {
+        if matches!(node, LogicalPlan::Join(_)) {
            total += 1;
        }
        Ok(TreeNodeRecursion::Continue)
@@ -89,7 +89,7 @@ fn count_trees(plan: &LogicalPlan) -> (usize, Vec<usize>) {
    while let Some(node) = to_visit.pop() {
        // if we encounter a join, we know we're at the root of the tree
        // count this tree and recurse on its inputs
-        if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) {
+        if matches!(node, LogicalPlan::Join(_)) {
            let (group_count, inputs) = count_tree(node);
            total += group_count;
            groups.push(group_count);
@@ -151,7 +151,7 @@ fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) {
    }

    // any join we count
-    if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) {
+    if matches!(node, LogicalPlan::Join(_)) {
        total += 1;
        Ok(TreeNodeRecursion::Continue)
    } else {
diff --git a/datafusion/catalog/README.md b/datafusion/catalog/README.md
new file mode 100644
index 0000000000000..5b201e736fdc4
--- /dev/null
+++ b/datafusion/catalog/README.md
@@ -0,0 +1,26 @@
+
+
+# DataFusion Catalog
+
+[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format.
+
+This crate is a submodule of DataFusion that provides catalog management functionality, including catalogs, schemas, and tables.
+
+[df]: https://crates.io/crates/datafusion
diff --git a/datafusion/catalog/src/catalog.rs b/datafusion/catalog/src/catalog.rs
index 9ee94e8f1fc33..048a7f14ed378 100644
--- a/datafusion/catalog/src/catalog.rs
+++ b/datafusion/catalog/src/catalog.rs
@@ -16,6 +16,7 @@
// under the License.

use std::any::Any;
+use std::fmt::Debug;
use std::sync::Arc;

pub use crate::schema::SchemaProvider;
@@ -101,7 +102,7 @@ use datafusion_common::Result;
///
/// [`TableProvider`]: crate::TableProvider
-pub trait CatalogProvider: Sync + Send {
+pub trait CatalogProvider: Debug + Sync + Send {
    /// Returns the catalog provider as [`Any`]
    /// so that it can be downcast to a specific implementation.
    fn as_any(&self) -> &dyn Any;
@@ -152,7 +153,7 @@ pub trait CatalogProvider: Sync + Send {
///
/// Please see the documentation on `CatalogProvider` for details of
/// implementing a custom catalog.
-pub trait CatalogProviderList: Sync + Send {
+pub trait CatalogProviderList: Debug + Sync + Send {
    /// Returns the catalog list as [`Any`]
    /// so that it can be downcast to a specific implementation.
    fn as_any(&self) -> &dyn Any;
diff --git a/datafusion/catalog/src/dynamic_file/catalog.rs b/datafusion/catalog/src/dynamic_file/catalog.rs
index cd586446f82c2..ccccb9762eb4c 100644
--- a/datafusion/catalog/src/dynamic_file/catalog.rs
+++ b/datafusion/catalog/src/dynamic_file/catalog.rs
@@ -20,9 +20,11 @@
use crate::{CatalogProvider, CatalogProviderList, SchemaProvider, TableProvider};
use async_trait::async_trait;
use std::any::Any;
+use std::fmt::Debug;
use std::sync::Arc;

/// Wraps another catalog provider list
+#[derive(Debug)]
pub struct DynamicFileCatalog {
    /// The inner catalog provider list
    inner: Arc<dyn CatalogProviderList>,
@@ -67,6 +69,7 @@ impl CatalogProviderList for DynamicFileCatalog {
}

/// Wraps another catalog provider
+#[derive(Debug)]
struct DynamicFileCatalogProvider {
    /// The inner catalog provider
    inner: Arc<dyn CatalogProvider>,
@@ -114,6 +117,7 @@ impl CatalogProvider for DynamicFileCatalogProvider {
///
/// The provider will try to create a table provider from the file path if the table provider
/// doesn't exist in the inner schema provider.
+#[derive(Debug)]
pub struct DynamicFileSchemaProvider {
    /// The inner schema provider
    inner: Arc<dyn SchemaProvider>,
@@ -174,7 +178,7 @@ impl SchemaProvider for DynamicFileSchemaProvider {

/// [UrlTableFactory] is a factory that can create a table provider from the given url.
#[async_trait]
-pub trait UrlTableFactory: Sync + Send {
+pub trait UrlTableFactory: Debug + Sync + Send {
    /// create a new table provider from the provided url
    async fn try_new(
        &self,
diff --git a/datafusion/catalog/src/schema.rs b/datafusion/catalog/src/schema.rs
index 21bca9fa828dc..5b37348fd7427 100644
--- a/datafusion/catalog/src/schema.rs
+++ b/datafusion/catalog/src/schema.rs
@@ -21,6 +21,7 @@
use async_trait::async_trait;
use datafusion_common::{exec_err, DataFusionError};
use std::any::Any;
+use std::fmt::Debug;
use std::sync::Arc;

use crate::table::TableProvider;
@@ -32,7 +33,7 @@ use datafusion_common::Result;
///
/// [`CatalogProvider`]: super::CatalogProvider
#[async_trait]
-pub trait SchemaProvider: Sync + Send {
+pub trait SchemaProvider: Debug + Sync + Send {
    /// Returns the owner of the Schema, default is None. This value is reported
    /// as part of `information_tables.schemata`
    fn owner_name(&self) -> Option<&str> {
diff --git a/datafusion/catalog/src/session.rs b/datafusion/catalog/src/session.rs
index 61d9c2d8a71e5..db49529ac43f5 100644
--- a/datafusion/catalog/src/session.rs
+++ b/datafusion/catalog/src/session.rs
@@ -139,6 +139,7 @@ impl From<&dyn Session> for TaskContext {
}
type SessionRefLock = Arc<Mutex<Option<Weak<RwLock<SessionState>>>>>;
/// The state store that stores the reference of the runtime session state.
+#[derive(Debug)]
pub struct SessionStore {
    session: SessionRefLock,
}
diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs
index 6c36d907acc3d..ca3a2bef882e2 100644
--- a/datafusion/catalog/src/table.rs
+++ b/datafusion/catalog/src/table.rs
@@ -25,6 +25,7 @@ use arrow_schema::SchemaRef;
use async_trait::async_trait;
use datafusion_common::Result;
use datafusion_common::{not_impl_err, Constraints, Statistics};
+use datafusion_expr::dml::InsertOp;
use datafusion_expr::{
    CreateExternalTable, Expr, LogicalPlan, TableProviderFilterPushDown, TableType,
};
@@ -274,7 +275,7 @@ pub trait TableProvider: Debug + Sync + Send {
        &self,
        _state: &dyn Session,
        _input: Arc<dyn ExecutionPlan>,
-        _overwrite: bool,
+        _insert_op: InsertOp,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        not_impl_err!("Insert into not implemented for this table")
    }
diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml
index 1ac27b40c2194..0747672a18f6e 100644
--- a/datafusion/common/Cargo.toml
+++ b/datafusion/common/Cargo.toml
@@ -56,6 +56,7 @@ arrow-schema = { workspace = true }
chrono = { workspace = true }
half = { workspace = true }
hashbrown = { workspace = true }
+indexmap = { workspace = true }
libc = "0.2.140"
num_cpus = { workspace = true }
object_store = { workspace = true, optional = true }
diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs
index 1e1c5d5424b08..15290204fbace 100644
--- a/datafusion/common/src/config.rs
+++ b/datafusion/common/src/config.rs
@@ -338,6 +338,12 @@ config_namespace! {
    /// if the source of statistics is accurate.
    /// We plan to make this the default in the future.
    pub use_row_number_estimates_to_optimize_partitioning: bool, default = false
+
+    /// Should DataFusion enforce batch size in joins or not. By default,
+    /// DataFusion will not enforce batch size in joins. Enforcing batch size
+    /// in joins can reduce memory usage when joining large
+    /// tables with a highly-selective join filter, but is also slightly slower.
+    pub enforce_batch_size_in_joins: bool, default = false
}
}
@@ -384,6 +390,14 @@ config_namespace! {
    /// and `Binary/BinaryLarge` with `BinaryView`.
    pub schema_force_view_types: bool, default = false

+    /// (reading) If true, parquet reader will read columns of
+    /// `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`.
+    ///
+    /// Parquet files generated by some legacy writers do not correctly set
+    /// the UTF8 flag for strings, causing string columns to be loaded as
+    /// BLOB instead.
+    pub binary_as_string: bool, default = false
+
    // The following options affect writing to parquet files
    // and map to parquet::file::properties::WriterProperties

@@ -862,7 +876,7 @@ pub trait ConfigExtension: ExtensionOptions {
}

/// An object-safe API for storing arbitrary configuration
-pub trait ExtensionOptions: Send + Sync + std::fmt::Debug + 'static {
+pub trait ExtensionOptions: Send + Sync + fmt::Debug + 'static {
    /// Return `self` as [`Any`]
    ///
    /// This is needed until trait upcasting is stabilised
@@ -1222,16 +1236,18 @@ impl ConfigField for TableOptions {
    fn set(&mut self, key: &str, value: &str) -> Result<()> {
        // Extensions are handled in the public `ConfigOptions::set`
        let (key, rem) = key.split_once('.').unwrap_or((key, ""));
-        let Some(format) = &self.current_format else {
-            return _config_err!("Specify a format for TableOptions");
-        };
        match key {
-            "format" => match format {
-                #[cfg(feature = "parquet")]
-                ConfigFileType::PARQUET => self.parquet.set(rem, value),
-                ConfigFileType::CSV => self.csv.set(rem, value),
-                ConfigFileType::JSON => self.json.set(rem, value),
-            },
+            "format" => {
+                let Some(format) = &self.current_format else {
+                    return _config_err!("Specify a format for TableOptions");
+                };
+                match format {
+                    #[cfg(feature = "parquet")]
+                    ConfigFileType::PARQUET => self.parquet.set(rem, value),
+                    ConfigFileType::CSV => self.csv.set(rem, value),
+                    ConfigFileType::JSON => self.json.set(rem, value),
+                }
+            }
            _ => _config_err!("Config value \"{key}\" not found on TableOptions"),
        }
    }
diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs
new file mode 100644
index 0000000000000..ab02915858cd2
--- /dev/null
+++ b/datafusion/common/src/cse.rs
@@ -0,0 +1,816 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Common Subexpression Elimination logic implemented in [`CSE`] can be controlled with
+//! a [`CSEController`], which defines how to eliminate common subtrees from a particular
+//! [`TreeNode`] tree.
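Stepping back to the two options added in `config.rs` above: both surface through the usual dotted config keys. A minimal sketch of toggling them (assuming the standard `SET` support on a `SessionContext` and a `tokio` runtime; both options default to `false`):

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Opt in to batch size enforcement in joins (execution namespace).
    ctx.sql("SET datafusion.execution.enforce_batch_size_in_joins = true")
        .await?;

    // Read Binary/LargeBinary Parquet columns as Utf8 (parquet sub-namespace).
    ctx.sql("SET datafusion.execution.parquet.binary_as_string = true")
        .await?;

    Ok(())
}
```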
+
+use crate::hash_utils::combine_hashes;
+use crate::tree_node::{
+    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
+    TreeNodeVisitor,
+};
+use crate::Result;
+use indexmap::IndexMap;
+use std::collections::HashMap;
+use std::hash::{BuildHasher, Hash, Hasher, RandomState};
+use std::marker::PhantomData;
+use std::sync::Arc;
+
+/// Hashes the direct content of a [`TreeNode`] without recursing into its children.
+///
+/// This method is useful to incrementally compute hashes, such as in [`CSE`] which builds
+/// a deep hash of a node and its descendants during the bottom-up phase of the first
+/// traversal and so avoids computing the hash of the node and then the hash of its
+/// descendants separately.
+///
+/// If a node doesn't have any children then the value returned by `hash_node()` is
+/// similar to `.hash()`, but does not necessarily return the same value.
+pub trait HashNode {
+    fn hash_node<H: Hasher>(&self, state: &mut H);
+}
+
+impl<T: HashNode + ?Sized> HashNode for Arc<T> {
+    fn hash_node<H: Hasher>(&self, state: &mut H) {
+        (**self).hash_node(state);
+    }
+}
+
+/// Identifier that represents a [`TreeNode`] tree.
+///
+/// This identifier is designed to be efficient to "hash", "accumulate" and "equal", and
+/// to have as few collisions as possible.
+#[derive(Debug, Eq, PartialEq)]
+struct Identifier<'n, N> {
+    // Hash of `node` built up incrementally during the first, visiting traversal.
+    // Its value is not necessarily equal to default hash of the node. E.g. it is not
+    // equal to `expr.hash()` if the node is `Expr`.
+    hash: u64,
+    node: &'n N,
+}
+
+impl<N> Clone for Identifier<'_, N> {
+    fn clone(&self) -> Self {
+        *self
+    }
+}
+impl<N> Copy for Identifier<'_, N> {}
+
+impl<N> Hash for Identifier<'_, N> {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        state.write_u64(self.hash);
+    }
+}
+
+impl<'n, N: HashNode> Identifier<'n, N> {
+    fn new(node: &'n N, random_state: &RandomState) -> Self {
+        let mut hasher = random_state.build_hasher();
+        node.hash_node(&mut hasher);
+        let hash = hasher.finish();
+        Self { hash, node }
+    }
+
+    fn combine(mut self, other: Option<Self>) -> Self {
+        other.map_or(self, |other_id| {
+            self.hash = combine_hashes(self.hash, other_id.hash);
+            self
+        })
+    }
+}
+
+/// A cache that contains the postorder index and the identifier of [`TreeNode`]s by the
+/// preorder index of the nodes.
+///
+/// This cache is filled by [`CSEVisitor`] during the first traversal and is
+/// used by [`CSERewriter`] during the second traversal.
+///
+/// The purpose of this cache is to quickly find the identifier of a node during the
+/// second traversal.
+///
+/// Elements in this array are added during `f_down` so the indexes represent the preorder
+/// index of nodes and thus element 0 belongs to the root of the tree.
+///
+/// The elements of the array are tuples that contain:
+/// - Postorder index that belongs to the preorder index. Assigned during `f_up`, starting
+///   from 0.
+/// - The optional [`Identifier`] of the node. If none, the node should not be considered
+///   for CSE.
+///
+/// # Example
+/// An expression tree like `(a + b)` would have the following `IdArray`:
+/// ```text
+/// [
+///     (2, Some(Identifier(hash_of("a + b"), &"a + b"))),
+///     (1, Some(Identifier(hash_of("a"), &"a"))),
+///     (0, Some(Identifier(hash_of("b"), &"b")))
+/// ]
+/// ```
+type IdArray<'n, N> = Vec<(usize, Option<Identifier<'n, N>>)>;
+
+#[derive(PartialEq, Eq)]
+/// How many times a node is evaluated. A node can be considered common if it is surely
+/// evaluated at least 2 times, or surely evaluated once and additionally evaluated
+/// conditionally.
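To make the `HashNode` contract above concrete, here is a minimal sketch with a hypothetical node type (`MyNode` is illustrative, not part of the patch). Only the node's own content is hashed; descendant hashes are folded in later via `Identifier::combine`:

```rust
use std::hash::{Hash, Hasher};

// Hypothetical tree node: a name plus children.
struct MyNode {
    name: String,
    children: Vec<MyNode>,
}

impl HashNode for MyNode {
    fn hash_node<H: Hasher>(&self, state: &mut H) {
        // Deliberately ignore `self.children`: CSE accumulates child hashes
        // incrementally during its bottom-up traversal.
        self.name.hash(state);
    }
}
```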
+enum NodeEvaluation {
+    SurelyOnce,
+    ConditionallyAtLeastOnce,
+    Common,
+}
+
+/// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers.
+type NodeStats<'n, N> = HashMap<Identifier<'n, N>, NodeEvaluation>;
+
+/// A map that contains the common [`TreeNode`]s and their alias by their identifiers,
+/// extracted during the second, rewriting traversal.
+type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;
+
+type ChildrenList<N> = (Vec<N>, Vec<N>);
+
+/// The [`TreeNode`] specific definition of elimination.
+pub trait CSEController {
+    /// The type of the tree nodes.
+    type Node;
+
+    /// Splits the children to normal and conditionally evaluated ones or returns `None`
+    /// if all are always evaluated.
+    fn conditional_children(node: &Self::Node) -> Option<ChildrenList<&Self::Node>>;
+
+    // Returns true if a node is valid. If a node is invalid then it can't be eliminated.
+    // Validity is propagated up, which means no subtree can be eliminated that contains
+    // an invalid node.
+    // (E.g. volatile expressions are not valid and subtrees containing such a node can't
+    // be extracted.)
+    fn is_valid(node: &Self::Node) -> bool;
+
+    // Returns true if a node should be ignored during CSE. Contrary to validity of a node,
+    // it is not propagated up.
+    fn is_ignored(&self, node: &Self::Node) -> bool;
+
+    // Generates a new name for the extracted subtree.
+    fn generate_alias(&self) -> String;
+
+    // Replaces a node with the generated alias.
+    fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;
+
+    // A helper method called on each node during top-down traversal during the second,
+    // rewriting traversal of CSE.
+    fn rewrite_f_down(&mut self, _node: &Self::Node) {}
+
+    // A helper method called on each node during bottom-up traversal during the second,
+    // rewriting traversal of CSE.
+    fn rewrite_f_up(&mut self, _node: &Self::Node) {}
+}
+
+/// The result of potentially rewriting a list of [`TreeNode`]s to eliminate common
+/// subtrees.
+#[derive(Debug)]
+pub enum FoundCommonNodes<N> {
+    /// No common [`TreeNode`]s were found
+    No { original_nodes_list: Vec<Vec<N>> },
+
+    /// Common [`TreeNode`]s were found
+    Yes {
+        /// extracted common [`TreeNode`]
+        common_nodes: Vec<(N, String)>,
+
+        /// new [`TreeNode`]s with common subtrees replaced
+        new_nodes_list: Vec<Vec<N>>,
+
+        /// original [`TreeNode`]s
+        original_nodes_list: Vec<Vec<N>>,
+    },
+}
+
+/// Go through a [`TreeNode`] tree and generate identifiers for each subtree.
+///
+/// An identifier contains information of the [`TreeNode`] itself and its subtrees.
+/// This visitor implementation uses a stack `visit_stack` to track traversal, which
+/// lets us know when the visit of a subtree is finished. When `pre_visit` is called
+/// (traversing to a new node), an `EnterMark` and a `NodeItem` will be pushed onto the
+/// stack, and an `EnterMark` is popped on leaving a node (`f_up()`). All `NodeItem`s
+/// before the first `EnterMark` are considered to be subtrees of the leaving node.
+///
+/// This visitor also records identifiers in `id_array`, so the following traversal
+/// pass can get the identifier of a node without recalculating it. We assign each node
+/// in the tree a series number, starting from 1, maintained by `series_number`. The
+/// series number represents the order in which we left (`f_up()`) a node, and has the
+/// property that a child node's series number is always smaller than its parent's.
+/// `id_array`, in contrast, is organized in the order in which we enter (`f_down()`) a
+/// node; `node_count` helps us get the index into `id_array` for each node.
+///
+/// A [`TreeNode`] without any children (column, literal etc.) will not have an identifier
+/// because it should not be recognized as a common subtree.
+struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
+    /// statistics of [`TreeNode`]s
+    node_stats: &'a mut NodeStats<'n, N>,
+
+    /// cache to speed up second traversal
+    id_array: &'a mut IdArray<'n, N>,
+
+    /// inner states
+    visit_stack: Vec<VisitRecord<'n, N>>,
+
+    /// preorder index, start from 0.
+    down_index: usize,
+
+    /// postorder index, start from 0.
+    up_index: usize,
+
+    /// a [`RandomState`] to generate hashes during the first traversal
+    random_state: &'a RandomState,
+
+    /// a flag to indicate that common [`TreeNode`]s were found
+    found_common: bool,
+
+    /// if we are in a conditional branch. A conditional branch means that the [`TreeNode`]
+    /// might not be executed depending on the runtime values of other [`TreeNode`]s, and
+    /// thus can not be extracted as a common [`TreeNode`].
+    conditional: bool,
+
+    controller: &'a C,
+}
+
+/// Record item used when traversing a [`TreeNode`] tree.
+enum VisitRecord<'n, N> {
+    /// Marks the beginning of [`TreeNode`]. It contains:
+    /// - The post-order index assigned during the first, visiting traversal.
+    EnterMark(usize),
+
+    /// Marks an accumulated subtree. It contains:
+    /// - The accumulated identifier of a subtree.
+    /// - An accumulated boolean flag if the subtree is valid for CSE.
+    ///   The flag is propagated up from children to parent. (E.g. volatile expressions
+    ///   are not valid and can't be extracted, but non-volatile children of volatile
+    ///   expressions can be extracted.)
+    NodeItem(Identifier<'n, N>, bool),
+}
+
+impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n, N, C> {
+    /// Find the first `EnterMark` in the stack, and accumulate every `NodeItem` before
+    /// it. Returns a tuple that contains:
+    /// - The pre-order index of the [`TreeNode`] we marked.
+    /// - The accumulated identifier of the children of the marked [`TreeNode`].
+    /// - An accumulated boolean flag from the children of the marked [`TreeNode`] if all
+    ///   children are valid for CSE (i.e. it is safe to extract the [`TreeNode`] as a
+    ///   common [`TreeNode`] from its children POV).
+    ///   (E.g. if any of the children of the marked expression is not valid (e.g. is
+    ///   volatile) then the expression is also not valid, so we can propagate this
+    ///   information up from children to parents via `visit_stack` during the first,
+    ///   visiting traversal, and there is no need to test the expression's validity
+    ///   beforehand with an extra traversal).
+    fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n, N>>, bool) {
+        let mut node_id = None;
+        let mut is_valid = true;
+
+        while let Some(item) = self.visit_stack.pop() {
+            match item {
+                VisitRecord::EnterMark(down_index) => {
+                    return (down_index, node_id, is_valid);
+                }
+                VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => {
+                    node_id = Some(sub_node_id.combine(node_id));
+                    is_valid &= sub_node_is_valid;
+                }
+            }
+        }
+        unreachable!("EnterMark should be paired with NodeItem");
+    }
+}
+
+impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisitor<'n>
+    for CSEVisitor<'_, 'n, N, C>
+{
+    type Node = N;
+
+    fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
+        self.id_array.push((0, None));
+        self.visit_stack
+            .push(VisitRecord::EnterMark(self.down_index));
+        self.down_index += 1;
+
+        // If a node can short-circuit then some of its children might not be executed so
+        // count the occurrence as either normal or conditional.
+        Ok(if self.conditional {
+            // If we are already in a conditionally evaluated subtree then continue
+            // traversal.
+            TreeNodeRecursion::Continue
+        } else {
+            // If we are in a node that can short-circuit then start new
+            // traversals on its normal and conditional children.
+            match C::conditional_children(node) {
+                Some((normal, conditional)) => {
+                    normal
+                        .into_iter()
+                        .try_for_each(|n| n.visit(self).map(|_| ()))?;
+                    self.conditional = true;
+                    conditional
+                        .into_iter()
+                        .try_for_each(|n| n.visit(self).map(|_| ()))?;
+                    self.conditional = false;
+
+                    TreeNodeRecursion::Jump
+                }
+
+                // In case of a non-short-circuit node continue the traversal.
+                _ => TreeNodeRecursion::Continue,
+            }
+        })
+    }
+
+    fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
+        let (down_index, sub_node_id, sub_node_is_valid) = self.pop_enter_mark();
+
+        let node_id = Identifier::new(node, self.random_state).combine(sub_node_id);
+        let is_valid = C::is_valid(node) && sub_node_is_valid;
+
+        self.id_array[down_index].0 = self.up_index;
+        if is_valid && !self.controller.is_ignored(node) {
+            self.id_array[down_index].1 = Some(node_id);
+            self.node_stats
+                .entry(node_id)
+                .and_modify(|evaluation| {
+                    if *evaluation == NodeEvaluation::SurelyOnce
+                        || *evaluation == NodeEvaluation::ConditionallyAtLeastOnce
+                            && !self.conditional
+                    {
+                        *evaluation = NodeEvaluation::Common;
+                        self.found_common = true;
+                    }
+                })
+                .or_insert_with(|| {
+                    if self.conditional {
+                        NodeEvaluation::ConditionallyAtLeastOnce
+                    } else {
+                        NodeEvaluation::SurelyOnce
+                    }
+                });
+        }
+        self.visit_stack
+            .push(VisitRecord::NodeItem(node_id, is_valid));
+        self.up_index += 1;
+
+        Ok(TreeNodeRecursion::Continue)
+    }
+}
+
+/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the
+/// corresponding temporary [`TreeNode`], which contains the evaluated result of the
+/// replaced [`TreeNode`] tree.
+struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
+    /// statistics of [`TreeNode`]s
+    node_stats: &'a NodeStats<'n, N>,
+
+    /// cache to speed up second traversal
+    id_array: &'a IdArray<'n, N>,
+
+    /// common [`TreeNode`]s, that are replaced during the second traversal, are collected
+    /// to this map
+    common_nodes: &'a mut CommonNodes<'n, N>,
+
+    // preorder index, starts from 0.
+    down_index: usize,
+
+    controller: &'a mut C,
+}
+
+impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
+    for CSERewriter<'_, '_, N, C>
+{
+    type Node = N;
+
+    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+        self.controller.rewrite_f_down(&node);
+
+        let (up_index, node_id) = self.id_array[self.down_index];
+        self.down_index += 1;
+
+        // Handle nodes with identifiers only
+        if let Some(node_id) = node_id {
+            let evaluation = self.node_stats.get(&node_id).unwrap();
+            if *evaluation == NodeEvaluation::Common {
+                // step the index to skip all sub-nodes (which have smaller series
+                // numbers).
+                while self.down_index < self.id_array.len()
+                    && self.id_array[self.down_index].0 < up_index
+                {
+                    self.down_index += 1;
+                }
+
+                let (node, alias) =
+                    self.common_nodes.entry(node_id).or_insert_with(|| {
+                        let node_alias = self.controller.generate_alias();
+                        (node, node_alias)
+                    });
+
+                let rewritten = self.controller.rewrite(node, alias);
+
+                return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
+            }
+        }
+
+        Ok(Transformed::no(node))
+    }
+
+    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+        self.controller.rewrite_f_up(&node);
+
+        Ok(Transformed::no(node))
+    }
+}
+
+/// The main entry point of Common Subexpression Elimination.
+///
+/// [`CSE`] requires a [`CSEController`], which defines how common subtrees of a particular
+/// [`TreeNode`] tree can be eliminated. The elimination process can be started with the
+/// [`CSE::extract_common_nodes()`] method.
+pub struct CSE<N, C: CSEController<Node = N>> {
+    random_state: RandomState,
+    phantom_data: PhantomData<N>,
+    controller: C,
+}
+
+impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C> {
+    pub fn new(controller: C) -> Self {
+        Self {
+            random_state: RandomState::new(),
+            phantom_data: PhantomData,
+            controller,
+        }
+    }
+
+    /// Add an identifier to `id_array` for every [`TreeNode`] in this tree.
+    fn node_to_id_array<'n>(
+        &self,
+        node: &'n N,
+        node_stats: &mut NodeStats<'n, N>,
+        id_array: &mut IdArray<'n, N>,
+    ) -> Result<bool> {
+        let mut visitor = CSEVisitor {
+            node_stats,
+            id_array,
+            visit_stack: vec![],
+            down_index: 0,
+            up_index: 0,
+            random_state: &self.random_state,
+            found_common: false,
+            conditional: false,
+            controller: &self.controller,
+        };
+        node.visit(&mut visitor)?;
+
+        Ok(visitor.found_common)
+    }
+
+    /// Returns the identifier list for each element in `nodes` and a flag to indicate if
+    /// the rewrite phase of CSE makes sense.
+    ///
+    /// Returns an array with 1 element for each input node in `nodes`
+    ///
+    /// Each element is itself the result of [`CSE::node_to_id_array`] for that node
+    /// (e.g. the identifiers for each node in the tree)
+    fn to_arrays<'n>(
+        &self,
+        nodes: &'n [N],
+        node_stats: &mut NodeStats<'n, N>,
+    ) -> Result<(bool, Vec<IdArray<'n, N>>)> {
+        let mut found_common = false;
+        nodes
+            .iter()
+            .map(|n| {
+                let mut id_array = vec![];
+                self.node_to_id_array(n, node_stats, &mut id_array)
+                    .map(|fc| {
+                        found_common |= fc;
+
+                        id_array
+                    })
+            })
+            .collect::<Result<Vec<_>>>()
+            .map(|id_arrays| (found_common, id_arrays))
+    }
+
+    /// Replace common subtrees in `node` with the corresponding temporary
+    /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`]
+    fn replace_common_node<'n>(
+        &mut self,
+        node: N,
+        id_array: &IdArray<'n, N>,
+        node_stats: &NodeStats<'n, N>,
+        common_nodes: &mut CommonNodes<'n, N>,
+    ) -> Result<N> {
+        if id_array.is_empty() {
+            Ok(Transformed::no(node))
+        } else {
+            node.rewrite(&mut CSERewriter {
+                node_stats,
+                id_array,
+                common_nodes,
+                down_index: 0,
+                controller: &mut self.controller,
+            })
+        }
+        .data()
+    }
+
+    /// Replace common subtrees in `nodes_list` with the corresponding temporary
+    /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`].
+    fn rewrite_nodes_list<'n>(
+        &mut self,
+        nodes_list: Vec<Vec<N>>,
+        arrays_list: &[Vec<IdArray<'n, N>>],
+        node_stats: &NodeStats<'n, N>,
+        common_nodes: &mut CommonNodes<'n, N>,
+    ) -> Result<Vec<Vec<N>>> {
+        nodes_list
+            .into_iter()
+            .zip(arrays_list.iter())
+            .map(|(nodes, arrays)| {
+                nodes
+                    .into_iter()
+                    .zip(arrays.iter())
+                    .map(|(node, id_array)| {
+                        self.replace_common_node(node, id_array, node_stats, common_nodes)
+                    })
+                    .collect::<Result<Vec<_>>>()
+            })
+            .collect::<Result<Vec<_>>>()
+    }
+
+    /// Extracts common [`TreeNode`]s and rewrites `nodes_list`.
+    ///
+    /// Returns [`FoundCommonNodes`] recording the result of the extraction.
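Ahead of the implementation that follows, a sketch of how a caller might drive the whole pipeline. This is illustrative only: it is generic over any node type `N` with a matching controller, and the bounds mirror the `impl` block above (paths assume the `datafusion_common` crate this file is added to):

```rust
use datafusion_common::cse::{CSEController, FoundCommonNodes, HashNode, CSE};
use datafusion_common::tree_node::TreeNode;
use datafusion_common::Result;

fn eliminate<N, C>(controller: C, nodes: Vec<N>) -> Result<()>
where
    N: TreeNode + HashNode + Clone + Eq,
    C: CSEController<Node = N>,
{
    match CSE::new(controller).extract_common_nodes(vec![nodes])? {
        FoundCommonNodes::Yes {
            common_nodes,
            new_nodes_list,
            ..
        } => {
            // Each extracted subtree comes with its generated alias; evaluate it
            // once and reuse it wherever `new_nodes_list` references the alias.
            for (_node, _alias) in &common_nodes {}
            let _ = new_nodes_list;
        }
        FoundCommonNodes::No { original_nodes_list } => {
            // Nothing common was found: fall back to the original nodes.
            let _ = original_nodes_list;
        }
    }
    Ok(())
}
```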
+ pub fn extract_common_nodes( + &mut self, + nodes_list: Vec>, + ) -> Result> { + let mut found_common = false; + let mut node_stats = NodeStats::new(); + let id_arrays_list = nodes_list + .iter() + .map(|nodes| { + self.to_arrays(nodes, &mut node_stats) + .map(|(fc, id_arrays)| { + found_common |= fc; + + id_arrays + }) + }) + .collect::>>()?; + if found_common { + let mut common_nodes = CommonNodes::new(); + let new_nodes_list = self.rewrite_nodes_list( + // Must clone the list of nodes as Identifiers use references to original + // nodes so we have to keep them intact. + nodes_list.clone(), + &id_arrays_list, + &node_stats, + &mut common_nodes, + )?; + assert!(!common_nodes.is_empty()); + + Ok(FoundCommonNodes::Yes { + common_nodes: common_nodes.into_values().collect(), + new_nodes_list, + original_nodes_list: nodes_list, + }) + } else { + Ok(FoundCommonNodes::No { + original_nodes_list: nodes_list, + }) + } + } +} + +#[cfg(test)] +mod test { + use crate::alias::AliasGenerator; + use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats, CSE}; + use crate::tree_node::tests::TestTreeNode; + use crate::Result; + use std::collections::HashSet; + use std::hash::{Hash, Hasher}; + + const CSE_PREFIX: &str = "__common_node"; + + #[derive(Clone, Copy)] + pub enum TestTreeNodeMask { + Normal, + NormalAndAggregates, + } + + pub struct TestTreeNodeCSEController<'a> { + alias_generator: &'a AliasGenerator, + mask: TestTreeNodeMask, + } + + impl<'a> TestTreeNodeCSEController<'a> { + fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) -> Self { + Self { + alias_generator, + mask, + } + } + } + + impl CSEController for TestTreeNodeCSEController<'_> { + type Node = TestTreeNode; + + fn conditional_children( + _: &Self::Node, + ) -> Option<(Vec<&Self::Node>, Vec<&Self::Node>)> { + None + } + + fn is_valid(_node: &Self::Node) -> bool { + true + } + + fn is_ignored(&self, node: &Self::Node) -> bool { + let is_leaf = node.is_leaf(); + let is_aggr = node.data == "avg" || node.data == "sum"; + + match self.mask { + TestTreeNodeMask::Normal => is_leaf || is_aggr, + TestTreeNodeMask::NormalAndAggregates => is_leaf, + } + } + + fn generate_alias(&self) -> String { + self.alias_generator.next(CSE_PREFIX) + } + + fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { + TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias)) + } + } + + impl HashNode for TestTreeNode { + fn hash_node(&self, state: &mut H) { + self.data.hash(state); + } + } + + #[test] + fn id_array_visitor() -> Result<()> { + let alias_generator = AliasGenerator::new(); + let eliminator = CSE::new(TestTreeNodeCSEController::new( + &alias_generator, + TestTreeNodeMask::Normal, + )); + + let a_plus_1 = TestTreeNode::new( + vec![ + TestTreeNode::new_leaf("a".to_string()), + TestTreeNode::new_leaf("1".to_string()), + ], + "+".to_string(), + ); + let avg_c = TestTreeNode::new( + vec![TestTreeNode::new_leaf("c".to_string())], + "avg".to_string(), + ); + let sum_a_plus_1 = TestTreeNode::new(vec![a_plus_1], "sum".to_string()); + let sum_a_plus_1_minus_avg_c = + TestTreeNode::new(vec![sum_a_plus_1, avg_c], "-".to_string()); + let root = TestTreeNode::new( + vec![ + sum_a_plus_1_minus_avg_c, + TestTreeNode::new_leaf("2".to_string()), + ], + "*".to_string(), + ); + + let [sum_a_plus_1_minus_avg_c, _] = root.children.as_slice() else { + panic!("Cannot extract subtree references") + }; + let [sum_a_plus_1, avg_c] = sum_a_plus_1_minus_avg_c.children.as_slice() else { + panic!("Cannot extract subtree 
references") + }; + let [a_plus_1] = sum_a_plus_1.children.as_slice() else { + panic!("Cannot extract subtree references") + }; + + // skip aggregates + let mut id_array = vec![]; + eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?; + + // Collect distinct hashes and set them to 0 in `id_array` + fn collect_hashes( + id_array: &mut IdArray<'_, TestTreeNode>, + ) -> HashSet { + id_array + .iter_mut() + .flat_map(|(_, id_option)| { + id_option.as_mut().map(|node_id| { + let hash = node_id.hash; + node_id.hash = 0; + hash + }) + }) + .collect::>() + } + + let hashes = collect_hashes(&mut id_array); + assert_eq!(hashes.len(), 3); + + let expected = vec![ + ( + 8, + Some(Identifier { + hash: 0, + node: &root, + }), + ), + ( + 6, + Some(Identifier { + hash: 0, + node: sum_a_plus_1_minus_avg_c, + }), + ), + (3, None), + ( + 2, + Some(Identifier { + hash: 0, + node: a_plus_1, + }), + ), + (0, None), + (1, None), + (5, None), + (4, None), + (7, None), + ]; + assert_eq!(expected, id_array); + + // include aggregates + let eliminator = CSE::new(TestTreeNodeCSEController::new( + &alias_generator, + TestTreeNodeMask::NormalAndAggregates, + )); + + let mut id_array = vec![]; + eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?; + + let hashes = collect_hashes(&mut id_array); + assert_eq!(hashes.len(), 5); + + let expected = vec![ + ( + 8, + Some(Identifier { + hash: 0, + node: &root, + }), + ), + ( + 6, + Some(Identifier { + hash: 0, + node: sum_a_plus_1_minus_avg_c, + }), + ), + ( + 3, + Some(Identifier { + hash: 0, + node: sum_a_plus_1, + }), + ), + ( + 2, + Some(Identifier { + hash: 0, + node: a_plus_1, + }), + ), + (0, None), + (1, None), + ( + 5, + Some(Identifier { + hash: 0, + node: avg_c, + }), + ), + (4, None), + (7, None), + ]; + assert_eq!(expected, id_array); + + Ok(()) + } +} diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 0dec14e9178a5..aa2d93989da19 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -226,7 +226,12 @@ impl DFSchema { for (field, qualifier) in self.inner.fields().iter().zip(&self.field_qualifiers) { if let Some(qualifier) = qualifier { - qualified_names.insert((qualifier, field.name())); + if !qualified_names.insert((qualifier, field.name())) { + return _schema_err!(SchemaError::DuplicateQualifiedField { + qualifier: Box::new(qualifier.clone()), + name: field.name().to_string(), + }); + } } else if !unqualified_names.insert(field.name()) { return _schema_err!(SchemaError::DuplicateUnqualifiedField { name: field.name().to_string() @@ -310,7 +315,6 @@ impl DFSchema { None => self_unqualified_names.contains(field.name().as_str()), }; if !duplicated_field { - // self.inner.fields.push(field.clone()); schema_builder.push(Arc::clone(field)); qualifiers.push(qualifier.cloned()); } @@ -401,33 +405,6 @@ impl DFSchema { } } - /// Check whether the column reference is ambiguous - pub fn check_ambiguous_name( - &self, - qualifier: Option<&TableReference>, - name: &str, - ) -> Result<()> { - let count = self - .iter() - .filter(|(field_q, f)| match (field_q, qualifier) { - (Some(q1), Some(q2)) => q1.resolved_eq(q2) && f.name() == name, - (None, None) => f.name() == name, - _ => false, - }) - .take(2) - .count(); - if count > 1 { - _schema_err!(SchemaError::AmbiguousReference { - field: Column { - relation: None, - name: name.to_string(), - }, - }) - } else { - Ok(()) - } - } - /// Find the qualified field with the given name pub fn qualified_field_with_name( 
&self, @@ -1165,7 +1142,10 @@ mod tests { let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let right = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let join = left.join(&right); - assert!(join.err().is_none()); + assert_eq!( + join.unwrap_err().strip_backtrace(), + "Schema error: Schema contains duplicate qualified field name t1.c0", + ); Ok(()) } diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 5d553d59da4ec..dd9d67d6bb47f 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -176,6 +176,7 @@ impl ParquetOptions { maximum_buffered_record_batches_per_stream: _, bloom_filter_on_read: _, // reads not used for writer props schema_force_view_types: _, + binary_as_string: _, // not used for writer props } = self; let mut builder = WriterProperties::builder() @@ -442,6 +443,7 @@ mod tests { .maximum_buffered_record_batches_per_stream, bloom_filter_on_read: defaults.bloom_filter_on_read, schema_force_view_types: defaults.schema_force_view_types, + binary_as_string: defaults.binary_as_string, } } @@ -543,6 +545,7 @@ mod tests { .maximum_buffered_record_batches_per_stream, bloom_filter_on_read: global_options_defaults.bloom_filter_on_read, schema_force_view_types: global_options_defaults.schema_force_view_types, + binary_as_string: global_options_defaults.binary_as_string, }, column_specific_options, key_value_metadata, diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 90f4e6e7e3d1e..ed9a68c19536c 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -23,11 +23,8 @@ use std::fmt::{Display, Formatter}; use std::ops::Deref; use std::vec::IntoIter; -use crate::error::_plan_err; use crate::utils::{merge_and_order_indices, set_difference}; -use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result}; - -use sqlparser::ast::TableConstraint; +use crate::{DFSchema, JoinType}; /// This object defines a constraint on a table. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -60,74 +57,6 @@ impl Constraints { Self { inner: constraints } } - /// Convert each `TableConstraint` to corresponding `Constraint` - pub fn new_from_table_constraints( - constraints: &[TableConstraint], - df_schema: &DFSchemaRef, - ) -> Result { - let constraints = constraints - .iter() - .map(|c: &TableConstraint| match c { - TableConstraint::Unique { name, columns, .. } => { - let field_names = df_schema.field_names(); - // Get unique constraint indices in the schema: - let indices = columns - .iter() - .map(|u| { - let idx = field_names - .iter() - .position(|item| *item == u.value) - .ok_or_else(|| { - let name = name - .as_ref() - .map(|name| format!("with name '{name}' ")) - .unwrap_or("".to_string()); - DataFusionError::Execution( - format!("Column for unique constraint {}not found in schema: {}", name,u.value) - ) - })?; - Ok(idx) - }) - .collect::>>()?; - Ok(Constraint::Unique(indices)) - } - TableConstraint::PrimaryKey { columns, .. 
} => { - let field_names = df_schema.field_names(); - // Get primary key indices in the schema: - let indices = columns - .iter() - .map(|pk| { - let idx = field_names - .iter() - .position(|item| *item == pk.value) - .ok_or_else(|| { - DataFusionError::Execution(format!( - "Column for primary key not found in schema: {}", - pk.value - )) - })?; - Ok(idx) - }) - .collect::>>()?; - Ok(Constraint::PrimaryKey(indices)) - } - TableConstraint::ForeignKey { .. } => { - _plan_err!("Foreign key constraints are not currently supported") - } - TableConstraint::Check { .. } => { - _plan_err!("Check constraints are not currently supported") - } - TableConstraint::Index { .. } => { - _plan_err!("Indexes are not currently supported") - } - TableConstraint::FulltextOrSpatial { .. } => { - _plan_err!("Indexes are not currently supported") - } - }) - .collect::>>()?; - Ok(Constraints::new_unverified(constraints)) - } - /// Check whether constraints is empty pub fn is_empty(&self) -> bool { self.inner.is_empty() diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 72cfeafd0bfec..8bd646626e068 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -102,8 +102,7 @@ fn hash_array_primitive( hashes_buffer: &mut [u64], rehash: bool, ) where - T: ArrowPrimitiveType, - ::Native: HashValue, + T: ArrowPrimitiveType, { assert_eq!( hashes_buffer.len(), diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index fbdae1c50a83e..d502e7836da3a 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -97,7 +97,7 @@ pub enum JoinConstraint { } impl Display for JoinSide { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { JoinSide::Left => write!(f, "left"), JoinSide::Right => write!(f, "right"), diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 10541e01914ad..e4575038ab988 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -31,6 +31,7 @@ mod unnest; pub mod alias; pub mod cast; pub mod config; +pub mod cse; pub mod display; pub mod error; pub mod file_options; @@ -70,7 +71,7 @@ pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::SchemaReference; pub use stats::{ColumnStatistics, Statistics}; pub use table_reference::{ResolvedTableReference, TableReference}; -pub use unnest::UnnestOptions; +pub use unnest::{RecursionUnnestOption, UnnestOptions}; pub use utils::project_schema; // These are hidden from docs purely to avoid polluting the public view of what this crate exports. diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs index e23edb4e2adb7..c73c8a55f18c5 100644 --- a/datafusion/common/src/parsers.rs +++ b/datafusion/common/src/parsers.rs @@ -18,7 +18,6 @@ //! 
diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs index e23edb4e2adb7..c73c8a55f18c5 100644 --- a/datafusion/common/src/parsers.rs +++ b/datafusion/common/src/parsers.rs @@ -18,7 +18,6 @@ //! Interval parsing logic use std::fmt::Display; -use std::result; use std::str::FromStr; use sqlparser::parser::ParserError; @@ -41,7 +40,7 @@ pub enum CompressionTypeVariant { impl FromStr for CompressionTypeVariant { type Err = ParserError; - fn from_str(s: &str) -> result::Result<Self, Self::Err> { + fn from_str(s: &str) -> Result<Self, Self::Err> { let s = s.to_uppercase(); match s.as_str() { "GZIP" | "GZ" => Ok(Self::GZIP),
diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs index 87254a499fb11..bdcf831c7884b 100644 --- a/datafusion/common/src/pyarrow.rs +++ b/datafusion/common/src/pyarrow.rs @@ -34,7 +34,7 @@ impl From<DataFusionError> for PyErr { } impl FromPyArrow for ScalarValue { - fn from_pyarrow_bound(value: &pyo3::Bound<'_, pyo3::PyAny>) -> PyResult<Self> { + fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> { let py = value.py(); let typ = value.getattr("type")?; let val = value.call_method0("as_py")?;
diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 3356a85fb6d47..7a1eaa2ad65b0 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -28,6 +28,7 @@ use std::fmt; use std::hash::Hash; use std::hash::Hasher; use std::iter::repeat; +use std::mem::{size_of, size_of_val}; use std::str::FromStr; use std::sync::Arc; @@ -58,6 +59,7 @@ use arrow::{ use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, ScalarBuffer}; use arrow_schema::{UnionFields, UnionMode}; +use crate::format::DEFAULT_CAST_OPTIONS; use half::f16; pub use struct_builder::ScalarStructBuilder; @@ -690,8 +692,8 @@ hash_float_value!((f64, u64), (f32, u32)); // # Panics // // Panics if there is an error when creating hash values for rows -impl std::hash::Hash for ScalarValue { - fn hash<H: std::hash::Hasher>(&self, state: &mut H) { +impl Hash for ScalarValue { + fn hash<H: Hasher>(&self, state: &mut H) { use ScalarValue::*; match self { Decimal128(v, p, s) => { @@ -767,7 +769,7 @@ impl std::hash::Hash for ScalarValue { } } -fn hash_nested_array<H: std::hash::Hasher>(arr: ArrayRef, state: &mut H) { +fn hash_nested_array<H: Hasher>(arr: ArrayRef, state: &mut H) { let arrays = vec![arr.to_owned()]; let hashes_buffer = &mut vec![0; arr.len()]; let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); @@ -801,7 +803,7 @@ fn dict_from_scalar<K: ArrowDictionaryKeyType>( let values_array = value.to_array_of_size(1)?; // Create a key array with `size` elements, each of 0 - let key_array: PrimitiveArray<K> = std::iter::repeat(if value.is_null() { + let key_array: PrimitiveArray<K> = repeat(if value.is_null() { None } else { Some(K::default_value()) @@ -2042,7 +2044,7 @@ impl ScalarValue { scale: i8, size: usize, ) -> Result<Decimal128Array> { - Ok(std::iter::repeat(value) + Ok(repeat(value) .take(size) .collect::<Decimal128Array>() .with_precision_and_scale(precision, scale)?)
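Editor's note: since `ScalarValue` implements `Hash` (per the impl reworked above), it can feed any standard `Hasher`. A minimal sketch using only the public `datafusion_common` API:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use datafusion_common::ScalarValue;

fn main() {
    // Equal scalars must produce equal hashes; the impl above guarantees this
    // by hashing the variant's contents (precision/scale included for decimals).
    let a = ScalarValue::Int32(Some(42));
    let b = ScalarValue::Int32(Some(42));
    let mut ha = DefaultHasher::new();
    let mut hb = DefaultHasher::new();
    a.hash(&mut ha);
    b.hash(&mut hb);
    assert_eq!(ha.finish(), hb.finish());
}
```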
@@ -2511,7 +2513,7 @@ impl ScalarValue { } fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result<ArrayRef> { - let arrays = std::iter::repeat(arr).take(size).collect::<Vec<_>>(); + let arrays = repeat(arr).take(size).collect::<Vec<_>>(); let ret = match !arrays.is_empty() { true => arrow::compute::concat(arrays.as_slice())?, false => arr.slice(0, 0), @@ -2809,22 +2811,30 @@ impl ScalarValue { /// Try to parse `value` into a ScalarValue of type `target_type` pub fn try_from_string(value: String, target_type: &DataType) -> Result<Self> { - let value = ScalarValue::from(value); - let cast_options = CastOptions { - safe: false, - format_options: Default::default(), - }; - let cast_arr = cast_with_options(&value.to_array()?, target_type, &cast_options)?; - ScalarValue::try_from_array(&cast_arr, 0) + ScalarValue::from(value).cast_to(target_type) } /// Try to cast this value to a ScalarValue of type `data_type` - pub fn cast_to(&self, data_type: &DataType) -> Result<Self> { - let cast_options = CastOptions { - safe: false, - format_options: Default::default(), + pub fn cast_to(&self, target_type: &DataType) -> Result<Self> { + self.cast_to_with_options(target_type, &DEFAULT_CAST_OPTIONS) + } + + /// Try to cast this value to a ScalarValue of type `data_type` with [`CastOptions`] + pub fn cast_to_with_options( + &self, + target_type: &DataType, + cast_options: &CastOptions<'static>, + ) -> Result<Self> { + let scalar_array = match (self, target_type) { + ( + ScalarValue::Float64(Some(float_ts)), + DataType::Timestamp(TimeUnit::Nanosecond, None), + ) => ScalarValue::Int64(Some((float_ts * 1_000_000_000_f64).trunc() as i64)) + .to_array()?, + _ => self.to_array()?, }; - let cast_arr = cast_with_options(&self.to_array()?, data_type, &cast_options)?; + + let cast_arr = cast_with_options(&scalar_array, target_type, cast_options)?; ScalarValue::try_from_array(&cast_arr, 0) } @@ -3074,7 +3084,7 @@ impl ScalarValue { /// Estimate size in bytes including `Self`. For values with internal containers such as `String` /// includes the allocated size (`capacity`) rather than the current length (`len`) pub fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) + match self { ScalarValue::Null | ScalarValue::Boolean(_) @@ -3128,12 +3138,12 @@ impl ScalarValue { ScalarValue::Map(arr) => arr.get_array_memory_size(), ScalarValue::Union(vals, fields, _mode) => { vals.as_ref() - .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv)) + .map(|(_id, sv)| sv.size() - size_of_val(sv)) .unwrap_or_default() // `fields` is boxed, so it is NOT already included in `self` - + std::mem::size_of_val(fields) - + (std::mem::size_of::<Field>() * fields.len()) - + fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::<usize>() + + size_of_val(fields) + + (size_of::<Field>() * fields.len()) + + fields.iter().map(|(_idx, field)| field.size() - size_of_val(field)).sum::<usize>() } ScalarValue::Dictionary(dt, sv) => { // `dt` and `sv` are boxed, so they are NOT already included in `self` @@ -3146,11 +3156,11 @@ impl ScalarValue { /// /// Includes the size of the [`Vec`] container itself. pub fn size_of_vec(vec: &Vec<Self>) -> usize { - std::mem::size_of_val(vec) - + (std::mem::size_of::<ScalarValue>() * vec.capacity()) + size_of_val(vec) + + (size_of::<ScalarValue>() * vec.capacity()) + vec .iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .map(|sv| sv.size() - size_of_val(sv)) .sum::<usize>() }
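Editor's note: a usage sketch for the `cast_to_with_options` API added above (assuming arrow's `CastOptions` import and the crate-level `Result` alias); with `safe: true` a failed cast yields a NULL scalar instead of an error:

```rust
use arrow::compute::CastOptions;
use arrow_schema::DataType;
use datafusion_common::{Result, ScalarValue};

fn demo() -> Result<()> {
    // cast_to uses DEFAULT_CAST_OPTIONS (safe: false): failures are errors.
    let n = ScalarValue::Int64(Some(42)).cast_to(&DataType::Int32)?;
    assert_eq!(n, ScalarValue::Int32(Some(42)));

    // With safe: true, an unparsable string becomes NULL rather than an error.
    let opts = CastOptions { safe: true, format_options: Default::default() };
    let null = ScalarValue::Utf8(Some("not a number".to_string()))
        .cast_to_with_options(&DataType::Int32, &opts)?;
    assert_eq!(null, ScalarValue::Int32(None));
    Ok(())
}
```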
@@ -3158,11 +3168,11 @@ impl ScalarValue { /// /// Includes the size of the [`VecDeque`] container itself. pub fn size_of_vec_deque(vec_deque: &VecDeque<Self>) -> usize { - std::mem::size_of_val(vec_deque) - + (std::mem::size_of::<ScalarValue>() * vec_deque.capacity()) + size_of_val(vec_deque) + + (size_of::<ScalarValue>() * vec_deque.capacity()) + vec_deque .iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .map(|sv| sv.size() - size_of_val(sv)) .sum::<usize>() } @@ -3170,11 +3180,11 @@ impl ScalarValue { /// /// Includes the size of the [`HashSet`] container itself. pub fn size_of_hashset(set: &HashSet<Self>) -> usize { - std::mem::size_of_val(set) - + (std::mem::size_of::<ScalarValue>() * set.capacity()) + size_of_val(set) + + (size_of::<ScalarValue>() * set.capacity()) + set .iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .map(|sv| sv.size() - size_of_val(sv)) .sum::<usize>() } } @@ -3577,9 +3587,8 @@ impl fmt::Display for ScalarValue { columns .iter() .zip(fields.iter()) - .enumerate() - .map(|(index, (column, field))| { - if nulls.is_some_and(|b| b.is_null(index)) { + .map(|(column, field)| { + if nulls.is_some_and(|b| b.is_null(0)) { format!("{}:NULL", field.name()) } else if let DataType::Struct(_) = field.data_type() { let sv = ScalarValue::Struct(Arc::new( @@ -3875,7 +3884,7 @@ mod tests { use arrow::compute::{is_null, kernels}; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_columns; - use arrow_buffer::Buffer; + use arrow_buffer::{Buffer, NullBuffer}; use arrow_schema::Fields; use chrono::NaiveDate; use rand::Rng; @@ -4437,7 +4446,7 @@ mod tests { let right_array = right.to_array().expect("Failed to convert to array"); let arrow_left_array = left_array.as_primitive::(); let arrow_right_array = right_array.as_primitive::(); - let arrow_result = kernels::numeric::add(arrow_left_array, arrow_right_array); + let arrow_result = add(arrow_left_array, arrow_right_array); assert_eq!(scalar_result.is_ok(), arrow_result.is_ok()); } @@ -5052,13 +5061,13 @@ mod tests { // thus the size of the enum appears to as well // The value may also change depending on rust version - assert_eq!(std::mem::size_of::<ScalarValue>(), 64); + assert_eq!(size_of::<ScalarValue>(), 64); } #[test] fn memory_size() { let sv = ScalarValue::Binary(Some(Vec::with_capacity(10))); - assert_eq!(sv.size(), std::mem::size_of::<ScalarValue>() + 10,); + assert_eq!(sv.size(), size_of::<ScalarValue>() + 10,); let sv_size = sv.size(); let mut v = Vec::with_capacity(10); @@ -5067,9 +5076,7 @@ assert_eq!(v.capacity(), 10); assert_eq!( ScalarValue::size_of_vec(&v), - std::mem::size_of::<Vec<ScalarValue>>() - + (9 * std::mem::size_of::<ScalarValue>()) - + sv_size, + size_of::<Vec<ScalarValue>>() + (9 * size_of::<ScalarValue>()) + sv_size, ); let mut s = HashSet::with_capacity(0); @@ -5079,8 +5086,8 @@ let s_capacity = s.capacity(); assert_eq!( ScalarValue::size_of_hashset(&s), - std::mem::size_of::<HashSet<ScalarValue>>() - + ((s_capacity - 1) * std::mem::size_of::<ScalarValue>()) + size_of::<HashSet<ScalarValue>>() + + ((s_capacity - 1) * size_of::<ScalarValue>()) + sv_size, ); } @@ -6589,6 +6596,43 @@ mod tests { assert_batches_eq!(&expected, &[batch]); } + #[test] + fn test_null_bug() { + let field_a = Field::new("a", DataType::Int32, true); + let field_b = Field::new("b", DataType::Int32, true); + let fields = Fields::from(vec![field_a, field_b]); + + let array_a = Arc::new(Int32Array::from_iter_values([1])); + let array_b = Arc::new(Int32Array::from_iter_values([2])); + let arrays: Vec<ArrayRef> = vec![array_a, array_b]; + + let mut not_nulls = BooleanBufferBuilder::new(1); + not_nulls.append(true); + let not_nulls = not_nulls.finish(); + let not_nulls = Some(NullBuffer::new(not_nulls)); + + let ar = StructArray::new(fields, arrays, not_nulls); + let s = ScalarValue::Struct(Arc::new(ar)); + + assert_eq!(s.to_string(), "{a:1,b:2}"); + assert_eq!(format!("{s:?}"), r#"Struct({a:1,b:2})"#); + + let ScalarValue::Struct(arr) = s else { + panic!("Expected struct"); + }; + + //verify compared to arrow display + let batch = RecordBatch::try_from_iter(vec![("s", arr as _)]).unwrap(); + let expected = [ + "+--------------+", + "| s |", + "+--------------+", + "| {a: 1, b: 2} |", + "+--------------+", + ]; + assert_batches_eq!(&expected, &[batch]); + } + #[test] fn test_struct_display_null() { let fields = vec![Field::new("a", DataType::Int32, false)];
diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index d8e62b3045f93..e669c674f78a2 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -190,7 +190,7 @@ impl Precision { } } -impl<T: fmt::Debug + Clone + PartialEq + Eq + PartialOrd> Debug for Precision<T> { +impl<T: Debug + Clone + PartialEq + Eq + PartialOrd> Debug for Precision<T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Precision::Exact(inner) => write!(f, "Exact({:?})", inner), @@ -200,7 +200,7 @@ } } -impl<T: fmt::Debug + Clone + PartialEq + Eq + PartialOrd> Display for Precision<T> { +impl<T: Debug + Clone + PartialEq + Eq + PartialOrd> Display for Precision<T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Precision::Exact(inner) => write!(f, "Exact({:?})", inner), @@ -341,7 +341,7 @@ fn check_num_rows(value: Option<usize>, is_exact: bool) -> Precision<usize> { } impl Display for Statistics { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // string of column statistics let column_stats = self .column_statistics
diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index 36254192550c8..d3b8c84512583 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -279,8 +279,88 @@ pub fn get_data_dir( } } +#[macro_export] +macro_rules! create_array { + (Boolean, $values: expr) => { + std::sync::Arc::new(arrow::array::BooleanArray::from($values)) + }; + (Int8, $values: expr) => { + std::sync::Arc::new(arrow::array::Int8Array::from($values)) + }; + (Int16, $values: expr) => { + std::sync::Arc::new(arrow::array::Int16Array::from($values)) + }; + (Int32, $values: expr) => { + std::sync::Arc::new(arrow::array::Int32Array::from($values)) + }; + (Int64, $values: expr) => { + std::sync::Arc::new(arrow::array::Int64Array::from($values)) + }; + (UInt8, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt8Array::from($values)) + }; + (UInt16, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt16Array::from($values)) + }; + (UInt32, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt32Array::from($values)) + }; + (UInt64, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt64Array::from($values)) + }; + (Float16, $values: expr) => { + std::sync::Arc::new(arrow::array::Float16Array::from($values)) + }; + (Float32, $values: expr) => { + std::sync::Arc::new(arrow::array::Float32Array::from($values)) + }; + (Float64, $values: expr) => { + std::sync::Arc::new(arrow::array::Float64Array::from($values)) + }; + (Utf8, $values: expr) => { + std::sync::Arc::new(arrow::array::StringArray::from($values)) + }; +} + +/// Creates a record batch from a literal slice of values, suitable for rapid +/// testing and development. +/// +/// Example: +/// ``` +/// use datafusion_common::{record_batch, create_array}; +/// let batch = record_batch!( +/// ("a", Int32, vec![1, 2, 3]), +/// ("b", Float64, vec![Some(4.0), None, Some(5.0)]), +/// ("c", Utf8, vec!["alpha", "beta", "gamma"]) +/// ); +/// ``` +#[macro_export] +macro_rules! record_batch { + ($(($name: expr, $type: ident, $values: expr)),*) => { + { + let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![ + $( + arrow_schema::Field::new($name, arrow_schema::DataType::$type, true), + )* + ])); + + let batch = arrow_array::RecordBatch::try_new( + schema, + vec![$( + $crate::create_array!($type, $values), + )*] + ); + + batch + } + } +} #[cfg(test)] mod tests { + use crate::cast::{as_float64_array, as_int32_array, as_string_array}; + use crate::error::Result; + use super::*; use std::env; @@ -333,4 +413,44 @@ mod tests { let res = parquet_test_data(); assert!(PathBuf::from(res).is_dir()); } + + #[test] + fn test_create_record_batch() -> Result<()> { + use arrow_array::Array; + + let batch = record_batch!( + ("a", Int32, vec![1, 2, 3, 4]), + ("b", Float64, vec![Some(4.0), None, Some(5.0), None]), + ("c", Utf8, vec!["alpha", "beta", "gamma", "delta"]) + )?; + + assert_eq!(3, batch.num_columns()); + assert_eq!(4, batch.num_rows()); + + let values: Vec<_> = as_int32_array(batch.column(0))? .values() .iter() .map(|v| v.to_owned()) .collect(); + assert_eq!(values, vec![1, 2, 3, 4]); + + let values: Vec<_> = as_float64_array(batch.column(1))? .values() .iter() .map(|v| v.to_owned()) .collect(); + assert_eq!(values, vec![4.0, 0.0, 5.0, 0.0]); + + let nulls: Vec<_> = as_float64_array(batch.column(1))? .nulls() .unwrap() .iter() .collect(); + assert_eq!(nulls, vec![true, false, true, false]); + + let values: Vec<_> = as_string_array(batch.column(2))?.iter().flatten().collect(); + assert_eq!(values, vec!["alpha", "beta", "gamma", "delta"]); + + Ok(()) + } }
diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 88300e3edd0ee..563f1fa85614b 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -681,7 +681,7 @@ impl<T> Transformed<T> { } } - /// Create a `Transformed` with `transformed and [`TreeNodeRecursion::Continue`]. + /// Create a `Transformed` with `transformed` and [`TreeNodeRecursion::Continue`]. pub fn new_transformed(data: T, transformed: bool) -> Self { Self::new(data, transformed, TreeNodeRecursion::Continue) } @@ -1027,7 +1027,7 @@ impl<T: ConcreteTreeNode> TreeNode for T { } #[cfg(test)] -mod tests { +pub(crate) mod tests { use std::collections::HashMap; use std::fmt::Display; @@ -1037,16 +1037,27 @@ }; use crate::Result; - #[derive(Debug, Eq, Hash, PartialEq)] - struct TestTreeNode<T> { - children: Vec<TestTreeNode<T>>, - data: T, + #[derive(Debug, Eq, Hash, PartialEq, Clone)] + pub struct TestTreeNode<T> { + pub(crate) children: Vec<TestTreeNode<T>>, + pub(crate) data: T, } impl<T> TestTreeNode<T> { - fn new(children: Vec<TestTreeNode<T>>, data: T) -> Self { + pub(crate) fn new(children: Vec<TestTreeNode<T>>, data: T) -> Self { Self { children, data } } + + pub(crate) fn new_leaf(data: T) -> Self { + Self { + children: vec![], + data, + } + } + + pub(crate) fn is_leaf(&self) -> bool { + self.children.is_empty() + } } impl<T> TreeNode for TestTreeNode<T> { @@ -1086,12 +1097,12 @@ // | // A fn test_tree() -> TestTreeNode<String> { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); @@ -1130,13 +1141,13 @@ // Expected transformed tree after a combined traversal fn transformed_tree() -> TestTreeNode<String> { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); @@ -1146,12 +1157,12 @@ // Expected transformed tree after a top-down traversal fn transformed_down_tree() -> TestTreeNode<String> { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g],
"f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1160,12 +1171,12 @@ mod tests { // Expected transformed tree after a bottom-up traversal fn transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); @@ -1202,12 +1213,12 @@ mod tests { } fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1236,12 +1247,12 @@ mod tests { } fn f_down_jump_on_e_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); @@ -1250,12 +1261,12 @@ mod tests { } fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], 
"f_down(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1289,12 +1300,12 @@ mod tests { } fn f_up_jump_on_a_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); @@ -1303,12 +1314,12 @@ mod tests { } fn f_up_jump_on_a_transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); @@ -1372,12 +1383,12 @@ mod tests { } fn f_down_stop_on_a_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1385,12 +1396,12 @@ mod tests { } fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = 
TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1406,12 +1417,12 @@ mod tests { } fn f_down_stop_on_e_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1419,12 +1430,12 @@ mod tests { } fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1451,12 +1462,12 @@ mod tests { } fn f_up_stop_on_a_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1464,12 +1475,12 @@ mod tests { } fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = 
TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); @@ -1499,13 +1510,13 @@ } fn f_up_stop_on_e_transformed_tree() -> TestTreeNode<String> { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1513,12 +1524,12 @@ } fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode<String> { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); @@ -2016,16 +2027,16 @@ // A #[test] fn test_apply_and_visit_references() -> Result<()> { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_a_2 = TestTreeNode::new(vec![], "a".to_string()); - let node_b_2 = TestTreeNode::new(vec![], "b".to_string()); + let node_a_2 = TestTreeNode::new_leaf("a".to_string()); + let node_b_2 = TestTreeNode::new_leaf("b".to_string()); let node_d_2 = TestTreeNode::new(vec![node_a_2], "d".to_string()); let node_c_2 = TestTreeNode::new(vec![node_b_2, node_d_2], "c".to_string()); - let node_a_3 = TestTreeNode::new(vec![], "a".to_string()); + let node_a_3 = TestTreeNode::new_leaf("a".to_string()); let tree = TestTreeNode::new(vec![node_e, node_c_2, node_a_3], "f".to_string()); let node_f_ref = &tree;
diff --git a/datafusion/common/src/unnest.rs b/datafusion/common/src/unnest.rs index fd92267f9b4c3..db48edd061605 100644 --- a/datafusion/common/src/unnest.rs +++ b/datafusion/common/src/unnest.rs @@ -17,6 +17,8 @@ //! [`UnnestOptions`] for unnesting structured types +use crate::Column; + /// Options for unnesting a column that contains a list type, /// replicating values in the other, non-nested rows. /// @@ -60,10 +62,27 @@ /// └─────────┘ └─────┘ └─────────┘ └─────┘ /// c1 c2 c1 c2 /// ``` +/// +/// `recursions` instructs how a column should be unnested (e.g. unnesting a column multiple +/// times, with depth = 1 and depth = 2). Any unnested column not mentioned in these +/// options is inferred to be unnested with depth = 1 #[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq)] pub struct UnnestOptions { /// Should nulls in the input be preserved? Defaults to true pub preserve_nulls: bool, + /// If specific columns need to be unnested multiple times (e.g. at different depths), + /// declare them here. Any unnested column not mentioned in this option + /// will be unnested with depth = 1 + pub recursions: Vec<RecursionUnnestOption>, +} + +/// Instruction on how to unnest a column (mostly with a list type), +/// such as how to name the output, and how many levels it should be unnested +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)] +pub struct RecursionUnnestOption { + pub input_column: Column, + pub output_column: Column, + pub depth: usize, } impl Default for UnnestOptions { @@ -71,6 +90,7 @@ impl Default for UnnestOptions { Self { // default to true to maintain backwards compatible behavior preserve_nulls: true, + recursions: vec![], } } } @@ -87,4 +107,10 @@ impl UnnestOptions { self.preserve_nulls = preserve_nulls; self } + + /// Set the recursions for the unnest operation + pub fn with_recursions(mut self, recursion: RecursionUnnestOption) -> Self { + self.recursions.push(recursion); + self + } }
diff --git a/datafusion/common/src/utils/memory.rs b/datafusion/common/src/utils/memory.rs index 2c34b61bd0930..d5ce59e3421b9 100644 --- a/datafusion/common/src/utils/memory.rs +++ b/datafusion/common/src/utils/memory.rs @@ -18,6 +18,7 @@ //! This module provides a function to estimate the memory size of a HashTable prior to allocation use crate::{DataFusionError, Result}; +use std::mem::size_of; /// Estimates the memory size required for a hash table prior to allocation. /// @@ -87,7 +88,7 @@ pub fn estimate_memory_size<T>(num_elements: usize, fixed_size: usize) -> Result<usize> { // + size of entry * number of buckets // + 1 byte for each bucket // + fixed size of collection (HashSet/HashTable) - std::mem::size_of::<T>() + size_of::<T>() .checked_mul(estimated_buckets)? .checked_add(estimated_buckets)? .checked_add(fixed_size)
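Editor's note: a hedged sketch of calling `estimate_memory_size` (the module path `datafusion_common::utils::memory` is assumed from the file location above):

```rust
use std::collections::HashSet;
use std::mem::size_of;

use datafusion_common::utils::memory::estimate_memory_size;
use datafusion_common::Result;

fn hashset_budget(num_elements: usize) -> Result<usize> {
    // T = u32: the estimate covers the buckets, one control byte per bucket,
    // and the fixed size of the collection itself, as derived above.
    estimate_memory_size::<u32>(num_elements, size_of::<HashSet<u32>>())
}

fn main() {
    // usize::MAX overflows the checked arithmetic and therefore errors.
    assert!(hashset_budget(usize::MAX).is_err());
    assert!(hashset_budget(8).unwrap() > 0);
}
```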
@@ -108,7 +109,7 @@ mod tests { #[test] fn test_estimate_memory() { // size (bytes): 48 - let fixed_size = std::mem::size_of::<HashSet<u32>>(); + let fixed_size = size_of::<HashSet<u32>>(); // estimated buckets: 16 = (8 * 8 / 7).next_power_of_two() let num_elements = 8; @@ -126,7 +127,7 @@ #[test] fn test_estimate_memory_overflow() { let num_elements = usize::MAX; - let fixed_size = std::mem::size_of::<HashSet<u32>>(); + let fixed_size = size_of::<HashSet<u32>>(); let estimated = estimate_memory_size::<u32>(num_elements, fixed_size); assert!(estimated.is_err());
diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 83f98ff9aff6a..dacf90af9bbfc 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -23,17 +23,14 @@ pub mod proxy; pub mod string_utils; use crate::error::{_internal_datafusion_err, _internal_err}; -use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; -use arrow::array::{ArrayRef, PrimitiveArray}; +use crate::{DataFusionError, Result, ScalarValue}; +use arrow::array::ArrayRef; use arrow::buffer::OffsetBuffer; -use arrow::compute; use arrow::compute::{partition, SortColumn, SortOptions}; -use arrow::datatypes::{Field, SchemaRef, UInt32Type}; -use arrow::record_batch::RecordBatch; +use arrow::datatypes::{Field, SchemaRef}; use arrow_array::cast::AsArray; use arrow_array::{ Array, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, - RecordBatchOptions, }; use arrow_schema::DataType; use sqlparser::ast::Ident; @@ -93,20 +90,6 @@ pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result<Vec<ScalarValue>> -/// Construct a new [`RecordBatch`] from the rows of the `record_batch` at the `indices`. -pub fn get_record_batch_at_indices( - record_batch: &RecordBatch, - indices: &PrimitiveArray<UInt32Type>, -) -> Result<RecordBatch> { - let new_columns = get_arrayref_at_indices(record_batch.columns(), indices)?; - RecordBatch::try_new_with_options( - record_batch.schema(), - new_columns, - &RecordBatchOptions::new().with_row_count(Some(indices.len())), - ) - .map_err(|e| arrow_datafusion_err!(e)) -} - /// This function compares two tuples depending on the given sort options. pub fn compare_rows( x: &[ScalarValue], @@ -290,24 +273,6 @@ pub(crate) fn parse_identifiers(s: &str) -> Result<Vec<Ident>> { Ok(idents) } -/// Construct a new [`Vec`] of [`ArrayRef`] from the rows of the `arrays` at the `indices`. -pub fn get_arrayref_at_indices( - arrays: &[ArrayRef], - indices: &PrimitiveArray<UInt32Type>, -) -> Result<Vec<ArrayRef>> { - arrays - .iter() - .map(|array| { - compute::take( - array.as_ref(), - indices, - None, // None: no index check - ) - .map_err(|e| arrow_datafusion_err!(e)) - }) - .collect() -} - pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec<String> { parse_identifiers(s) .unwrap_or_default() @@ -1003,39 +968,6 @@ mod tests { Ok(()) } - #[test] - fn test_get_arrayref_at_indices() -> Result<()> { - let arrays: Vec<ArrayRef> = vec![ - Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.])), - Arc::new(Float64Array::from(vec![2.0, 3.0, 3.0, 4.0, 5.0])), - Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 10., 11.0])), - Arc::new(Float64Array::from(vec![15.0, 13.0, 8.0, 5., 0.0])), - ]; - - let row_indices_vec: Vec<Vec<u32>> = vec![ - // Get rows 0 and 1 - vec![0, 1], - // Get rows 0 and 1 - vec![0, 2], - // Get rows 1 and 3 - vec![1, 3], - // Get rows 2 and 4 - vec![2, 4], - ]; - for row_indices in row_indices_vec { - let indices = PrimitiveArray::from_iter_values(row_indices.iter().cloned()); - let chunk = get_arrayref_at_indices(&arrays, &indices)?; - for (arr_orig, arr_chunk) in arrays.iter().zip(&chunk) { - for (idx, orig_idx) in row_indices.iter().enumerate() { - let res1 = ScalarValue::try_from_array(arr_orig, *orig_idx as usize)?; - let res2 = ScalarValue::try_from_array(arr_chunk, idx)?; - assert_eq!(res1, res2); - } - } - } - Ok(()) - } - #[test] fn test_get_at_indices() -> Result<()> { let in_vec = vec![1, 2, 3, 4, 5, 6, 7];
diff --git a/datafusion/common/src/utils/proxy.rs b/datafusion/common/src/utils/proxy.rs index d68b5e354384a..5d14a15171295 100644 --- a/datafusion/common/src/utils/proxy.rs +++ b/datafusion/common/src/utils/proxy.rs @@ -18,6 +18,7 @@ //! [`VecAllocExt`] and [`RawTableAllocExt`] to help track memory allocations use hashbrown::raw::{Bucket, RawTable}; +use std::mem::size_of; /// Extension trait for [`Vec`] to account for allocations. pub trait VecAllocExt { @@ -93,7 +94,7 @@ impl<T> VecAllocExt for Vec<T> { let new_capacity = self.capacity(); if new_capacity > prev_capacty { // capacity changed, so we allocated more - let bump_size = (new_capacity - prev_capacty) * std::mem::size_of::<T>(); + let bump_size = (new_capacity - prev_capacty) * size_of::<T>(); // Note multiplication should never overflow because `push` would // have panic'd first, but the checked_add could potentially // overflow since accounting could be tracking additional values, and @@ -102,7 +103,7 @@ } } fn allocated_size(&self) -> usize { - std::mem::size_of::<T>() * self.capacity() + size_of::<T>() * self.capacity() } } @@ -157,7 +158,7 @@ impl<T> RawTableAllocExt for RawTable<T> { // need to request more memory let bump_elements = self.capacity().max(16); - let bump_size = bump_elements * std::mem::size_of::<T>(); + let bump_size = bump_elements * size_of::<T>(); *accounting = (*accounting).checked_add(bump_size).expect("overflow"); self.reserve(bump_elements, hasher);
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 01ba90ee5de87..8c4ad80e29245 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -30,7 +30,7 @@ authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version and fails with # "Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in 'arrow-datafusion/Cargo.toml'" # https://github.com/foresterre/cargo-msrv/issues/590 -rust-version = "1.78" +rust-version = "1.79" [lints] workspace = true @@ -67,8 +67,6 @@ math_expressions = ["datafusion-functions/math_expressions"] parquet = ["datafusion-common/parquet", "dep:parquet"] pyarrow = ["datafusion-common/pyarrow", "parquet"] regex_expressions = [ - "datafusion-physical-expr/regex_expressions", - "datafusion-optimizer/regex_expressions", "datafusion-functions/regex_expressions", ] serde = ["arrow-schema/serde"]
diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index bc4298786002e..f82a126c56520 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -249,7 +249,7 @@ fn criterion_benchmark(c: &mut Criterion) { } // Temporary file must outlive the benchmarks, it is deleted when dropped - std::mem::drop(temp_file); + drop(temp_file); } criterion_group!(benches, criterion_benchmark);
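Editor's note: for the `VecAllocExt` accounting contract shown in proxy.rs above, a hedged sketch of the intended invariant. The method name `push_accounted` is assumed from the trait's purpose and is not shown in the hunks here:

```rust
use datafusion_common::utils::proxy::VecAllocExt;

fn main() {
    let mut v: Vec<u64> = Vec::new();
    let mut accounting = 0usize;
    for i in 0..100u64 {
        // push_accounted (assumed name) bumps `accounting` whenever the Vec grows.
        v.push_accounted(i, &mut accounting);
    }
    // Starting from an empty Vec, the accounted bytes equal the current allocation.
    assert_eq!(accounting, v.allocated_size());
}
```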
diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 00f6d5916751b..140e266a02720 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -15,22 +15,31 @@ // specific language governing permissions and limitations // under the License. +extern crate arrow; #[macro_use] extern crate criterion; -extern crate arrow; extern crate datafusion; mod data_utils; + use crate::criterion::Criterion; use arrow::datatypes::{DataType, Field, Fields, Schema}; use datafusion::datasource::MemTable; use datafusion::execution::context::SessionContext; +use itertools::Itertools; +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::PathBuf; use std::sync::Arc; use test_utils::tpcds::tpcds_schemas; use test_utils::tpch::tpch_schemas; use test_utils::TableDef; use tokio::runtime::Runtime; +const BENCHMARKS_PATH_1: &str = "../../benchmarks/"; +const BENCHMARKS_PATH_2: &str = "./benchmarks/"; +const CLICKBENCH_DATA_PATH: &str = "data/hits_partitioned/"; + /// Create a logical plan from the specified sql fn logical_plan(ctx: &SessionContext, sql: &str) { let rt = Runtime::new().unwrap(); @@ -60,7 +69,9 @@ fn create_schema(column_prefix: &str, num_columns: usize) -> Schema { fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc<MemTable> { let schema = Arc::new(create_schema(column_prefix, num_columns)); - MemTable::try_new(schema, vec![]).map(Arc::new).unwrap() + MemTable::try_new(schema, vec![vec![]]) + .map(Arc::new) + .unwrap() } fn create_context() -> SessionContext { @@ -89,7 +100,37 @@ fn register_defs(ctx: SessionContext, defs: Vec<TableDef>) -> SessionContext { ctx } +fn register_clickbench_hits_table() -> SessionContext { + let ctx = SessionContext::new(); + let rt = Runtime::new().unwrap(); + + // use an external table for clickbench benchmarks + let path = + if PathBuf::from(format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}")).exists() { + format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}") + } else { + format!("{BENCHMARKS_PATH_2}{CLICKBENCH_DATA_PATH}") + }; + + let sql = format!("CREATE EXTERNAL TABLE hits STORED AS PARQUET LOCATION '{path}'"); + + rt.block_on(ctx.sql(&sql)).unwrap(); + + let count = + rt.block_on(async { ctx.table("hits").await.unwrap().count().await.unwrap() }); + assert!(count > 0); + ctx +} + fn criterion_benchmark(c: &mut Criterion) { + // verify that we can load the clickbench data prior to running the benchmark + if !PathBuf::from(format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}")).exists() + && !PathBuf::from(format!("{BENCHMARKS_PATH_2}{CLICKBENCH_DATA_PATH}")).exists() + { + panic!("benchmarks/data/hits_partitioned/ could not be loaded. Please run \ 'benchmarks/bench.sh data clickbench_partitioned' prior to running this benchmark") + } + let ctx = create_context(); // Test simplest @@ -144,6 +185,85 @@ }) }); + c.bench_function("physical_select_aggregates_from_200", |b| { + let mut aggregates = String::new(); + for i in 0..200 { + if i > 0 { + aggregates.push_str(", "); + } + aggregates.push_str(format!("MAX(a{})", i).as_str()); + } + let query = format!("SELECT {} FROM t1", aggregates); + b.iter(|| { + physical_plan(&ctx, &query); + }); + }); + + // Benchmark for Physical Planning Joins + c.bench_function("physical_join_consider_sort", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7, t2.b8 \ FROM t1, t2 WHERE a7 = b7 \ ORDER BY a7", + ); + }); + }); + + c.bench_function("physical_theta_join_consider_sort", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7, t2.b8 \ FROM t1, t2 WHERE a7 < b7 \ ORDER BY a7", + ); + }); + }); + + c.bench_function("physical_many_self_joins", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT ta.a9, tb.a10, tc.a11, td.a12, te.a13, tf.a14 \ FROM t1 AS ta, t1 AS tb, t1 AS tc, t1 AS td, t1 AS te, t1 AS tf \ WHERE ta.a9 = tb.a10 AND tb.a10 = tc.a11 AND tc.a11 = td.a12 AND \ td.a12 = te.a13 AND te.a13 = tf.a14", + ); + }); + }); + + c.bench_function("physical_unnest_to_join", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7 \ FROM t1 WHERE a7 = (SELECT b8 FROM t2)", + ); + }); + }); + + c.bench_function("physical_intersection", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7 FROM t1 \ INTERSECT SELECT t2.b8 FROM t2", + ); + }); + }); + // these two queries should be equivalent + c.bench_function("physical_join_distinct", |b| { + b.iter(|| { + logical_plan( + &ctx, + "SELECT DISTINCT t1.a7 \ FROM t1, t2 WHERE t1.a7 = t2.b8", + ); + }); + }); + // --- TPC-H --- let tpch_ctx = register_defs(SessionContext::new(), tpch_schemas()); "q16", "q17", "q18", "q19", "q20", "q21", "q22", ]; + let benchmarks_path = if PathBuf::from(BENCHMARKS_PATH_1).exists() { + BENCHMARKS_PATH_1 + } else { + BENCHMARKS_PATH_2 + }; + for q in tpch_queries { let sql = - std::fs::read_to_string(format!("../../benchmarks/queries/{q}.sql")).unwrap(); + std::fs::read_to_string(format!("{benchmarks_path}queries/{q}.sql")).unwrap(); c.bench_function(&format!("physical_plan_tpch_{}", q), |b| { b.iter(|| physical_plan(&tpch_ctx, &sql)) }); @@ -165,7 +291,7 @@ let all_tpch_sql_queries = tpch_queries .iter() .map(|q| { - std::fs::read_to_string(format!("../../benchmarks/queries/{q}.sql")).unwrap() + std::fs::read_to_string(format!("{benchmarks_path}queries/{q}.sql")).unwrap() }) .collect::<Vec<_>>(); @@ -177,26 +303,25 @@ }) }); - c.bench_function("logical_plan_tpch_all", |b| { - b.iter(|| { - for sql in &all_tpch_sql_queries { - logical_plan(&tpch_ctx, sql) - } - }) - }); + // c.bench_function("logical_plan_tpch_all", |b| { + // b.iter(|| { + // for sql in &all_tpch_sql_queries { + // logical_plan(&tpch_ctx, sql) + // } + // }) + // }); // --- TPC-DS --- let tpcds_ctx = register_defs(SessionContext::new(), tpcds_schemas()); - - // 10, 35: Physical plan does not support logical expression Exists(<subquery>) - // 45: Physical plan does not support logical expression (<subquery>) - // 41: Optimizing disjunctions not supported - let ignored = [10, 35, 41, 45]; + let tests_path = if PathBuf::from("./tests/").exists() { + "./tests/" + } else { + "datafusion/core/tests/" + }; let raw_tpcds_sql_queries = (1..100) - .filter(|q| !ignored.contains(q)) - .map(|q| std::fs::read_to_string(format!("./tests/tpc-ds/{q}.sql")).unwrap()) + .map(|q| std::fs::read_to_string(format!("{tests_path}tpc-ds/{q}.sql")).unwrap()) .collect::<Vec<_>>(); // some queries have multiple statements @@ -213,10 +338,53 @@ }) }); - c.bench_function("logical_plan_tpcds_all", |b| { + // c.bench_function("logical_plan_tpcds_all", |b| { + // b.iter(|| { + // for sql in &all_tpcds_sql_queries { + // logical_plan(&tpcds_ctx, sql) + // } + // }) + // }); + + // -- clickbench -- + + let queries_file = + File::open(format!("{benchmarks_path}queries/clickbench/queries.sql")).unwrap(); + let extended_file = + File::open(format!("{benchmarks_path}queries/clickbench/extended.sql")).unwrap(); + + let clickbench_queries: Vec<String> = BufReader::new(queries_file) + .lines() + .chain(BufReader::new(extended_file).lines()) + .map(|l| l.expect("Could not parse line")) + .collect_vec(); + + let clickbench_ctx = register_clickbench_hits_table(); + + // for (i, sql) in clickbench_queries.iter().enumerate() { + // c.bench_function(&format!("logical_plan_clickbench_q{}", i + 1), |b| { + // b.iter(|| logical_plan(&clickbench_ctx, sql)) + // }); + // } + + for (i, sql) in clickbench_queries.iter().enumerate() { + c.bench_function(&format!("physical_plan_clickbench_q{}", i + 1), |b| { + b.iter(|| physical_plan(&clickbench_ctx, sql)) + }); + } + + // c.bench_function("logical_plan_clickbench_all", |b| { + // b.iter(|| { + // for sql in &clickbench_queries { + // logical_plan(&clickbench_ctx, sql) + // } + // }) + // }); + + c.bench_function("physical_plan_clickbench_all", |b| { b.iter(|| { - for sql in &all_tpcds_sql_queries { - logical_plan(&tpcds_ctx, sql) + for sql in &clickbench_queries { + physical_plan(&clickbench_ctx, sql) } }) });
diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs new file mode 100644 index 0000000000000..3aedcbc2aa63e --- /dev/null +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -0,0 +1,297 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::execution::SessionStateDefaults; +use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::{ + aggregate_doc_sections, scalar_doc_sections, window_doc_sections, AggregateUDF, + DocSection, Documentation, ScalarUDF, WindowUDF, +}; +use hashbrown::HashSet; +use itertools::Itertools; +use std::env::args; +use std::fmt::Write as _; + +/// Print documentation for all functions of a given type to stdout +/// +/// Usage: `cargo run --bin print_functions_docs -- <function_type>` +/// +/// Called from `dev/update_function_docs.sh` +fn main() -> Result<()> { + let args: Vec<String> = args().collect(); + + if args.len() != 2 { + panic!( + "Usage: {} type (one of 'aggregate', 'scalar', 'window')", + args[0] + ); + } + + let function_type = args[1].trim().to_lowercase(); + let docs = match function_type.as_str() { + "aggregate" => print_aggregate_docs(), + "scalar" => print_scalar_docs(), + "window" => print_window_docs(), + _ => { + panic!("Unknown function type: {}", function_type) + } + }?; + + println!("{docs}"); + Ok(()) +} + +fn print_aggregate_docs() -> Result<String> { + let mut providers: Vec<Box<dyn DocProvider>> = vec![]; + + for f in SessionStateDefaults::default_aggregate_functions() { + providers.push(Box::new(f.as_ref().clone())); + } + + print_docs(providers, aggregate_doc_sections::doc_sections()) +} + +fn print_scalar_docs() -> Result<String> { + let mut providers: Vec<Box<dyn DocProvider>> = vec![]; + + for f in SessionStateDefaults::default_scalar_functions() { + providers.push(Box::new(f.as_ref().clone())); + } + + print_docs(providers, scalar_doc_sections::doc_sections()) +} + +fn print_window_docs() -> Result<String> { + let mut providers: Vec<Box<dyn DocProvider>> = vec![]; + + for f in SessionStateDefaults::default_window_functions() { + providers.push(Box::new(f.as_ref().clone())); + } + + print_docs(providers, window_doc_sections::doc_sections()) +} + +fn print_docs( + providers: Vec<Box<dyn DocProvider>>, + doc_sections: Vec<DocSection>, +) -> Result<String> { + let mut docs = "".to_string(); + + // Ensure that all providers have documentation + let mut providers_with_no_docs = HashSet::new(); + + // doc sections only includes sections that have 'include' == true + for doc_section in doc_sections { + // make sure there is at least one function that is in this doc section + if !&providers.iter().any(|f| { + if let Some(documentation) = f.get_documentation() { + documentation.doc_section == doc_section + } else { + false + } + }) { + continue; + } + + // filter out functions that are not in this doc section + let providers: Vec<&Box<dyn DocProvider>> = providers + .iter() + .filter(|&f| { + if let Some(documentation) = f.get_documentation() { + documentation.doc_section == doc_section + } else { + providers_with_no_docs.insert(f.get_name()); + false + } + }) + .collect::<Vec<_>>(); + + // write out section header + let _ = writeln!(docs, "\n## {} \n", doc_section.label); + + if let Some(description) = doc_section.description { + let _ = writeln!(docs, "{description}"); + } + + // names is a sorted list of function names and aliases since we display + // both in the documentation + let names = get_names_and_aliases(&providers); + + // write out the list of function names and aliases + names.iter().for_each(|name| { + let _ = writeln!(docs, "- [{name}](#{name})"); + }); + + // write out each function and alias in the order of the sorted name list + for name in names { + let f = providers + .iter() + .find(|f| f.get_name() == name || f.get_aliases().contains(&name)) + .unwrap(); + + let aliases = f.get_aliases(); + let documentation = f.get_documentation(); + + // if this name is an alias we need to display what it's an
alias of + if aliases.contains(&name) { + let fname = f.get_name(); + let _ = writeln!(docs, r#"### `{name}`"#); + let _ = writeln!(docs, "_Alias of [{fname}](#{fname})._"); + continue; + } + + // otherwise display the documentation for the function + let Some(documentation) = documentation else { + unreachable!() + }; + + // first, the name, description and syntax example + let _ = write!( + docs, + r#" +### `{}` + +{} + +``` +{} +``` +"#, + name, documentation.description, documentation.syntax_example + ); + + // next, arguments + if let Some(args) = &documentation.arguments { + let _ = writeln!(docs, "#### Arguments\n"); + for (arg_name, arg_desc) in args { + let _ = writeln!(docs, "- **{arg_name}**: {arg_desc}"); + } + } + + // next, sql example if provided + if let Some(example) = &documentation.sql_example { + let _ = writeln!( + docs, + r#" +#### Example + +{} +"#, + example + ); + } + + if let Some(alt_syntax) = &documentation.alternative_syntax { + let _ = writeln!(docs, "#### Alternative Syntax\n"); + for syntax in alt_syntax { + let _ = writeln!(docs, "```sql\n{}\n```", syntax); + } + } + + // next, aliases + if !f.get_aliases().is_empty() { + let _ = writeln!(docs, "#### Aliases"); + + for alias in f.get_aliases() { + let _ = writeln!(docs, "- {}", alias.replace("_", r#"\_"#)); + } + } + + // finally, any related udfs + if let Some(related_udfs) = &documentation.related_udfs { + let _ = writeln!(docs, "\n**Related functions**:"); + + for related in related_udfs { + let _ = writeln!(docs, "- [{related}](#{related})"); + } + } + } + } + + // If there are any functions that do not have documentation, print them out + // eventually make this an error: https://github.com/apache/datafusion/issues/12872 + if !providers_with_no_docs.is_empty() { + eprintln!("INFO: The following functions do not have documentation:"); + for f in &providers_with_no_docs { + eprintln!(" - {f}"); + } + not_impl_err!("Some functions do not have documentation. 
Please implement `documentation` for: {providers_with_no_docs:?}") + } else { + Ok(docs) + } +} + +/// Trait for accessing name / aliases / documentation for differnet functions +trait DocProvider { + fn get_name(&self) -> String; + fn get_aliases(&self) -> Vec; + fn get_documentation(&self) -> Option<&Documentation>; +} + +impl DocProvider for AggregateUDF { + fn get_name(&self) -> String { + self.name().to_string() + } + fn get_aliases(&self) -> Vec { + self.aliases().iter().map(|a| a.to_string()).collect() + } + fn get_documentation(&self) -> Option<&Documentation> { + self.documentation() + } +} + +impl DocProvider for ScalarUDF { + fn get_name(&self) -> String { + self.name().to_string() + } + fn get_aliases(&self) -> Vec { + self.aliases().iter().map(|a| a.to_string()).collect() + } + fn get_documentation(&self) -> Option<&Documentation> { + self.documentation() + } +} + +impl DocProvider for WindowUDF { + fn get_name(&self) -> String { + self.name().to_string() + } + fn get_aliases(&self) -> Vec { + self.aliases().iter().map(|a| a.to_string()).collect() + } + fn get_documentation(&self) -> Option<&Documentation> { + self.documentation() + } +} + +#[allow(clippy::borrowed_box)] +#[allow(clippy::ptr_arg)] +fn get_names_and_aliases(functions: &Vec<&Box>) -> Vec { + functions + .iter() + .flat_map(|f| { + if f.get_aliases().is_empty() { + vec![f.get_name().to_string()] + } else { + let mut names = vec![f.get_name().to_string()]; + names.extend(f.get_aliases().iter().cloned()); + names + } + }) + .sorted() + .collect_vec() +} diff --git a/datafusion/core/src/catalog_common/information_schema.rs b/datafusion/core/src/catalog_common/information_schema.rs index df4257504b1d8..180994b1cbe89 100644 --- a/datafusion/core/src/catalog_common/information_schema.rs +++ b/datafusion/core/src/catalog_common/information_schema.rs @@ -26,7 +26,7 @@ use arrow::{ }; use async_trait::async_trait; use datafusion_common::DataFusionError; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::{any::Any, sync::Arc}; use crate::catalog::{CatalogProviderList, SchemaProvider, TableProvider}; @@ -57,6 +57,7 @@ pub const INFORMATION_SCHEMA_TABLES: &[&str] = /// demand. This means that if more tables are added to the underlying /// providers, they will appear the next time the `information_schema` /// table is queried. 
+#[derive(Debug)] pub struct InformationSchemaProvider { config: InformationSchemaConfig, } @@ -70,20 +71,11 @@ impl InformationSchemaProvider { } } -#[derive(Clone)] +#[derive(Clone, Debug)] struct InformationSchemaConfig { catalog_list: Arc<dyn CatalogProviderList>, } -impl Debug for InformationSchemaConfig { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("InformationSchemaConfig") - // TODO it would be great to print the catalog list here - // but that would require CatalogProviderList to implement Debug - .finish_non_exhaustive() - } -} - impl InformationSchemaConfig { /// Construct the `information_schema.tables` virtual table async fn make_tables( diff --git a/datafusion/core/src/catalog_common/listing_schema.rs b/datafusion/core/src/catalog_common/listing_schema.rs index 5b91f963ca244..665ea58c5f755 100644 --- a/datafusion/core/src/catalog_common/listing_schema.rs +++ b/datafusion/core/src/catalog_common/listing_schema.rs @@ -48,6 +48,7 @@ use object_store::ObjectStore; /// - `s3://host.example.com:3000/data/tpch/customer/_delta_log/` /// /// [`ObjectStore`]: object_store::ObjectStore +#[derive(Debug)] pub struct ListingSchemaProvider { authority: String, path: object_store::path::Path, @@ -135,6 +136,7 @@ impl ListingSchemaProvider { file_type: self.format.clone(), table_partition_cols: vec![], if_not_exists: false, + temporary: false, definition: None, order_exprs: vec![], unbounded: false, diff --git a/datafusion/core/src/catalog_common/memory.rs b/datafusion/core/src/catalog_common/memory.rs index 6d8bddec45473..f25146616891f 100644 --- a/datafusion/core/src/catalog_common/memory.rs +++ b/datafusion/core/src/catalog_common/memory.rs @@ -28,6 +28,7 @@ use std::any::Any; use std::sync::Arc; /// Simple in-memory list of catalogs +#[derive(Debug)] pub struct MemoryCatalogProviderList { /// Collection of catalogs containing schemas and ultimately TableProviders pub catalogs: DashMap<String, Arc<dyn CatalogProvider>>, } @@ -71,6 +72,7 @@ impl CatalogProviderList for MemoryCatalogProviderList { } /// Simple in-memory implementation of a catalog. +#[derive(Debug)] pub struct MemoryCatalogProvider { schemas: DashMap<String, Arc<dyn SchemaProvider>>, } @@ -136,6 +138,7 @@ impl CatalogProvider for MemoryCatalogProvider { } /// Simple in-memory implementation of a schema. +#[derive(Debug)] pub struct MemorySchemaProvider { tables: DashMap<String, Arc<dyn TableProvider>>, } @@ -248,6 +251,7 @@ mod test { #[test] fn default_register_schema_not_supported() { // mimic a new CatalogProvider and ensure it does not support registering schemas + #[derive(Debug)] struct TestProvider {} impl CatalogProvider for TestProvider { fn as_any(&self) -> &dyn Any { diff --git a/datafusion/core/src/catalog_common/mod.rs b/datafusion/core/src/catalog_common/mod.rs index b8414378862e4..68c78dda48999 100644 --- a/datafusion/core/src/catalog_common/mod.rs +++ b/datafusion/core/src/catalog_common/mod.rs @@ -36,10 +36,6 @@ pub use datafusion_sql::{ResolvedTableReference, TableReference}; use std::collections::BTreeSet; use std::ops::ControlFlow; -/// See [`CatalogProviderList`] -#[deprecated(since = "35.0.0", note = "use [`CatalogProviderList`] instead")] -pub trait CatalogList: CatalogProviderList {} - /// Collects all tables and views referenced in the SQL statement. CTEs are collected separately. /// This can be used to determine which tables need to be in the catalog for a query to be planned.
/// @@ -185,9 +181,7 @@ pub fn resolve_table_references( let _ = s.as_ref().visit(visitor); } DFStatement::CreateExternalTable(table) => { - visitor - .relations - .insert(ObjectName(vec![Ident::from(table.name.as_str())])); + visitor.relations.insert(table.name.clone()); } DFStatement::CopyTo(CopyToStatement { source, .. }) => match source { CopyToSource::Relation(table_name) => { diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 72b763ce0f2b1..e5d352a63c7a3 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -52,6 +52,7 @@ use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions, }; +use datafusion_expr::dml::InsertOp; use datafusion_expr::{case, is_null, lit, SortExpr}; use datafusion_expr::{ utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, @@ -66,8 +67,9 @@ use datafusion_catalog::Session; /// Contains options that control how data is /// written out from a DataFrame pub struct DataFrameWriteOptions { - /// Controls if existing data should be overwritten - overwrite: bool, + /// Controls how new data should be written to the table, determining whether + /// to append, overwrite, or replace existing data. + insert_op: InsertOp, /// Controls if all partitions should be coalesced into a single output file /// Generally will have slower performance when set to true. single_file_output: bool, @@ -80,14 +82,15 @@ impl DataFrameWriteOptions { /// Create a new DataFrameWriteOptions with default values pub fn new() -> Self { DataFrameWriteOptions { - overwrite: false, + insert_op: InsertOp::Append, single_file_output: false, partition_by: vec![], } } - /// Set the overwrite option to true or false - pub fn with_overwrite(mut self, overwrite: bool) -> Self { - self.overwrite = overwrite; + + /// Set the insert operation + pub fn with_insert_operation(mut self, insert_op: InsertOp) -> Self { + self.insert_op = insert_op; self } @@ -370,32 +373,9 @@ impl DataFrame { self.select(expr) } - /// Expand each list element of a column to multiple rows. - #[deprecated(since = "37.0.0", note = "use unnest_columns instead")] - pub fn unnest_column(self, column: &str) -> Result { - self.unnest_columns(&[column]) - } - - /// Expand each list element of a column to multiple rows, with - /// behavior controlled by [`UnnestOptions`]. - /// - /// Please see the documentation on [`UnnestOptions`] for more - /// details about the meaning of unnest. - #[deprecated(since = "37.0.0", note = "use unnest_columns_with_options instead")] - pub fn unnest_column_with_options( - self, - column: &str, - options: UnnestOptions, - ) -> Result { - self.unnest_columns_with_options(&[column], options) - } - /// Expand multiple list/struct columns into a set of rows and new columns. /// - /// See also: - /// - /// 1. [`UnnestOptions`] documentation for the behavior of `unnest` - /// 2. [`Self::unnest_column_with_options`] + /// See also: [`UnnestOptions`] documentation for the behavior of `unnest` /// /// # Example /// ``` @@ -532,9 +512,26 @@ impl DataFrame { group_expr: Vec, aggr_expr: Vec, ) -> Result { + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let aggr_expr_len = aggr_expr.len(); let plan = LogicalPlanBuilder::from(self.plan) .aggregate(group_expr, aggr_expr)? 
.build()?; + let plan = if is_grouping_set { + let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len; + // For grouping sets we do a project to not expose the internal grouping id + let exprs = plan + .schema() + .columns() + .into_iter() + .enumerate() + .filter(|(idx, _)| *idx != grouping_id_pos) + .map(|(_, column)| Expr::Column(column)) + .collect::>(); + LogicalPlanBuilder::from(plan).project(exprs)?.build()? + } else { + plan + }; Ok(DataFrame { session_state: self.session_state, plan, @@ -1525,7 +1522,7 @@ impl DataFrame { self.plan, table_name.to_owned(), &arrow_schema, - write_options.overwrite, + write_options.insert_op, )? .build()?; @@ -1566,10 +1563,11 @@ impl DataFrame { options: DataFrameWriteOptions, writer_options: Option, ) -> Result, DataFusionError> { - if options.overwrite { - return Err(DataFusionError::NotImplemented( - "Overwrites are not implemented for DataFrame::write_csv.".to_owned(), - )); + if options.insert_op != InsertOp::Append { + return Err(DataFusionError::NotImplemented(format!( + "{} is not implemented for DataFrame::write_csv.", + options.insert_op + ))); } let format = if let Some(csv_opts) = writer_options { @@ -1626,10 +1624,11 @@ impl DataFrame { options: DataFrameWriteOptions, writer_options: Option, ) -> Result, DataFusionError> { - if options.overwrite { - return Err(DataFusionError::NotImplemented( - "Overwrites are not implemented for DataFrame::write_json.".to_owned(), - )); + if options.insert_op != InsertOp::Append { + return Err(DataFusionError::NotImplemented(format!( + "{} is not implemented for DataFrame::write_json.", + options.insert_op + ))); } let format = if let Some(json_opts) = writer_options { @@ -1942,12 +1941,12 @@ mod tests { use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr}; use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; - use arrow::array::{self, Int32Array}; + use arrow::array::Int32Array; use datafusion_common::{assert_batches_eq, Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt, + cast, create_udf, lit, BuiltInWindowFunction, ExprFunctionExt, ScalarFunctionImplementation, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; @@ -1980,8 +1979,8 @@ mod tests { let batch = RecordBatch::try_new( dual_schema.clone(), vec![ - Arc::new(array::Int32Array::from(vec![1])), - Arc::new(array::StringArray::from(vec!["a"])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), ], ) .unwrap(); @@ -2177,7 +2176,7 @@ mod tests { async fn select_with_window_exprs() -> Result<()> { // build plan using Table API let t = test_table().await?; - let first_row = Expr::WindowFunction(expr::WindowFunction::new( + let first_row = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::BuiltInWindowFunction( BuiltInWindowFunction::FirstValue, ), @@ -2601,6 +2600,54 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_with_union() -> Result<()> { + let df = test_table().await?; + + let df1 = df + .clone() + // GROUP BY `c1` + .aggregate(vec![col("c1")], vec![min(col("c2"))])? + // SELECT `c1` , min(c2) as `result` + .select(vec![col("c1"), min(col("c2")).alias("result")])?; + let df2 = df + .clone() + // GROUP BY `c1` + .aggregate(vec![col("c1")], vec![max(col("c3"))])? 
+ // SELECT `c1` , max(c3) as `result` + .select(vec![col("c1"), max(col("c3")).alias("result")])?; + + let df_union = df1.union(df2)?; + let df = df_union + // GROUP BY `c1` + .aggregate( + vec![col("c1")], + vec![sum(col("result")).alias("sum_result")], + )? + // SELECT `c1`, sum(result) as `sum_result` + .select(vec![(col("c1")), col("sum_result")])?; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + [ + "+----+------------+", + "| c1 | sum_result |", + "+----+------------+", + "| a | 84 |", + "| b | 69 |", + "| c | 124 |", + "| d | 126 |", + "| e | 121 |", + "+----+------------+" + ], + &df_results + ); + + Ok(()) + } + #[tokio::test] async fn test_aggregate_subexpr() -> Result<()> { let df = test_table().await?; @@ -2965,9 +3012,7 @@ mod tests { JoinType::Inner, Some(Expr::Literal(ScalarValue::Null)), )?; - let expected_plan = "CrossJoin:\ - \n TableScan: a projection=[c1], full_filters=[Boolean(NULL)]\ - \n TableScan: b projection=[c1]"; + let expected_plan = "EmptyRelation"; assert_eq!(expected_plan, format!("{}", join.into_optimized_plan()?)); // JOIN ON expression must be boolean type @@ -3375,52 +3420,6 @@ mod tests { Ok(()) } - // Table 't1' self join - // Supplementary test of issue: https://github.com/apache/datafusion/issues/7790 - #[tokio::test] - async fn with_column_self_join() -> Result<()> { - let df = test_table().await?.select_columns(&["c1"])?; - let ctx = SessionContext::new(); - - ctx.register_table("t1", df.into_view())?; - - let df = ctx - .table("t1") - .await? - .join( - ctx.table("t1").await?, - JoinType::Inner, - &["c1"], - &["c1"], - None, - )? - .sort(vec![ - // make the test deterministic - col("t1.c1").sort(true, true), - ])? - .limit(0, Some(1))?; - - let df_results = df.clone().collect().await?; - assert_batches_sorted_eq!( - [ - "+----+----+", - "| c1 | c1 |", - "+----+----+", - "| a | a |", - "+----+----+", - ], - &df_results - ); - - let actual_err = df.clone().with_column("new_column", lit(true)).unwrap_err(); - let expected_err = "Error during planning: Projections require unique expression names \ - but the expression \"t1.c1\" at position 0 and \"t1.c1\" at position 1 have the same name. 
\ - Consider aliasing (\"AS\") one of them."; - assert_eq!(actual_err.strip_backtrace(), expected_err); - - Ok(()) - } - #[tokio::test] async fn with_column_renamed() -> Result<()> { let df = test_table() @@ -3571,11 +3570,10 @@ mod tests { #[tokio::test] async fn with_column_renamed_case_sensitive() -> Result<()> { - let config = - SessionConfig::from_string_hash_map(&std::collections::HashMap::from([( - "datafusion.sql_parser.enable_ident_normalization".to_owned(), - "false".to_owned(), - )]))?; + let config = SessionConfig::from_string_hash_map(&HashMap::from([( + "datafusion.sql_parser.enable_ident_normalization".to_owned(), + "false".to_owned(), + )]))?; let ctx = SessionContext::new_with_config(config); let name = "aggregate_test_100"; register_aggregate_csv(&ctx, name).await?; @@ -3647,7 +3645,7 @@ mod tests { #[tokio::test] async fn row_writer_resize_test() -> Result<()> { - let schema = Arc::new(Schema::new(vec![arrow::datatypes::Field::new( + let schema = Arc::new(Schema::new(vec![Field::new( "column_1", DataType::Utf8, false, @@ -3656,7 +3654,7 @@ mod tests { let data = RecordBatch::try_new( schema, vec![ - Arc::new(arrow::array::StringArray::from(vec![ + Arc::new(StringArray::from(vec![ Some("2a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), Some("3a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800"), ])) diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index 66974e37f4533..f90b35fde6baf 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -26,6 +26,7 @@ use super::{ }; use datafusion_common::config::TableParquetOptions; +use datafusion_expr::dml::InsertOp; impl DataFrame { /// Execute the `DataFrame` and write the results to Parquet file(s). 
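Since `DataFrameWriteOptions` now carries an `InsertOp` instead of an `overwrite` flag, callers name the write mode explicitly. A minimal sketch of the new call shape; note that the file writers in this patch still return `NotImplemented` for anything other than `Append`:

```rust
use datafusion::dataframe::DataFrameWriteOptions;
use datafusion_expr::dml::InsertOp;

fn main() {
    // Append is the default; Overwrite replaces the old `with_overwrite(true)`.
    let _options =
        DataFrameWriteOptions::new().with_insert_operation(InsertOp::Overwrite);
}
```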
@@ -57,10 +58,11 @@ impl DataFrame { options: DataFrameWriteOptions, writer_options: Option, ) -> Result, DataFusionError> { - if options.overwrite { - return Err(DataFusionError::NotImplemented( - "Overwrites are not implemented for DataFrame::write_parquet.".to_owned(), - )); + if options.insert_op != InsertOp::Append { + return Err(DataFusionError::NotImplemented(format!( + "{} is not implemented for DataFrame::write_parquet.", + options.insert_op + ))); } let format = if let Some(parquet_opts) = writer_options { diff --git a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs index 3a5d50bba07fc..9f089c7c0cea8 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs @@ -206,7 +206,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { fn build_primitive_array(&self, rows: RecordSlice, col_name: &str) -> ArrayRef where T: ArrowNumericType + Resolver, - T::Native: num_traits::cast::NumCast, + T::Native: NumCast, { Arc::new( rows.iter() @@ -354,7 +354,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { let builder = builder .as_any_mut() .downcast_mut::>() - .ok_or_else(||ArrowError::SchemaError( + .ok_or_else(||SchemaError( "Cast failed for ListBuilder during nested data parsing".to_string(), ))?; for val in vals { @@ -369,7 +369,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { builder.append(true); } DataType::Dictionary(_, _) => { - let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||ArrowError::SchemaError( + let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||SchemaError( "Cast failed for ListBuilder during nested data parsing".to_string(), ))?; for val in vals { @@ -402,7 +402,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { col_name: &str, ) -> ArrowResult where - T::Native: num_traits::cast::NumCast, + T::Native: NumCast, T: ArrowPrimitiveType + ArrowDictionaryKeyType, { let mut builder: StringDictionaryBuilder = @@ -453,12 +453,10 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::UInt64 => { self.build_dictionary_array::(rows, col_name) } - _ => Err(ArrowError::SchemaError( - "unsupported dictionary key type".to_string(), - )), + _ => Err(SchemaError("unsupported dictionary key type".to_string())), } } else { - Err(ArrowError::SchemaError( + Err(SchemaError( "dictionary types other than UTF-8 not yet supported".to_string(), )) } @@ -532,7 +530,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::UInt32 => self.read_primitive_list_values::(rows), DataType::UInt64 => self.read_primitive_list_values::(rows), DataType::Float16 => { - return Err(ArrowError::SchemaError("Float16 not supported".to_string())) + return Err(SchemaError("Float16 not supported".to_string())) } DataType::Float32 => self.read_primitive_list_values::(rows), DataType::Float64 => self.read_primitive_list_values::(rows), @@ -541,7 +539,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { | DataType::Date64 | DataType::Time32(_) | DataType::Time64(_) => { - return Err(ArrowError::SchemaError( + return Err(SchemaError( "Temporal types are not yet supported, see ARROW-4803".to_string(), )) } @@ -573,7 +571,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { // extract list values, with non-lists converted to Value::Null let array_item_count = rows .iter() - .map(|row| match row { + .map(|row| match maybe_resolve_union(row) { Value::Array(values) => values.len(), _ => 1, }) @@ -623,7 
+621,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { .unwrap() } datatype => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "Nested list of {datatype:?} not supported" ))); } @@ -737,7 +735,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { &field_path, ), t => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "TimeUnit {t:?} not supported with Time64" ))) } @@ -751,7 +749,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { &field_path, ), t => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "TimeUnit {t:?} not supported with Time32" ))) } @@ -854,7 +852,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { make_array(data) } _ => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "type {:?} not supported", field.data_type() ))) @@ -870,7 +868,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { fn read_primitive_list_values(&self, rows: &[&Value]) -> ArrayData where T: ArrowPrimitiveType + ArrowNumericType, - T::Native: num_traits::cast::NumCast, + T::Native: NumCast, { let values = rows .iter() @@ -970,7 +968,7 @@ fn resolve_u8(v: &Value) -> AvroResult { other => Err(AvroError::GetU8(other.into())), }?; if let Value::Int(n) = int { - if n >= 0 && n <= std::convert::From::from(u8::MAX) { + if n >= 0 && n <= From::from(u8::MAX) { return Ok(n as u8); } } @@ -1048,7 +1046,7 @@ fn maybe_resolve_union(value: &Value) -> &Value { impl Resolver for N where N: ArrowNumericType, - N::Native: num_traits::cast::NumCast, + N::Native: NumCast, { fn resolve(value: &Value) -> Option { let value = maybe_resolve_union(value); @@ -1643,6 +1641,93 @@ mod test { assert_batches_eq!(expected, &[batch]); } + #[test] + fn test_avro_nullable_struct_array() { + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "col1", + "type": [ + "null", + { + "type": "array", + "items": { + "type": [ + "null", + { + "type": "record", + "name": "Item", + "fields": [ + { + "name": "id", + "type": "long" + } + ] + } + ] + } + } + ], + "default": null + } + ] + }"#, + ) + .unwrap(); + let jv1 = serde_json::json!({ + "col1": [ + { + "id": 234 + }, + { + "id": 345 + } + ] + }); + let r1 = apache_avro::to_value(jv1) + .unwrap() + .resolve(&schema) + .unwrap(); + let r2 = apache_avro::to_value(serde_json::json!({ "col1": null })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + for _i in 0..5 { + w.append(r1.clone()).unwrap(); + } + w.append(r2).unwrap(); + let bytes = w.into_inner().unwrap(); + + let mut reader = ReaderBuilder::new() + .read_schema() + .with_batch_size(20) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 6); + assert_eq!(batch.num_columns(), 1); + + let expected = [ + "+------------------------+", + "| col1 |", + "+------------------------+", + "| [{id: 234}, {id: 345}] |", + "| [{id: 234}, {id: 345}] |", + "| [{id: 234}, {id: 345}] |", + "| [{id: 234}, {id: 345}] |", + "| [{id: 234}, {id: 345}] |", + "| |", + "+------------------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + #[test] fn test_avro_iterator() { let reader = build_reader("alltypes_plain.avro", 5); diff --git a/datafusion/core/src/datasource/avro_to_arrow/mod.rs b/datafusion/core/src/datasource/avro_to_arrow/mod.rs index c59078c89dd00..71184a78c96f5 100644 --- 
a/datafusion/core/src/datasource/avro_to_arrow/mod.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/mod.rs @@ -39,7 +39,7 @@ use std::io::Read; pub fn read_avro_schema_from_reader(reader: &mut R) -> Result { let avro_reader = apache_avro::Reader::new(reader)?; let schema = avro_reader.writer_schema(); - schema::to_arrow_schema(schema) + to_arrow_schema(schema) } #[cfg(not(feature = "avro"))] diff --git a/datafusion/core/src/datasource/dynamic_file.rs b/datafusion/core/src/datasource/dynamic_file.rs index a95f3abb939b2..6654d0871c3f6 100644 --- a/datafusion/core/src/datasource/dynamic_file.rs +++ b/datafusion/core/src/datasource/dynamic_file.rs @@ -30,7 +30,7 @@ use crate::error::Result; use crate::execution::context::SessionState; /// [DynamicListTableFactory] is a factory that can create a [ListingTable] from the given url. -#[derive(Default)] +#[derive(Default, Debug)] pub struct DynamicListTableFactory { /// The session store that contains the current session. session_store: SessionStore, @@ -69,11 +69,18 @@ impl UrlTableFactory for DynamicListTableFactory { .ok_or_else(|| plan_datafusion_err!("get current SessionStore error"))?; match ListingTableConfig::new(table_url.clone()) - .infer(state) + .infer_options(state) .await { - Ok(cfg) => ListingTable::try_new(cfg) - .map(|table| Some(Arc::new(table) as Arc)), + Ok(cfg) => { + let cfg = cfg + .infer_partitions_from_path(state) + .await? + .infer_schema(state) + .await?; + ListingTable::try_new(cfg) + .map(|table| Some(Arc::new(table) as Arc)) + } Err(_) => Ok(None), } } diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 6ee4280956e87..c10ebbd6c9eab 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -47,6 +47,7 @@ use datafusion_common::{ not_impl_err, DataFusionError, GetExt, Statistics, DEFAULT_ARROW_EXTENSION, }; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; use datafusion_physical_plan::metrics::MetricsSet; @@ -181,7 +182,7 @@ impl FileFormat for ArrowFormat { conf: FileSinkConfig, order_requirements: Option, ) -> Result> { - if conf.overwrite { + if conf.insert_op != InsertOp::Append { return not_impl_err!("Overwrites are not implemented yet for Arrow format"); } diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 99e8f13776fcc..3cb5ae4f85cad 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -46,6 +46,7 @@ use datafusion_common::{ exec_err, not_impl_err, DataFusionError, GetExt, DEFAULT_CSV_EXTENSION, }; use datafusion_execution::TaskContext; +use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::metrics::MetricsSet; @@ -77,7 +78,7 @@ impl CsvFormatFactory { } } -impl fmt::Debug for CsvFormatFactory { +impl Debug for CsvFormatFactory { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("CsvFormatFactory") .field("options", &self.options) @@ -382,7 +383,7 @@ impl FileFormat for CsvFormat { conf: FileSinkConfig, order_requirements: Option, ) -> Result> { - if conf.overwrite { + if conf.insert_op != InsertOp::Append { return not_impl_err!("Overwrites are not implemented yet for CSV"); } @@ -770,7 +771,7 
@@ mod tests { "c7: Int64", "c8: Int64", "c9: Int64", - "c10: Int64", + "c10: Utf8", "c11: Float64", "c12: Float64", "c13: Utf8" @@ -906,7 +907,7 @@ mod tests { Field::new("c7", DataType::Int64, true), Field::new("c8", DataType::Int64, true), Field::new("c9", DataType::Int64, true), - Field::new("c10", DataType::Int64, true), + Field::new("c10", DataType::Utf8, true), Field::new("c11", DataType::Float64, true), Field::new("c12", DataType::Float64, true), Field::new("c13", DataType::Utf8, true), @@ -967,7 +968,7 @@ mod tests { limit: Option, has_header: bool, ) -> Result> { - let root = format!("{}/csv", crate::test_util::arrow_test_data()); + let root = format!("{}/csv", arrow_test_data()); let format = CsvFormat::default().with_has_header(has_header); scan_format(state, &format, &root, file_name, projection, limit).await } diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 4471d7d6cb31c..fd97da52165b9 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -46,6 +46,7 @@ use datafusion_common::config::{ConfigField, ConfigFileType, JsonOptions}; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::{not_impl_err, GetExt, DEFAULT_JSON_EXTENSION}; use datafusion_execution::TaskContext; +use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::ExecutionPlan; @@ -117,7 +118,7 @@ impl GetExt for JsonFormatFactory { } } -impl fmt::Debug for JsonFormatFactory { +impl Debug for JsonFormatFactory { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("JsonFormatFactory") .field("options", &self.options) @@ -252,7 +253,7 @@ impl FileFormat for JsonFormat { conf: FileSinkConfig, order_requirements: Option, ) -> Result> { - if conf.overwrite { + if conf.insert_op != InsertOp::Append { return not_impl_err!("Overwrites are not implemented yet for Json"); } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 60f2b2dcefa93..24f1111517d2e 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -42,7 +42,7 @@ use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::{ExecutionPlan, Statistics}; -use arrow_schema::{DataType, Field, Schema}; +use arrow_schema::{DataType, Field, FieldRef, Schema}; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{internal_err, not_impl_err, GetExt}; use datafusion_expr::Expr; @@ -79,7 +79,7 @@ pub trait FileFormatFactory: Sync + Send + GetExt + Debug { /// /// [`TableProvider`]: crate::catalog::TableProvider #[async_trait] -pub trait FileFormat: Send + Sync + fmt::Debug { +pub trait FileFormat: Send + Sync + Debug { /// Returns the table provider as [`Any`](std::any::Any) so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -224,7 +224,7 @@ pub fn format_as_file_type( /// downcasted to a [DefaultFileType]. 
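`format_as_file_type` and `file_type_to_format` are the two halves of a round trip between a concrete `FileFormatFactory` and the dynamic `FileType` wrapper. A minimal sketch of that round trip, assuming `CsvFormatFactory::new()` as constructed earlier in this patch:

```rust
use std::sync::Arc;

use datafusion::datasource::file_format::csv::CsvFormatFactory;
use datafusion::datasource::file_format::{file_type_to_format, format_as_file_type};
use datafusion::error::Result;
use datafusion_common::GetExt;

fn main() -> Result<()> {
    // Wrap the factory as a dynamic FileType, then recover the factory.
    let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new()));
    let factory = file_type_to_format(&file_type)?;
    assert_eq!(factory.get_ext(), "csv");
    Ok(())
}
```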
pub fn file_type_to_format( file_type: &Arc, -) -> datafusion_common::Result> { +) -> Result> { match file_type .as_ref() .as_any() @@ -235,22 +235,26 @@ pub fn file_type_to_format( } } +/// Create a new field with the specified data type, copying the other +/// properties from the input field +fn field_with_new_type(field: &FieldRef, new_type: DataType) -> FieldRef { + Arc::new(field.as_ref().clone().with_data_type(new_type)) +} + /// Transform a schema to use view types for Utf8 and Binary +/// +/// See [parquet::ParquetFormat::force_view_types] for details pub fn transform_schema_to_view(schema: &Schema) -> Schema { let transformed_fields: Vec> = schema .fields .iter() .map(|field| match field.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => Arc::new(Field::new( - field.name(), - DataType::Utf8View, - field.is_nullable(), - )), - DataType::Binary | DataType::LargeBinary => Arc::new(Field::new( - field.name(), - DataType::BinaryView, - field.is_nullable(), - )), + DataType::Utf8 | DataType::LargeUtf8 => { + field_with_new_type(field, DataType::Utf8View) + } + DataType::Binary | DataType::LargeBinary => { + field_with_new_type(field, DataType::BinaryView) + } _ => field.clone(), }) .collect(); @@ -276,6 +280,7 @@ pub(crate) fn coerce_file_schema_to_view_type( (f.name(), dt) }) .collect(); + if !transform { return None; } @@ -285,14 +290,13 @@ pub(crate) fn coerce_file_schema_to_view_type( .iter() .map( |field| match (table_fields.get(field.name()), field.data_type()) { - (Some(DataType::Utf8View), DataType::Utf8) - | (Some(DataType::Utf8View), DataType::LargeUtf8) => Arc::new( - Field::new(field.name(), DataType::Utf8View, field.is_nullable()), - ), - (Some(DataType::BinaryView), DataType::Binary) - | (Some(DataType::BinaryView), DataType::LargeBinary) => Arc::new( - Field::new(field.name(), DataType::BinaryView, field.is_nullable()), - ), + (Some(DataType::Utf8View), DataType::Utf8 | DataType::LargeUtf8) => { + field_with_new_type(field, DataType::Utf8View) + } + ( + Some(DataType::BinaryView), + DataType::Binary | DataType::LargeBinary, + ) => field_with_new_type(field, DataType::BinaryView), _ => field.clone(), }, ) @@ -304,6 +308,78 @@ pub(crate) fn coerce_file_schema_to_view_type( )) } +/// Transform a schema so that any binary types are strings +pub fn transform_binary_to_string(schema: &Schema) -> Schema { + let transformed_fields: Vec> = schema + .fields + .iter() + .map(|field| match field.data_type() { + DataType::Binary => field_with_new_type(field, DataType::Utf8), + DataType::LargeBinary => field_with_new_type(field, DataType::LargeUtf8), + DataType::BinaryView => field_with_new_type(field, DataType::Utf8View), + _ => field.clone(), + }) + .collect(); + Schema::new_with_metadata(transformed_fields, schema.metadata.clone()) +} + +/// If the table schema uses a string type, coerce the file schema to use a string type. 
+/// +/// See [parquet::ParquetFormat::binary_as_string] for details +pub(crate) fn coerce_file_schema_to_string_type( + table_schema: &Schema, + file_schema: &Schema, +) -> Option { + let mut transform = false; + let table_fields: HashMap<_, _> = table_schema + .fields + .iter() + .map(|f| (f.name(), f.data_type())) + .collect(); + let transformed_fields: Vec> = file_schema + .fields + .iter() + .map( + |field| match (table_fields.get(field.name()), field.data_type()) { + // table schema uses string type, coerce the file schema to use string type + ( + Some(DataType::Utf8), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::Utf8) + } + // table schema uses large string type, coerce the file schema to use large string type + ( + Some(DataType::LargeUtf8), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::LargeUtf8) + } + // table schema uses string view type, coerce the file schema to use view type + ( + Some(DataType::Utf8View), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::Utf8View) + } + _ => field.clone(), + }, + ) + .collect(); + + if !transform { + None + } else { + Some(Schema::new_with_metadata( + transformed_fields, + file_schema.metadata.clone(), + )) + } +} + #[cfg(test)] pub(crate) mod test_util { use std::ops::Range; @@ -371,8 +447,8 @@ pub(crate) mod test_util { iterations_detected: Arc>, } - impl std::fmt::Display for VariableStream { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl Display for VariableStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "VariableStream") } } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 35296b0d79076..9153e71a5c267 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -20,13 +20,15 @@ use std::any::Any; use std::fmt; use std::fmt::Debug; +use std::ops::Range; use std::sync::Arc; use super::write::demux::start_demuxer_task; use super::write::{create_writer, SharedBuffer}; use super::{ - coerce_file_schema_to_view_type, transform_schema_to_view, FileFormat, - FileFormatFactory, FilePushdownSupport, FileScanConfig, + coerce_file_schema_to_string_type, coerce_file_schema_to_view_type, + transform_binary_to_string, transform_schema_to_view, FileFormat, FileFormatFactory, + FilePushdownSupport, FileScanConfig, }; use crate::arrow::array::RecordBatch; use crate::arrow::datatypes::{Fields, Schema, SchemaRef}; @@ -47,19 +49,20 @@ use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ - exec_err, internal_datafusion_err, not_impl_err, DataFusionError, GetExt, + internal_datafusion_err, not_impl_err, DataFusionError, GetExt, DEFAULT_PARQUET_EXTENSION, }; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_execution::TaskContext; +use datafusion_expr::dml::InsertOp; use datafusion_expr::Expr; use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; use datafusion_physical_expr::PhysicalExpr; use 
datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; -use bytes::{BufMut, BytesMut}; +use bytes::Bytes; use hashbrown::HashMap; use log::debug; use object_store::buffered::BufWriter; @@ -70,8 +73,7 @@ use parquet::arrow::arrow_writer::{ use parquet::arrow::{ arrow_to_parquet_schema, parquet_to_arrow_schema, AsyncArrowWriter, }; -use parquet::file::footer::{decode_footer, decode_metadata}; -use parquet::file::metadata::{ParquetMetaData, RowGroupMetaData}; +use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataReader, RowGroupMetaData}; use parquet::file::properties::WriterProperties; use parquet::file::writer::SerializedFileWriter; use parquet::format::FileMetaData; @@ -83,10 +85,13 @@ use crate::datasource::physical_plan::parquet::{ can_expr_be_pushed_down_with_schemas, ParquetExecBuilder, }; use datafusion_physical_expr_common::sort_expr::LexRequirement; -use futures::{StreamExt, TryStreamExt}; +use futures::future::BoxFuture; +use futures::{FutureExt, StreamExt, TryStreamExt}; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; use parquet::arrow::arrow_reader::statistics::StatisticsConverter; +use parquet::arrow::async_reader::MetadataFetch; +use parquet::errors::ParquetError; /// Initial writing buffer size. Note this is just a size hint for efficiency. It /// will grow beyond the set value if needed. @@ -160,7 +165,7 @@ impl GetExt for ParquetFormatFactory { } } -impl fmt::Debug for ParquetFormatFactory { +impl Debug for ParquetFormatFactory { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ParquetFormatFactory") .field("ParquetFormatFactory", &self.options) @@ -249,13 +254,29 @@ impl ParquetFormat { self.options.global.schema_force_view_types } - /// If true, will use view types (StringView and BinaryView). - /// - /// Refer to [`Self::force_view_types`]. + /// If true, will use view types. See [`Self::force_view_types`] for details pub fn with_force_view_types(mut self, use_views: bool) -> Self { self.options.global.schema_force_view_types = use_views; self } + + /// Return `true` if binary types will be read as strings. + /// + /// If this returns true, DataFusion will instruct the parquet reader + /// to read binary columns such as `Binary` or `BinaryView` as the + /// corresponding string type such as `Utf8` or `LargeUtf8`. + /// The parquet reader has special optimizations for `Utf8` and `LargeUtf8` + /// validation, and such queries are significantly faster than reading + /// binary columns and then casting to string columns. + pub fn binary_as_string(&self) -> bool { + self.options.global.binary_as_string + } + + /// If true, will read binary types as strings. 
See [`Self::binary_as_string`] for details + pub fn with_binary_as_string(mut self, binary_as_string: bool) -> Self { + self.options.global.binary_as_string = binary_as_string; + self + } } /// Clears all metadata (Schema level and field level) on an iterator @@ -346,6 +367,12 @@ impl FileFormat for ParquetFormat { Schema::try_merge(schemas) }?; + let schema = if self.binary_as_string() { + transform_binary_to_string(&schema) + } else { + schema + }; + let schema = if self.force_view_types() { transform_schema_to_view(&schema) } else { @@ -403,7 +430,7 @@ impl FileFormat for ParquetFormat { conf: FileSinkConfig, order_requirements: Option, ) -> Result> { - if conf.overwrite { + if conf.insert_op != InsertOp::Append { return not_impl_err!("Overwrites are not implemented yet for Parquet"); } @@ -440,6 +467,33 @@ impl FileFormat for ParquetFormat { } } +/// [`MetadataFetch`] adapter for reading bytes from an [`ObjectStore`] +struct ObjectStoreFetch<'a> { + store: &'a dyn ObjectStore, + meta: &'a ObjectMeta, +} + +impl<'a> ObjectStoreFetch<'a> { + fn new(store: &'a dyn ObjectStore, meta: &'a ObjectMeta) -> Self { + Self { store, meta } + } +} + +impl<'a> MetadataFetch for ObjectStoreFetch<'a> { + fn fetch( + &mut self, + range: Range, + ) -> BoxFuture<'_, Result> { + async { + self.store + .get_range(&self.meta.location, range) + .await + .map_err(ParquetError::from) + } + .boxed() + } +} + /// Fetches parquet metadata from ObjectStore for given object /// /// This component is a subject to **change** in near future and is exposed for low level integrations @@ -451,57 +505,14 @@ pub async fn fetch_parquet_metadata( meta: &ObjectMeta, size_hint: Option, ) -> Result { - if meta.size < 8 { - return exec_err!("file size of {} is less than footer", meta.size); - } + let file_size = meta.size; + let fetch = ObjectStoreFetch::new(store, meta); - // If a size hint is provided, read more than the minimum size - // to try and avoid a second fetch. - let footer_start = if let Some(size_hint) = size_hint { - meta.size.saturating_sub(size_hint) - } else { - meta.size - 8 - }; - - let suffix = store - .get_range(&meta.location, footer_start..meta.size) - .await?; - - let suffix_len = suffix.len(); - - let mut footer = [0; 8]; - footer.copy_from_slice(&suffix[suffix_len - 8..suffix_len]); - - let length = decode_footer(&footer)?; - - if meta.size < length + 8 { - return exec_err!( - "file size of {} is less than footer + metadata {}", - meta.size, - length + 8 - ); - } - - // Did not fetch the entire file metadata in the initial read, need to make a second request - if length > suffix_len - 8 { - let metadata_start = meta.size - length - 8; - let remaining_metadata = store - .get_range(&meta.location, metadata_start..footer_start) - .await?; - - let mut metadata = BytesMut::with_capacity(length); - - metadata.put(remaining_metadata.as_ref()); - metadata.put(&suffix[..suffix_len - 8]); - - Ok(decode_metadata(metadata.as_ref())?) - } else { - let metadata_start = meta.size - length - 8; - - Ok(decode_metadata( - &suffix[metadata_start - footer_start..suffix_len - 8], - )?) 
- } + ParquetMetaDataReader::new() + .with_prefetch_hint(size_hint) + .load_and_finish(fetch, file_size) + .await + .map_err(DataFusionError::from) } /// Read and parse the schema of the Parquet file at location `path` @@ -564,6 +575,10 @@ pub fn statistics_from_parquet_meta_calc( file_metadata.schema_descr(), file_metadata.key_value_metadata(), )?; + if let Some(merged) = coerce_file_schema_to_string_type(&table_schema, &file_schema) { + file_schema = merged; + } + if let Some(merged) = coerce_file_schema_to_view_type(&table_schema, &file_schema) { file_schema = merged; } @@ -723,13 +738,14 @@ impl ParquetSink { .iter() .map(|(s, _)| s) .collect(); - Arc::new(Schema::new( + Arc::new(Schema::new_with_metadata( schema .fields() .iter() .filter(|f| !partition_names.contains(&f.name())) .map(|f| (**f).clone()) .collect::>(), + schema.metadata().clone(), )) } else { self.config.output_schema().clone() @@ -1423,7 +1439,7 @@ mod tests { } impl Display for RequestCountingObjectStore { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "RequestCounting({})", self.inner) } } @@ -1691,7 +1707,7 @@ mod tests { let null_utf8 = if force_views { ScalarValue::Utf8View(None) } else { - ScalarValue::Utf8(None) + Utf8(None) }; // Fetch statistics for first file @@ -1704,7 +1720,7 @@ mod tests { let expected_type = if force_views { ScalarValue::Utf8View } else { - ScalarValue::Utf8 + Utf8 }; assert_eq!( c1_stats.max_value, @@ -2269,7 +2285,7 @@ mod tests { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![], - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, }; let parquet_sink = Arc::new(ParquetSink::new( @@ -2364,7 +2380,7 @@ mod tests { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("a".to_string(), DataType::Utf8)], // add partitioning - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, }; let parquet_sink = Arc::new(ParquetSink::new( @@ -2447,7 +2463,7 @@ mod tests { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![], - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, }; let parquet_sink = Arc::new(ParquetSink::new( diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index 427b28db40301..1746ffef8282b 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -280,9 +280,8 @@ async fn hive_style_partitions_demuxer( Some(part_tx) => part_tx, None => { // Create channel for previously unseen distinct partition key and notify consumer of new file - let (part_tx, part_rx) = tokio::sync::mpsc::channel::( - max_buffered_recordbatches, - ); + let (part_tx, part_rx) = + mpsc::channel::(max_buffered_recordbatches); let file_path = compute_hive_style_file_path( &part_key, &partition_by, diff --git a/datafusion/core/src/datasource/function.rs b/datafusion/core/src/datasource/function.rs index 14bbc431f9739..37ce59f8207b2 100644 --- a/datafusion/core/src/datasource/function.rs +++ b/datafusion/core/src/datasource/function.rs @@ -22,15 +22,17 @@ use super::TableProvider; use datafusion_common::Result; use datafusion_expr::Expr; +use std::fmt::Debug; use std::sync::Arc; /// A trait for 
table function implementations -pub trait TableFunctionImpl: Sync + Send { +pub trait TableFunctionImpl: Debug + Sync + Send { /// Create a table provider fn call(&self, args: &[Expr]) -> Result>; } /// A table that uses a function to generate data +#[derive(Debug)] pub struct TableFunction { /// Name of the table function name: String, diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 72d7277d6ae26..47012f777ad1e 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -24,6 +24,7 @@ use std::sync::Arc; use super::ListingTableUrl; use super::PartitionedFile; use crate::execution::context::SessionState; +use datafusion_common::internal_err; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{BinaryExpr, Operator}; @@ -285,25 +286,20 @@ async fn prune_partitions( let props = ExecutionProps::new(); // Applies `filter` to `batch` returning `None` on error - let do_filter = |filter| -> Option { - let expr = create_physical_expr(filter, &df_schema, &props).ok()?; - expr.evaluate(&batch) - .ok()? - .into_array(partitions.len()) - .ok() + let do_filter = |filter| -> Result { + let expr = create_physical_expr(filter, &df_schema, &props)?; + expr.evaluate(&batch)?.into_array(partitions.len()) }; - //.Compute the conjunction of the filters, ignoring errors + //.Compute the conjunction of the filters let mask = filters .iter() - .fold(None, |acc, filter| match (acc, do_filter(filter)) { - (Some(a), Some(b)) => Some(and(&a, b.as_boolean()).unwrap_or(a)), - (None, Some(r)) => Some(r.as_boolean().clone()), - (r, None) => r, - }); + .map(|f| do_filter(f).map(|a| a.as_boolean().clone())) + .reduce(|a, b| Ok(and(&a?, &b?)?)); let mask = match mask { - Some(mask) => mask, + Some(Ok(mask)) => mask, + Some(Err(err)) => return Err(err), None => return Ok(partitions), }; @@ -401,8 +397,8 @@ fn evaluate_partition_prefix<'a>( /// Discover the partitions on the given path and prune out files /// that belong to irrelevant partitions using `filters` expressions. -/// `filters` might contain expressions that can be resolved only at the -/// file level (e.g. Parquet row group pruning). +/// `filters` should only contain expressions that can be evaluated +/// using only the partition columns. 
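The reworked mask computation in `prune_partitions` above no longer swallows filter-evaluation errors: each filter now yields a `Result`, and `reduce` short-circuits the conjunction at the first `Err`. A self-contained sketch of the same pattern, with plain booleans standing in for Arrow boolean arrays:

```rust
/// Fold fallible boolean masks into one conjunction, propagating the first error.
fn conjoin(masks: Vec<Result<bool, String>>) -> Option<Result<bool, String>> {
    masks.into_iter().reduce(|a, b| Ok(a? && b?))
}

fn main() {
    assert_eq!(conjoin(vec![Ok(true), Ok(false)]), Some(Ok(false)));
    // An evaluation error is now surfaced instead of being ignored.
    assert_eq!(
        conjoin(vec![Ok(true), Err("bad filter".to_string())]),
        Some(Err("bad filter".to_string()))
    );
    // No filters at all: the caller keeps every partition.
    assert_eq!(conjoin(vec![]), None);
}
```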
pub async fn pruned_partition_list<'a>( ctx: &'a SessionState, store: &'a dyn ObjectStore, @@ -413,6 +409,12 @@ pub async fn pruned_partition_list<'a>( ) -> Result>> { // if no partition col => simply list all the files if partition_cols.is_empty() { + if !filters.is_empty() { + return internal_err!( + "Got partition filters for unpartitioned table {}", + table_path + ); + } return Ok(Box::pin( table_path .list_all_files(ctx, store, file_extension) @@ -631,13 +633,11 @@ mod tests { ]); let filter1 = Expr::eq(col("part1"), lit("p1v2")); let filter2 = Expr::eq(col("part2"), lit("p2v1")); - // filter3 cannot be resolved at partition pruning - let filter3 = Expr::eq(col("part2"), col("other")); let pruned = pruned_partition_list( &state, store.as_ref(), &ListingTableUrl::parse("file:///tablepath/").unwrap(), - &[filter1, filter2, filter3], + &[filter1, filter2], ".parquet", &[ (String::from("part1"), DataType::Utf8), diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 2a35fddeb0337..ea2e098ef14ec 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -33,7 +33,8 @@ use crate::datasource::{ }; use crate::execution::context::SessionState; use datafusion_catalog::TableProvider; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{config_err, DataFusionError, Result}; +use datafusion_expr::dml::InsertOp; use datafusion_expr::{utils::conjunction, Expr, TableProviderFilterPushDown}; use datafusion_expr::{SortExpr, TableType}; use datafusion_physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}; @@ -191,6 +192,38 @@ impl ListingTableConfig { pub async fn infer(self, state: &SessionState) -> Result { self.infer_options(state).await?.infer_schema(state).await } + + /// Infer the partition columns from the path. Requires `self.options` to be set prior to using. + pub async fn infer_partitions_from_path(self, state: &SessionState) -> Result { + match self.options { + Some(options) => { + let Some(url) = self.table_paths.first() else { + return config_err!("No table path found"); + }; + let partitions = options + .infer_partitions(state, url) + .await? + .into_iter() + .map(|col_name| { + ( + col_name, + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + ), + ) + }) + .collect::>(); + let options = options.with_table_partition_cols(partitions); + Ok(Self { + table_paths: self.table_paths, + file_schema: self.file_schema, + options: Some(options), + }) + } + None => config_err!("No `ListingOptions` set for inferring schema"), + } + } } /// Options for creating a [`ListingTable`] @@ -504,7 +537,7 @@ impl ListingOptions { /// Infer the partitioning at the given path on the provided object store. /// For performance reasons, it doesn't read all the files on disk /// and therefore may fail to detect invalid partitioning. 
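`infer_partitions_from_path` slots between `infer_options` and `infer_schema`, which is how `DynamicListTableFactory` now builds its config in the `dynamic_file.rs` hunk earlier in this patch. A sketch of that staged flow, where `table_for_url` is a hypothetical helper, not part of the patch:

```rust
use std::sync::Arc;

use datafusion::datasource::listing::{
    ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::error::Result;
use datafusion::execution::context::SessionState;

/// Hypothetical helper: infer format options, then partition columns from
/// the path, then the file schema, before constructing the table.
async fn table_for_url(
    state: &SessionState,
    url: ListingTableUrl,
) -> Result<Arc<ListingTable>> {
    let config = ListingTableConfig::new(url)
        .infer_options(state)
        .await?
        .infer_partitions_from_path(state)
        .await?
        .infer_schema(state)
        .await?;
    Ok(Arc::new(ListingTable::try_new(config)?))
}
```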
- async fn infer_partitions( + pub(crate) async fn infer_partitions( &self, state: &SessionState, table_path: &ListingTableUrl, @@ -686,10 +719,16 @@ impl ListingTable { builder.push(Field::new(part_col_name, part_col_type.clone(), false)); } + let table_schema = Arc::new( + builder + .finish() + .with_metadata(file_schema.metadata().clone()), + ); + let table = Self { table_paths: config.table_paths, file_schema, - table_schema: Arc::new(builder.finish()), + table_schema, options, definition: None, collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), @@ -749,6 +788,16 @@ impl ListingTable { } } +// Expressions can be used for partition pruning if they can be evaluated using +// only the partition columns and there are partition columns. +fn can_be_evaluted_for_partition_pruning( + partition_column_names: &[&str], + expr: &Expr, +) -> bool { + !partition_column_names.is_empty() + && expr_applicable_for_cols(partition_column_names, expr) +} + #[async_trait] impl TableProvider for ListingTable { fn as_any(&self) -> &dyn Any { @@ -774,10 +823,28 @@ impl TableProvider for ListingTable { filters: &[Expr], limit: Option<usize>, ) -> Result<Arc<dyn ExecutionPlan>> { + // extract types of partition columns + let table_partition_cols = self + .options + .table_partition_cols + .iter() + .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) + .collect::<Result<Vec<_>>>()?; + + let table_partition_col_names = table_partition_cols + .iter() + .map(|field| field.name().as_str()) + .collect::<Vec<_>>(); + // If the filters can be resolved using only partition cols, there is no need to + // pushdown it to TableScan, otherwise, `unhandled` pruning predicates will be generated + let (partition_filters, filters): (Vec<_>, Vec<_>) = + filters.iter().cloned().partition(|filter| { + can_be_evaluted_for_partition_pruning(&table_partition_col_names, filter) + }); // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here? let session_state = state.as_any().downcast_ref::<SessionState>().unwrap(); let (mut partitioned_file_lists, statistics) = self - .list_files_for_scan(session_state, filters, limit) + .list_files_for_scan(session_state, &partition_filters, limit) .await?; // if no files need to be read, return an `EmptyExec` @@ -813,28 +880,6 @@ impl TableProvider for ListingTable { None => {} // no ordering required }; - // extract types of partition columns - let table_partition_cols = self - .options - .table_partition_cols - .iter() - .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) - .collect::<Result<Vec<_>>>()?; - - // If the filters can be resolved using only partition cols, there is no need to - // pushdown it to TableScan, otherwise, `unhandled` pruning predicates will be generated - let table_partition_col_names = table_partition_cols - .iter() - .map(|field| field.name().as_str()) - .collect::<Vec<_>>(); - let filters = filters - .iter() - .filter(|filter| { - !expr_applicable_for_cols(&table_partition_col_names, filter) - }) - .cloned() - .collect::<Vec<_>>(); - let filters = conjunction(filters.to_vec()) .map(|expr| -> Result<_> { // NOTE: Use the table schema (NOT file schema) here because `expr` may contain references to partition columns.
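The `scan` rework above classifies each pushed-down filter before listing files: filters that touch only partition columns drive pruning, everything else stays with the file scan. A toy stand-in for that split using `Iterator::partition`, with strings in place of `Expr`s and `refs_only_partition_cols` as a stand-in for the real `expr_applicable_for_cols` check:

```rust
/// Toy predicate standing in for `can_be_evaluted_for_partition_pruning`.
fn refs_only_partition_cols(filter: &str, partition_cols: &[&str]) -> bool {
    !partition_cols.is_empty() && partition_cols.iter().any(|c| filter.starts_with(c))
}

fn main() {
    let partition_cols = ["year", "month"];
    let filters = vec!["year = 2024", "month = 10", "value > 5"];

    // Same shape as the split in `scan`: pruning filters vs. scan filters.
    let (partition_filters, file_filters): (Vec<_>, Vec<_>) = filters
        .into_iter()
        .partition(|f| refs_only_partition_cols(f, &partition_cols));

    assert_eq!(partition_filters, ["year = 2024", "month = 10"]);
    assert_eq!(file_filters, ["value > 5"]);
}
```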
@@ -875,18 +920,17 @@ impl TableProvider for ListingTable { &self, filters: &[&Expr], ) -> Result> { + let partition_column_names = self + .options + .table_partition_cols + .iter() + .map(|col| col.0.as_str()) + .collect::>(); filters .iter() .map(|filter| { - if expr_applicable_for_cols( - &self - .options - .table_partition_cols - .iter() - .map(|col| col.0.as_str()) - .collect::>(), - filter, - ) { + if can_be_evaluted_for_partition_pruning(&partition_column_names, filter) + { // if filter can be handled by partition pruning, it is exact return Ok(TableProviderFilterPushDown::Exact); } @@ -916,7 +960,7 @@ impl TableProvider for ListingTable { &self, state: &dyn Session, input: Arc, - overwrite: bool, + insert_op: InsertOp, ) -> Result> { // Check that the schema of the plan matches the schema of this table. if !self @@ -975,7 +1019,7 @@ impl TableProvider for ListingTable { file_groups, output_schema: self.schema(), table_partition_cols: self.options.table_partition_cols.clone(), - overwrite, + insert_op, keep_partition_by_columns, }; @@ -1990,7 +2034,8 @@ mod tests { // Therefore, we will have 8 partitions in the final plan. // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)? + .build()?; // Create a physical plan from the insert plan let plan = session_ctx .state() diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index fed63ec12b496..581d88d25884a 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -91,7 +91,7 @@ impl TableProviderFactory for ListingTableFactory { .field_with_name(col) .map_err(|e| arrow_datafusion_err!(e)) }) - .collect::>>()? + .collect::>>()? 
.into_iter() .map(|f| (f.name().to_owned(), f.data_type().to_owned())) .collect(); @@ -197,6 +197,7 @@ mod tests { schema: Arc::new(DFSchema::empty()), table_partition_cols: vec![], if_not_exists: false, + temporary: false, definition: None, order_exprs: vec![], unbounded: false, @@ -236,6 +237,7 @@ mod tests { schema: Arc::new(DFSchema::empty()), table_partition_cols: vec![], if_not_exists: false, + temporary: false, definition: None, order_exprs: vec![], unbounded: false, diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 70f3c36b81e19..3c2d1b0205d6e 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -37,13 +37,14 @@ use crate::physical_planner::create_physical_sort_exprs; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use datafusion_catalog::Session; use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt}; use datafusion_execution::TaskContext; +use datafusion_expr::dml::InsertOp; +use datafusion_expr::SortExpr; use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; -use datafusion_catalog::Session; -use datafusion_expr::SortExpr; use futures::StreamExt; use log::debug; use parking_lot::Mutex; @@ -240,7 +241,7 @@ impl TableProvider for MemTable { ) }) .collect::>>()?; - exec = exec.with_sort_information(file_sort_order); + exec = exec.try_with_sort_information(file_sort_order)?; } Ok(Arc::new(exec)) @@ -262,7 +263,7 @@ impl TableProvider for MemTable { &self, _state: &dyn Session, input: Arc, - overwrite: bool, + insert_op: InsertOp, ) -> Result> { // If we are inserting into the table, any sort order may be messed up so reset it here *self.sort_order.lock() = vec![]; @@ -289,8 +290,8 @@ impl TableProvider for MemTable { .collect::>() ); } - if overwrite { - return not_impl_err!("Overwrite not implemented for MemoryTable yet"); + if insert_op != InsertOp::Append { + return not_impl_err!("{insert_op} not implemented for MemoryTable yet"); } let sink = Arc::new(MemSink::new(self.batches.clone())); Ok(Arc::new(DataSinkExec::new( @@ -638,7 +639,8 @@ mod tests { let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)? 
+ .build()?; // Create a physical plan from the insert plan let plan = session_ctx .state() diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 6cd1864deb1d4..5beffc3b0581d 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -1216,7 +1216,7 @@ mod tests { let session_ctx = SessionContext::new(); let store = object_store::memory::InMemory::new(); - let data = bytes::Bytes::from("a,b\n1,2\n3,4"); + let data = Bytes::from("a,b\n1,2\n3,4"); let path = object_store::path::Path::from("a.csv"); store.put(&path, data.into()).await.unwrap(); @@ -1247,7 +1247,7 @@ mod tests { let session_ctx = SessionContext::new(); let store = object_store::memory::InMemory::new(); - let data = bytes::Bytes::from("a,b\r1,2\r3,4"); + let data = Bytes::from("a,b\r1,2\r3,4"); let path = object_store::path::Path::from("a.csv"); store.put(&path, data.into()).await.unwrap(); diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 2c438e8b0e78b..96c0e452e29e4 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -19,7 +19,8 @@ //! file sources. use std::{ - borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc, vec, + borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, mem::size_of, + sync::Arc, vec, }; use super::{get_projected_output_ordering, statistics::MinMaxStatistics}; @@ -248,9 +249,10 @@ impl FileScanConfig { column_statistics: table_cols_stats, }; - let projected_schema = Arc::new( - Schema::new(table_fields).with_metadata(self.file_schema.metadata().clone()), - ); + let projected_schema = Arc::new(Schema::new_with_metadata( + table_fields, + self.file_schema.metadata().clone(), + )); let projected_output_ordering = get_projected_output_ordering(self, &projected_schema); @@ -281,7 +283,12 @@ impl FileScanConfig { fields.map_or_else( || Arc::clone(&self.file_schema), - |f| Arc::new(Schema::new(f).with_metadata(self.file_schema.metadata.clone())), + |f| { + Arc::new(Schema::new_with_metadata( + f, + self.file_schema.metadata.clone(), + )) + }, ) } @@ -491,7 +498,7 @@ impl ZeroBufferGenerator where T: ArrowNativeType, { - const SIZE: usize = std::mem::size_of::(); + const SIZE: usize = size_of::(); fn get_buffer(&mut self, n_vals: usize) -> Buffer { match &mut self.cache { diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 4018b3bb2920f..407a3b74f79f2 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -36,6 +36,7 @@ pub use self::parquet::{ParquetExec, ParquetFileMetrics, ParquetFileReaderFactor pub use arrow_file::ArrowExec; pub use avro::AvroExec; pub use csv::{CsvConfig, CsvExec, CsvExecBuilder, CsvOpener}; +use datafusion_expr::dml::InsertOp; pub use file_groups::FileGroupPartitioner; pub use file_scan_config::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, @@ -83,8 +84,9 @@ pub struct FileSinkConfig { /// A vector of column names and their corresponding data types, /// representing the partitioning columns for the file pub table_partition_cols: Vec<(String, DataType)>, - /// Controls whether existing data should be overwritten by this sink - pub overwrite: bool, 
+ /// Controls how new data should be written to the file, determining whether + /// to append to, overwrite, or replace records in existing files. + pub insert_op: InsertOp, /// Controls whether partition columns are kept for the file pub keep_partition_by_columns: bool, } @@ -761,7 +763,7 @@ mod tests { /// create a PartitionedFile for testing fn partitioned_file(path: &str) -> PartitionedFile { let object_meta = ObjectMeta { - location: object_store::path::Path::parse(path).unwrap(), + location: Path::parse(path).unwrap(), last_modified: Utc::now(), size: 42, e_tag: None, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 6afb66cc7c02e..059f86ce110f4 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -166,6 +166,33 @@ pub use writer::plan_to_parquet; /// [`RowFilter`]: parquet::arrow::arrow_reader::RowFilter /// [Parquet PageIndex]: https://github.com/apache/parquet-format/blob/master/PageIndex.md /// +/// # Example: rewriting `ParquetExec` +/// +/// You can modify a `ParquetExec` using [`ParquetExecBuilder`], for example +/// to change files or add a predicate. +/// +/// ```no_run +/// # use std::sync::Arc; +/// # use arrow::datatypes::Schema; +/// # use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +/// # use datafusion::datasource::listing::PartitionedFile; +/// # fn parquet_exec() -> ParquetExec { unimplemented!() } +/// // Split a single ParquetExec into multiple ParquetExecs, one for each file +/// let exec = parquet_exec(); +/// let existing_file_groups = &exec.base_config().file_groups; +/// let new_execs = existing_file_groups +/// .iter() +/// .map(|file_group| { +/// // create a new exec by copying the existing exec into a builder +/// let new_exec = exec.clone() +/// .into_builder() +/// .with_file_groups(vec![file_group.clone()]) +/// .build(); +/// new_exec +/// }) +/// .collect::>(); +/// ``` +/// /// # Implementing External Indexes /// /// It is possible to restrict the row groups and selections within those row @@ -257,6 +284,12 @@ pub struct ParquetExec { schema_adapter_factory: Option>, } +impl From for ParquetExecBuilder { + fn from(exec: ParquetExec) -> Self { + exec.into_builder() + } +} + /// [`ParquetExecBuilder`], builder for [`ParquetExec`]. /// /// See example on [`ParquetExec`]. @@ -291,6 +324,12 @@ impl ParquetExecBuilder { } } + /// Update the list of file groups to read + pub fn with_file_groups(mut self, file_groups: Vec>) -> Self { + self.file_scan_config.file_groups = file_groups; + self + } + /// Set the filter predicate when reading.
/// /// See the "Predicate Pushdown" section of the [`ParquetExec`] documentation @@ -459,6 +498,34 @@ impl ParquetExec { ParquetExecBuilder::new(file_scan_config) } + /// Convert this `ParquetExec` into a builder for modification + pub fn into_builder(self) -> ParquetExecBuilder { + // list out fields so it is clear what is being dropped + // (note the fields which are dropped are re-created as part of calling + // `build` on the builder) + let Self { + base_config, + projected_statistics: _, + metrics: _, + predicate, + pruning_predicate: _, + page_pruning_predicate: _, + metadata_size_hint, + parquet_file_reader_factory, + cache: _, + table_parquet_options, + schema_adapter_factory, + } = self; + ParquetExecBuilder { + file_scan_config: base_config, + predicate, + metadata_size_hint, + table_parquet_options, + parquet_file_reader_factory, + schema_adapter_factory, + } + } + /// [`FileScanConfig`] that controls this scan (such as which files to read) pub fn base_config(&self) -> &FileScanConfig { &self.base_config @@ -479,9 +546,15 @@ impl ParquetExec { self.pruning_predicate.as_ref() } + /// return the optional file reader factory + pub fn parquet_file_reader_factory( + &self, + ) -> Option<&Arc> { + self.parquet_file_reader_factory.as_ref() + } + /// Optional user defined parquet file reader factory. /// - /// See documentation on [`ParquetExecBuilder::with_parquet_file_reader_factory`] pub fn with_parquet_file_reader_factory( mut self, parquet_file_reader_factory: Arc, ) -> Self { @@ -490,6 +563,11 @@ impl ParquetExec { self } + /// return the optional schema adapter factory + pub fn schema_adapter_factory(&self) -> Option<&Arc> { + self.schema_adapter_factory.as_ref() + } + /// Optional schema adapter factory. /// /// See documentation on [`ParquetExecBuilder::with_schema_adapter_factory`] @@ -586,7 +664,14 @@ impl ParquetExec { ) } - fn with_file_groups(mut self, file_groups: Vec>) -> Self { + /// Updates the file groups to read and recalculates the output partitioning + /// + /// Note this function does not update statistics or other properties + /// that depend on the file groups. + fn with_file_groups_and_update_partitioning( + mut self, + file_groups: Vec>, + ) -> Self { self.base_config.file_groups = file_groups; // Changing file groups may invalidate output partitioning. Update it also let output_partitioning = Self::output_partitioning_helper(&self.base_config); @@ -679,7 +764,8 @@ impl ExecutionPlan for ParquetExec { let mut new_plan = self.clone(); if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { - new_plan = new_plan.with_file_groups(repartitioned_file_groups); + new_plan = new_plan + .with_file_groups_and_update_partitioning(repartitioned_file_groups); } Ok(Some(Arc::new(new_plan))) } @@ -2141,7 +2227,7 @@ mod tests { // execute a simple query and write the results to parquet let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; - std::fs::create_dir(&out_dir).unwrap(); + fs::create_dir(&out_dir).unwrap(); let df = ctx.sql("SELECT c1, c2 FROM test").await?; let schema: Schema = df.schema().into(); // Register a listing table - this will use all files in the directory as data sources diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs index a818a88502842..4990cb4dd735d 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -17,7 +17,9 @@ //!
[`ParquetOpener`] for opening Parquet files -use crate::datasource::file_format::coerce_file_schema_to_view_type; +use crate::datasource::file_format::{ + coerce_file_schema_to_string_type, coerce_file_schema_to_view_type, +}; use crate::datasource::physical_plan::parquet::page_filter::PagePruningAccessPlanFilter; use crate::datasource::physical_plan::parquet::row_group_filter::RowGroupAccessPlanFilter; use crate::datasource::physical_plan::parquet::{ @@ -80,7 +82,7 @@ pub(super) struct ParquetOpener { } impl FileOpener for ParquetOpener { - fn open(&self, file_meta: FileMeta) -> datafusion_common::Result { + fn open(&self, file_meta: FileMeta) -> Result { let file_range = file_meta.range.clone(); let extensions = file_meta.extensions.clone(); let file_name = file_meta.location().to_string(); @@ -121,7 +123,14 @@ impl FileOpener for ParquetOpener { let mut metadata_timer = file_metrics.metadata_load_time.timer(); let metadata = ArrowReaderMetadata::load_async(&mut reader, options.clone()).await?; - let mut schema = metadata.schema().clone(); + let mut schema = Arc::clone(metadata.schema()); + + if let Some(merged) = + coerce_file_schema_to_string_type(&table_schema, &schema) + { + schema = Arc::new(merged); + } + // read with view types if let Some(merged) = coerce_file_schema_to_view_type(&table_schema, &schema) { @@ -130,16 +139,16 @@ impl FileOpener for ParquetOpener { let options = ArrowReaderOptions::new() .with_page_index(enable_page_index) - .with_schema(schema.clone()); + .with_schema(Arc::clone(&schema)); let metadata = - ArrowReaderMetadata::try_new(metadata.metadata().clone(), options)?; + ArrowReaderMetadata::try_new(Arc::clone(metadata.metadata()), options)?; metadata_timer.stop(); let mut builder = ParquetRecordBatchStreamBuilder::new_with_metadata(reader, metadata); - let file_schema = builder.schema().clone(); + let file_schema = Arc::clone(builder.schema()); let (schema_mapping, adapted_projections) = schema_adapter.map_schema(&file_schema)?; @@ -177,7 +186,7 @@ impl FileOpener for ParquetOpener { // Determine which row groups to actually read. The idea is to skip // as many row groups as possible based on the metadata and query - let file_metadata = builder.metadata().clone(); + let file_metadata = Arc::clone(builder.metadata()); let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); let rg_metadata = file_metadata.row_groups(); // track which row groups to actually read diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs index a1d74cb54355e..7406676652f66 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs @@ -779,11 +779,8 @@ mod tests { // INT32: c1 > 5, the c1 is decimal(9,2) // The type of scalar value if decimal(9,2), don't need to do cast - let schema = Arc::new(Schema::new(vec![Field::new( - "c1", - DataType::Decimal128(9, 2), - false, - )])); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", Decimal128(9, 2), false)])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -849,11 +846,8 @@ mod tests { // The c1 type is decimal(9,0) in the parquet file, and the type of scalar is decimal(5,2). 
// We should convert all type to the coercion type, which is decimal(11,2) // The decimal of arrow is decimal(5,2), the decimal of parquet is decimal(9,0) - let schema = Arc::new(Schema::new(vec![Field::new( - "c1", - DataType::Decimal128(9, 0), - false, - )])); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", Decimal128(9, 0), false)])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { @@ -863,7 +857,7 @@ mod tests { .with_scale(0) .with_precision(9); let schema_descr = get_test_schema_descr(vec![field]); - let expr = cast(col("c1"), DataType::Decimal128(11, 2)).gt(cast( + let expr = cast(col("c1"), Decimal128(11, 2)).gt(cast( lit(ScalarValue::Decimal128(Some(500), 5, 2)), Decimal128(11, 2), )); @@ -947,7 +941,7 @@ mod tests { // INT64: c1 < 5, the c1 is decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( "c1", - DataType::Decimal128(18, 2), + Decimal128(18, 2), false, )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT64) @@ -1005,7 +999,7 @@ mod tests { // the type of parquet is decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( "c1", - DataType::Decimal128(18, 2), + Decimal128(18, 2), false, )])); let field = PrimitiveTypeField::new("c1", PhysicalType::FIXED_LEN_BYTE_ARRAY) @@ -1018,7 +1012,7 @@ mod tests { .with_byte_len(16); let schema_descr = get_test_schema_descr(vec![field]); // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), DataType::Decimal128(28, 3)); + let left = cast(col("c1"), Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); @@ -1083,7 +1077,7 @@ mod tests { // the type of parquet is decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( "c1", - DataType::Decimal128(18, 2), + Decimal128(18, 2), false, )])); let field = PrimitiveTypeField::new("c1", PhysicalType::BYTE_ARRAY) @@ -1096,7 +1090,7 @@ mod tests { .with_byte_len(16); let schema_descr = get_test_schema_descr(vec![field]); // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), DataType::Decimal128(28, 3)); + let left = cast(col("c1"), Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); diff --git a/datafusion/core/src/datasource/physical_plan/statistics.rs b/datafusion/core/src/datasource/physical_plan/statistics.rs index e1c61ec1a7129..3ca3ba89f4d97 100644 --- a/datafusion/core/src/datasource/physical_plan/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/statistics.rs @@ -278,13 +278,9 @@ impl MinMaxStatistics { fn sort_columns_from_physical_sort_exprs( sort_order: &[PhysicalSortExpr], -) -> Option> { +) -> Option> { sort_order .iter() - .map(|expr| { - expr.expr - .as_any() - .downcast_ref::() - }) + .map(|expr| expr.expr.as_any().downcast_ref::()) .collect::>>() } diff --git a/datafusion/core/src/datasource/schema_adapter.rs b/datafusion/core/src/datasource/schema_adapter.rs index fdf3381758a48..5ba597e4b5420 100644 --- a/datafusion/core/src/datasource/schema_adapter.rs +++ b/datafusion/core/src/datasource/schema_adapter.rs @@ -32,11 +32,19 @@ use std::sync::Arc; /// /// This interface provides a way to implement custom schema adaptation logic /// for ParquetExec (for example, to fill missing columns with default value 
-/// other than null) +/// other than null). +/// +/// Most users should use [`DefaultSchemaAdapterFactory`]. See that struct for +/// more details and examples. pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static { - /// Provides `SchemaAdapter`. - // The design of this function is mostly modeled for the needs of DefaultSchemaAdapterFactory, - // read its implementation docs for the reasoning + /// Create a [`SchemaAdapter`] + /// + /// Arguments: + /// + /// * `projected_table_schema`: The schema for the table, projected to + /// include only the fields being output (projected) by this mapping. + /// + /// * `table_schema`: The entire table schema for the table fn create( &self, projected_table_schema: SchemaRef, @@ -44,53 +52,57 @@ pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static { ) -> Box; } -/// Adapt file-level [`RecordBatch`]es to a table schema, which may have a schema -/// obtained from merging multiple file-level schemas. -/// -/// This is useful for enabling schema evolution in partitioned datasets. -/// -/// This has to be done in two stages. +/// Creates [`SchemaMapper`]s to map file-level [`RecordBatch`]es to a table +/// schema, which may have a schema obtained from merging multiple file-level +/// schemas. /// -/// 1. Before reading the file, we have to map projected column indexes from the -/// table schema to the file schema. +/// This is useful for implementing schema evolution in partitioned datasets. /// -/// 2. After reading a record batch map the read columns back to the expected -/// columns indexes and insert null-valued columns wherever the file schema was -/// missing a column present in the table schema. +/// See [`DefaultSchemaAdapterFactory`] for more details and examples. pub trait SchemaAdapter: Send + Sync { /// Map a column index in the table schema to a column index in a particular /// file schema /// + /// This is used while reading a file to push down projections by mapping + /// projected column indexes from the table schema to the file schema + /// /// Panics if index is not in range for the table schema fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option; - /// Creates a `SchemaMapping` that can be used to cast or map the columns - /// from the file schema to the table schema. + /// Creates a mapping for casting columns from the file schema to the table + /// schema. /// - /// If the provided `file_schema` contains columns of a different type to the expected - /// `table_schema`, the method will attempt to cast the array data from the file schema - /// to the table schema where possible. + /// This is used after reading a record batch. The returned [`SchemaMapper`]: /// - /// Returns a [`SchemaMapper`] that can be applied to the output batch - /// along with an ordered list of columns to project from the file + /// 1. Maps columns to the expected column indexes + /// 2. Handles missing values (e.g. fills nulls or a default value) for + /// columns in the table schema not in the file schema + /// 3. Handles different types: if the column in the file schema has a + /// different type than `table_schema`, the mapper will resolve this + /// difference (e.g. by casting to the appropriate type) + /// + /// Returns: + /// * a [`SchemaMapper`] + /// * an ordered list of columns to project from the file fn map_schema( &self, file_schema: &Schema, ) -> datafusion_common::Result<(Arc, Vec)>; } -/// Maps, by casting or reordering columns from the file schema to the table -/// schema.
+/// Maps columns from a specific file schema to the table schema. +/// +/// See [`DefaultSchemaAdapterFactory`] for more details and examples. pub trait SchemaMapper: Debug + Send + Sync { - /// Adapts a `RecordBatch` to match the `table_schema` using the stored - /// mapping and conversions. + /// Adapts a `RecordBatch` to match the `table_schema` fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result; /// Adapts a [`RecordBatch`] that does not have all the columns from the /// file schema. /// - /// This method is used when applying a filter to a subset of the columns as - /// part of `DataFusionArrowPredicate` when `filter_pushdown` is enabled. + /// This method is used, for example, when applying a filter to a subset of + /// the columns as part of `DataFusionArrowPredicate` when `filter_pushdown` + /// is enabled. /// /// This method is slower than `map_batch` as it looks up columns by name. fn map_partial_batch( @@ -99,11 +111,106 @@ pub trait SchemaMapper: Debug + Send + Sync { ) -> datafusion_common::Result; } -/// Implementation of [`SchemaAdapterFactory`] that maps columns by name -/// and casts columns to the expected type. +/// Default [`SchemaAdapterFactory`] for mapping schemas. +/// +/// This can be used to adapt file-level record batches to a table schema and +/// implement schema evolution. +/// +/// Given an input file schema and a table schema, this factory returns a +/// [`SchemaAdapter`] that returns [`SchemaMapper`]s that: +/// +/// 1. Reorder columns +/// 2. Cast columns to the correct type +/// 3. Fill missing columns with nulls +/// +/// # Errors: +/// +/// * If a column in the table schema is non-nullable but is not present in the +/// file schema (i.e. it is missing), the returned mapper tries to fill it with +/// nulls, resulting in a schema error. +/// +/// # Illustration of Schema Mapping +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌───────┐ ┌───────┐ │ ┌───────┐ ┌───────┐ ┌───────┐ │ +/// ││ 1.0 │ │ "foo" │ ││ NULL │ │ "foo" │ │ "1.0" │ +/// ├───────┤ ├───────┤ │ Schema mapping ├───────┤ ├───────┤ ├───────┤ │ +/// ││ 2.0 │ │ "bar" │ ││ NULL │ │ "bar" │ │ "2.0" │ +/// └───────┘ └───────┘ │────────────────▶ └───────┘ └───────┘ └───────┘ │ +/// │ │ +/// column "c" column "b"│ column "a" column "b" column "c"│ +/// │ Float64 Utf8 │ Int32 Utf8 Utf8 +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// Input Record Batch Output Record Batch +/// +/// Schema { Schema { +/// "c": Float64, "a": Int32, +/// "b": Utf8, "b": Utf8, +/// } "c": Utf8, +/// } +/// ``` +/// +/// # Example of using the `DefaultSchemaAdapterFactory` to map [`RecordBatch`]s +/// +/// Note `SchemaMapping` also supports mapping partial batches, which is used as +/// part of predicate pushdown.
+/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow::datatypes::{DataType, Field, Schema}; +/// # use datafusion::datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaAdapterFactory}; +/// # use datafusion_common::record_batch; +/// // Table has fields "a", "b" and "c" +/// let table_schema = Schema::new(vec![ +/// Field::new("a", DataType::Int32, true), +/// Field::new("b", DataType::Utf8, true), +/// Field::new("c", DataType::Utf8, true), +/// ]); +/// +/// // create an adapter to map the table schema to the file schema +/// let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); +/// +/// // The file schema has fields "c" and "b" but "b" is stored as a 'Float64' +/// // instead of 'Utf8' +/// let file_schema = Schema::new(vec![ +/// Field::new("c", DataType::Utf8, true), +/// Field::new("b", DataType::Float64, true), +/// ]); +/// +/// // Get a mapping from the file schema to the table schema +/// let (mapper, _indices) = adapter.map_schema(&file_schema).unwrap(); +/// +/// let file_batch = record_batch!( +/// ("c", Utf8, vec!["foo", "bar"]), +/// ("b", Float64, vec![1.0, 2.0]) +/// ).unwrap(); +/// +/// let mapped_batch = mapper.map_batch(file_batch).unwrap(); +/// +/// // the mapped batch has the correct schema and the "b" column has been cast to Utf8 +/// let expected_batch = record_batch!( +/// ("a", Int32, vec![None, None]), // missing column filled with nulls +/// ("b", Utf8, vec!["1.0", "2.0"]), // b was cast to string and order was changed +/// ("c", Utf8, vec!["foo", "bar"]) +/// ).unwrap(); +/// assert_eq!(mapped_batch, expected_batch); +/// ``` #[derive(Clone, Debug, Default)] pub struct DefaultSchemaAdapterFactory; +impl DefaultSchemaAdapterFactory { + /// Create a new factory for mapping batches from a file schema to a table + /// schema. + /// + /// This is a convenience for [`DefaultSchemaAdapterFactory::create`] with + /// the same schema for both the projected table schema and the table + /// schema. + pub fn from_schema(table_schema: SchemaRef) -> Box { + Self.create(Arc::clone(&table_schema), table_schema) + } +} + impl SchemaAdapterFactory for DefaultSchemaAdapterFactory { fn create( &self, @@ -117,8 +224,8 @@ impl SchemaAdapterFactory for DefaultSchemaAdapterFactory { } } -/// This SchemaAdapter requires both the table schema and the projected table schema because of the -/// needs of the [`SchemaMapping`] it creates. Read its documentation for more details +/// This SchemaAdapter requires both the table schema and the projected table +/// schema. See [`SchemaMapping`] for more details #[derive(Clone, Debug)] pub(crate) struct DefaultSchemaAdapter { /// The schema for the table, projected to include only the fields being output (projected) by the @@ -142,11 +249,12 @@ impl SchemaAdapter for DefaultSchemaAdapter { Some(file_schema.fields.find(field.name())?.0) } - /// Creates a `SchemaMapping` that can be used to cast or map the columns from the file schema to the table schema. + /// Creates a `SchemaMapping` for casting or mapping the columns from the + /// file schema to the table schema. /// - /// If the provided `file_schema` contains columns of a different type to the expected - /// `table_schema`, the method will attempt to cast the array data from the file schema - /// to the table schema where possible.
+ /// If the provided `file_schema` contains columns of a different type to + /// the expected `table_schema`, the method will attempt to cast the array + /// data from the file schema to the table schema where possible. /// /// Returns a [`SchemaMapping`] that can be applied to the output batch /// along with an ordered list of columns to project from the file @@ -189,36 +297,45 @@ impl SchemaAdapter for DefaultSchemaAdapter { } } -/// The SchemaMapping struct holds a mapping from the file schema to the table schema -/// and any necessary type conversions that need to be applied. +/// The SchemaMapping struct holds a mapping from the file schema to the table +/// schema and any necessary type conversions. +/// +/// Note: because the `map_batch` and `map_partial_batch` functions have different +/// needs, this struct holds two schemas: +/// +/// 1. The projected **table** schema +/// 2. The full table schema /// -/// This needs both the projected table schema and full table schema because its different -/// functions have different needs. The [`map_batch`] function is only used by the ParquetOpener to -/// produce a RecordBatch which has the projected schema, since that's the schema which is supposed -/// to come out of the execution of this query. [`map_partial_batch`], however, is used to create a -/// RecordBatch with a schema that can be used for Parquet pushdown, meaning that it may contain -/// fields which are not in the projected schema (as the fields that parquet pushdown filters -/// operate can be completely distinct from the fields that are projected (output) out of the -/// ParquetExec). +/// [`map_batch`] is used by the ParquetOpener to produce a RecordBatch which +/// has the projected schema, since that's the schema which is supposed to come +/// out of the execution of this query. Thus `map_batch` uses +/// `projected_table_schema` as it can only operate on the projected fields. /// -/// [`map_partial_batch`] uses `table_schema` to create the resulting RecordBatch (as it could be -/// operating on any fields in the schema), while [`map_batch`] uses `projected_table_schema` (as -/// it can only operate on the projected fields). +/// [`map_partial_batch`] is used to create a RecordBatch with a schema that +/// can be used for Parquet predicate pushdown, meaning that it may contain +/// fields which are not in the projected schema (as the fields that parquet +/// pushdown filters operate on can be completely distinct from the fields that are +/// projected (output) out of the ParquetExec). `map_partial_batch` thus uses +/// `table_schema` to create the resulting RecordBatch (as it could be operating +/// on any fields in the schema). /// /// [`map_batch`]: Self::map_batch /// [`map_partial_batch`]: Self::map_partial_batch #[derive(Debug)] pub struct SchemaMapping { - /// The schema of the table. This is the expected schema after conversion and it should match - /// the schema of the query result. + /// The schema of the table. This is the expected schema after conversion + /// and it should match the schema of the query result. projected_table_schema: SchemaRef, - /// Mapping from field index in `projected_table_schema` to index in projected file_schema. - /// They are Options instead of just plain `usize`s because the table could have fields that - /// don't exist in the file. + /// Mapping from field index in `projected_table_schema` to index in + /// projected file_schema.
+ /// + /// They are Options instead of just plain `usize`s because the table could + /// have fields that don't exist in the file. field_mappings: Vec>, - /// The entire table schema, as opposed to the projected_table_schema (which only contains the - /// columns that we are projecting out of this query). This contains all fields in the table, - /// regardless of if they will be projected out or not. + /// The entire table schema, as opposed to the projected_table_schema (which + /// only contains the columns that we are projecting out of this query). + /// This contains all fields in the table, regardless of whether they will be + /// projected out or not. table_schema: SchemaRef, } @@ -304,7 +421,8 @@ impl SchemaMapper for SchemaMapping { // Necessary to handle empty batches let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); - let schema = Arc::new(Schema::new(fields)); + let schema = + Arc::new(Schema::new_with_metadata(fields, schema.metadata().clone())); let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; Ok(record_batch) } @@ -330,8 +448,9 @@ mod tests { use crate::datasource::listing::PartitionedFile; use crate::datasource::schema_adapter::{ - SchemaAdapter, SchemaAdapterFactory, SchemaMapper, + DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper, }; + use datafusion_common::record_batch; #[cfg(feature = "parquet")] use parquet::arrow::ArrowWriter; use tempfile::TempDir; @@ -359,7 +478,7 @@ mod tests { writer.close().unwrap(); let location = Path::parse(path.to_str().unwrap()).unwrap(); - let metadata = std::fs::metadata(path.as_path()).expect("Local file metadata"); + let metadata = fs::metadata(path.as_path()).expect("Local file metadata"); let meta = ObjectMeta { location, last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), @@ -404,6 +523,58 @@ mod tests { assert_batches_sorted_eq!(expected, &read); } + #[test] + fn default_schema_adapter() { + let table_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]); + + // file has a subset of the table schema fields and a different type + let file_schema = Schema::new(vec![ + Field::new("c", DataType::Float64, true), // not in table schema + Field::new("b", DataType::Float64, true), + ]); + + let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); + let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); + assert_eq!(indices, vec![1]); + + let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap(); + + let mapped_batch = mapper.map_batch(file_batch).unwrap(); + + // the mapped batch has the correct schema and the "b" column has been cast to Utf8 + let expected_batch = record_batch!( + ("a", Int32, vec![None, None]), // missing column filled with nulls + ("b", Utf8, vec!["1.0", "2.0"]) // b was cast to string and order was changed + ) + .unwrap(); + assert_eq!(mapped_batch, expected_batch); + } + + #[test] + fn default_schema_adapter_non_nullable_columns() { + let table_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), // "a" is declared non-nullable + Field::new("b", DataType::Utf8, true), + ]); + let file_schema = Schema::new(vec![ + // since file doesn't have "a" it will be filled with nulls + Field::new("b", DataType::Float64, true), + ]); + + let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); + let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); +
assert_eq!(indices, vec![0]); + + let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap(); + + // Mapping fails because it tries to fill in a non-nullable column with nulls + let err = mapper.map_batch(file_batch).unwrap_err().to_string(); + assert!(err.contains("Invalid argument error: Column 'a' is declared as non-nullable but contains null values"), "{err}"); + } + #[derive(Debug)] struct TestSchemaAdapterFactory; diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index d30247e2c67a0..34023fbbb6207 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -33,6 +33,7 @@ use arrow_schema::SchemaRef; use datafusion_common::{config_err, plan_err, Constraints, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::dml::InsertOp; use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; use datafusion_physical_plan::metrics::MetricsSet; @@ -350,7 +351,7 @@ impl TableProvider for StreamTable { &self, _state: &dyn Session, input: Arc, - _overwrite: bool, + _insert_op: InsertOp, ) -> Result> { let ordering = match self.0.order.first() { Some(x) => { diff --git a/datafusion/core/src/execution/context/avro.rs b/datafusion/core/src/execution/context/avro.rs index e829f6123eab4..a31f2af642d04 100644 --- a/datafusion/core/src/execution/context/avro.rs +++ b/datafusion/core/src/execution/context/avro.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - use super::super::options::{AvroReadOptions, ReadOptions}; use super::{DataFilePaths, DataFrame, Result, SessionContext}; +use datafusion_common::TableReference; +use std::sync::Arc; impl SessionContext { /// Creates a [`DataFrame`] for reading an Avro data source. @@ -39,15 +39,15 @@ impl SessionContext { /// SQL statements executed against this context. pub async fn register_avro( &self, - name: &str, - table_path: &str, + table_ref: impl Into, + table_path: impl AsRef, options: AvroReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); self.register_listing_table( - name, + table_ref, table_path, listing_options, options.schema.map(|s| Arc::new(s.to_owned())), diff --git a/datafusion/core/src/execution/context/csv.rs b/datafusion/core/src/execution/context/csv.rs index 08e93cb613056..e97c70ef98121 100644 --- a/datafusion/core/src/execution/context/csv.rs +++ b/datafusion/core/src/execution/context/csv.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - use crate::datasource::physical_plan::plan_to_csv; +use datafusion_common::TableReference; +use std::sync::Arc; use super::super::options::{CsvReadOptions, ReadOptions}; use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; @@ -55,15 +55,15 @@ impl SessionContext { /// statements executed against this context. 
pub async fn register_csv( &self, - name: &str, - table_path: &str, + table_ref: impl Into, + table_path: impl AsRef, options: CsvReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); self.register_listing_table( - name, + table_ref, table_path, listing_options, options.schema.map(|s| Arc::new(s.to_owned())), diff --git a/datafusion/core/src/execution/context/json.rs b/datafusion/core/src/execution/context/json.rs index c21e32cfdefbf..c9a9492f9162e 100644 --- a/datafusion/core/src/execution/context/json.rs +++ b/datafusion/core/src/execution/context/json.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - use crate::datasource::physical_plan::plan_to_json; +use datafusion_common::TableReference; +use std::sync::Arc; use super::super::options::{NdJsonReadOptions, ReadOptions}; use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; @@ -41,15 +41,15 @@ impl SessionContext { /// from SQL statements executed against this context. pub async fn register_json( &self, - name: &str, - table_path: &str, + table_ref: impl Into, + table_path: impl AsRef, options: NdJsonReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); self.register_listing_table( - name, + table_ref, table_path, listing_options, options.schema.map(|s| Arc::new(s.to_owned())), diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 53eb7c431b475..333f83c673cc2 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -738,6 +738,11 @@ impl SessionContext { cmd: &CreateExternalTable, ) -> Result { let exist = self.table_exist(cmd.name.clone())?; + + if cmd.temporary { + return not_impl_err!("Temporary tables not supported"); + } + if exist { match cmd.if_not_exists { true => return self.return_empty_dataframe(), @@ -761,10 +766,16 @@ impl SessionContext { or_replace, constraints, column_defaults, + temporary, } = cmd; let input = Arc::unwrap_or_clone(input); let input = self.state().optimize(&input)?; + + if temporary { + return not_impl_err!("Temporary tables not supported"); + } + let table = self.table(name.clone()).await; match (if_not_exists, or_replace, table) { (true, false, Ok(_)) => self.return_empty_dataframe(), @@ -813,10 +824,15 @@ impl SessionContext { input, or_replace, definition, + temporary, } = cmd; let view = self.table(name.clone()).await; + if temporary { + return not_impl_err!("Temporary views not supported"); + } + match (or_replace, view) { (true, Ok(_)) => { self.deregister_table(name.clone())?; @@ -1264,7 +1280,7 @@ impl SessionContext { /// [`ObjectStore`]: object_store::ObjectStore pub async fn register_listing_table( &self, - name: &str, + table_ref: impl Into, table_path: impl AsRef, options: ListingOptions, provided_schema: Option, @@ -1279,10 +1295,7 @@ impl SessionContext { .with_listing_options(options) .with_schema(resolved_schema); let table = ListingTable::try_new(config)?.with_definition(sql_definition); - self.register_table( - TableReference::Bare { table: name.into() }, - Arc::new(table), - )?; + self.register_table(table_ref, Arc::new(table))?; Ok(()) } @@ -1550,7 +1563,7 @@ impl From for SessionStateBuilder { /// A planner used to add extensions to DataFusion logical and physical plans. 
#[async_trait] -pub trait QueryPlanner { +pub trait QueryPlanner: Debug { /// Given a `LogicalPlan`, create an [`ExecutionPlan`] suitable for execution async fn create_physical_plan( &self, @@ -1563,7 +1576,7 @@ pub trait QueryPlanner { /// and interact with [SessionState] to registers new udf, udaf or udwf. #[async_trait] -pub trait FunctionFactory: Sync + Send { +pub trait FunctionFactory: Debug + Sync + Send { /// Handles creation of user defined function specified in [CreateFunction] statement async fn create( &self, @@ -1586,6 +1599,7 @@ pub enum RegisterFunction { /// Default implementation of [SerializerRegistry] that throws unimplemented error /// for all requests. +#[derive(Debug)] pub struct EmptySerializerRegistry; impl SerializerRegistry for EmptySerializerRegistry { @@ -2125,13 +2139,14 @@ mod tests { fn create_physical_expr( &self, _expr: &Expr, - _input_dfschema: &crate::common::DFSchema, + _input_dfschema: &DFSchema, _session_state: &SessionState, - ) -> Result> { + ) -> Result> { unimplemented!() } } + #[derive(Debug)] struct MyQueryPlanner {} #[async_trait] diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index 1d83c968c1a89..3f23c150be839 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -21,6 +21,7 @@ use super::super::options::{ParquetReadOptions, ReadOptions}; use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; use crate::datasource::physical_plan::parquet::plan_to_parquet; +use datafusion_common::TableReference; use parquet::file::properties::WriterProperties; impl SessionContext { @@ -42,15 +43,15 @@ impl SessionContext { /// statements executed against this context. pub async fn register_parquet( &self, - name: &str, - table_path: &str, + table_ref: impl Into, + table_path: impl AsRef, options: ParquetReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); self.register_listing_table( - name, + table_ref, table_path, listing_options, options.schema.map(|s| Arc::new(s.to_owned())), diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 3e6577a486084..d50c912dd2fdc 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -174,27 +174,30 @@ pub struct SessionState { } impl Debug for SessionState { + /// Prefer having short fields at the top and long vector fields near the end + /// Group fields by fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SessionState") .field("session_id", &self.session_id) - .field("analyzer", &"...") - .field("expr_planners", &"...") - .field("optimizer", &"...") - .field("physical_optimizers", &"...") - .field("query_planner", &"...") - .field("catalog_list", &"...") - .field("table_functions", &"...") + .field("config", &self.config) + .field("runtime_env", &self.runtime_env) + .field("catalog_list", &self.catalog_list) + .field("serializer_registry", &self.serializer_registry) + .field("file_formats", &self.file_formats) + .field("execution_props", &self.execution_props) + .field("table_options", &self.table_options) + .field("table_factories", &self.table_factories) + .field("function_factory", &self.function_factory) + .field("expr_planners", &self.expr_planners) + .field("query_planners", &self.query_planner) + .field("analyzer", &self.analyzer) 
+ .field("optimizer", &self.optimizer) + .field("physical_optimizers", &self.physical_optimizers) + .field("table_functions", &self.table_functions) .field("scalar_functions", &self.scalar_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) - .field("serializer_registry", &"...") - .field("config", &self.config) - .field("table_options", &self.table_options) - .field("execution_props", &self.execution_props) - .field("table_factories", &"...") - .field("runtime_env", &self.runtime_env) - .field("function_factory", &"...") - .finish_non_exhaustive() + .finish() } } @@ -509,7 +512,7 @@ impl SessionState { /// [`catalog::resolve_table_references`]: crate::catalog_common::resolve_table_references pub fn resolve_table_references( &self, - statement: &datafusion_sql::parser::Statement, + statement: &Statement, ) -> datafusion_common::Result> { let enable_ident_normalization = self.config.options().sql_parser.enable_ident_normalization; @@ -523,7 +526,7 @@ impl SessionState { /// Convert an AST Statement into a LogicalPlan pub async fn statement_to_plan( &self, - statement: datafusion_sql::parser::Statement, + statement: Statement, ) -> datafusion_common::Result { let references = self.resolve_table_references(&statement)?; @@ -1519,6 +1522,37 @@ impl SessionStateBuilder { } } +impl Debug for SessionStateBuilder { + /// Prefer having short fields at the top and long vector fields near the end + /// Group fields by + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SessionStateBuilder") + .field("session_id", &self.session_id) + .field("config", &self.config) + .field("runtime_env", &self.runtime_env) + .field("catalog_list", &self.catalog_list) + .field("serializer_registry", &self.serializer_registry) + .field("file_formats", &self.file_formats) + .field("execution_props", &self.execution_props) + .field("table_options", &self.table_options) + .field("table_factories", &self.table_factories) + .field("function_factory", &self.function_factory) + .field("expr_planners", &self.expr_planners) + .field("query_planners", &self.query_planner) + .field("analyzer_rules", &self.analyzer_rules) + .field("analyzer", &self.analyzer) + .field("optimizer_rules", &self.optimizer_rules) + .field("optimizer", &self.optimizer) + .field("physical_optimizer_rules", &self.physical_optimizer_rules) + .field("physical_optimizers", &self.physical_optimizers) + .field("table_functions", &self.table_functions) + .field("scalar_functions", &self.scalar_functions) + .field("aggregate_functions", &self.aggregate_functions) + .field("window_functions", &self.window_functions) + .finish() + } +} + impl Default for SessionStateBuilder { fn default() -> Self { Self::new() @@ -1795,6 +1829,7 @@ impl From<&SessionState> for TaskContext { } /// The query planner used if no user defined planner is provided +#[derive(Debug)] struct DefaultQueryPlanner {} #[async_trait] diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index c971e61506339..aa4bcb6837493 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1416,8 +1416,8 @@ pub(crate) mod tests { use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; use datafusion_physical_expr::{ - expressions, expressions::binary, expressions::lit, LexOrdering, - 
PhysicalSortExpr, PhysicalSortRequirement, + expressions::binary, expressions::lit, LexOrdering, PhysicalSortExpr, + PhysicalSortRequirement, }; use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::PlanProperties; @@ -1646,8 +1646,7 @@ pub(crate) mod tests { .enumerate() .map(|(index, (_col, name))| { ( - Arc::new(expressions::Column::new(name, index)) - as Arc, + Arc::new(Column::new(name, index)) as Arc, name.clone(), ) }) diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 499fb9cbbcf03..1c63df1f0281f 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -140,20 +140,32 @@ fn swap_join_projection( left_schema_len: usize, right_schema_len: usize, projection: Option<&Vec>, + join_type: &JoinType, ) -> Option> { - projection.map(|p| { - p.iter() - .map(|i| { - // If the index is less than the left schema length, it is from the left schema, so we add the right schema length to it. - // Otherwise, it is from the right schema, so we subtract the left schema length from it. - if *i < left_schema_len { - *i + right_schema_len - } else { - *i - left_schema_len - } - }) - .collect() - }) + match join_type { + // For Anti/Semi join types, projection should remain unmodified, + // since these joins' output schema remains the same after the swap + JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::RightAnti + | JoinType::RightSemi => projection.cloned(), + + _ => projection.map(|p| { + p.iter() + .map(|i| { + // If the index is less than the left schema length, it is from + // the left schema, so we add the right schema length to it. + // Otherwise, it is from the right schema, so we subtract the left + // schema length from it. + if *i < left_schema_len { + *i + right_schema_len + } else { + *i - left_schema_len + } + }) + .collect() + }), + } } /// This function swaps the inputs of the given join operator.
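The index-remapping rule in `swap_join_projection` can be illustrated in isolation. Below is a minimal, standalone sketch of that rule; the `JoinType` enum here mirrors only a subset of DataFusion's, and the function is illustrative rather than the library's implementation:

```rust
// Standalone sketch of the projection-remapping rule described above.
#[derive(Clone, Copy)]
enum JoinType {
    Inner,
    LeftSemi,
    LeftAnti,
    RightSemi,
    RightAnti,
}

fn swap_projection_indexes(
    left_schema_len: usize,
    right_schema_len: usize,
    projection: Option<&Vec<usize>>,
    join_type: JoinType,
) -> Option<Vec<usize>> {
    match join_type {
        // Semi/anti joins emit columns from only one side, so their output
        // schema (and thus the projection) is unchanged by the swap.
        JoinType::LeftSemi
        | JoinType::LeftAnti
        | JoinType::RightSemi
        | JoinType::RightAnti => projection.cloned(),
        // Other join types concatenate both sides, so each index must be
        // shifted across the new left/right boundary.
        _ => projection.map(|p| {
            p.iter()
                .map(|&i| {
                    if i < left_schema_len {
                        i + right_schema_len // was on the left, now on the right
                    } else {
                        i - left_schema_len // was on the right, now on the left
                    }
                })
                .collect()
        }),
    }
}

fn main() {
    // Inner join, 3 columns on the left and 2 on the right: left index 1
    // shifts right past 2 columns, right index 4 shifts left past 3 columns.
    let projection = vec![1usize, 4];
    let swapped =
        swap_projection_indexes(3, 2, Some(&projection), JoinType::Inner).unwrap();
    assert_eq!(swapped, vec![3, 1]);
}
```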
@@ -179,17 +191,20 @@ pub fn swap_hash_join( left.schema().fields().len(), right.schema().fields().len(), hash_join.projection.as_ref(), + hash_join.join_type(), ), partition_mode, hash_join.null_equals_null(), )?; + // For anti/semi joins, or when HashJoinExec already has an embedded projection, the output column order is preserved, so there is no need to add a projection again if matches!( hash_join.join_type(), JoinType::LeftSemi | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti - ) { + ) || hash_join.projection.is_some() + { Ok(Arc::new(new_join)) } else { // TODO avoid adding ProjectionExec again and again, only adding Final Projection @@ -1287,6 +1302,65 @@ mod tests_statistical { ); } + #[rstest( + join_type, projection, small_on_right, + case::inner(JoinType::Inner, vec![1], true), + case::left(JoinType::Left, vec![1], true), + case::right(JoinType::Right, vec![1], true), + case::full(JoinType::Full, vec![1], true), + case::left_anti(JoinType::LeftAnti, vec![0], false), + case::left_semi(JoinType::LeftSemi, vec![0], false), + case::right_anti(JoinType::RightAnti, vec![0], true), + case::right_semi(JoinType::RightSemi, vec![0], true), + )] + #[tokio::test] + async fn test_hash_join_swap_on_joins_with_projections( + join_type: JoinType, + projection: Vec, + small_on_right: bool, + ) -> Result<()> { + let (big, small) = create_big_and_small(); + + let left = if small_on_right { &big } else { &small }; + let right = if small_on_right { &small } else { &big }; + + let left_on = if small_on_right { + "big_col" + } else { + "small_col" + }; + let right_on = if small_on_right { + "small_col" + } else { + "big_col" + }; + + let join = Arc::new(HashJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + vec![( + Arc::new(Column::new_with_schema(left_on, &left.schema())?), + Arc::new(Column::new_with_schema(right_on, &right.schema())?), + )], + None, + &join_type, + Some(projection), + PartitionMode::Partitioned, + false, + )?); + + let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned) + .expect("swap_hash_join must support joins with projections"); + let swapped_join = swapped.as_any().downcast_ref::().expect( + "ProjectionExec won't be added above if HashJoinExec contains embedded projection", + ); + + assert_eq!(swapped_join.projection, Some(vec![0_usize])); + assert_eq!(swapped.schema().fields.len(), 1); + assert_eq!(swapped.schema().fields[0].name(), "small_col"); + Ok(()) + } + #[tokio::test] async fn test_swap_reverting_projection() { let left_schema = Schema::new(vec![ diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 9bc2bb1d1db98..eb03b337779c1 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -458,7 +458,7 @@ pub trait PruningStatistics { /// [`Snowflake SIGMOD Paper`]: https://dl.acm.org/doi/10.1145/2882903.2903741 /// [small materialized aggregates]: https://www.vldb.org/conf/1998/p476.pdf /// [zone maps]: https://dl.acm.org/doi/10.1007/978-3-642-03730-6_10 -///[data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515 +/// [data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515 #[derive(Debug, Clone)] pub struct PruningPredicate { /// The input schema against which the predicate will be evaluated @@ -478,6 +478,36 @@ pub struct PruningPredicate { literal_guarantees: Vec, } +/// Rewrites predicates that [`PredicateRewriter`] cannot handle, e.g.
certain +/// complex expressions or predicates that reference columns that are not in the +/// schema. +pub trait UnhandledPredicateHook { + /// Called when a predicate cannot be rewritten in terms of statistics or + /// references a column that is not in the schema. + fn handle(&self, expr: &Arc) -> Arc; +} + +/// The default handling for unhandled predicates is to return a constant `true` +/// (meaning don't prune the container) +#[derive(Debug, Clone)] +struct ConstantUnhandledPredicateHook { + default: Arc, +} + +impl Default for ConstantUnhandledPredicateHook { + fn default() -> Self { + Self { + default: Arc::new(phys_expr::Literal::new(ScalarValue::from(true))), + } + } +} + +impl UnhandledPredicateHook for ConstantUnhandledPredicateHook { + fn handle(&self, _expr: &Arc) -> Arc { + self.default.clone() + } +} + impl PruningPredicate { /// Try to create a new instance of [`PruningPredicate`] /// @@ -502,10 +532,16 @@ impl PruningPredicate { /// See the struct level documentation on [`PruningPredicate`] for more /// details. pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { + let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; + // build predicate expression once let mut required_columns = RequiredColumns::new(); - let predicate_expr = - build_predicate_expression(&expr, schema.as_ref(), &mut required_columns); + let predicate_expr = build_predicate_expression( + &expr, + schema.as_ref(), + &mut required_columns, + &unhandled_hook, + ); let literal_guarantees = LiteralGuarantee::analyze(&expr); @@ -1312,27 +1348,78 @@ fn build_is_null_column_expr( /// an OR chain const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20; +/// Rewrite a predicate expression in terms of statistics (min/max/null_counts) +/// for use as a [`PruningPredicate`]. +pub struct PredicateRewriter { + unhandled_hook: Arc, +} + +impl Default for PredicateRewriter { + fn default() -> Self { + Self { + unhandled_hook: Arc::new(ConstantUnhandledPredicateHook::default()), + } + } +} + +impl PredicateRewriter { + /// Create a new `PredicateRewriter` + pub fn new() -> Self { + Self::default() + } + + /// Set the unhandled hook to be used when a predicate cannot be rewritten + pub fn with_unhandled_hook( + self, + unhandled_hook: Arc, + ) -> Self { + Self { unhandled_hook } + } + + /// Translate logical filter expression into pruning predicate + /// expression that will evaluate to FALSE if it can be determined no + /// rows between the min/max values could pass the predicates. + /// + /// Any predicates that cannot be translated will be passed to `unhandled_hook`. + /// + /// Returns the pruning predicate as a [`PhysicalExpr`] + /// + /// Notice: does not handle [`phys_expr::InListExpr`] with more than 20 values; such expressions fall back to calling `unhandled_hook` + pub fn rewrite_predicate_to_statistics_predicate( + &self, + expr: &Arc, + schema: &Schema, + ) -> Arc { + let mut required_columns = RequiredColumns::new(); + build_predicate_expression( + expr, + schema, + &mut required_columns, + &self.unhandled_hook, + ) + } +} + /// Translate logical filter expression into pruning predicate /// expression that will evaluate to FALSE if it can be determined no /// rows between the min/max values could pass the predicates. /// +/// Any predicates that cannot be translated will be passed to `unhandled_hook`.
+/// +/// Returns the pruning predicate as a [`PhysicalExpr`] /// -/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will be rewritten to TRUE +/// Notice: does not handle [`phys_expr::InListExpr`] with more than 20 values; such expressions fall back to calling `unhandled_hook` fn build_predicate_expression( expr: &Arc, schema: &Schema, required_columns: &mut RequiredColumns, + unhandled_hook: &Arc, ) -> Arc { - // Returned for unsupported expressions. Such expressions are - // converted to TRUE. - let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))); - // predicate expression can only be a binary expression let expr_any = expr.as_any(); if let Some(is_null) = expr_any.downcast_ref::() { return build_is_null_column_expr(is_null.arg(), schema, required_columns, false) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(is_not_null) = expr_any.downcast_ref::() { return build_is_null_column_expr( @@ -1341,19 +1428,19 @@ fn build_predicate_expression( required_columns, true, ) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(col) = expr_any.downcast_ref::() { return build_single_column_expr(col, schema, required_columns, false) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(not) = expr_any.downcast_ref::() { // match !col (don't do so recursively) if let Some(col) = not.arg().as_any().downcast_ref::() { return build_single_column_expr(col, schema, required_columns, true) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } else { - return unhandled; + return unhandled_hook.handle(expr); } } if let Some(in_list) = expr_any.downcast_ref::() { @@ -1382,9 +1469,14 @@ fn build_predicate_expression( }) .reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _) .unwrap(); - return build_predicate_expression(&change_expr, schema, required_columns); + return build_predicate_expression( + &change_expr, + schema, + required_columns, + unhandled_hook, + ); } else { - return unhandled; + return unhandled_hook.handle(expr); } } @@ -1396,13 +1488,15 @@ fn build_predicate_expression( bin_expr.right().clone(), ) } else { - return unhandled; + return unhandled_hook.handle(expr); } }; if op == Operator::And || op == Operator::Or { - let left_expr = build_predicate_expression(&left, schema, required_columns); - let right_expr = build_predicate_expression(&right, schema, required_columns); + let left_expr = + build_predicate_expression(&left, schema, required_columns, unhandled_hook); + let right_expr = + build_predicate_expression(&right, schema, required_columns, unhandled_hook); // simplify boolean expression if applicable let expr = match (&left_expr, op, &right_expr) { (left, Operator::And, _) if is_always_true(left) => right_expr, @@ -1410,7 +1504,7 @@ fn build_predicate_expression( (left, Operator::Or, right) if is_always_true(left) || is_always_true(right) => { - unhandled + Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))) } _ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)), }; @@ -1423,12 +1517,11 @@ fn build_predicate_expression( Ok(builder) => builder, // allow partial failure in predicate expression generation // this can still produce a useful predicate when multiple conditions are joined using AND - Err(_) => { - return unhandled; - } + Err(_) => return unhandled_hook.handle(expr), };
build_statistics_expr(&mut expr_builder) + .unwrap_or_else(|_| unhandled_hook.handle(expr)) } fn build_statistics_expr( @@ -1582,6 +1675,8 @@ mod tests { use arrow_array::UInt64Array; use datafusion_expr::expr::InList; use datafusion_expr::{cast, is_null, try_cast, Expr}; + use datafusion_functions_nested::expr_fn::{array_has, make_array}; + use datafusion_physical_expr::expressions as phys_expr; use datafusion_physical_expr::planner::logical2physical; #[derive(Debug, Default)] @@ -3397,6 +3492,74 @@ mod tests { // TODO: add test for other case and op } + #[test] + fn test_rewrite_expr_to_prunable_custom_unhandled_hook() { + struct CustomUnhandledHook; + + impl UnhandledPredicateHook for CustomUnhandledHook { + /// This handles an arbitrary case of a column that doesn't exist in the schema + /// by renaming it to yet another column that doesn't exist in the schema + /// (the transformation is arbitrary, the point is that it can do whatever it wants) + fn handle(&self, _expr: &Arc) -> Arc { + Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(42)))) + } + } + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let schema_with_b = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let rewriter = PredicateRewriter::new() + .with_unhandled_hook(Arc::new(CustomUnhandledHook {})); + + let transform_expr = |expr| { + let expr = logical2physical(&expr, &schema_with_b); + rewriter.rewrite_predicate_to_statistics_predicate(&expr, &schema) + }; + + // transform an arbitrary valid expression that we know is handled + let known_expression = col("a").eq(lit(12)); + let known_expression_transformed = PredicateRewriter::new() + .rewrite_predicate_to_statistics_predicate( + &logical2physical(&known_expression, &schema), + &schema, + ); + + // an expression referencing an unknown column (that is not in the schema) gets passed to the hook + let input = col("b").eq(lit(12)); + let expected = logical2physical(&lit(42), &schema); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // more complex case with unknown column + let input = known_expression.clone().and(input.clone()); + let expected = phys_expr::BinaryExpr::new( + known_expression_transformed.clone(), + Operator::And, + logical2physical(&lit(42), &schema), + ); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // an unknown expression gets passed to the hook + let input = array_has(make_array(vec![lit(1)]), col("a")); + let expected = logical2physical(&lit(42), &schema); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // more complex case with unknown expression + let input = known_expression.and(input); + let expected = phys_expr::BinaryExpr::new( + known_expression_transformed.clone(), + Operator::And, + logical2physical(&lit(42), &schema), + ); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + } + #[test] fn test_rewrite_expr_to_prunable_error() { // cast string value to numeric value @@ -3886,6 +4049,7 @@ mod tests { required_columns: &mut RequiredColumns, ) -> Arc { let expr = logical2physical(expr, schema); - build_predicate_expression(&expr, schema, required_columns) + let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; + build_predicate_expression(&expr, schema, 
required_columns, &unhandled_hook) } } diff --git a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs index c0d9140c025e5..26cdd65883e41 100644 --- a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs +++ b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs @@ -131,10 +131,10 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder { /// successfully. Any errors occurring during the conversion process are /// passed through. fn try_convert_aggregate_if_better( - aggr_exprs: Vec, + aggr_exprs: Vec>, prefix_requirement: &[PhysicalSortRequirement], eq_properties: &EquivalenceProperties, -) -> Result> { +) -> Result>> { aggr_exprs .into_iter() .map(|aggr_expr| { @@ -154,7 +154,7 @@ fn try_convert_aggregate_if_better( let reqs = concat_slices(prefix_requirement, &aggr_sort_reqs); if eq_properties.ordering_satisfy_requirement(&reqs) { // Existing ordering satisfies the aggregator requirements: - aggr_expr.with_beneficial_ordering(true)? + aggr_expr.with_beneficial_ordering(true)?.map(Arc::new) } else if eq_properties.ordering_satisfy_requirement(&concat_slices( prefix_requirement, &reverse_aggr_req, @@ -163,12 +163,14 @@ fn try_convert_aggregate_if_better( // given the existing ordering (if possible): aggr_expr .reverse_expr() + .map(Arc::new) .unwrap_or(aggr_expr) .with_beneficial_ordering(true)? + .map(Arc::new) } else { // There is no beneficial ordering present -- aggregation // will still work albeit in a less efficient mode. - aggr_expr.with_beneficial_ordering(false)? + aggr_expr.with_beneficial_ordering(false)?.map(Arc::new) } .ok_or_else(|| { plan_datafusion_err!( diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index b2b912d8add20..ffedc2d6b6ef2 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -29,13 +29,12 @@ use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; use crate::logical_expr::{ - Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Window, + Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Values, Window, }; use crate::logical_expr::{ Expr, LogicalPlan, Partitioning as LogicalPartitioning, PlanType, Repartition, UserDefinedLogicalNode, }; -use crate::logical_expr::{Limit, Values}; use crate::physical_expr::{create_physical_expr, create_physical_exprs}; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; @@ -71,15 +70,15 @@ use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, ScalarValue, }; -use datafusion_expr::dml::CopyTo; +use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr::{ physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, SortExpr, - StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, Extension, FetchType, Filter, JoinType, RecursiveQuery, + SkipType, SortExpr, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; 
use datafusion_physical_expr::expressions::Literal; @@ -529,7 +528,7 @@ impl DefaultPhysicalPlanner { file_groups: vec![], output_schema: Arc::new(schema), table_partition_cols, - overwrite: false, + insert_op: InsertOp::Append, keep_partition_by_columns, }; @@ -542,7 +541,7 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Dml(DmlStatement { table_name, - op: WriteOp::InsertInto, + op: WriteOp::Insert(insert_op), .. }) => { let name = table_name.table(); @@ -550,23 +549,7 @@ impl DefaultPhysicalPlanner { if let Some(provider) = schema.table(name).await? { let input_exec = children.one()?; provider - .insert_into(session_state, input_exec, false) - .await? - } else { - return exec_err!("Table '{table_name}' does not exist"); - } - } - LogicalPlan::Dml(DmlStatement { - table_name, - op: WriteOp::InsertOverwrite, - .. - }) => { - let name = table_name.table(); - let schema = session_state.schema_for_ref(table_name.clone())?; - if let Some(provider) = schema.table(name).await? { - let input_exec = children.one()?; - provider - .insert_into(session_state, input_exec, true) + .insert_into(session_state, input_exec, *insert_op) .await? } else { return exec_err!("Table '{table_name}' does not exist"); @@ -708,10 +691,6 @@ impl DefaultPhysicalPlanner { physical_input_schema.clone(), )?); - // update group column indices based on partial aggregate plan evaluation - let final_group: Vec> = - initial_aggr.output_group_expr(); - let can_repartition = !groups.is_empty() && session_state.config().target_partitions() > 1 && session_state.config().repartition_aggregations(); @@ -732,13 +711,7 @@ impl DefaultPhysicalPlanner { AggregateMode::Final }; - let final_grouping_set = PhysicalGroupBy::new_single( - final_group - .iter() - .enumerate() - .map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone())) - .collect(), - ); + let final_grouping_set = initial_aggr.group_expr().as_final(); Arc::new(AggregateExec::try_new( next_partition_mode, @@ -822,8 +795,20 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Subquery(_) => todo!(), LogicalPlan::SubqueryAlias(_) => children.one()?, - LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + LogicalPlan::Limit(limit) => { let input = children.one()?; + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return not_impl_err!( + "Unsupported OFFSET expression: {:?}", + limit.skip + ); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return not_impl_err!( + "Unsupported LIMIT expression: {:?}", + limit.fetch + ); + }; // GlobalLimitExec requires a single partition for input let input = if input.output_partitioning().partition_count() == 1 { @@ -832,13 +817,13 @@ impl DefaultPhysicalPlanner { // Apply a LocalLimitExec to each partition. 
The optimizer will also insert // a CoalescePartitionsExec between the GlobalLimitExec and LocalLimitExec if let Some(fetch) = fetch { - Arc::new(LocalLimitExec::new(input, *fetch + skip)) + Arc::new(LocalLimitExec::new(input, fetch + skip)) } else { input } }; - Arc::new(GlobalLimitExec::new(input, *skip, *fetch)) + Arc::new(GlobalLimitExec::new(input, skip, fetch)) } LogicalPlan::Unnest(Unnest { list_type_columns, @@ -1040,14 +1025,21 @@ impl DefaultPhysicalPlanner { }) .collect(); + let metadata: HashMap<_, _> = left_df_schema + .metadata() + .clone() + .into_iter() + .chain(right_df_schema.metadata().clone()) + .collect(); + // Construct intermediate schemas used for filtering data and // convert logical expression to physical according to filter schema let filter_df_schema = DFSchema::new_with_metadata( filter_df_fields, - HashMap::new(), + metadata.clone(), )?; let filter_schema = - Schema::new_with_metadata(filter_fields, HashMap::new()); + Schema::new_with_metadata(filter_fields, metadata); let filter_expr = create_physical_expr( expr, &filter_df_schema, @@ -1071,14 +1063,18 @@ impl DefaultPhysicalPlanner { session_state.config_options().optimizer.prefer_hash_join; let join: Arc = if join_on.is_empty() { - // there is no equal join condition, use the nested loop join - // TODO optimize the plan, and use the config of `target_partitions` and `repartition_joins` - Arc::new(NestedLoopJoinExec::try_new( - physical_left, - physical_right, - join_filter, - join_type, - )?) + if join_filter.is_none() && matches!(join_type, JoinType::Inner) { + // cross join if there is no join conditions and no join filter set + Arc::new(CrossJoinExec::new(physical_left, physical_right)) + } else { + // there is no equal join condition, use the nested loop join + Arc::new(NestedLoopJoinExec::try_new( + physical_left, + physical_right, + join_filter, + join_type, + )?) + } } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() && !prefer_hash_join @@ -1138,10 +1134,6 @@ impl DefaultPhysicalPlanner { join } } - LogicalPlan::CrossJoin(_) => { - let [left, right] = children.two()?; - Arc::new(CrossJoinExec::new(left, right)) - } LogicalPlan::RecursiveQuery(RecursiveQuery { name, is_distinct, .. 
}) => { @@ -1549,7 +1541,7 @@ pub fn create_window_expr( } type AggregateExprWithOptionalArgs = ( - AggregateFunctionExpr, + Arc, // The filter clause, if any Option>, // Ordering requirements, if any @@ -1613,7 +1605,8 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( .alias(name) .with_ignore_nulls(ignore_nulls) .with_distinct(*distinct) - .build()?; + .build() + .map(Arc::new)?; (agg_expr, filter, physical_sort_exprs) }; @@ -2361,7 +2354,7 @@ mod tests { .expect("hash aggregate"); assert_eq!( "sum(aggregate_test_100.c3)", - final_hash_agg.schema().field(2).name() + final_hash_agg.schema().field(3).name() ); // we need access to the input to the partial aggregate so that other projects can // implement serde @@ -2573,6 +2566,10 @@ mod tests { ) -> Result { unimplemented!("NoOp"); } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[derive(Debug)] diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 08740daa0c8e7..9ac75c8f3efb3 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -69,7 +69,7 @@ pub fn create_table_dual() -> Arc { let batch = RecordBatch::try_new( dual_schema.clone(), vec![ - Arc::new(array::Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![1])), Arc::new(array::StringArray::from(vec!["a"])), ], ) diff --git a/datafusion/core/tests/core_integration.rs b/datafusion/core/tests/core_integration.rs index 79e5056e3cf5b..e0917e6cca198 100644 --- a/datafusion/core/tests/core_integration.rs +++ b/datafusion/core/tests/core_integration.rs @@ -24,6 +24,9 @@ mod dataframe; /// Run all tests that are found in the `macro_hygiene` directory mod macro_hygiene; +/// Run all tests that are found in the `execution` directory +mod execution; + /// Run all tests that are found in the `expr_api` directory mod expr_api; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 3520ab8fed2b3..0c3c2a99517ef 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1434,9 +1434,7 @@ async fn unnest_analyze_metrics() -> Result<()> { .explain(false, true)? .collect() .await?; - let formatted = arrow::util::pretty::pretty_format_batches(&results) - .unwrap() - .to_string(); + let formatted = pretty_format_batches(&results).unwrap().to_string(); assert_contains!(&formatted, "elapsed_compute="); assert_contains!(&formatted, "input_batches=1"); assert_contains!(&formatted, "input_rows=5"); diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs new file mode 100644 index 0000000000000..168bf484e5411 --- /dev/null +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::Int64Array; +use arrow_schema::{DataType, Field}; +use datafusion::execution::session_state::SessionStateBuilder; +use datafusion_common::{Column, DFSchema, Result, ScalarValue}; +use datafusion_execution::TaskContext; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::logical_plan::{LogicalPlan, Values}; +use datafusion_expr::{Aggregate, AggregateUDF, Expr}; +use datafusion_functions_aggregate::count::Count; +use datafusion_physical_plan::collect; +use std::collections::HashMap; +use std::fmt::Debug; +use std::ops::Deref; +use std::sync::Arc; + +///! Logical plans need to provide stable semantics, as downstream projects +///! create them and depend on them. Test executable semantics of logical plans. + +#[tokio::test] +async fn count_only_nulls() -> Result<()> { + // Input: VALUES (NULL), (NULL), (NULL) AS _(col) + let input_schema = Arc::new(DFSchema::from_unqualified_fields( + vec![Field::new("col", DataType::Null, true)].into(), + HashMap::new(), + )?); + let input = Arc::new(LogicalPlan::Values(Values { + schema: input_schema, + values: vec![ + vec![Expr::Literal(ScalarValue::Null)], + vec![Expr::Literal(ScalarValue::Null)], + vec![Expr::Literal(ScalarValue::Null)], + ], + })); + let input_col_ref = Expr::Column(Column { + relation: None, + name: "col".to_string(), + }); + + // Aggregation: count(col) AS count + let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + input, + vec![], + vec![Expr::AggregateFunction(AggregateFunction { + func: Arc::new(AggregateUDF::new_from_impl(Count::new())), + args: vec![input_col_ref], + distinct: false, + filter: None, + order_by: None, + null_treatment: None, + })], + )?); + + // Execute and verify results + let session_state = SessionStateBuilder::new().build(); + let physical_plan = session_state.create_physical_plan(&aggregate).await?; + let result = + collect(physical_plan, Arc::new(TaskContext::from(&session_state))).await?; + + let result = only(result.as_slice()); + let result_schema = result.schema(); + let field = only(result_schema.fields().deref()); + let column = only(result.columns()); + + assert_eq!(field.data_type(), &DataType::Int64); // TODO should be UInt64 + assert_eq!(column.deref(), &Int64Array::from(vec![0])); + + Ok(()) +} + +fn only(elements: &[T]) -> &T +where + T: Debug, +{ + let [element] = elements else { + panic!("Expected exactly one element, got {:?}", elements); + }; + element +} diff --git a/datafusion/core/tests/execution/mod.rs b/datafusion/core/tests/execution/mod.rs new file mode 100644 index 0000000000000..8169db1a4611e --- /dev/null +++ b/datafusion/core/tests/execution/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +mod logical_plan; diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index cbd8926721529..81a33361008f0 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -37,14 +37,14 @@ mod simplification; fn test_octet_length() { #[rustfmt::skip] evaluate_expr_test( - octet_length(col("list")), + octet_length(col("id")), vec![ "+------+", "| expr |", "+------+", - "| 5 |", - "| 18 |", - "| 6 |", + "| 1 |", + "| 1 |", + "| 1 |", "+------+", ], ); diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index d7995d4663be4..68785b7a5a45c 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -29,10 +29,10 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::builder::table_scan_with_filters; use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ - expr, table_scan, Cast, ColumnarValue, ExprSchemable, LogicalPlan, - LogicalPlanBuilder, ScalarUDF, Volatility, + table_scan, Cast, ColumnarValue, ExprSchemable, LogicalPlan, LogicalPlanBuilder, + ScalarUDF, Volatility, }; -use datafusion_functions::{math, string}; +use datafusion_functions::math; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; use datafusion_optimizer::{OptimizerContext, OptimizerRule}; @@ -333,8 +333,8 @@ fn simplify_scan_predicate() -> Result<()> { .build()?; // before simplify: t.g = power(t.f, 1.0) - // after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)" - let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]"; + // after simplify: t.g = t.f" + let expected = "TableScan: test, full_filters=[g = f]"; let actual = get_optimized_plan_formatted(plan, &Utc::now()); assert_eq!(expected, actual); Ok(()) @@ -368,13 +368,13 @@ fn test_const_evaluator() { #[test] fn test_const_evaluator_scalar_functions() { // concat("foo", "bar") --> "foobar" - let expr = string::expr_fn::concat(vec![lit("foo"), lit("bar")]); + let expr = concat(vec![lit("foo"), lit("bar")]); test_evaluate(expr, lit("foobar")); // ensure arguments are also constant folded // concat("foo", concat("bar", "baz")) --> "foobarbaz" - let concat1 = string::expr_fn::concat(vec![lit("bar"), lit("baz")]); - let expr = string::expr_fn::concat(vec![lit("foo"), concat1]); + let concat1 = concat(vec![lit("bar"), lit("baz")]); + let expr = concat(vec![lit("foo"), concat1]); test_evaluate(expr, lit("foobarbaz")); // Check non string arguments @@ -407,7 +407,7 @@ fn test_const_evaluator_scalar_functions() { #[test] fn test_const_evaluator_now() { let ts_nanos = 1599566400000000000i64; - let time = chrono::Utc.timestamp_nanos(ts_nanos); + let time = Utc.timestamp_nanos(ts_nanos); let ts_string = "2020-09-08T12:05:00+00:00"; // now() --> ts test_evaluate_with_start_time(now(), lit_timestamp_nano(ts_nanos), &time); @@ -429,7 +429,7 @@ fn test_evaluator_udfs() { // immutable UDF should get folded // udf_add(1+2, 30+40) --> 73 - let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + let expr = Expr::ScalarFunction(ScalarFunction::new_udf( make_udf_add(Volatility::Immutable), args.clone(), )); @@ -438,21 +438,16 @@ fn test_evaluator_udfs() { // stable UDF should be entirely folded // udf_add(1+2, 30+40) --> 73 let fun = 
make_udf_add(Volatility::Stable); - let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( - Arc::clone(&fun), - args.clone(), - )); + let expr = + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), args.clone())); test_evaluate(expr, lit(73)); // volatile UDF should have args folded // udf_add(1+2, 30+40) --> udf_add(3, 70) let fun = make_udf_add(Volatility::Volatile); - let expr = - Expr::ScalarFunction(expr::ScalarFunction::new_udf(Arc::clone(&fun), args)); - let expected_expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( - Arc::clone(&fun), - folded_args, - )); + let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), args)); + let expected_expr = + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), folded_args)); test_evaluate(expr, expected_expr); } diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 62e9be63983cb..28901b14b5b7d 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -44,6 +44,154 @@ use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use tokio::task::JoinSet; +use crate::fuzz_cases::aggregation_fuzzer::{ + AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig, QueryBuilder, +}; + +// ======================================================================== +// The new aggregation fuzz tests based on [`AggregationFuzzer`] +// ======================================================================== +// +// Notes on tests: +// +// Since the supported types differ for each aggregation function, the tests +// below are structured so they enumerate each different aggregate function. +// +// The test framework handles varying combinations of arguments (data types), +// sortedness, and grouping parameters +// +// TODO: Test floating point values (where output needs to be compared with some +// acceptable range due to floating point rounding) +// +// TODO: test other aggregate functions +// - AVG (unstable given the wide range of inputs) +// +// TODO: specific test for ordering (ensure all group by columns are ordered) +// Currently the data is sorted by random columns, so there are almost no +// repeated runs. 
To improve coverage we should also sort by lower cardinality columns
+#[tokio::test(flavor = "multi_thread")]
+async fn test_min() {
+    let data_gen_config = baseline_config();
+
+    // Queries like SELECT min(a) FROM fuzz_table GROUP BY b
+    let query_builder = QueryBuilder::new()
+        .with_table_name("fuzz_table")
+        .with_aggregate_function("min")
+        // min works on all column types
+        .with_aggregate_arguments(data_gen_config.all_columns())
+        .with_group_by_columns(data_gen_config.all_columns());
+
+    AggregationFuzzerBuilder::from(data_gen_config)
+        .add_query_builder(query_builder)
+        .build()
+        .run()
+        .await;
+}
+
+#[tokio::test(flavor = "multi_thread")]
+async fn test_max() {
+    let data_gen_config = baseline_config();
+
+    // Queries like SELECT max(a) FROM fuzz_table GROUP BY b
+    let query_builder = QueryBuilder::new()
+        .with_table_name("fuzz_table")
+        .with_aggregate_function("max")
+        // max works on all column types
+        .with_aggregate_arguments(data_gen_config.all_columns())
+        .with_group_by_columns(data_gen_config.all_columns());
+
+    AggregationFuzzerBuilder::from(data_gen_config)
+        .add_query_builder(query_builder)
+        .build()
+        .run()
+        .await;
+}
+
+#[tokio::test(flavor = "multi_thread")]
+async fn test_sum() {
+    let data_gen_config = baseline_config();
+
+    // Queries like SELECT sum(a), sum(DISTINCT a) FROM fuzz_table GROUP BY b
+    let query_builder = QueryBuilder::new()
+        .with_table_name("fuzz_table")
+        .with_aggregate_function("sum")
+        .with_distinct_aggregate_function("sum")
+        // sum only works on numeric columns
+        .with_aggregate_arguments(data_gen_config.numeric_columns())
+        .with_group_by_columns(data_gen_config.all_columns());
+
+    AggregationFuzzerBuilder::from(data_gen_config)
+        .add_query_builder(query_builder)
+        .build()
+        .run()
+        .await;
+}
+
+#[tokio::test(flavor = "multi_thread")]
+async fn test_count() {
+    let data_gen_config = baseline_config();
+
+    // Queries like SELECT count(a), count(DISTINCT a) FROM fuzz_table GROUP BY b
+    let query_builder = QueryBuilder::new()
+        .with_table_name("fuzz_table")
+        .with_aggregate_function("count")
+        .with_distinct_aggregate_function("count")
+        // count works for all argument types
+        .with_aggregate_arguments(data_gen_config.all_columns())
+        .with_group_by_columns(data_gen_config.all_columns());
+
+    AggregationFuzzerBuilder::from(data_gen_config)
+        .add_query_builder(query_builder)
+        .build()
+        .run()
+        .await;
+}
+
+/// Return a standard set of columns for testing data generation
+///
+/// Includes numeric and string types
+///
+/// Does not include:
+/// 1. Floating point numbers
+/// 1.
structured types +fn baseline_config() -> DatasetGeneratorConfig { + let columns = vec![ + ColumnDescr::new("i8", DataType::Int8), + ColumnDescr::new("i16", DataType::Int16), + ColumnDescr::new("i32", DataType::Int32), + ColumnDescr::new("i64", DataType::Int64), + ColumnDescr::new("u8", DataType::UInt8), + ColumnDescr::new("u16", DataType::UInt16), + ColumnDescr::new("u32", DataType::UInt32), + ColumnDescr::new("u64", DataType::UInt64), + ColumnDescr::new("date32", DataType::Date32), + ColumnDescr::new("date64", DataType::Date64), + // TODO: date/time columns + // todo decimal columns + // begin string columns + ColumnDescr::new("utf8", DataType::Utf8), + ColumnDescr::new("largeutf8", DataType::LargeUtf8), + // TODO add support for utf8view in data generator + // ColumnDescr::new("utf8view", DataType::Utf8View), + // todo binary + ]; + + DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set: vec![ + // low cardinality to try and get many repeated runs + vec![String::from("u8")], + vec![String::from("utf8"), String::from("u8")], + ], + } +} + +// ======================================================================== +// The old aggregation fuzz tests +// ======================================================================== + +/// Tracks if this stream is generating input or output /// Tests that streaming aggregate and batch (non streaming) aggregate produce /// same results #[tokio::test(flavor = "multi_thread")] @@ -58,7 +206,7 @@ async fn streaming_aggregate_test() { vec!["d", "c", "a"], vec!["d", "c", "b", "a"], ]; - let n = 300; + let n = 10; let distincts = vec![10, 20]; for distinct in distincts { let mut join_set = JoinSet::new(); @@ -100,7 +248,8 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let running_source = Arc::new( MemoryExec::try_new(&[input1.clone()], schema.clone(), None) .unwrap() - .with_sort_information(vec![sort_keys]), + .try_with_sort_information(vec![sort_keys]) + .unwrap(), ); let aggregate_expr = @@ -109,6 +258,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .schema(Arc::clone(&schema)) .alias("sum1") .build() + .map(Arc::new) .unwrap(), ]; let expr = group_by_columns @@ -311,6 +461,7 @@ async fn group_by_string_test( let actual = extract_result_counts(results); assert_eq!(expected, actual); } + async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { struct Visitor { expected_sort: bool, diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs new file mode 100644 index 0000000000000..af454bee7ce81 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -0,0 +1,343 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{cmp, sync::Arc};
+
+use datafusion::{
+    datasource::MemTable,
+    prelude::{SessionConfig, SessionContext},
+};
+use datafusion_catalog::TableProvider;
+use datafusion_common::error::Result;
+use datafusion_common::ScalarValue;
+use datafusion_expr::col;
+use rand::{thread_rng, Rng};
+
+use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset;
+
+/// SessionContext generator
+///
+/// During testing, `generate_baseline` is called first to create a standard
+/// [`SessionContext`], and the `sql` is run on it to get the `expected result`.
+/// Then `generate` is called several times to create random [`SessionContext`]s,
+/// and the same `sql` is run on them to get the `actual results`.
+/// Finally, the `actual results` are compared with the `expected result`; the
+/// test succeeds only if they all match.
+///
+/// The following [`SessionContext`] parameters used in query running are
+/// generated randomly:
+/// - `batch_size`
+/// - `target_partitions`
+/// - `skip_partial` parameters
+/// - hint `sorted` or not
+/// - `spilling` or not (TODO: a special `MemoryPool` may be needed
+///   to support this)
+///
+pub struct SessionContextGenerator {
+    /// Current testing dataset
+    dataset: Arc<Dataset>,
+
+    /// Table name of the test table
+    table_name: String,
+
+    /// Used to generate the random `batch_size`
+    ///
+    /// The generated `batch_size` is between (0, total_rows_num]
+    max_batch_size: usize,
+
+    /// Candidate `SkipPartialParams`, one of which will be picked randomly
+    candidate_skip_partial_params: Vec<SkipPartialParams>,
+
+    /// The upper bound of the randomly generated target partitions;
+    /// the lower bound is 1
+    max_target_partitions: usize,
+}
+
+impl SessionContextGenerator {
+    pub fn new(dataset_ref: Arc<Dataset>, table_name: &str) -> Self {
+        let candidate_skip_partial_params = vec![
+            SkipPartialParams::ensure_trigger(),
+            SkipPartialParams::ensure_not_trigger(),
+        ];
+
+        let max_batch_size = cmp::max(1, dataset_ref.total_rows_num);
+        let max_target_partitions = num_cpus::get();
+
+        Self {
+            dataset: dataset_ref,
+            table_name: table_name.to_string(),
+            max_batch_size,
+            candidate_skip_partial_params,
+            max_target_partitions,
+        }
+    }
+}
+
+impl SessionContextGenerator {
+    /// Generate the `SessionContext` for the baseline run
+    pub fn generate_baseline(&self) -> Result<SessionContextWithParams> {
+        let schema = self.dataset.batches[0].schema();
+        let batches = self.dataset.batches.clone();
+        let provider = MemTable::try_new(schema, vec![batches])?;
+
+        // The baseline context should disable all possible optimizations
+        // and aim for correctness above all.
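+        // Concretely: one target partition, the maximum batch size, and
+        // partial-aggregation skipping disabled; the result it produces
+        // serves as the ground truth for the randomized contexts.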
+ let batch_size = self.max_batch_size; + let target_partitions = 1; + let skip_partial_params = SkipPartialParams::ensure_not_trigger(); + + let builder = GeneratedSessionContextBuilder { + batch_size, + target_partitions, + skip_partial_params, + sort_hint: false, + table_name: self.table_name.clone(), + table_provider: Arc::new(provider), + }; + + builder.build() + } + + /// Randomly generate session context + pub fn generate(&self) -> Result { + let mut rng = thread_rng(); + let schema = self.dataset.batches[0].schema(); + let batches = self.dataset.batches.clone(); + let provider = MemTable::try_new(schema, vec![batches])?; + + // We will randomly generate following options: + // - `batch_size`, from range: [1, `total_rows_num`] + // - `target_partitions`, from range: [1, cpu_num] + // - `skip_partial`, trigger or not trigger currently for simplicity + // - `sorted`, if found a sorted dataset, will or will not push down this information + // - `spilling`(TODO) + let batch_size = rng.gen_range(1..=self.max_batch_size); + + let target_partitions = rng.gen_range(1..=self.max_target_partitions); + + let skip_partial_params_idx = + rng.gen_range(0..self.candidate_skip_partial_params.len()); + let skip_partial_params = + self.candidate_skip_partial_params[skip_partial_params_idx]; + + let (provider, sort_hint) = + if rng.gen_bool(0.5) && !self.dataset.sort_keys.is_empty() { + // Sort keys exist and random to push down + let sort_exprs = self + .dataset + .sort_keys + .iter() + .map(|key| col(key).sort(true, true)) + .collect::>(); + (provider.with_sort_order(vec![sort_exprs]), true) + } else { + (provider, false) + }; + + let builder = GeneratedSessionContextBuilder { + batch_size, + target_partitions, + sort_hint, + skip_partial_params, + table_name: self.table_name.clone(), + table_provider: Arc::new(provider), + }; + + builder.build() + } +} + +/// The generated [`SessionContext`] with its params +/// +/// Storing the generated `params` is necessary for +/// reporting the broken test case. 
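+///
+/// Typical flow (illustrative sketch, mirroring `test_generated_context`
+/// below; error handling elided):
+///
+/// ```ignore
+/// let ctx_generator = SessionContextGenerator::new(dataset, "fuzz_table");
+/// let baseline = ctx_generator.generate_baseline()?; // deterministic params
+/// let random = ctx_generator.generate()?;            // randomized params
+/// // run the same SQL on `baseline.ctx` and `random.ctx` and compare results
+/// ```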
+pub struct SessionContextWithParams { + pub ctx: SessionContext, + pub params: SessionContextParams, +} + +/// Collect the generated params, and build the [`SessionContext`] +struct GeneratedSessionContextBuilder { + batch_size: usize, + target_partitions: usize, + sort_hint: bool, + skip_partial_params: SkipPartialParams, + table_name: String, + table_provider: Arc, +} + +impl GeneratedSessionContextBuilder { + fn build(self) -> Result { + // Build session context + let mut session_config = SessionConfig::default(); + session_config = session_config.set( + "datafusion.execution.batch_size", + &ScalarValue::UInt64(Some(self.batch_size as u64)), + ); + session_config = session_config.set( + "datafusion.execution.target_partitions", + &ScalarValue::UInt64(Some(self.target_partitions as u64)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &ScalarValue::UInt64(Some(self.skip_partial_params.rows_threshold as u64)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold", + &ScalarValue::Float64(Some(self.skip_partial_params.ratio_threshold)), + ); + + let ctx = SessionContext::new_with_config(session_config); + ctx.register_table(self.table_name, self.table_provider)?; + + let params = SessionContextParams { + batch_size: self.batch_size, + target_partitions: self.target_partitions, + sort_hint: self.sort_hint, + skip_partial_params: self.skip_partial_params, + }; + + Ok(SessionContextWithParams { ctx, params }) + } +} + +/// The generated params for [`SessionContext`] +#[derive(Debug)] +#[allow(dead_code)] +pub struct SessionContextParams { + batch_size: usize, + target_partitions: usize, + sort_hint: bool, + skip_partial_params: SkipPartialParams, +} + +/// Partial skipping parameters +#[derive(Debug, Clone, Copy)] +pub struct SkipPartialParams { + /// Related to `skip_partial_aggregation_probe_ratio_threshold` in `ExecutionOptions` + pub ratio_threshold: f64, + + /// Related to `skip_partial_aggregation_probe_rows_threshold` in `ExecutionOptions` + pub rows_threshold: usize, +} + +impl SkipPartialParams { + /// Generate `SkipPartialParams` ensuring to trigger partial skipping + pub fn ensure_trigger() -> Self { + Self { + ratio_threshold: 0.0, + rows_threshold: 0, + } + } + + /// Generate `SkipPartialParams` ensuring not to trigger partial skipping + pub fn ensure_not_trigger() -> Self { + Self { + ratio_threshold: 1.0, + rows_threshold: usize::MAX, + } + } +} + +#[cfg(test)] +mod test { + use arrow_array::{RecordBatch, StringArray, UInt32Array}; + use arrow_schema::{DataType, Field, Schema}; + + use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; + + use super::*; + + #[tokio::test] + async fn test_generated_context() { + // 1. 
Define a test dataset firstly + let a_col: StringArray = [ + Some("rust"), + Some("java"), + Some("cpp"), + Some("go"), + Some("go1"), + Some("python"), + Some("python1"), + Some("python2"), + ] + .into_iter() + .collect(); + // Sort by "b" + let b_col: UInt32Array = [ + Some(1), + Some(2), + Some(4), + Some(8), + Some(8), + Some(16), + Some(16), + Some(16), + ] + .into_iter() + .collect(); + let schema = Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::UInt32, true), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(a_col), Arc::new(b_col)], + ) + .unwrap(); + + // One row a group to create batches + let mut batches = Vec::with_capacity(batch.num_rows()); + for start in 0..batch.num_rows() { + let sub_batch = batch.slice(start, 1); + batches.push(sub_batch); + } + + let dataset = Dataset::new(batches, vec!["b".to_string()]); + + // 2. Generate baseline context, and some randomly session contexts. + // Run the same query on them, and all randoms' results should equal to baseline's + let ctx_generator = SessionContextGenerator::new(Arc::new(dataset), "fuzz_table"); + + let query = "select b, count(a) from fuzz_table group by b"; + let baseline_wrapped_ctx = ctx_generator.generate_baseline().unwrap(); + let mut random_wrapped_ctxs = Vec::with_capacity(8); + for _ in 0..8 { + let ctx = ctx_generator.generate().unwrap(); + random_wrapped_ctxs.push(ctx); + } + + let base_result = baseline_wrapped_ctx + .ctx + .sql(query) + .await + .unwrap() + .collect() + .await + .unwrap(); + + for wrapped_ctx in random_wrapped_ctxs { + let random_result = wrapped_ctx + .ctx + .sql(query) + .await + .unwrap() + .collect() + .await + .unwrap(); + check_equality_of_batches(&base_result, &random_result).unwrap(); + } + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs new file mode 100644 index 0000000000000..ef9b5a7f355a7 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -0,0 +1,508 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use std::sync::Arc;
+
+use arrow::datatypes::{
+    Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
+    Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_schema::{DataType, Field, Schema};
+use datafusion_common::{arrow_datafusion_err, DataFusionError, Result};
+use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};
+use datafusion_physical_plan::sorts::sort::sort_batch;
+use rand::{
+    rngs::{StdRng, ThreadRng},
+    thread_rng, Rng, SeedableRng,
+};
+use test_utils::{
+    array_gen::{PrimitiveArrayGenerator, StringArrayGenerator},
+    stagger_batch,
+};
+
+/// Config for the datasets generator
+///
+/// # Parameters
+/// - `columns`: defines the `column name`s and `column data type`s of the test
+///   datasets; their values are then generated randomly by the generator when
+///   the `generate` function is called
+///
+/// - `rows_num_range`: the row count of each generated dataset is picked
+///   randomly from this range
+///
+/// - `sort_keys_set`: if sort keys are defined, `generate` first creates one
+///   `base dataset`. The `base dataset` is then sorted by each `sort_key`
+///   respectively, so `len(sort_keys_set) + 1` datasets are returned
+///
+#[derive(Debug, Clone)]
+pub struct DatasetGeneratorConfig {
+    /// Descriptions of the columns in the datasets (required)
+    pub columns: Vec<ColumnDescr>,
+
+    /// Row count range of the generated datasets (required)
+    pub rows_num_range: (usize, usize),
+
+    /// Additional optional sort keys
+    ///
+    /// The generated datasets always include a non-sorted copy. For each
+    /// element in `sort_keys_set`, an additional dataset is created that
+    /// is sorted by these values as well.
+    pub sort_keys_set: Vec<Vec<String>>,
+}
+
+impl DatasetGeneratorConfig {
+    /// return a list of all column names
+    pub fn all_columns(&self) -> Vec<&str> {
+        self.columns.iter().map(|d| d.name.as_str()).collect()
+    }
+
+    /// return a list of column names that are "numeric"
+    pub fn numeric_columns(&self) -> Vec<&str> {
+        self.columns
+            .iter()
+            .filter_map(|d| {
+                if d.column_type.is_numeric() {
+                    Some(d.name.as_str())
+                } else {
+                    None
+                }
+            })
+            .collect()
+    }
+}
+
+/// Dataset generator
+///
+/// It generates random [`Dataset`]s when the `generate` function is called.
+///
+/// The generation logic in `generate`:
+///
+/// - First, randomly generate a base record batch with `batch_generator`,
+///   using the `columns` and `rows_num_range` from the `config` (see
+///   [`DatasetGeneratorConfig`] for details).
+///
+/// - Sort the base batch by each entry of `sort_keys_set` in the `config`
+///   to generate another `len(sort_keys_set)` sorted batches.
+///
+/// - Split each batch into multiple sub-batches with random row counts,
+///   and use them to create the `Dataset`s.
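+///
+/// A usage sketch (assuming a `config` whose `sort_keys_set` has one entry):
+///
+/// ```ignore
+/// let generator = DatasetGenerator::new(config);
+/// // one unsorted base dataset, plus one sorted dataset per sort-key set
+/// let datasets = generator.generate()?;
+/// assert_eq!(datasets.len(), 2);
+/// ```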
+/// +pub struct DatasetGenerator { + batch_generator: RecordBatchGenerator, + sort_keys_set: Vec>, +} + +impl DatasetGenerator { + pub fn new(config: DatasetGeneratorConfig) -> Self { + let batch_generator = RecordBatchGenerator::new( + config.rows_num_range.0, + config.rows_num_range.1, + config.columns, + ); + + Self { + batch_generator, + sort_keys_set: config.sort_keys_set, + } + } + + pub fn generate(&self) -> Result> { + let mut datasets = Vec::with_capacity(self.sort_keys_set.len() + 1); + + // Generate the base batch (unsorted) + let base_batch = self.batch_generator.generate()?; + let batches = stagger_batch(base_batch.clone()); + let dataset = Dataset::new(batches, Vec::new()); + datasets.push(dataset); + + // Generate the related sorted batches + let schema = base_batch.schema_ref(); + for sort_keys in self.sort_keys_set.clone() { + let sort_exprs = sort_keys + .iter() + .map(|key| { + let col_expr = col(key, schema)?; + Ok(PhysicalSortExpr::new_default(col_expr)) + }) + .collect::>>()?; + let sorted_batch = sort_batch(&base_batch, &sort_exprs, None)?; + + let batches = stagger_batch(sorted_batch); + let dataset = Dataset::new(batches, sort_keys); + datasets.push(dataset); + } + + Ok(datasets) + } +} + +/// Single test data set +#[derive(Debug)] +pub struct Dataset { + pub batches: Vec, + pub total_rows_num: usize, + pub sort_keys: Vec, +} + +impl Dataset { + pub fn new(batches: Vec, sort_keys: Vec) -> Self { + let total_rows_num = batches.iter().map(|batch| batch.num_rows()).sum::(); + + Self { + batches, + total_rows_num, + sort_keys, + } + } +} + +#[derive(Debug, Clone)] +pub struct ColumnDescr { + // Column name + name: String, + + // Data type of this column + column_type: DataType, +} + +impl ColumnDescr { + #[inline] + pub fn new(name: &str, column_type: DataType) -> Self { + Self { + name: name.to_string(), + column_type, + } + } +} + +/// Record batch generator +struct RecordBatchGenerator { + min_rows_nun: usize, + + max_rows_num: usize, + + columns: Vec, + + candidate_null_pcts: Vec, +} + +macro_rules! generate_string_array { + ($SELF:ident, $NUM_ROWS:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $OFFSET_TYPE:ty) => {{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + let max_len = $BATCH_GEN_RNG.gen_range(1..50); + let num_distinct_strings = if $NUM_ROWS > 1 { + $BATCH_GEN_RNG.gen_range(1..$NUM_ROWS) + } else { + $NUM_ROWS + }; + + let mut generator = StringArrayGenerator { + max_len, + num_strings: $NUM_ROWS, + num_distinct_strings, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$OFFSET_TYPE>() + }}; +} + +macro_rules! generate_primitive_array { + ($SELF:ident, $NUM_ROWS:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => { + paste::paste! 
{{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + let num_distinct_primitives = if $NUM_ROWS > 1 { + $BATCH_GEN_RNG.gen_range(1..$NUM_ROWS) + } else { + $NUM_ROWS + }; + + let mut generator = PrimitiveArrayGenerator { + num_primitives: $NUM_ROWS, + num_distinct_primitives, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$ARROW_TYPE>() + }}} +} + +impl RecordBatchGenerator { + fn new(min_rows_nun: usize, max_rows_num: usize, columns: Vec) -> Self { + let candidate_null_pcts = vec![0.0, 0.01, 0.1, 0.5]; + + Self { + min_rows_nun, + max_rows_num, + columns, + candidate_null_pcts, + } + } + + fn generate(&self) -> Result { + let mut rng = thread_rng(); + let num_rows = rng.gen_range(self.min_rows_nun..=self.max_rows_num); + let array_gen_rng = StdRng::from_seed(rng.gen()); + + // Build arrays + let mut arrays = Vec::with_capacity(self.columns.len()); + for col in self.columns.iter() { + let array = self.generate_array_of_type( + col.column_type.clone(), + num_rows, + &mut rng, + array_gen_rng.clone(), + ); + arrays.push(array); + } + + // Build schema + let fields = self + .columns + .iter() + .map(|col| Field::new(col.name.clone(), col.column_type.clone(), true)) + .collect::>(); + let schema = Arc::new(Schema::new(fields)); + + RecordBatch::try_new(schema, arrays).map_err(|e| arrow_datafusion_err!(e)) + } + + fn generate_array_of_type( + &self, + data_type: DataType, + num_rows: usize, + batch_gen_rng: &mut ThreadRng, + array_gen_rng: StdRng, + ) -> ArrayRef { + match data_type { + DataType::Int8 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + Int8Type + ) + } + DataType::Int16 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + Int16Type + ) + } + DataType::Int32 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + Int32Type + ) + } + DataType::Int64 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + Int64Type + ) + } + DataType::UInt8 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + UInt8Type + ) + } + DataType::UInt16 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + UInt16Type + ) + } + DataType::UInt32 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + UInt32Type + ) + } + DataType::UInt64 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + UInt64Type + ) + } + DataType::Float32 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + Float32Type + ) + } + DataType::Float64 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + Float64Type + ) + } + DataType::Date32 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + Date32Type + ) + } + DataType::Date64 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + Date64Type + ) + } + DataType::Utf8 => { + generate_string_array!(self, num_rows, batch_gen_rng, array_gen_rng, i32) + } + DataType::LargeUtf8 => { + generate_string_array!(self, num_rows, batch_gen_rng, array_gen_rng, i64) + } + _ => { + panic!("Unsupported data generator type: {data_type}") + } + } + } +} + +#[cfg(test)] +mod test { + use arrow_array::UInt32Array; + + use 
crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; + + use super::*; + + #[test] + fn test_generated_datasets() { + // The test datasets generation config + // We expect that after calling `generate` + // - Generate 2 datasets + // - They have 2 column "a" and "b", + // "a"'s type is `Utf8`, and "b"'s type is `UInt32` + // - One of them is unsorted, another is sorted by column "b" + // - Their rows num should be same and between [16, 32] + let config = DatasetGeneratorConfig { + columns: vec![ + ColumnDescr { + name: "a".to_string(), + column_type: DataType::Utf8, + }, + ColumnDescr { + name: "b".to_string(), + column_type: DataType::UInt32, + }, + ], + rows_num_range: (16, 32), + sort_keys_set: vec![vec!["b".to_string()]], + }; + + let gen = DatasetGenerator::new(config); + let datasets = gen.generate().unwrap(); + + // Should Generate 2 datasets + assert_eq!(datasets.len(), 2); + + // Should have 2 column "a" and "b", + // "a"'s type is `Utf8`, and "b"'s type is `UInt32` + let check_fields = |batch: &RecordBatch| { + assert_eq!(batch.num_columns(), 2); + let fields = batch.schema().fields().clone(); + assert_eq!(fields[0].name(), "a"); + assert_eq!(*fields[0].data_type(), DataType::Utf8); + assert_eq!(fields[1].name(), "b"); + assert_eq!(*fields[1].data_type(), DataType::UInt32); + }; + + let batch = &datasets[0].batches[0]; + check_fields(batch); + let batch = &datasets[1].batches[0]; + check_fields(batch); + + // One batches should be sort by "b" + let sorted_batches = &datasets[1].batches; + let b_vals = sorted_batches.iter().flat_map(|batch| { + let uint_array = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + uint_array.iter() + }); + let mut prev_b_val = u32::MIN; + for b_val in b_vals { + let b_val = b_val.unwrap_or(u32::MIN); + assert!(b_val >= prev_b_val); + prev_b_val = b_val; + } + + // Two batches should be same after sorting + check_equality_of_batches(&datasets[0].batches, &datasets[1].batches).unwrap(); + + // Rows num should between [16, 32] + let rows_num0 = datasets[0] + .batches + .iter() + .map(|batch| batch.num_rows()) + .sum::(); + let rows_num1 = datasets[1] + .batches + .iter() + .map(|batch| batch.num_rows()) + .sum::(); + assert_eq!(rows_num0, rows_num1); + assert!(rows_num0 >= 16); + assert!(rows_num0 <= 32); + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs new file mode 100644 index 0000000000000..0704bafa0318a --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -0,0 +1,508 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
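+
+// Overview (illustrative): callers typically assemble and run the fuzzer as
+//
+//     AggregationFuzzerBuilder::from(data_gen_config)
+//         .add_query_builder(query_builder)
+//         .build()
+//         .run()
+//         .await;
+//
+// as the tests in `aggregate_fuzz.rs` do.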
+ +use std::collections::HashSet; +use std::sync::Arc; + +use arrow::util::pretty::pretty_format_batches; +use arrow_array::RecordBatch; +use datafusion_common::{DataFusionError, Result}; +use rand::{thread_rng, Rng}; +use tokio::task::JoinSet; + +use crate::fuzz_cases::aggregation_fuzzer::{ + check_equality_of_batches, + context_generator::{SessionContextGenerator, SessionContextWithParams}, + data_generator::{Dataset, DatasetGenerator, DatasetGeneratorConfig}, + run_sql, +}; + +/// Rounds to call `generate` of [`SessionContextGenerator`] +/// in [`AggregationFuzzer`], `ctx_gen_rounds` random [`SessionContext`] +/// will generated for each dataset for testing. +const CTX_GEN_ROUNDS: usize = 16; + +/// Aggregation fuzzer's builder +pub struct AggregationFuzzerBuilder { + /// See `candidate_sqls` in [`AggregationFuzzer`], no default, and required to set + candidate_sqls: Vec>, + + /// See `table_name` in [`AggregationFuzzer`], no default, and required to set + table_name: Option>, + + /// Used to generate `dataset_generator` in [`AggregationFuzzer`], + /// no default, and required to set + data_gen_config: Option, + + /// See `data_gen_rounds` in [`AggregationFuzzer`], default 16 + data_gen_rounds: usize, +} + +impl AggregationFuzzerBuilder { + fn new() -> Self { + Self { + candidate_sqls: Vec::new(), + table_name: None, + data_gen_config: None, + data_gen_rounds: 16, + } + } + + /// Adds random SQL queries to the fuzzer along with the table name + pub fn add_query_builder(mut self, query_builder: QueryBuilder) -> Self { + const NUM_QUERIES: usize = 10; + for _ in 0..NUM_QUERIES { + self = self.add_sql(&query_builder.generate_query()); + } + self.table_name(query_builder.table_name()) + } + + fn add_sql(mut self, sql: &str) -> Self { + self.candidate_sqls.push(Arc::from(sql)); + self + } + + pub fn table_name(mut self, table_name: &str) -> Self { + self.table_name = Some(Arc::from(table_name)); + self + } + + pub fn data_gen_config(mut self, data_gen_config: DatasetGeneratorConfig) -> Self { + self.data_gen_config = Some(data_gen_config); + self + } + + pub fn build(self) -> AggregationFuzzer { + assert!(!self.candidate_sqls.is_empty()); + let candidate_sqls = self.candidate_sqls; + let table_name = self.table_name.expect("table_name is required"); + let data_gen_config = self.data_gen_config.expect("data_gen_config is required"); + let data_gen_rounds = self.data_gen_rounds; + + let dataset_generator = DatasetGenerator::new(data_gen_config); + + AggregationFuzzer { + candidate_sqls, + table_name, + dataset_generator, + data_gen_rounds, + } + } +} + +impl Default for AggregationFuzzerBuilder { + fn default() -> Self { + Self::new() + } +} + +impl From for AggregationFuzzerBuilder { + fn from(value: DatasetGeneratorConfig) -> Self { + Self::default().data_gen_config(value) + } +} + +/// AggregationFuzzer randomly generating multiple [`AggregationFuzzTestTask`], +/// and running them to check the correctness of the optimizations +/// (e.g. sorted, partial skipping, spilling...) +pub struct AggregationFuzzer { + /// Candidate test queries represented by sqls + candidate_sqls: Vec>, + + /// The queried table name + table_name: Arc, + + /// Dataset generator used to randomly generate datasets + dataset_generator: DatasetGenerator, + + /// Rounds to call `generate` of [`DatasetGenerator`], + /// len(sort_keys_set) + 1` datasets will be generated for testing. + /// + /// It is suggested to set value 2x or more bigger than num of + /// `candidate_sqls` for better test coverage. 
+ data_gen_rounds: usize, +} + +/// Query group including the tested dataset and its sql query +struct QueryGroup { + dataset: Dataset, + sql: Arc<str>, +} + +impl AggregationFuzzer { + /// Run the fuzzer, printing an error and panicking if any of the tasks fail + pub async fn run(&self) { + let res = self.run_inner().await; + + if let Err(e) = res { + // Print the error via `Display` so that it displays nicely (the default `unwrap()` + // prints using `Debug`, which escapes newlines and makes multi-line messages + // hard to read) + println!("{e}"); + panic!("Error!"); + } + } + + async fn run_inner(&self) -> Result<()> { + let mut join_set = JoinSet::new(); + let mut rng = thread_rng(); + + // Loop to generate datasets and their queries + for _ in 0..self.data_gen_rounds { + // Generate datasets first + let datasets = self + .dataset_generator + .generate() + .expect("dataset generation should succeed"); + + // Then, for each of them, randomly select a test sql + let query_groups = datasets + .into_iter() + .map(|dataset| { + let sql_idx = rng.gen_range(0..self.candidate_sqls.len()); + let sql = self.candidate_sqls[sql_idx].clone(); + + QueryGroup { dataset, sql } + }) + .collect::<Vec<_>>(); + + for q in &query_groups { + println!(" Testing with query {}", q.sql); + } + + let tasks = self.generate_fuzz_tasks(query_groups).await; + for task in tasks { + join_set.spawn(async move { task.run().await }); + } + } + + while let Some(join_handle) = join_set.join_next().await { + // propagate errors + join_handle.map_err(|e| { + DataFusionError::Internal(format!( + "AggregationFuzzer task error: {:?}", + e + )) + })??; + } + Ok(()) + } + + async fn generate_fuzz_tasks( + &self, + query_groups: Vec<QueryGroup>, + ) -> Vec<AggregationFuzzTestTask> { + let mut tasks = Vec::with_capacity(query_groups.len() * CTX_GEN_ROUNDS); + for QueryGroup { dataset, sql } in query_groups { + let dataset_ref = Arc::new(dataset); + let ctx_generator = + SessionContextGenerator::new(dataset_ref.clone(), &self.table_name); + + // Generate the baseline context, and compute the baseline result first + let baseline_ctx_with_params = ctx_generator + .generate_baseline() + .expect("baseline session context generation should succeed"); + let baseline_result = run_sql(&sql, &baseline_ctx_with_params.ctx) + .await + .expect("baseline sql should run successfully"); + let baseline_result = Arc::new(baseline_result); + // Generate test tasks + for _ in 0..CTX_GEN_ROUNDS { + let ctx_with_params = ctx_generator + .generate() + .expect("session context generation should succeed"); + let task = AggregationFuzzTestTask { + dataset_ref: dataset_ref.clone(), + expected_result: baseline_result.clone(), + sql: sql.clone(), + ctx_with_params, + }; + + tasks.push(task); + } + } + tasks + } +} + +/// One test task generated by [`AggregationFuzzer`] +/// +/// It includes: +/// - `expected_result`, the expected result generated by the baseline [`SessionContext`] +/// (with all possible optimizations disabled to ensure correctness). +/// +/// - `ctx`, a randomly generated [`SessionContext`]; `sql` is run +/// against it afterwards, and the result is checked against the expected one. +/// +/// - `sql`, the selected test sql +/// +/// - `dataset_ref`, the input dataset, kept for error reporting when an +/// inconsistency is found between the `ctx` result and the expected result.
+/// +struct AggregationFuzzTestTask { + /// Generated session context in current test case + ctx_with_params: SessionContextWithParams, + + /// Expected result in current test case + /// It is generated from `query` + `baseline session context` + expected_result: Arc<Vec<RecordBatch>>, + + /// The test query + /// Currently represented as sql. + sql: Arc<str>, + + /// The test dataset for error reporting + dataset_ref: Arc<Dataset>, +} + +impl AggregationFuzzTestTask { + async fn run(&self) -> Result<()> { + let task_result = run_sql(&self.sql, &self.ctx_with_params.ctx) + .await + .map_err(|e| e.context(self.context_error_report()))?; + self.check_result(&task_result, &self.expected_result) + } + + fn check_result( + &self, + task_result: &[RecordBatch], + expected_result: &[RecordBatch], + ) -> Result<()> { + check_equality_of_batches(task_result, expected_result).map_err(|e| { + // If an inconsistent result is found, print the test details first for reproduction + let message = format!( + "##### AggregationFuzzer error report #####\n\ + ### Sql:\n{}\n\ + ### Schema:\n{}\n\ + ### Session context params:\n{:?}\n\ + ### Inconsistent row:\n\ + - row_idx:{}\n\ + - task_row:{}\n\ + - expected_row:{}\n\ + ### Task total result:\n{}\n\ + ### Expected total result:\n{}\n\ + ### Input:\n{}\n\ + ", + self.sql, + self.dataset_ref.batches[0].schema_ref(), + self.ctx_with_params.params, + e.row_idx, + e.lhs_row, + e.rhs_row, + format_batches_with_limit(task_result), + format_batches_with_limit(expected_result), + format_batches_with_limit(&self.dataset_ref.batches), + ); + DataFusionError::Internal(message) + }) + } + + /// Returns a formatted error message + fn context_error_report(&self) -> String { + format!( + "##### AggregationFuzzer error report #####\n\ + ### Sql:\n{}\n\ + ### Schema:\n{}\n\ + ### Session context params:\n{:?}\n\ + ### Input:\n{}\n\ + ", + self.sql, + self.dataset_ref.batches[0].schema_ref(), + self.ctx_with_params.params, + pretty_format_batches(&self.dataset_ref.batches).unwrap(), + ) + } +} + +/// Pretty prints the `RecordBatch`es, limited to the first 100 rows +fn format_batches_with_limit(batches: &[RecordBatch]) -> impl std::fmt::Display { + const MAX_ROWS: usize = 100; + let mut row_count = 0; + let to_print = batches + .iter() + .filter_map(|b| { + if row_count >= MAX_ROWS { + None + } else if row_count + b.num_rows() > MAX_ROWS { + // output only the rows that still fit under the limit + let slice_len = MAX_ROWS - row_count; + let b = b.slice(0, slice_len); + row_count += slice_len; + Some(b) + } else { + row_count += b.num_rows(); + Some(b.clone()) + } + }) + .collect::<Vec<_>>(); + + pretty_format_batches(&to_print).unwrap() +} + +/// Random aggregate query builder +/// +/// Creates queries like +/// ```sql +/// SELECT AGG(..)
 FROM table_name GROUP BY <group_by_columns> +/// ``` +#[derive(Debug, Default)] +pub struct QueryBuilder { + /// The name of the table to query + table_name: String, + /// Aggregate functions to be used in the query + /// (function_name, is_distinct) + aggregate_functions: Vec<(String, bool)>, + /// Columns to be used in group by + group_by_columns: Vec<String>, + /// Possible columns for arguments in the aggregate functions + /// + /// Assumes each + arguments: Vec<String>, +} +impl QueryBuilder { + pub fn new() -> Self { + Default::default() + } + + /// Return the table name, if set + pub fn table_name(&self) -> &str { + &self.table_name + } + + /// Set the table name for the query builder + pub fn with_table_name(mut self, table_name: impl Into<String>) -> Self { + self.table_name = table_name.into(); + self + } + + /// Add a new possible aggregate function to the query builder + pub fn with_aggregate_function( + mut self, + aggregate_function: impl Into<String>, + ) -> Self { + self.aggregate_functions + .push((aggregate_function.into(), false)); + self + } + + /// Add a new possible `DISTINCT` aggregate function to the query + /// + /// This is different from `with_aggregate_function` because only certain + /// aggregates support `DISTINCT` + pub fn with_distinct_aggregate_function( + mut self, + aggregate_function: impl Into<String>, + ) -> Self { + self.aggregate_functions + .push((aggregate_function.into(), true)); + self + } + + /// Add columns to be used in the group by + pub fn with_group_by_columns<'a>( + mut self, + group_by: impl IntoIterator<Item = &'a str>, + ) -> Self { + let group_by = group_by.into_iter().map(String::from); + self.group_by_columns.extend(group_by); + self + } + + /// Add columns to be used as arguments in the aggregate functions + pub fn with_aggregate_arguments<'a>( + mut self, + arguments: impl IntoIterator<Item = &'a str>, + ) -> Self { + let arguments = arguments.into_iter().map(String::from); + self.arguments.extend(arguments); + self + } + + pub fn generate_query(&self) -> String { + let group_by = self.random_group_by(); + let mut query = String::from("SELECT "); + query.push_str(&self.random_aggregate_functions().join(", ")); + query.push_str(" FROM "); + query.push_str(&self.table_name); + if !group_by.is_empty() { + query.push_str(" GROUP BY "); + query.push_str(&group_by.join(", ")); + } + query + } + + /// Generate some random aggregate function invocations (potentially repeating).
+ /// + /// Each aggregate function invocation is of the form + /// + /// ```sql + /// function_name(argument) as alias + /// ``` + /// + /// where + /// * `function_name` is randomly selected from [`Self::aggregate_functions`] + /// * `argument` is randomly selected from [`Self::arguments`] + /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names) + fn random_aggregate_functions(&self) -> Vec<String> { + const MAX_NUM_FUNCTIONS: usize = 5; + let mut rng = thread_rng(); + let num_aggregate_functions = rng.gen_range(1..MAX_NUM_FUNCTIONS); + + let mut alias_gen = 1; + + let mut aggregate_functions = vec![]; + while aggregate_functions.len() < num_aggregate_functions { + let idx = rng.gen_range(0..self.aggregate_functions.len()); + let (function_name, is_distinct) = &self.aggregate_functions[idx]; + let argument = self.random_argument(); + let alias = format!("col{}", alias_gen); + let distinct = if *is_distinct { "DISTINCT " } else { "" }; + alias_gen += 1; + let function = format!("{function_name}({distinct}{argument}) as {alias}"); + aggregate_functions.push(function); + } + aggregate_functions + } + + /// Pick a random aggregate function argument + fn random_argument(&self) -> String { + let mut rng = thread_rng(); + let idx = rng.gen_range(0..self.arguments.len()); + self.arguments[idx].clone() + } + + /// Pick a random number of fields to group by (non-repeating) + /// + /// Limited to 3 group by columns to ensure coverage for large groups. With + /// larger numbers of columns, each group has many fewer values. + fn random_group_by(&self) -> Vec<String> { + let mut rng = thread_rng(); + const MAX_GROUPS: usize = 3; + // Cap the group by size at MAX_GROUPS while never exceeding the number + // of available columns + let max_groups = self.group_by_columns.len().min(MAX_GROUPS); + let num_group_by = rng.gen_range(1..=max_groups); + + let mut already_used = HashSet::new(); + let mut group_by = vec![]; + while group_by.len() < num_group_by { + let idx = rng.gen_range(0..self.group_by_columns.len()); + if already_used.insert(idx) { + group_by.push(self.group_by_columns[idx].clone()); + } + } + group_by + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs new file mode 100644 index 0000000000000..d93a5b7b9360b --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
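A hedged end-to-end sketch of how the pieces above compose into a test. The table name, column layout, and aggregate choices are illustrative assumptions; every call, however, is an API defined in this PR (`DatasetGeneratorConfig`, `QueryBuilder`, `AggregationFuzzerBuilder`):

use arrow::datatypes::DataType;
use crate::fuzz_cases::aggregation_fuzzer::{
    AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig, QueryBuilder,
};

#[tokio::test]
async fn fuzz_sum_example() {
    // Two u32 columns; one extra sorted dataset is generated per sort key
    let data_gen_config = DatasetGeneratorConfig {
        columns: vec![
            ColumnDescr { name: "a".to_string(), column_type: DataType::UInt32 },
            ColumnDescr { name: "b".to_string(), column_type: DataType::UInt32 },
        ],
        rows_num_range: (512, 1024),
        sort_keys_set: vec![vec!["a".to_string()]],
    };
    // `add_query_builder` derives 10 random queries from this builder
    let query_builder = QueryBuilder::new()
        .with_table_name("fuzz_table")
        .with_aggregate_function("sum")
        .with_distinct_aggregate_function("count")
        .with_group_by_columns(["a"])
        .with_aggregate_arguments(["b"]);
    AggregationFuzzerBuilder::from(data_gen_config)
        .add_query_builder(query_builder)
        .build()
        .run()
        .await;
}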
+ +use arrow::util::pretty::pretty_format_batches; +use arrow_array::RecordBatch; +use datafusion::prelude::SessionContext; +use datafusion_common::error::Result; + +mod context_generator; +mod data_generator; +mod fuzzer; + +pub use data_generator::{ColumnDescr, DatasetGeneratorConfig}; +pub use fuzzer::*; + +#[derive(Debug)] +pub(crate) struct InconsistentResult { + pub row_idx: usize, + pub lhs_row: String, + pub rhs_row: String, +} + +pub(crate) fn check_equality_of_batches( + lhs: &[RecordBatch], + rhs: &[RecordBatch], +) -> std::result::Result<(), InconsistentResult> { + let lhs_formatted_batches = pretty_format_batches(lhs).unwrap().to_string(); + let mut lhs_formatted_batches_sorted: Vec<&str> = + lhs_formatted_batches.trim().lines().collect(); + lhs_formatted_batches_sorted.sort_unstable(); + let rhs_formatted_batches = pretty_format_batches(rhs).unwrap().to_string(); + let mut rhs_formatted_batches_sorted: Vec<&str> = + rhs_formatted_batches.trim().lines().collect(); + rhs_formatted_batches_sorted.sort_unstable(); + + for (row_idx, (lhs_row, rhs_row)) in lhs_formatted_batches_sorted + .iter() + .zip(&rhs_formatted_batches_sorted) + .enumerate() + { + if lhs_row != rhs_row { + return Err(InconsistentResult { + row_idx, + lhs_row: lhs_row.to_string(), + rhs_row: rhs_row.to_string(), + }); + } + } + + Ok(()) +} + +pub(crate) async fn run_sql(sql: &str, ctx: &SessionContext) -> Result<Vec<RecordBatch>> { + ctx.sql(sql).await?.collect().await +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/mod.rs b/datafusion/core/tests/fuzz_cases/equivalence/mod.rs new file mode 100644 index 0000000000000..2f8a38200bf12 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/mod.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `EquivalenceProperties` fuzz testing + +mod ordering; +mod projection; +mod properties; +mod utils; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs new file mode 100644 index 0000000000000..94157e11702ca --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -0,0 +1,395 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::fuzz_cases::equivalence::utils::{ + convert_to_orderings, create_random_schema, create_test_params, create_test_schema_2, + generate_table_for_eq_properties, generate_table_for_orderings, + is_table_same_after_sort, TestScalarUDF, +}; +use arrow_schema::SortOptions; +use datafusion_common::{DFSchema, Result}; +use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use itertools::Itertools; +use std::sync::Arc; + +#[test] +fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 5; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate data that satisfies the given properties + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + let col_exprs = [ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + ]; + + for n_req in 0..=col_exprs.len() { + for exprs in col_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: SORT_OPTIONS, + }) + .collect::<Vec<_>>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether the `ordering_satisfy` API result matches + // the experimental result.
+ assert_eq!( + eq_properties.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + + Ok(()) +} + +#[test] +fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate data that satisfies the given properties + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc<dyn PhysicalExpr>; + let exprs = [ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: SORT_OPTIONS, + }) + .collect::<Vec<_>>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether the `ordering_satisfy` API result matches + // the experimental result. + + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + + Ok(()) +} + +#[test] +fn test_ordering_satisfy_with_equivalence() -> Result<()> { + // The schema satisfies the following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g. they are aliases).
+ let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, 625, 5)?; + + // The first element in the tuple stores the ordering requirement; the second element is the expected return value of `ordering_satisfy` + let requirements = vec![ + // `a ASC NULLS LAST` expects `ordering_satisfy` to be `true`, since the existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it + (vec![(col_a, option_asc)], true), + (vec![(col_a, option_desc)], false), + // Test whether equivalence works as expected + (vec![(col_c, option_asc)], true), + (vec![(col_c, option_desc)], false), + // Test whether ordering equivalence works as expected + (vec![(col_d, option_asc)], true), + (vec![(col_d, option_asc), (col_b, option_asc)], true), + (vec![(col_d, option_desc), (col_b, option_asc)], false), + ( + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + true, + ), + (vec![(col_e, option_desc), (col_f, option_asc)], true), + (vec![(col_e, option_asc), (col_f, option_asc)], false), + (vec![(col_e, option_desc), (col_b, option_asc)], false), + (vec![(col_e, option_asc), (col_b, option_asc)], false), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_f, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_f, option_asc), + ], + false, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_b, option_asc), + ], + false, + ), + (vec![(col_d, option_asc), (col_e, option_desc)], true), + ( + vec![ + (col_d, option_asc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_f, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + (col_f, option_asc), + ], + true, + ), + ]; + + for (cols, expected) in requirements { + let err_msg = format!("Error in test case:{cols:?}"); + let required = cols + .into_iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: Arc::clone(expr), + options, + }) + .collect::<Vec<_>>(); + + // Check the expected result against the experimental result.
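+ // (The first assertion validates the generated table itself against the
+ // expected answer; the second validates the `ordering_satisfy`
+ // implementation against that same ground truth.)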
+ assert_eq!( + is_table_same_after_sort( + required.clone(), + table_data_with_properties.clone() + )?, + expected + ); + assert_eq!( + eq_properties.ordering_satisfy(&required), + expected, + "{err_msg}" + ); + } + + Ok(()) +} + +// This test checks given a table is ordered with `[a ASC, b ASC, c ASC, d ASC]` and `[a ASC, c ASC, b ASC, d ASC]` +// whether the table is also ordered with `[a ASC, b ASC, d ASC]` and `[a ASC, c ASC, d ASC]` +// Since these orderings cannot be deduced, these orderings shouldn't be satisfied by the table generated. +// For background see discussion: https://github.com/apache/datafusion/issues/12700#issuecomment-2411134296 +#[test] +fn test_ordering_satisfy_on_data() -> Result<()> { + let schema = create_test_schema_2()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let orderings = vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + (col_d, option_asc), + ], + // [a ASC, c ASC, b ASC, d ASC] + vec![ + (col_a, option_asc), + (col_c, option_asc), + (col_b, option_asc), + (col_d, option_asc), + ], + ]; + let orderings = convert_to_orderings(&orderings); + + let batch = generate_table_for_orderings(orderings, schema, 1000, 10)?; + + // [a ASC, c ASC, d ASC] cannot be deduced + let ordering = vec![ + (col_a, option_asc), + (col_c, option_asc), + (col_d, option_asc), + ]; + let ordering = convert_to_orderings(&[ordering])[0].clone(); + assert!(!is_table_same_after_sort(ordering, batch.clone())?); + + // [a ASC, b ASC, d ASC] cannot be deduced + let ordering = vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_d, option_asc), + ]; + let ordering = convert_to_orderings(&[ordering])[0].clone(); + assert!(!is_table_same_after_sort(ordering, batch.clone())?); + + // [a ASC, b ASC] can be deduced + let ordering = vec![(col_a, option_asc), (col_b, option_asc)]; + let ordering = convert_to_orderings(&[ordering])[0].clone(); + assert!(is_table_same_after_sort(ordering, batch.clone())?); + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs new file mode 100644 index 0000000000000..c0c8517a612b4 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
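For intuition on why `[a ASC, b ASC, d ASC]` cannot be deduced in the preceding `test_ordering_satisfy_on_data`: consider the two rows (a, b, c, d) = (0, 0, 0, 1) and (0, 0, 1, 0) (an illustrative example, not taken from the test). Both `[a ASC, b ASC, c ASC, d ASC]` and `[a ASC, c ASC, b ASC, d ASC]` hold, since the `c` values break the tie in each case, yet `d` decreases while `(a, b)` ties, so `[a ASC, b ASC, d ASC]` is violated.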
+ +use crate::fuzz_cases::equivalence::utils::{ + apply_projection, create_random_schema, generate_table_for_eq_properties, + is_table_same_after_sort, TestScalarUDF, +}; +use arrow_schema::SortOptions; +use datafusion_common::{DFSchema, Result}; +use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use itertools::Itertools; +use std::sync::Arc; + +#[test] +fn project_orderings_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate data that satisfies the given properties + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc<dyn PhysicalExpr>; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (Arc::clone(expr), name.to_string())) + .collect::<Vec<_>>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + // Make sure each ordering after projection is valid. + for ordering in projected_eq.oeq_class().iter() { + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs + ); + // Since the ordered section satisfies the schema, we expect + // the result to be unchanged after sorting (i.e. the sort was unnecessary).
+ assert!( + is_table_same_after_sort( + ordering.clone(), + projected_batch.clone(), + )?, + "{}", + err_msg + ); + } + } + } + } + + Ok(()) +} + +#[test] +fn ordering_satisfy_after_projection_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate data that satisfies the given properties + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc<dyn PhysicalExpr>; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (Arc::clone(expr), name.to_string())) + .collect::<Vec<_>>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + let projection_mapping = + ProjectionMapping::try_new(&proj_exprs, &test_schema)?; + + let projected_exprs = projection_mapping + .iter() + .map(|(_source, target)| Arc::clone(target)) + .collect::<Vec<_>>(); + + for n_req in 0..=projected_exprs.len() { + for exprs in projected_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: SORT_OPTIONS, + }) + .collect::<Vec<_>>(); + let expected = is_table_same_after_sort( + requirement.clone(), + projected_batch.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping + ); + // Check whether the `ordering_satisfy` API result matches + // the experimental result. + assert_eq!( + projected_eq.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + } + } + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs new file mode 100644 index 0000000000000..e704fcacc3289 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::fuzz_cases::equivalence::utils::{ + create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, + TestScalarUDF, +}; +use datafusion_common::{DFSchema, Result}; +use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use itertools::Itertools; +use std::sync::Arc; + +#[test] +fn test_find_longest_permutation_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate data that satisfies the given properties + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc<dyn PhysicalExpr>; + let exprs = [ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let exprs = exprs.into_iter().cloned().collect::<Vec<_>>(); + let (ordering, indices) = eq_properties.find_longest_permutation(&exprs); + // Make sure that `find_longest_permutation` return values are consistent + let ordering2 = indices + .iter() + .zip(ordering.iter()) + .map(|(&idx, sort_expr)| PhysicalSortExpr { + expr: Arc::clone(&exprs[idx]), + options: sort_expr.options, + }) + .collect::<Vec<_>>(); + assert_eq!( + ordering, ordering2, + "indices and lexicographical ordering do not match" + ); + + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + assert_eq!(ordering.len(), indices.len(), "{}", err_msg); + // Since the ordered section satisfies the schema, we expect + // the result to be unchanged after sorting (i.e. the sort was unnecessary).
+ assert!( + is_table_same_after_sort( + ordering.clone(), + table_data_with_properties.clone(), + )?, + "{}", + err_msg + ); + } + } + } + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs new file mode 100644 index 0000000000000..acc45fe0e591b --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -0,0 +1,627 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +use datafusion::physical_plan::expressions::col; +use datafusion::physical_plan::expressions::Column; +use datafusion_physical_expr::{ConstExpr, EquivalenceProperties, PhysicalSortExpr}; +use std::any::Any; +use std::cmp::Ordering; +use std::sync::Arc; + +use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow_array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array}; +use arrow_schema::{SchemaRef, SortOptions}; +use datafusion_common::utils::{compare_rows, get_row_at_idx}; +use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::equivalence::{EquivalenceClass, ProjectionMapping}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef}; + +use itertools::izip; +use rand::prelude::*; + +pub fn output_schema( + mapping: &ProjectionMapping, + input_schema: &Arc<Schema>, +) -> Result<SchemaRef> { + // Calculate output schema + let fields: Result<Vec<Field>> = mapping + .iter() + .map(|(source, target)| { + let name = target + .as_any() + .downcast_ref::<Column>() + .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? + .name(); + let field = Field::new( + name, + source.data_type(input_schema)?, + source.nullable(input_schema)?, + ); + + Ok(field) + }) + .collect(); + + let output_schema = Arc::new(Schema::new_with_metadata( + fields?, + input_schema.metadata().clone(), + )); + + Ok(output_schema) +} + +// Generate a schema which consists of 6 columns (a, b, c, d, e, f) +pub fn create_test_schema_2() -> Result<SchemaRef> { + let a = Field::new("a", DataType::Float64, true); + let b = Field::new("b", DataType::Float64, true); + let c = Field::new("c", DataType::Float64, true); + let d = Field::new("d", DataType::Float64, true); + let e = Field::new("e", DataType::Float64, true); + let f = Field::new("f", DataType::Float64, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) +} + +/// Construct a schema with a random ordering +/// among columns a, b, c, d, where +/// Column [a=f] (e.g. they are aliases).
+/// Column e is constant. +pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema_2()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; + + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); + // Define a and f as aliases + eq_properties.add_equal_conditions(col_a, col_f)?; + // Column e has a constant value. + eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); + + // Randomly order columns for sorting + let mut rng = StdRng::seed_from_u64(seed); + let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted + + let options_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + while !remaining_exprs.is_empty() { + let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + remaining_exprs.shuffle(&mut rng); + + let ordering = remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: options_asc, + }) + .collect(); + + eq_properties.add_new_orderings([ordering]); + } + + Ok((test_schema, eq_properties)) +} + +// Apply projection to the input_data; return the projected equivalence properties and record batch +pub fn apply_projection( + proj_exprs: Vec<(Arc<dyn PhysicalExpr>, String)>, + input_data: &RecordBatch, + input_eq_properties: &EquivalenceProperties, +) -> Result<(RecordBatch, EquivalenceProperties)> { + let input_schema = input_data.schema(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + let output_schema = output_schema(&projection_mapping, &input_schema)?; + let num_rows = input_data.num_rows(); + // Apply projection to the input record batch. + let projected_values = projection_mapping + .iter() + .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) + .collect::<Result<Vec<_>>>()?; + let projected_batch = if projected_values.is_empty() { + RecordBatch::new_empty(Arc::clone(&output_schema)) + } else { + RecordBatch::try_new(Arc::clone(&output_schema), projected_values)?
+ }; + + let projected_eq = input_eq_properties.project(&projection_mapping, output_schema); + Ok((projected_batch, projected_eq)) +} + +#[test] +fn add_equal_conditions_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("x", DataType::Int64, true), + Field::new("y", DataType::Int64, true), + ])); + + let mut eq_properties = EquivalenceProperties::new(schema); + let col_a_expr = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>; + let col_c_expr = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>; + let col_x_expr = Arc::new(Column::new("x", 3)) as Arc<dyn PhysicalExpr>; + let col_y_expr = Arc::new(Column::new("y", 4)) as Arc<dyn PhysicalExpr>; + + // a and b are aliases + eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + + // This new entry is redundant, size shouldn't increase + eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 2); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + + // b and c are aliases. The existing equivalence class should expand; + // however, there shouldn't be any new equivalence class + eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 3); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + + // This is a new equality set, hence the equivalence class count should be 2. + eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?; + assert_eq!(eq_properties.eq_group().len(), 2); + + // This equality bridges distinct equality sets. + // Hence the equivalence class count should decrease from 2 to 1. + eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 5); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_x_expr)); + assert!(eq_groups.contains(&col_y_expr)); + + Ok(()) +} + +/// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. +/// +/// The function works by adding a unique column of ascending integers to the original table. This column ensures +/// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can +/// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce +/// deterministic sorting results. +/// +/// If the table remains the same after sorting with the added unique column, it indicates that the table was +/// already sorted according to `required_ordering` to begin with.
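+/// +/// Put differently: appending the unique ascending column to the ordering forces rows that tie on every requested key to keep their original relative order, so duplicate rows cannot spuriously change the outcome of the comparison below.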
+pub fn is_table_same_after_sort( + mut required_ordering: Vec<PhysicalSortExpr>, + batch: RecordBatch, +) -> Result<bool> { + // Clone the original schema and columns + let original_schema = batch.schema(); + let mut columns = batch.columns().to_vec(); + + // Create a new unique column + let n_row = batch.num_rows(); + let vals: Vec<usize> = (0..n_row).collect::<Vec<_>>(); + let vals: Vec<f64> = vals.into_iter().map(|val| val as f64).collect(); + let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; + columns.push(Arc::clone(&unique_col)); + + // Create a new schema with the added unique column + let unique_col_name = "unique"; + let unique_field = Arc::new(Field::new(unique_col_name, DataType::Float64, false)); + let fields: Vec<_> = original_schema + .fields() + .iter() + .cloned() + .chain(std::iter::once(unique_field)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + + // Create a new batch with the added column + let new_batch = RecordBatch::try_new(Arc::clone(&schema), columns)?; + + // Add the unique column to the required ordering to ensure deterministic results + required_ordering.push(PhysicalSortExpr { + expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), + options: Default::default(), + }); + + // Convert the required ordering to a list of SortColumn + let sort_columns = required_ordering + .iter() + .map(|order_expr| { + let expr_result = order_expr.expr.evaluate(&new_batch)?; + let values = expr_result.into_array(new_batch.num_rows())?; + Ok(SortColumn { + values, + options: Some(order_expr.options), + }) + }) + .collect::<Result<Vec<_>>>()?; + + // Check if the indices after sorting match the initial ordering + let sorted_indices = lexsort_to_indices(&sort_columns, None)?; + let original_indices = UInt32Array::from_iter_values(0..n_row as u32); + + Ok(sorted_indices == original_indices) +} + +// If we have already generated a random result for one of the +// expressions in an equivalence class, other expressions in the same +// equivalence class must use the same result. This util returns the +// already-calculated result, when available. +fn get_representative_arr( + eq_group: &EquivalenceClass, + existing_vec: &[Option<ArrayRef>], + schema: SchemaRef, +) -> Option<ArrayRef> { + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::<Column>().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + if let Some(res) = &existing_vec[idx] { + return Some(Arc::clone(res)); + } + } + None +} + +// Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) +pub fn create_test_schema() -> Result<SchemaRef> { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let g = Field::new("g", DataType::Int32, true); + let h = Field::new("h", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h])); + + Ok(schema) +} + +/// Construct a schema with the following properties. +/// The schema satisfies the following orderings: +/// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] +/// and +/// Column [a=c] (e.g. they are aliases).
+pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); + eq_properties.add_equal_conditions(col_a, col_c)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let orderings = vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [e DESC, f ASC, g ASC] + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + ]; + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + Ok((test_schema, eq_properties)) +} + +// Generate a table that satisfies the given equivalence properties; i.e. +// equivalences, ordering equivalences, and constants. +pub fn generate_table_for_eq_properties( + eq_properties: &EquivalenceProperties, + n_elem: usize, + n_distinct: usize, +) -> Result<RecordBatch> { + let mut rng = StdRng::seed_from_u64(23); + + let schema = eq_properties.schema(); + let mut schema_vec = vec![None; schema.fields.len()]; + + // Utility closure to generate random array + let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { + let values: Vec<f64> = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) + }; + + // Fill constant columns + for constant in &eq_properties.constants { + let col = constant.expr().as_any().downcast_ref::<Column>().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = + Arc::new(Float64Array::from_iter_values(vec![0.0; n_elem])) as ArrayRef; + schema_vec[idx] = Some(arr); + } + + // Fill columns based on ordering equivalences + for ordering in eq_properties.oeq_class.iter() { + let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let col = expr.as_any().downcast_ref::<Column>().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = generate_random_array(n_elem, n_distinct); + ( + SortColumn { + values: arr, + options: Some(*options), + }, + idx, + ) + }) + .unzip(); + + let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + for (idx, arr) in izip!(indices, sort_arrs) { + schema_vec[idx] = Some(arr); + } + } + + // Fill columns based on equivalence groups + for eq_group in eq_properties.eq_group.iter() { + let representative_array = + get_representative_arr(eq_group, &schema_vec, Arc::clone(schema)) + .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::<Column>().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + schema_vec[idx] = Some(Arc::clone(&representative_array)); + } + } + + let res: Vec<_> = schema_vec + .into_iter() + .zip(schema.fields.iter()) + .map(|(elem, field)| { + ( + field.name(), + // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) + elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)),
+ ) + }) + .collect(); + + Ok(RecordBatch::try_from_iter(res)?) +} + +// Generate a table that satisfies the given orderings. +pub fn generate_table_for_orderings( + mut orderings: Vec<LexOrdering>, + schema: SchemaRef, + n_elem: usize, + n_distinct: usize, +) -> Result<RecordBatch> { + let mut rng = StdRng::seed_from_u64(23); + + assert!(!orderings.is_empty()); + // Sort the inner vectors by their lengths (longest first) + orderings.sort_by_key(|v| std::cmp::Reverse(v.len())); + + let arrays = schema + .fields + .iter() + .map(|field| { + ( + field.name(), + generate_random_f64_array(n_elem, n_distinct, &mut rng), + ) + }) + .collect::<Vec<_>>(); + let batch = RecordBatch::try_from_iter(arrays)?; + + // Sort batch according to first ordering expression + let sort_columns = get_sort_columns(&batch, &orderings[0])?; + let sort_indices = lexsort_to_indices(&sort_columns, None)?; + let mut batch = take_record_batch(&batch, &sort_indices)?; + + // Prune out rows that are invalid according to the remaining orderings. + for ordering in orderings.iter().skip(1) { + let sort_columns = get_sort_columns(&batch, ordering)?; + + // Collect sort options and values into separate vectors. + let (sort_options, sort_col_values): (Vec<_>, Vec<_>) = sort_columns + .into_iter() + .map(|sort_col| (sort_col.options.unwrap(), sort_col.values)) + .unzip(); + + let mut cur_idx = 0; + let mut keep_indices = vec![cur_idx as u32]; + for next_idx in 1..batch.num_rows() { + let cur_row = get_row_at_idx(&sort_col_values, cur_idx)?; + let next_row = get_row_at_idx(&sort_col_values, next_idx)?; + + if compare_rows(&cur_row, &next_row, &sort_options)? != Ordering::Greater { + // The next row satisfies the given ordering relation compared to the current row. + keep_indices.push(next_idx as u32); + cur_idx = next_idx; + } + } + // Only keep valid rows that satisfy the given ordering relation.
+ batch = take_record_batch(&batch, &UInt32Array::from_iter_values(keep_indices))?; + } + + Ok(batch) +} + +// Convert each tuple to PhysicalSortExpr +pub fn convert_to_sort_exprs( + in_data: &[(&Arc<dyn PhysicalExpr>, SortOptions)], +) -> Vec<PhysicalSortExpr> { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: Arc::clone(*expr), + options: *options, + }) + .collect() +} + +// Convert each inner tuple to PhysicalSortExpr +pub fn convert_to_orderings( + orderings: &[Vec<(&Arc<dyn PhysicalExpr>, SortOptions)>], +) -> Vec<Vec<PhysicalSortExpr>> { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) + .collect() +} + +// Utility function to generate a random f64 array +fn generate_random_f64_array( + n_elems: usize, + n_distinct: usize, + rng: &mut StdRng, +) -> ArrayRef { + let values: Vec<f64> = (0..n_elems) + .map(|_| rng.gen_range(0..n_distinct) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) +} + +// Helper function to get sort columns from a batch +fn get_sort_columns( + batch: &RecordBatch, + ordering: LexOrderingRef, +) -> Result<Vec<SortColumn>> { + ordering + .iter() + .map(|expr| expr.evaluate_to_sort_column(batch)) + .collect::<Result<Vec<_>>>() +} + +#[derive(Debug, Clone)] +pub struct TestScalarUDF { + pub(crate) signature: Signature, +} + +impl TestScalarUDF { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Float64, Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "test-scalar-udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + let arg_type = &arg_types[0]; + + match arg_type { + DataType::Float32 => Ok(DataType::Float32), + _ => Ok(DataType::Float64), + } + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> { + Ok(input[0].sort_properties) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new({ + let arg = &args[0].as_any().downcast_ref::<Float64Array>().ok_or_else( + || { + DataFusionError::Internal(format!( + "could not cast {} to {}", + self.name(), + std::any::type_name::<Float64Array>() + )) + }, + )?; + + arg.iter() + .map(|a| a.map(f64::floor)) + .collect::<Float64Array>() + }), + DataType::Float32 => Arc::new({ + let arg = &args[0].as_any().downcast_ref::<Float32Array>().ok_or_else( + || { + DataFusionError::Internal(format!( + "could not cast {} to {}", + self.name(), + std::any::type_name::<Float32Array>() + )) + }, + )?; + + arg.iter() + .map(|a| a.map(f32::floor)) + .collect::<Float32Array>() + }), + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + Ok(ColumnarValue::Array(arr)) + } +} diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 96aa1be181f53..c8478db22bd4a 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -41,6 +41,7 @@ use datafusion::physical_plan::joins::{ }; use datafusion::physical_plan::memory::MemoryExec; +use crate::fuzz_cases::join_fuzz::JoinTestType::NljHj; use datafusion::prelude::{SessionConfig, SessionContext}; use test_utils::stagger_batch_with_seed; @@ -89,6 +90,7 @@ fn col_lt_col_filter(schema1: Arc<Schema>, schema2: Arc<Schema>) -> JoinFilter { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_inner_join_1k_filtered() { JoinFuzzTestCase::new(
make_staggered_batches(1000), @@ -101,6 +103,7 @@ async fn test_inner_join_1k_filtered() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_inner_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -113,6 +116,7 @@ async fn test_inner_join_1k() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_left_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -125,8 +129,7 @@ async fn test_left_join_1k() { } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 +#[allow(unused_qualifications)] async fn test_left_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -134,11 +137,12 @@ async fn test_left_join_1k_filtered() { JoinType::Left, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } #[tokio::test] +#[allow(unused_qualifications)] async fn test_right_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -151,8 +155,7 @@ async fn test_right_join_1k() { } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 +#[allow(unused_qualifications)] async fn test_right_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -160,11 +163,12 @@ async fn test_right_join_1k_filtered() { JoinType::Right, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } #[tokio::test] +#[allow(unused_qualifications)] async fn test_full_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -177,6 +181,7 @@ async fn test_full_join_1k() { } #[tokio::test] +#[allow(unused_qualifications)] // flaky for HjSmj case // https://github.com/apache/datafusion/issues/12359 async fn test_full_join_1k_filtered() { @@ -191,6 +196,7 @@ async fn test_full_join_1k_filtered() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_semi_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -203,6 +209,7 @@ async fn test_semi_join_1k() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_semi_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -215,6 +222,7 @@ async fn test_semi_join_1k_filtered() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_anti_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -227,8 +235,7 @@ async fn test_anti_join_1k() { } #[tokio::test] -// flaky for HjSmj case, giving 1 rows difference sometimes -// https://github.com/apache/datafusion/issues/11555 +#[allow(unused_qualifications)] async fn test_anti_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -236,7 +243,7 @@ async fn test_anti_join_1k_filtered() { JoinType::LeftAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj, NljHj], false) .await } @@ -454,6 +461,7 @@ impl JoinFuzzTestCase { /// `join_tests` - identifies what join types to test /// if `debug` flag is set the test will save randomly generated inputs and outputs to user folders, /// so it is easy to debug a test on top of the failed data + #[allow(unused_qualifications)] async fn run_test(&self, join_tests: &[JoinTestType], debug: bool) { for batch_size in self.batch_sizes { let session_config = SessionConfig::new().with_batch_size(*batch_size); @@ -515,14 +523,11 @@ impl JoinFuzzTestCase { "input2", ); - if 
join_tests.contains(&JoinTestType::NljHj) - && join_tests.contains(&JoinTestType::NljHj) - && nlj_rows != hj_rows - { + if join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows { println!("=============== HashJoinExec =================="); hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); println!("=============== NestedLoopJoinExec =================="); - smj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + nlj_formatted_sorted.iter().for_each(|s| println!("{}", s)); Self::save_partitioned_batches_as_parquet( &nlj_collected, diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index 95d97709f3195..c52acdd82764c 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -341,7 +341,7 @@ async fn run_limit_test(fetch: usize, data: &SortedData) { /// Return random ASCII String with len fn get_random_string(len: usize) -> String { - rand::thread_rng() + thread_rng() .sample_iter(rand::distributions::Alphanumeric) .take(len) .map(char::from) diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index 69241571b4af0..49db0d31a8e9c 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -21,6 +21,9 @@ mod join_fuzz; mod merge_fuzz; mod sort_fuzz; +mod aggregation_fuzzer; +mod equivalence; + mod limit_fuzz; mod sort_preserving_repartition_fuzz; mod window_fuzz; diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index fae4731569b69..4ba06ef1d2a60 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -37,8 +37,8 @@ use test_utils::{batches_to_vec, partitions_to_sorted_vec}; const KB: usize = 1 << 10; #[tokio::test] #[cfg_attr(tarpaulin, ignore)] -async fn test_sort_1k_mem() { - for (batch_size, should_spill) in [(5, false), (20000, true), (1000000, true)] { +async fn test_sort_10k_mem() { + for (batch_size, should_spill) in [(5, false), (20000, true), (500000, true)] { SortTest::new() .with_int32_batches(batch_size) .with_pool_size(10 * KB) diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 408cadc35f485..353db86683631 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -29,7 +29,7 @@ mod sp_repartition_fuzz_tests { metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, repartition::RepartitionExec, sorts::sort_preserving_merge::SortPreservingMergeExec, - sorts::streaming_merge::streaming_merge, + sorts::streaming_merge::StreamingMergeBuilder, stream::RecordBatchStreamAdapter, ExecutionPlan, Partitioning, }; @@ -174,7 +174,7 @@ mod sp_repartition_fuzz_tests { }) .unzip(); - let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + let sort_arrs = lexsort(&sort_columns, None)?; for (idx, arr) in izip!(indices, sort_arrs) { schema_vec[idx] = Some(arr); } @@ -246,15 +246,14 @@ mod sp_repartition_fuzz_tests { MemoryConsumer::new("test".to_string()).register(context.memory_pool()); // Internally SortPreservingMergeExec uses this function for merging. 
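[Editor's note: the hunk below migrates this fuzz test from the free `streaming_merge` function to the new `StreamingMergeBuilder`. The following sketch is editorial, not part of the patch; the `merge_sorted` wrapper, its parameters, and the exact import paths are assumptions, while the builder calls mirror the hunk itself.]

```rust
use arrow_schema::SchemaRef;
use datafusion::error::Result;
use datafusion::execution::memory_pool::MemoryReservation;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::metrics::BaselineMetrics;
use datafusion::physical_plan::sorts::streaming_merge::StreamingMergeBuilder;
use datafusion_physical_expr::PhysicalSortExpr;

// Merge several individually sorted partitions into one globally sorted
// stream; every input stream must already be ordered by `exprs`.
fn merge_sorted(
    streams: Vec<SendableRecordBatchStream>,
    schema: SchemaRef,
    exprs: &[PhysicalSortExpr],
    metrics: BaselineMetrics,
    reservation: MemoryReservation,
) -> Result<SendableRecordBatchStream> {
    StreamingMergeBuilder::new()
        .with_streams(streams)
        .with_schema(schema)
        .with_expressions(exprs)
        .with_metrics(metrics)
        .with_batch_size(8192) // output batch size; the fuzz test below uses 1
        .with_reservation(reservation)
        .build()
}
```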
- let res = streaming_merge( - streams, - schema, - &exprs, - BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0), - 1, - None, - mem_reservation, - )?; + let res = StreamingMergeBuilder::new() + .with_streams(streams) + .with_schema(schema) + .with_expressions(&exprs) + .with_metrics(BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0)) + .with_batch_size(1) + .with_reservation(mem_reservation) + .build()?; let res = collect(res).await?; // Contains the merged result. let res = concat_batches(&res[0].schema(), &res)?; @@ -359,7 +358,8 @@ mod sp_repartition_fuzz_tests { let running_source = Arc::new( MemoryExec::try_new(&[input1.clone()], schema.clone(), None) .unwrap() - .with_sort_information(vec![sort_keys.clone()]), + .try_with_sort_information(vec![sort_keys.clone()]) + .unwrap(), ); let hash_exprs = vec![col("c", &schema).unwrap()]; diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index a6c2cf700cc4e..61b4e32ad6c9e 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -45,6 +45,8 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use test_utils::add_empty_batches; use datafusion::functions_window::row_number::row_number_udwf; +use datafusion_functions_window::lead_lag::{lag_udwf, lead_udwf}; +use datafusion_functions_window::rank::{dense_rank_udwf, rank_udwf}; use hashbrown::HashMap; use rand::distributions::Alphanumeric; use rand::rngs::StdRng; @@ -196,7 +198,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lag), + WindowFunctionDefinition::WindowUDF(lag_udwf()), // its name "LAG", // no argument @@ -210,7 +212,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + WindowFunctionDefinition::WindowUDF(lead_udwf()), // its name "LEAD", // no argument @@ -224,9 +226,9 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Rank), + WindowFunctionDefinition::WindowUDF(rank_udwf()), // its name - "RANK", + "rank", // no argument vec![], // Expected causality, for None cases causality will be determined from window frame boundaries @@ -238,11 +240,9 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::DenseRank, - ), + WindowFunctionDefinition::WindowUDF(dense_rank_udwf()), // its name - "DENSE_RANK", + "dense_rank", // no argument vec![], // Expected causality, for None cases causality will be determined from window frame boundaries @@ -293,7 +293,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { vec![window_expr], memory_exec.clone(), vec![], - InputOrderMode::Linear, + Linear, )?); let task_ctx = ctx.task_ctx(); let mut collected_results = @@ -382,28 +382,19 @@ fn get_random_function( ); window_fn_map.insert( "rank", - ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::Rank, - ), - vec![], - ), + (WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![]), ); window_fn_map.insert( "dense_rank", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::DenseRank, - ), + WindowFunctionDefinition::WindowUDF(dense_rank_udwf()), vec![], ), 
     );
     window_fn_map.insert(
         "lead",
         (
-            WindowFunctionDefinition::BuiltInWindowFunction(
-                BuiltInWindowFunction::Lead,
-            ),
+            WindowFunctionDefinition::WindowUDF(lead_udwf()),
             vec![
                 arg.clone(),
                 lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
@@ -414,9 +405,7 @@ fn get_random_function(
     window_fn_map.insert(
         "lag",
         (
-            WindowFunctionDefinition::BuiltInWindowFunction(
-                BuiltInWindowFunction::Lag,
-            ),
+            WindowFunctionDefinition::WindowUDF(lag_udwf()),
             vec![
                 arg.clone(),
                 lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
@@ -603,7 +592,7 @@ async fn run_window_test(
     orderby_columns: Vec<&str>,
     search_mode: InputOrderMode,
 ) -> Result<()> {
-    let is_linear = !matches!(search_mode, InputOrderMode::Sorted);
+    let is_linear = !matches!(search_mode, Sorted);
     let mut rng = StdRng::seed_from_u64(random_seed);
     let schema = input1[0].schema();
     let session_config = SessionConfig::new().with_batch_size(50);
@@ -654,7 +643,7 @@ async fn run_window_test(
     ];
     let mut exec1 = Arc::new(
         MemoryExec::try_new(&[vec![concat_input_record]], schema.clone(), None)?
-            .with_sort_information(vec![source_sort_keys.clone()]),
+            .try_with_sort_information(vec![source_sort_keys.clone()])?,
     ) as _;
     // Table is ordered according to ORDER BY a, b, c In linear test we use PARTITION BY b, ORDER BY a
     // For WindowAggExec to produce correct result it need table to be ordered by b,a. Hence add a sort.
@@ -680,7 +669,7 @@ async fn run_window_test(
     )?) as _;
     let exec2 = Arc::new(
         MemoryExec::try_new(&[input1.clone()], schema.clone(), None)?
-            .with_sort_information(vec![source_sort_keys.clone()]),
+            .try_with_sort_information(vec![source_sort_keys.clone()])?,
     );
     let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
         vec![create_window_expr(
diff --git a/datafusion/core/tests/macro_hygiene/mod.rs b/datafusion/core/tests/macro_hygiene/mod.rs
index 72ac6e64fb0c5..c35e46c0c558f 100644
--- a/datafusion/core/tests/macro_hygiene/mod.rs
+++ b/datafusion/core/tests/macro_hygiene/mod.rs
@@ -37,3 +37,13 @@ mod plan_datafusion_err {
         plan_datafusion_err!("foo");
     }
 }
+
+mod record_batch {
+    // NO other imports!
+    use datafusion_common::record_batch;
+
+    #[test]
+    fn test_macro() {
+        record_batch!(("column_name", Int32, vec![1, 2, 3])).unwrap();
+    }
+}
diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs
index ec66df45c7baa..fc2fb9afb5f93 100644
--- a/datafusion/core/tests/memory_limit/mod.rs
+++ b/datafusion/core/tests/memory_limit/mod.rs
@@ -840,7 +840,7 @@ impl TableProvider for SortedTableProvider {
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let mem_exec =
             MemoryExec::try_new(&self.batches, self.schema(), projection.cloned())?
- .with_sort_information(self.sort_information.clone()); + .try_with_sort_information(self.sort_information.clone())?; Ok(Arc::new(mem_exec)) } diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 18d8300fb254d..4b5d22bfa71ff 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -28,7 +28,6 @@ use datafusion::execution::context::SessionState; use datafusion::prelude::SessionContext; use datafusion_common::stats::Precision; use datafusion_execution::cache::cache_manager::CacheManagerConfig; -use datafusion_execution::cache::cache_unit; use datafusion_execution::cache::cache_unit::{ DefaultFileStatisticsCache, DefaultListFilesCache, }; @@ -211,8 +210,8 @@ fn get_cache_runtime_state() -> ( SessionState, ) { let cache_config = CacheManagerConfig::default(); - let file_static_cache = Arc::new(cache_unit::DefaultFileStatisticsCache::default()); - let list_file_cache = Arc::new(cache_unit::DefaultListFilesCache::default()); + let file_static_cache = Arc::new(DefaultFileStatisticsCache::default()); + let list_file_cache = Arc::new(DefaultListFilesCache::default()); let cache_config = cache_config .with_files_statistics_cache(Some(file_static_cache.clone())) diff --git a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs deleted file mode 100644 index bbf4dcd2b799d..0000000000000 --- a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs +++ /dev/null @@ -1,325 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Tests for the physical optimizer
-
-use datafusion_common::config::ConfigOptions;
-use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics;
-use datafusion_physical_optimizer::PhysicalOptimizerRule;
-use datafusion_physical_plan::aggregates::AggregateExec;
-use datafusion_physical_plan::projection::ProjectionExec;
-use datafusion_physical_plan::ExecutionPlan;
-use std::sync::Arc;
-
-use datafusion::error::Result;
-use datafusion::logical_expr::Operator;
-use datafusion::prelude::SessionContext;
-use datafusion::test_util::TestAggregate;
-use datafusion_physical_plan::aggregates::PhysicalGroupBy;
-use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
-use datafusion_physical_plan::common;
-use datafusion_physical_plan::filter::FilterExec;
-use datafusion_physical_plan::memory::MemoryExec;
-
-use arrow::array::Int32Array;
-use arrow::datatypes::{DataType, Field, Schema};
-use arrow::record_batch::RecordBatch;
-use datafusion_common::cast::as_int64_array;
-use datafusion_physical_expr::expressions::{self, cast};
-use datafusion_physical_plan::aggregates::AggregateMode;
-
-/// Mock data using a MemoryExec which has an exact count statistic
-fn mock_data() -> Result<Arc<MemoryExec>> {
-    let schema = Arc::new(Schema::new(vec![
-        Field::new("a", DataType::Int32, true),
-        Field::new("b", DataType::Int32, true),
-    ]));
-
-    let batch = RecordBatch::try_new(
-        Arc::clone(&schema),
-        vec![
-            Arc::new(Int32Array::from(vec![Some(1), Some(2), None])),
-            Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])),
-        ],
-    )?;
-
-    Ok(Arc::new(MemoryExec::try_new(
-        &[vec![batch]],
-        Arc::clone(&schema),
-        None,
-    )?))
-}
-
-/// Checks that the count optimization was applied and we still get the right result
-async fn assert_count_optim_success(
-    plan: AggregateExec,
-    agg: TestAggregate,
-) -> Result<()> {
-    let session_ctx = SessionContext::new();
-    let state = session_ctx.state();
-    let plan: Arc<dyn ExecutionPlan> = Arc::new(plan);
-
-    let optimized =
-        AggregateStatistics::new().optimize(Arc::clone(&plan), state.config_options())?;
-
-    // A ProjectionExec is a sign that the count optimization was applied
-    assert!(optimized.as_any().is::<ProjectionExec>());
-
-    // run both the optimized and nonoptimized plan
-    let optimized_result =
-        common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?;
-    let nonoptimized_result =
-        common::collect(plan.execute(0, session_ctx.task_ctx())?).await?;
-    assert_eq!(optimized_result.len(), nonoptimized_result.len());
-
-    // and validate the results are the same and expected
-    assert_eq!(optimized_result.len(), 1);
-    check_batch(optimized_result.into_iter().next().unwrap(), &agg);
-    // check the non optimized one too to ensure types and names remain the same
-    assert_eq!(nonoptimized_result.len(), 1);
-    check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg);
-
-    Ok(())
-}
-
-fn check_batch(batch: RecordBatch, agg: &TestAggregate) {
-    let schema = batch.schema();
-    let fields = schema.fields();
-    assert_eq!(fields.len(), 1);
-
-    let field = &fields[0];
-    assert_eq!(field.name(), agg.column_name());
-    assert_eq!(field.data_type(), &DataType::Int64);
-    // note that nullabiolity differs
-
-    assert_eq!(
-        as_int64_array(batch.column(0)).unwrap().values(),
-        &[agg.expected_count()]
-    );
-}
-
-#[tokio::test]
-async fn test_count_partial_direct_child() -> Result<()> {
-    // basic test case with the aggregation applied on a source with exact statistics
-    let source = mock_data()?;
-    let schema = source.schema();
-    let agg = TestAggregate::new_count_star();
-
-    let
partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_count_partial_with_nulls_direct_child() -> Result<()> { - // basic test case with the aggregation applied on a source with exact statistics - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_column(&schema); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_count_partial_indirect_child() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_star(); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - // We introduce an intermediate optimization step between the partial and final aggregtator - let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(coalesce), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_count_partial_with_nulls_indirect_child() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_column(&schema); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - // We introduce an intermediate optimization step between the partial and final aggregtator - let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(coalesce), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_count_inexact_stat() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_star(); - - // adding a filter makes the statistics inexact - let filter = Arc::new(FilterExec::try_new( - expressions::binary( - expressions::col("a", &schema)?, - Operator::Gt, - cast(expressions::lit(1u32), &schema, DataType::Int32)?, - &schema, - )?, - source, - )?); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - filter, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - 
vec![agg.count_expr(&schema)],
-        vec![None],
-        Arc::new(partial_agg),
-        Arc::clone(&schema),
-    )?;
-
-    let conf = ConfigOptions::new();
-    let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
-
-    // check that the original ExecutionPlan was not replaced
-    assert!(optimized.as_any().is::<AggregateExec>());
-
-    Ok(())
-}
-
-#[tokio::test]
-async fn test_count_with_nulls_inexact_stat() -> Result<()> {
-    let source = mock_data()?;
-    let schema = source.schema();
-    let agg = TestAggregate::new_count_column(&schema);
-
-    // adding a filter makes the statistics inexact
-    let filter = Arc::new(FilterExec::try_new(
-        expressions::binary(
-            expressions::col("a", &schema)?,
-            Operator::Gt,
-            cast(expressions::lit(1u32), &schema, DataType::Int32)?,
-            &schema,
-        )?,
-        source,
-    )?);
-
-    let partial_agg = AggregateExec::try_new(
-        AggregateMode::Partial,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        filter,
-        Arc::clone(&schema),
-    )?;
-
-    let final_agg = AggregateExec::try_new(
-        AggregateMode::Final,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        Arc::new(partial_agg),
-        Arc::clone(&schema),
-    )?;
-
-    let conf = ConfigOptions::new();
-    let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
-
-    // check that the original ExecutionPlan was not replaced
-    assert!(optimized.as_any().is::<AggregateExec>());
-
-    Ok(())
-}
diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs
index 24e46b3ad97c7..85076abdaf299 100644
--- a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs
+++ b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs
@@ -84,7 +84,7 @@ fn parquet_exec(schema: &SchemaRef) -> Arc<ParquetExec> {
 fn partial_aggregate_exec(
     input: Arc<dyn ExecutionPlan>,
     group_by: PhysicalGroupBy,
-    aggr_expr: Vec<AggregateFunctionExpr>,
+    aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
 ) -> Arc<dyn ExecutionPlan> {
     let schema = input.schema();
     let n_aggr = aggr_expr.len();
@@ -104,7 +104,7 @@ fn final_aggregate_exec(
     input: Arc<dyn ExecutionPlan>,
     group_by: PhysicalGroupBy,
-    aggr_expr: Vec<AggregateFunctionExpr>,
+    aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
 ) -> Arc<dyn ExecutionPlan> {
     let schema = input.schema();
     let n_aggr = aggr_expr.len();
@@ -130,11 +130,12 @@ fn count_expr(
     expr: Arc<dyn PhysicalExpr>,
     name: &str,
     schema: &Schema,
-) -> AggregateFunctionExpr {
+) -> Arc<AggregateFunctionExpr> {
     AggregateExprBuilder::new(count_udaf(), vec![expr])
         .schema(Arc::new(schema.clone()))
         .alias(name)
         .build()
+        .map(Arc::new)
         .unwrap()
 }
 
@@ -218,6 +219,7 @@ fn aggregations_with_group_combined() -> datafusion_common::Result<()> {
             .schema(Arc::clone(&schema))
             .alias("Sum(b)")
             .build()
+            .map(Arc::new)
             .unwrap(),
     ];
     let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs
index 042f6d622565c..6859e2f1468ce 100644
--- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs
+++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs
@@ -347,10 +347,10 @@ fn test_has_aggregate_expression() -> Result<()> {
     let single_agg = AggregateExec::try_new(
         AggregateMode::Single,
         build_group_by(&schema, vec!["a".to_string()]),
-        vec![agg.count_expr(&schema)], /* aggr_expr */
-        vec![None],                    /* filter_expr */
-        source,                        /* input */
-        schema.clone(),                /* input_schema */
+        vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */
+        vec![None],                              /* filter_expr */
+        source,                                  /* input */
+        schema.clone(),                          /* input_schema */
     )?;
     let limit_exec = LocalLimitExec::new(
        Arc::new(single_agg),
@@ -375,7 +375,7 @@ fn test_has_filter() -> Result<()> {
     // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec
     // the `a > 1` filter is applied in the AggregateExec
     let filter_expr = Some(expressions::binary(
-        expressions::col("a", &schema)?,
+        col("a", &schema)?,
         Operator::Gt,
         cast(expressions::lit(1u32), &schema, DataType::Int32)?,
         &schema,
@@ -384,10 +384,10 @@ fn test_has_filter() -> Result<()> {
     let single_agg = AggregateExec::try_new(
         AggregateMode::Single,
         build_group_by(&schema.clone(), vec!["a".to_string()]),
-        vec![agg.count_expr(&schema)], /* aggr_expr */
-        vec![filter_expr],             /* filter_expr */
-        source,                        /* input */
-        schema.clone(),                /* input_schema */
+        vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */
+        vec![filter_expr],                       /* filter_expr */
+        source,                                  /* input */
+        schema.clone(),                          /* input_schema */
     )?;
     let limit_exec = LocalLimitExec::new(
         Arc::new(single_agg),
@@ -408,7 +408,7 @@
 #[test]
 fn test_has_order_by() -> Result<()> {
     let sort_key = vec![PhysicalSortExpr {
-        expr: expressions::col("a", &schema()).unwrap(),
+        expr: col("a", &schema()).unwrap(),
         options: SortOptions::default(),
     }];
     let source = parquet_exec_with_sort(vec![sort_key]);
diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs
index 4ec981bf2a741..c06783aa0277b 100644
--- a/datafusion/core/tests/physical_optimizer/mod.rs
+++ b/datafusion/core/tests/physical_optimizer/mod.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-mod aggregate_statistics;
 mod combine_partial_final_agg;
 mod limit_pushdown;
 mod limited_distinct_aggregation;
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index addabc8a36127..fab92c0f9c2bf 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -33,7 +33,7 @@ async fn join_change_in_planner() -> Result<()> {
         Field::new("a2", DataType::UInt32, false),
     ]));
     // Specify the ordering:
-    let file_sort_order = vec![[datafusion_expr::col("a1")]
+    let file_sort_order = vec![[col("a1")]
         .into_iter()
         .map(|e| {
             let ascending = true;
@@ -101,7 +101,7 @@ async fn join_no_order_on_filter() -> Result<()> {
         Field::new("a3", DataType::UInt32, false),
     ]));
     // Specify the ordering:
-    let file_sort_order = vec![[datafusion_expr::col("a1")]
+    let file_sort_order = vec![[col("a1")]
         .into_iter()
         .map(|e| {
             let ascending = true;
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index dc9d047860213..177427b47d218 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -65,7 +65,7 @@ pub mod select;
 mod sql_api;
 
 async fn register_aggregate_csv_by_sql(ctx: &SessionContext) {
-    let testdata = datafusion::test_util::arrow_test_data();
+    let testdata = test_util::arrow_test_data();
 
     let df = ctx
         .sql(&format!(
@@ -103,7 +103,7 @@ async fn register_aggregate_csv_by_sql(ctx: &SessionContext) {
 }
 
 async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> {
-    let testdata = datafusion::test_util::arrow_test_data();
+    let testdata = test_util::arrow_test_data();
     let schema = test_util::aggr_test_schema();
     ctx.register_csv(
         "aggregate_test_100",
@@ -227,7 +227,7 @@ fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> {
 }
 
 async fn register_alltypes_parquet(ctx: &SessionContext) {
-    let testdata = datafusion::test_util::parquet_test_data();
+    let testdata =
test_util::parquet_test_data(); ctx.register_parquet( "alltypes_plain", &format!("{testdata}/alltypes_plain.parquet"), diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index b99bc26800449..252d76d0f9d92 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -229,9 +229,6 @@ async fn tpcds_logical_q40() -> Result<()> { } #[tokio::test] -#[ignore] -// Optimizer rule 'scalar_subquery_to_join' failed: Optimizing disjunctions not supported! -// issue: https://github.com/apache/datafusion/issues/5368 async fn tpcds_logical_q41() -> Result<()> { create_logical_plan(41).await } @@ -571,7 +568,6 @@ async fn tpcds_physical_q9() -> Result<()> { create_physical_plan(9).await } -#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q10() -> Result<()> { create_physical_plan(10).await @@ -697,7 +693,6 @@ async fn tpcds_physical_q34() -> Result<()> { create_physical_plan(34).await } -#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q35() -> Result<()> { create_physical_plan(35).await @@ -728,8 +723,6 @@ async fn tpcds_physical_q40() -> Result<()> { create_physical_plan(40).await } -#[ignore] -// Context("check_analyzed_plan", Plan("Correlated column is not allowed in predicate: (..) #[tokio::test] async fn tpcds_physical_q41() -> Result<()> { create_physical_plan(41).await @@ -750,7 +743,6 @@ async fn tpcds_physical_q44() -> Result<()> { create_physical_plan(44).await } -#[ignore] // Physical plan does not support logical expression () #[tokio::test] async fn tpcds_physical_q45() -> Result<()> { create_physical_plan(45).await diff --git a/datafusion/core/tests/user_defined/expr_planner.rs b/datafusion/core/tests/user_defined/expr_planner.rs index 1b23bf9ab2ef5..ad9c1280d6b11 100644 --- a/datafusion/core/tests/user_defined/expr_planner.rs +++ b/datafusion/core/tests/user_defined/expr_planner.rs @@ -29,6 +29,7 @@ use datafusion_expr::expr::Alias; use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr}; use datafusion_expr::BinaryExpr; +#[derive(Debug)] struct MyCustomPlanner; impl ExprPlanner for MyCustomPlanner { diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs new file mode 100644 index 0000000000000..ff14fa0be3fb6 --- /dev/null +++ b/datafusion/core/tests/user_defined/insert_operation.rs @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use std::{any::Any, sync::Arc};
+
+use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use async_trait::async_trait;
+use datafusion::{
+    error::Result,
+    prelude::{SessionConfig, SessionContext},
+};
+use datafusion_catalog::{Session, TableProvider};
+use datafusion_expr::{dml::InsertOp, Expr, TableType};
+use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
+use datafusion_physical_plan::{DisplayAs, ExecutionMode, ExecutionPlan, PlanProperties};
+
+#[tokio::test]
+async fn insert_operation_is_passed_correctly_to_table_provider() {
+    // Use the SQLite syntax so we can test the "INSERT OR REPLACE INTO" syntax
+    let ctx = session_ctx_with_dialect("SQLite");
+    let table_provider = Arc::new(TestInsertTableProvider::new());
+    ctx.register_table("testing", table_provider.clone())
+        .unwrap();
+
+    let sql = "INSERT INTO testing (column) VALUES (1)";
+    assert_insert_op(&ctx, sql, InsertOp::Append).await;
+
+    let sql = "INSERT OVERWRITE testing (column) VALUES (1)";
+    assert_insert_op(&ctx, sql, InsertOp::Overwrite).await;
+
+    let sql = "REPLACE INTO testing (column) VALUES (1)";
+    assert_insert_op(&ctx, sql, InsertOp::Replace).await;
+
+    let sql = "INSERT OR REPLACE INTO testing (column) VALUES (1)";
+    assert_insert_op(&ctx, sql, InsertOp::Replace).await;
+}
+
+async fn assert_insert_op(ctx: &SessionContext, sql: &str, insert_op: InsertOp) {
+    let df = ctx.sql(sql).await.unwrap();
+    let plan = df.create_physical_plan().await.unwrap();
+    let exec = plan.as_any().downcast_ref::<TestInsertExec>().unwrap();
+    assert_eq!(exec.op, insert_op);
+}
+
+fn session_ctx_with_dialect(dialect: impl Into<String>) -> SessionContext {
+    let mut config = SessionConfig::new();
+    let options = config.options_mut();
+    options.sql_parser.dialect = dialect.into();
+    SessionContext::new_with_config(config)
+}
+
+#[derive(Debug)]
+struct TestInsertTableProvider {
+    schema: SchemaRef,
+}
+
+impl TestInsertTableProvider {
+    fn new() -> Self {
+        Self {
+            schema: SchemaRef::new(Schema::new(vec![Field::new(
+                "column",
+                DataType::Int64,
+                false,
+            )])),
+        }
+    }
+}
+
+#[async_trait]
+impl TableProvider for TestInsertTableProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn table_type(&self) -> TableType {
+        TableType::Base
+    }
+
+    async fn scan(
+        &self,
+        _state: &dyn Session,
+        _projection: Option<&Vec<usize>>,
+        _filters: &[Expr],
+        _limit: Option<usize>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        unimplemented!("TestInsertTableProvider is a stub for testing.")
+    }
+
+    async fn insert_into(
+        &self,
+        _state: &dyn Session,
+        _input: Arc<dyn ExecutionPlan>,
+        insert_op: InsertOp,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(TestInsertExec::new(insert_op)))
+    }
+}
+
+#[derive(Debug)]
+struct TestInsertExec {
+    op: InsertOp,
+    plan_properties: PlanProperties,
+}
+
+impl TestInsertExec {
+    fn new(op: InsertOp) -> Self {
+        let eq_properties = EquivalenceProperties::new(make_count_schema());
+        let plan_properties = PlanProperties::new(
+            eq_properties,
+            Partitioning::UnknownPartitioning(1),
+            ExecutionMode::Bounded,
+        );
+        Self {
+            op,
+            plan_properties,
+        }
+    }
+}
+
+impl DisplayAs for TestInsertExec {
+    fn fmt_as(
+        &self,
+        _t: datafusion_physical_plan::DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        write!(f, "TestInsertExec")
+    }
+}
+
+impl ExecutionPlan for TestInsertExec {
+    fn name(&self) -> &str {
+        "TestInsertExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.plan_properties
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![]
+    }
+
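+    // Editor's note (not part of the patch): `TestInsertExec` is a leaf plan,
+    // so `children()` above returns an empty Vec and `with_new_children`
+    // below can simply assert that it receives no children and return itself
+    // unchanged, which satisfies the `ExecutionPlan` contract for leaves.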
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        assert!(children.is_empty());
+        Ok(self)
+    }
+
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<datafusion_execution::TaskContext>,
+    ) -> Result<datafusion_execution::SendableRecordBatchStream> {
+        unimplemented!("TestInsertExec is a stub for testing.")
+    }
+}
+
+fn make_count_schema() -> SchemaRef {
+    Arc::new(Schema::new(vec![Field::new(
+        "count",
+        DataType::UInt64,
+        false,
+    )]))
+}
diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs
index 56cec8df468b0..5d84cdb692830 100644
--- a/datafusion/core/tests/user_defined/mod.rs
+++ b/datafusion/core/tests/user_defined/mod.rs
@@ -32,3 +32,6 @@ mod user_defined_table_functions;
 
 /// Tests for Expression Planner
 mod expr_planner;
+
+/// Tests for insert operations
+mod insert_operation;
diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index 1e0d3d9d514e8..497addd23094a 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -747,7 +747,7 @@ impl Accumulator for FirstSelector {
     }
 
     fn size(&self) -> usize {
-        std::mem::size_of_val(self)
+        size_of_val(self)
     }
 }
 
@@ -816,7 +816,7 @@ impl Accumulator for TestGroupsAccumulator {
     }
 
     fn size(&self) -> usize {
-        std::mem::size_of::<Self>()
+        size_of::<Self>()
     }
 
     fn state(&mut self) -> Result<Vec<ScalarValue>> {
@@ -864,6 +864,6 @@ impl GroupsAccumulator for TestGroupsAccumulator {
     }
 
     fn size(&self) -> usize {
-        std::mem::size_of::<Self>()
+        size_of::<Self>()
     }
 }
diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs
index caf639434a999..c962567844022 100644
--- a/datafusion/core/tests/user_defined/user_defined_plan.rs
+++ b/datafusion/core/tests/user_defined/user_defined_plan.rs
@@ -81,7 +81,7 @@ use datafusion::{
         runtime_env::RuntimeEnv,
     },
     logical_expr::{
-        Expr, Extension, Limit, LogicalPlan, Sort, UserDefinedLogicalNode,
+        Expr, Extension, LogicalPlan, Sort, UserDefinedLogicalNode,
         UserDefinedLogicalNodeCore,
     },
     optimizer::{OptimizerConfig, OptimizerRule},
@@ -98,7 +98,7 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::ScalarValue;
 use datafusion_expr::tree_node::replace_sort_expression;
-use datafusion_expr::{Projection, SortExpr};
+use datafusion_expr::{FetchType, Projection, SortExpr};
 use datafusion_optimizer::optimizer::ApplyOrder;
 use datafusion_optimizer::AnalyzerRule;
 
@@ -312,6 +312,7 @@ fn make_topk_context() -> SessionContext {
 
 // ------ The implementation of the TopK code follows -----
 
+#[derive(Debug)]
 struct TopKQueryPlanner {}
 
 #[async_trait]
@@ -360,28 +361,28 @@ impl OptimizerRule for TopKOptimizerRule {
         // Note: this code simply looks for the pattern of a Limit followed by a
         // Sort and replaces it by a TopK node. It does not handle many
         // edge cases (e.g multiple sort columns, sort ASC / DESC), etc.
-        if let LogicalPlan::Limit(Limit {
-            fetch: Some(fetch),
-            input,
+        let LogicalPlan::Limit(ref limit) = plan else {
+            return Ok(Transformed::no(plan));
+        };
+        let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else {
+            return Ok(Transformed::no(plan));
+        };
+
+        if let LogicalPlan::Sort(Sort {
+            ref expr,
+            ref input,
             ..
-        }) = &plan
+        }) = limit.input.as_ref()
         {
-            if let LogicalPlan::Sort(Sort {
-                ref expr,
-                ref input,
-                ..
-            }) = **input
-            {
-                if expr.len() == 1 {
-                    // we found a sort with a single sort expr, replace with a a TopK
-                    return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
-                        node: Arc::new(TopKPlanNode {
-                            k: *fetch,
-                            input: input.as_ref().clone(),
-                            expr: expr[0].clone(),
-                        }),
-                    })));
-                }
+            if expr.len() == 1 {
+                // we found a sort with a single sort expr, replace with a TopK
+                return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
+                    node: Arc::new(TopKPlanNode {
+                        k: fetch,
+                        input: input.as_ref().clone(),
+                        expr: expr[0].clone(),
+                    }),
+                })));
             }
         }
 
@@ -442,6 +443,10 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
             expr: replace_sort_expression(self.expr.clone(), exprs.swap_remove(0)),
         })
     }
+
+    fn supports_limit_pushdown(&self) -> bool {
+        false // Disallow limit push-down by default
+    }
 }
 
 /// Physical planner for TopK nodes
@@ -508,11 +513,7 @@ impl Debug for TopKExec {
 }
 
 impl DisplayAs for TopKExec {
-    fn fmt_as(
-        &self,
-        t: DisplayFormatType,
-        f: &mut std::fmt::Formatter,
-    ) -> std::fmt::Result {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
                 write!(f, "TopKExec: k={}", self.k)
diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index 013aec48d5108..f1b1728623998 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -16,9 +16,11 @@
 // under the License.
 
 use std::any::Any;
+use std::collections::HashMap;
 use std::hash::{DefaultHasher, Hash, Hasher};
 use std::sync::Arc;
 
+use arrow::array::as_string_array;
 use arrow::compute::kernels::numeric::add;
 use arrow_array::builder::BooleanBuilder;
 use arrow_array::cast::AsArray;
@@ -483,6 +485,185 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
     Ok(())
 }
 
+/// Volatile UDF that should append a different value to each row
+#[derive(Debug)]
+struct AddIndexToStringVolatileScalarUDF {
+    name: String,
+    signature: Signature,
+    return_type: DataType,
+}
+
+impl AddIndexToStringVolatileScalarUDF {
+    fn new() -> Self {
+        Self {
+            name: "add_index_to_string".to_string(),
+            signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile),
+            return_type: DataType::Utf8,
+        }
+    }
+}
+
+impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(self.return_type.clone())
+    }
+
+    fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        not_impl_err!("index_with_offset function does not accept arguments")
+    }
+
+    fn invoke_batch(
+        &self,
+        args: &[ColumnarValue],
+        number_rows: usize,
+    ) -> Result<ColumnarValue> {
+        let answer = match &args[0] {
+            // When called with static arguments, the result is returned as an array.
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => {
+                let mut answer = vec![];
+                for index in 1..=number_rows {
+                    // When calling a function with immutable arguments, the result is returned with ")".
+                    // Example: SELECT add_index_to_string('const_value') FROM table;
+                    answer.push(index.to_string() + ") " + value);
+                }
+                answer
+            }
+            // The result is returned as an array when called with dynamic arguments.
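+            // Editor's note (not part of the patch): in the array arm below,
+            // an input column ["test_1", "test_1", "test_2"] yields
+            // ["1. test_1", "2. test_1", "1. test_2"]; the HashMap keeps one
+            // running index per distinct input value.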
+ ColumnarValue::Array(array) => { + let string_array = as_string_array(array); + let mut counter = HashMap::<&str, u64>::new(); + string_array + .iter() + .map(|value| { + let value = value.expect("Unexpected null"); + let index = counter.get(value).unwrap_or(&0) + 1; + counter.insert(value, index); + + // When calling a function with mutable arguments, the result is returned with ".". + // Example: SELECT add_index_to_string(table.value) FROM table; + index.to_string() + ". " + value + }) + .collect() + } + _ => unimplemented!(), + }; + Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer)))) + } +} + +#[tokio::test] +async fn volatile_scalar_udf_with_params() -> Result<()> { + { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StringArray::from(vec![ + "test_1", "test_1", "test_1", "test_2", "test_2", "test_1", "test_2", + ]))], + )?; + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); + + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); + + let result = + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") // with dynamic function parameters + .await?; + let expected = [ + "+-----------+", + "| str |", + "+-----------+", + "| 1. test_1 |", + "| 2. test_1 |", + "| 3. test_1 |", + "| 1. test_2 |", + "| 2. test_2 |", + "| 4. test_1 |", + "| 3. test_2 |", + "+-----------+", + ]; + assert_batches_eq!(expected, &result); + + let result = + plan_and_collect(&ctx, "select add_index_to_string('test') AS str from t") // with fixed function parameters + .await?; + let expected = [ + "+---------+", + "| str |", + "+---------+", + "| 1) test |", + "| 2) test |", + "| 3) test |", + "| 4) test |", + "| 5) test |", + "| 6) test |", + "| 7) test |", + "+---------+", + ]; + assert_batches_eq!(expected, &result); + + let result = + plan_and_collect(&ctx, "select add_index_to_string('test_value') as str") // with fixed function parameters + .await?; + let expected = [ + "+---------------+", + "| str |", + "+---------------+", + "| 1) test_value |", + "+---------------+", + ]; + assert_batches_eq!(expected, &result); + } + { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StringArray::from(vec![ + "test_1", "test_1", "test_1", + ]))], + )?; + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); + + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); + + let result = + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") + .await?; + let expected = [ + "+-----------+", // + "| str |", // + "+-----------+", // + "| 1. test_1 |", // + "| 2. test_1 |", // + "| 3. 
test_1 |", // + "+-----------+", + ]; + assert_batches_eq!(expected, &result); + } + Ok(()) +} + #[derive(Debug)] struct CastToI64UDF { signature: Signature, @@ -755,11 +936,11 @@ struct ScalarFunctionWrapper { name: String, expr: Expr, signature: Signature, - return_type: arrow_schema::DataType, + return_type: DataType, } impl ScalarUDFImpl for ScalarFunctionWrapper { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -767,21 +948,15 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { &self.name } - fn signature(&self) -> &datafusion_expr::Signature { + fn signature(&self) -> &Signature { &self.signature } - fn return_type( - &self, - _arg_types: &[arrow_schema::DataType], - ) -> Result { + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(self.return_type.clone()) } - fn invoke( - &self, - _args: &[datafusion_expr::ColumnarValue], - ) -> Result { + fn invoke(&self, _args: &[ColumnarValue]) -> Result { internal_err!("This function should not get invoked!") } @@ -861,10 +1036,7 @@ impl TryFrom for ScalarFunctionWrapper { .into_iter() .map(|a| a.data_type) .collect(), - definition - .params - .behavior - .unwrap_or(datafusion_expr::Volatility::Volatile), + definition.params.behavior.unwrap_or(Volatility::Volatile), ), }) } @@ -1169,7 +1341,7 @@ fn custom_sqrt(args: &[ColumnarValue]) -> Result { } async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::arrow_test_data(); + let testdata = test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); ctx.register_csv( "aggregate_test_100", @@ -1181,7 +1353,7 @@ async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { } async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::parquet_test_data(); + let testdata = test_util::parquet_test_data(); ctx.register_parquet( "alltypes_plain", &format!("{testdata}/alltypes_plain.parquet"), diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index fe57752db52ea..0cc156866d4d1 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -192,6 +192,7 @@ impl SimpleCsvTable { } } +#[derive(Debug)] struct SimpleCsvTableFunc {} impl TableFunctionImpl for SimpleCsvTableFunc { diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index d96bb23953aee..8fe028eedd443 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -36,6 +36,7 @@ use datafusion_expr::{ PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// A query with a window function evaluated over the entire partition const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \ @@ -552,7 +553,10 @@ impl OddCounter { &self.signature } - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state)))) } @@ -589,11 +593,7 @@ impl PartitionEvaluator for OddCounter { Ok(scalar) } - fn evaluate_all( - 
-        &mut self,
-        values: &[arrow_array::ArrayRef],
-        num_rows: usize,
-    ) -> Result<ArrayRef> {
+    fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
         println!("evaluate_all, values: {values:#?}, num_rows: {num_rows}");
 
         self.test_state.inc_evaluate_all_called();
@@ -637,7 +637,7 @@ fn odd_count(arr: &Int64Array) -> i64 {
 }
 
 /// returns an array of num_rows that has the number of odd values in `arr`
-fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> arrow_array::ArrayRef {
+fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> ArrayRef {
     let array: Int64Array = std::iter::repeat(odd_count(arr)).take(num_rows).collect();
     Arc::new(array)
 }
diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs
index cede75d21ca47..53646dc5b468e 100644
--- a/datafusion/execution/src/config.rs
+++ b/datafusion/execution/src/config.rs
@@ -432,6 +432,20 @@ impl SessionConfig {
         self
     }
 
+    /// Enables or disables the enforcement of batch size in joins
+    pub fn with_enforce_batch_size_in_joins(
+        mut self,
+        enforce_batch_size_in_joins: bool,
+    ) -> Self {
+        self.options.execution.enforce_batch_size_in_joins = enforce_batch_size_in_joins;
+        self
+    }
+
+    /// Returns true if the joins will be enforced to output batches of the configured size
+    pub fn enforce_batch_size_in_joins(&self) -> bool {
+        self.options.execution.enforce_batch_size_in_joins
+    }
+
     /// Convert configuration options to name-value pairs with values
     /// converted to strings.
     ///
diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs
index c98d7e5579f0f..38c259fcbdc8e 100644
--- a/datafusion/execution/src/disk_manager.rs
+++ b/datafusion/execution/src/disk_manager.rs
@@ -173,7 +173,7 @@ fn create_local_dirs(local_dirs: Vec<String>) -> Result<Vec<Arc<TempDir>>> {
     local_dirs
         .iter()
         .map(|root| {
-            if !std::path::Path::new(root).exists() {
+            if !Path::new(root).exists() {
                 std::fs::create_dir(root)?;
             }
             Builder::new()
diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs
index dcd59acbd49eb..5bf30b724d0b9 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -68,11 +68,35 @@ pub use pool::*;
 /// Note that a `MemoryPool` can be shared by concurrently executing plans,
 /// which can be used to control memory usage in a multi-tenant system.
 ///
+/// # How MemoryPool works by example
+///
+/// Scenario 1:
+/// For the `Filter` operator, `RecordBatch`es will stream through it, so it
+/// doesn't have to keep track of memory usage through [`MemoryPool`].
+///
+/// Scenario 2:
+/// For the `CrossJoin` operator, if the input size gets larger, the
+/// intermediate state will also grow. So the `CrossJoin` operator will use
+/// [`MemoryPool`] to limit the memory usage.
+/// 2.1 The `CrossJoin` operator has read a new batch and asked the memory pool
+/// for additional memory. The memory pool updates the usage and returns success.
+/// 2.2 `CrossJoin` has read another batch and tries to reserve more memory
+/// again, but the memory pool does not have enough memory. Since the `CrossJoin`
+/// operator has not implemented spilling, it will stop execution and return an
+/// error.
+///
+/// Scenario 3:
+/// For the `Aggregate` operator, its intermediate states will also accumulate
+/// as the input size gets larger, but with spilling capability. When it tries
+/// to reserve more memory from the memory pool, and the memory pool has already
+/// reached the memory limit, it will return an error.
+/// Then, the `Aggregate` operator will spill the intermediate buffers to disk,
+/// release memory from the memory pool, and retry the memory reservation.
+///
 /// # Implementing `MemoryPool`
 ///
 /// You can implement a custom allocation policy by implementing the
 /// [`MemoryPool`] trait and configuring a `SessionContext` appropriately.
-/// However, mDataFusion comes with the following simple memory pool implementations that
+/// However, DataFusion comes with the following simple memory pool implementations that
 /// handle many common cases:
 ///
 /// * [`UnboundedMemoryPool`]: no memory limits (the default)
@@ -310,13 +334,17 @@ impl Drop for MemoryReservation {
     }
 }
 
-const TB: u64 = 1 << 40;
-const GB: u64 = 1 << 30;
-const MB: u64 = 1 << 20;
-const KB: u64 = 1 << 10;
+pub mod units {
+    pub const TB: u64 = 1 << 40;
+    pub const GB: u64 = 1 << 30;
+    pub const MB: u64 = 1 << 20;
+    pub const KB: u64 = 1 << 10;
+}
 
 /// Present size in human readable form
 pub fn human_readable_size(size: usize) -> String {
+    use units::*;
+
     let size = size as u64;
     let (value, unit) = {
         if size >= 2 * TB {
diff --git a/datafusion/execution/src/stream.rs b/datafusion/execution/src/stream.rs
index 7fc5e458b86b5..f3eb7b77e03cc 100644
--- a/datafusion/execution/src/stream.rs
+++ b/datafusion/execution/src/stream.rs
@@ -20,7 +20,9 @@ use datafusion_common::Result;
 use futures::Stream;
 use std::pin::Pin;
 
-/// Trait for types that stream [arrow::record_batch::RecordBatch]
+/// Trait for types that stream [RecordBatch]
+///
+/// See [`SendableRecordBatchStream`] for more details.
 pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
     /// Returns the schema of this `RecordBatchStream`.
     ///
@@ -29,5 +31,23 @@ pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
     fn schema(&self) -> SchemaRef;
 }
 
-/// Trait for a [`Stream`] of [`RecordBatch`]es
+/// Trait for a [`Stream`] of [`RecordBatch`]es that can be passed between threads
+///
+/// This trait is used to retrieve the results of DataFusion execution plan nodes.
+///
+/// The trait is a specialized Rust Async [`Stream`] that also knows the schema
+/// of the data it will return (even if the stream has no data). Every
+/// `RecordBatch` returned by the stream should have the same schema as returned
+/// by [`schema`](`RecordBatchStream::schema`).
+///
+/// # Error Handling
+///
+/// Once a stream returns an error, it should not be polled again (the caller
+/// should stop calling `next`) and should handle the error.
+///
+/// However, returning `Ready(None)` (end of stream) is likely the safest
+/// behavior after an error. Like [`Stream`]s, `RecordBatchStream`s should not
+/// be polled after the end of the stream or after returning an error. However,
+/// also like [`Stream`]s, there is no mechanism to prevent callers from
+/// polling, so returning `Ready(None)` is recommended.
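[Editor's aside between hunks, not part of the patch: a minimal sketch of consuming the stream type documented above while following its error-handling guidance. The `row_count` helper is invented for illustration; it assumes only `futures::StreamExt` and the re-exported `SendableRecordBatchStream`.]

```rust
use datafusion_common::Result;
use datafusion_execution::SendableRecordBatchStream;
use futures::StreamExt;

// Drain a stream produced by `ExecutionPlan::execute`, counting rows.
async fn row_count(mut stream: SendableRecordBatchStream) -> Result<usize> {
    // The schema is available immediately, even before any batch arrives.
    let _schema = stream.schema();
    let mut rows = 0;
    while let Some(batch) = stream.next().await {
        // `?` stops at the first error; the stream must not be polled again.
        rows += batch?.num_rows();
    }
    Ok(rows)
}
```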
pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;
diff --git a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml
index 7e477efc4ebc1..de11b19c3b06b 100644
--- a/datafusion/expr-common/Cargo.toml
+++ b/datafusion/expr-common/Cargo.toml
@@ -40,4 +40,5 @@ path = "src/lib.rs"
 [dependencies]
 arrow = { workspace = true }
 datafusion-common = { workspace = true }
+itertools = { workspace = true }
 paste = "^1.0"
diff --git a/datafusion/expr-common/src/accumulator.rs b/datafusion/expr-common/src/accumulator.rs
index 75335209451e1..7155c7993f8c9 100644
--- a/datafusion/expr-common/src/accumulator.rs
+++ b/datafusion/expr-common/src/accumulator.rs
@@ -39,7 +39,7 @@ use std::fmt::Debug;
 /// function])
 ///
 /// * convert its internal state to a vector of aggregate values via
-/// [`state`] and combine the state from multiple accumulators'
+/// [`state`] and combine the state from multiple accumulators
 /// via [`merge_batch`], as part of efficient multi-phase grouping.
 ///
 /// [`GroupsAccumulator`]: crate::GroupsAccumulator
@@ -68,7 +68,7 @@ pub trait Accumulator: Send + Sync + Debug {
     /// result in potentially non-deterministic behavior.
     ///
     /// This function gets `&mut self` to allow for the accumulator to build
-    /// arrow compatible internal state that can be returned without copying
+    /// arrow-compatible internal state that can be returned without copying
     /// when possible (for example distinct strings)
     fn evaluate(&mut self) -> Result<ScalarValue>;
 
@@ -89,14 +89,14 @@ pub trait Accumulator: Send + Sync + Debug {
     /// result in potentially non-deterministic behavior.
     ///
     /// This function gets `&mut self` to allow for the accumulator to build
-    /// arrow compatible internal state that can be returned without copying
+    /// arrow-compatible internal state that can be returned without copying
     /// when possible (for example distinct strings).
     ///
     /// Intermediate state is used for "multi-phase" grouping in
     /// DataFusion, where an aggregate is computed in parallel with
     /// multiple `Accumulator` instances, as described below:
     ///
-    /// # MultiPhase Grouping
+    /// # Multi-Phase Grouping
     ///
     /// ```text
     ///                               ▲
@@ -140,9 +140,9 @@
     /// to be summed together)
     ///
     /// Some accumulators can return multiple values for their
-    /// intermediate states. For example average, tracks `sum` and
-    /// `n`, and this function should return
-    /// a vector of two values, sum and n.
+    /// intermediate states. For example, the average accumulator
+    /// tracks `sum` and `n`, and this function should return a vector
+    /// of two values, sum and n.
     ///
     /// Note that [`ScalarValue::List`] can be used to pass multiple
     /// values if the number of intermediate values is not known at
@@ -204,7 +204,7 @@
     /// The final output is computed by repartitioning the result of
     /// [`Self::state`] from each Partial aggregate and `hash(group keys)` so
     /// that each distinct group key appears in exactly one of the
-    /// `AggregateMode::Final` GroupBy nodes. The output of the final nodes are
+    /// `AggregateMode::Final` GroupBy nodes. The outputs of the final nodes are
     /// then unioned together to produce the overall final output.
/// /// Here is an example that shows the distribution of groups in the diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index bfefb37c98d75..4b9454ed739d7 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -17,10 +17,9 @@ //! [`ColumnarValue`] represents the result of evaluating an expression. -use arrow::array::ArrayRef; -use arrow::array::NullArray; +use arrow::array::{Array, ArrayRef, NullArray}; use arrow::compute::{kernels, CastOptions}; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::DataType; use datafusion_common::format::DEFAULT_CAST_OPTIONS; use datafusion_common::{internal_err, Result, ScalarValue}; use std::sync::Arc; @@ -130,7 +129,7 @@ impl ColumnarValue { }) } - /// null columnar values are implemented as a null array in order to pass batch + /// Null columnar values are implemented as a null array in order to pass batch /// num_rows pub fn create_null_array(num_rows: usize) -> Self { ColumnarValue::Array(Arc::new(NullArray::new(num_rows))) @@ -194,28 +193,9 @@ impl ColumnarValue { ColumnarValue::Array(array) => Ok(ColumnarValue::Array( kernels::cast::cast_with_options(array, cast_type, &cast_options)?, )), - ColumnarValue::Scalar(scalar) => { - let scalar_array = - if cast_type == &DataType::Timestamp(TimeUnit::Nanosecond, None) { - if let ScalarValue::Float64(Some(float_ts)) = scalar { - ScalarValue::Int64(Some( - (float_ts * 1_000_000_000_f64).trunc() as i64, - )) - .to_array()? - } else { - scalar.to_array()? - } - } else { - scalar.to_array()? - }; - let cast_array = kernels::cast::cast_with_options( - &scalar_array, - cast_type, - &cast_options, - )?; - let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; - Ok(ColumnarValue::Scalar(cast_scalar)) - } + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( + scalar.cast_to_with_options(cast_type, &cast_options)?, + )), } } } diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 8e81c51d8460f..2c8b126cb52ca 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -90,6 +90,11 @@ impl EmitTo { /// faster for queries with many group values. See the [Aggregating Millions of /// Groups Fast blog] for more background. /// +/// [`NullState`] can help keep the state for groups that have not seen any +/// values and produce the correct output for those groups. +/// +/// [`NullState`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/struct.NullState.html +/// /// # Details /// Each group is assigned a `group_index` by the hash table and each /// accumulator manages the specific state, one per `group_index`. @@ -106,8 +111,7 @@ pub trait GroupsAccumulator: Send { /// /// * `values`: the input arguments to the accumulator /// - /// * `group_indices`: To which groups do the rows in `values` - /// belong, group id) + /// * `group_indices`: The group indices to which each row in `values` belongs. /// /// * `opt_filter`: if present, only update aggregate state using /// `values[i]` if `opt_filter[i]` is true @@ -117,6 +121,11 @@ pub trait GroupsAccumulator: Send { /// /// Note that subsequent calls to update_batch may have larger /// total_num_groups as new groups are seen. + /// + /// See [`NullState`] to help keep the state for groups that have not seen any + /// values and produce the correct output for those groups. 
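Returning to the `ColumnarValue::cast_to` change above: the scalar branch no longer round-trips through a one-element array (and no longer special-cases `Float64` to nanosecond timestamps), but delegates to `ScalarValue::cast_to_with_options`, whose call shape is taken from the hunk itself. A hedged usage sketch — the helper function and `main` are hypothetical:

```rust
use arrow::datatypes::DataType;
use datafusion_common::format::DEFAULT_CAST_OPTIONS;
use datafusion_common::{Result, ScalarValue};

/// Hypothetical helper mirroring the new scalar path in ColumnarValue::cast_to
fn cast_scalar(scalar: ScalarValue, target: &DataType) -> Result<ScalarValue> {
    // Cast the scalar directly; no intermediate single-row array is built
    scalar.cast_to_with_options(target, &DEFAULT_CAST_OPTIONS)
}

fn main() -> Result<()> {
    // Float-to-int casts truncate, following the arrow cast kernel
    let v = cast_scalar(ScalarValue::Float64(Some(1.5)), &DataType::Int64)?;
    assert_eq!(v, ScalarValue::Int64(Some(1)));
    Ok(())
}
```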
+ /// + /// [`NullState`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/struct.NullState.html fn update_batch( &mut self, values: &[ArrayRef], @@ -175,9 +184,9 @@ pub trait GroupsAccumulator: Send { /// differ. See [`Self::state`] for more details on how state is /// used and merged. /// - /// * `values`: arrays produced from calling `state` previously to the accumulator + /// * `values`: arrays produced from previously calling `state` on other accumulators. /// - /// Other arguments are the same as for [`Self::update_batch`]; + /// Other arguments are the same as for [`Self::update_batch`]. fn merge_batch( &mut self, values: &[ArrayRef], @@ -186,7 +195,7 @@ pub trait GroupsAccumulator: Send { total_num_groups: usize, ) -> Result<()>; - /// Converts an input batch directly the intermediate aggregate state. + /// Converts an input batch directly to the intermediate aggregate state. /// /// This is the equivalent of treating each input row as its own group. It /// is invoked when the Partial phase of a multi-phase aggregation is not diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 6424888c896a5..ffaa32f08075c 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -1223,8 +1223,8 @@ pub fn satisfy_greater( } } - // Only the lower bound of left hand side and the upper bound of the right - // hand side can change after propagating the greater-than operation. + // Only the lower bound of left-hand side and the upper bound of the right-hand + // side can change after propagating the greater-than operation. let new_left_lower = if left.lower.is_null() || left.lower <= right.lower { if strict { next_value(right.lower.clone()) @@ -1753,7 +1753,7 @@ impl NullableInterval { } _ => Ok(Self::MaybeNull { values }), } - } else if op.is_comparison_operator() { + } else if op.supports_propagation() { Ok(Self::Null { datatype: DataType::Boolean, }) diff --git a/datafusion/expr-common/src/operator.rs b/datafusion/expr-common/src/operator.rs index e013b6fafa22d..6ca0f04897aca 100644 --- a/datafusion/expr-common/src/operator.rs +++ b/datafusion/expr-common/src/operator.rs @@ -142,10 +142,11 @@ impl Operator { ) } - /// Return true if the operator is a comparison operator. + /// Return true if the comparison operator can be used in interval arithmetic and constraint + /// propagation /// - /// For example, 'Binary(a, >, b)' would be a comparison expression. - pub fn is_comparison_operator(&self) -> bool { + /// For example, 'Binary(a, >, b)' expression supports propagation. + pub fn supports_propagation(&self) -> bool { matches!( self, Operator::Eq @@ -163,6 +164,15 @@ impl Operator { ) } + /// Return true if the comparison operator can be used in interval arithmetic and constraint + /// propagation + /// + /// For example, 'Binary(a, >, b)' expression supports propagation. + #[deprecated(since = "43.0.0", note = "please use `supports_propagation` instead")] + pub fn is_comparison_operator(&self) -> bool { + self.supports_propagation() + } + /// Return true if the operator is a logic operator. 
/// /// For example, 'Binary(Binary(a, >, b), AND, Binary(a, <, b + 3))' would diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index d1553b3315e71..24cb54f634b14 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -35,7 +35,7 @@ pub const TIMEZONE_WILDCARD: &str = "+TZ"; /// valid length. It exists to avoid the need to enumerate all possible fixed size list lengths. pub const FIXED_SIZE_LIST_WILDCARD: i32 = i32::MIN; -///A function's volatility, which defines the functions eligibility for certain optimizations +/// A function's volatility, which defines the function's eligibility for certain optimizations #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] pub enum Volatility { /// An immutable function will always return the same output when given the same @@ -86,7 +86,7 @@ pub enum Volatility { /// ``` #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum TypeSignature { - /// One or more arguments of an common type out of a list of valid types. + /// One or more arguments of a common type out of a list of valid types. /// /// # Examples /// A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` @@ -125,6 +125,11 @@ pub enum TypeSignature { /// Fixed number of arguments of numeric types. /// See to know which type is considered numeric Numeric(usize), + /// Fixed number of arguments of all the same string types. + /// The precedence of types, from high to low, is Utf8View, LargeUtf8 and Utf8. + /// Null is considered as `Utf8` by default. + /// Dictionary with a string value type is also handled. + String(usize), } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -190,8 +195,11 @@ impl TypeSignature { .collect::<Vec<String>>() .join(", ")] } + TypeSignature::String(num) => { + vec![format!("String({num})")] + } TypeSignature::Numeric(num) => { - vec![format!("Numeric({})", num)] + vec![format!("Numeric({num})")] } TypeSignature::Exact(types) | TypeSignature::Coercible(types) => { vec![Self::join_types(types, ", ")] @@ -280,6 +288,14 @@ impl Signature { } } + /// A specified number of string arguments + pub fn string(arg_count: usize, volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::String(arg_count), + volatility, + } + } + /// An arbitrary number of arguments of any type.
pub fn variadic_any(volatility: Volatility) -> Self { Self { diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index 2add9e7c1867c..fee75f9e45959 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -143,21 +143,21 @@ pub fn check_arg_count( Ok(()) } -/// function return type of a sum +/// Function return type of a sum pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> { match arg_type { DataType::Int64 => Ok(DataType::Int64), DataType::UInt64 => Ok(DataType::UInt64), DataType::Float64 => Ok(DataType::Float64), DataType::Decimal128(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+10), s) - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + // In Spark, the result type is DECIMAL(min(38,precision+10), s) + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal128(new_precision, *scale)) } DataType::Decimal256(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+10), s) - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + // In Spark, the result type is DECIMAL(min(38,precision+10), s) + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal256(new_precision, *scale)) } @@ -165,7 +165,7 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> { } } -/// function return type of variance +/// Function return type of variance pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { @@ -174,7 +174,7 @@ pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> { } } -/// function return type of covariance +/// Function return type of covariance pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { @@ -183,7 +183,7 @@ pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> { } } -/// function return type of correlation +/// Function return type of correlation pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { @@ -192,19 +192,19 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> { } } -/// function return type of an average +/// Function return type of an average pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType> { match arg_type { DataType::Decimal128(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + // In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
+ // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4); let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); Ok(DataType::Decimal128(new_precision, new_scale)) } DataType::Decimal256(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + // In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4); let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); Ok(DataType::Decimal256(new_precision, new_scale)) } @@ -217,16 +217,16 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType> } } -/// internal sum type of an average +/// Internal sum type of an average pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> { match arg_type { DataType::Decimal128(precision, scale) => { - // in the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) + // In Spark, the sum type of avg is DECIMAL(min(38,precision+10), s) let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal128(new_precision, *scale)) } DataType::Decimal256(precision, scale) => { - // in Spark the sum type of avg is DECIMAL(min(38,precision+10), s) + // In Spark, the sum type of avg is DECIMAL(min(38,precision+10), s) let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal256(new_precision, *scale)) } diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index c1e96a8fa97d6..31fe6a59baee7 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -25,10 +25,13 @@ use crate::operator::Operator; use arrow::array::{new_empty_array, Array}; use arrow::compute::can_cast_types; use arrow::datatypes::{ - DataType, Field, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, - DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION, + DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result}; +use datafusion_common::{ + exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, plan_err, Result, +}; +use itertools::Itertools; /// The type signature of an instantiation of binary operator expression such as /// `lhs + rhs` @@ -86,7 +89,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result<Signature> And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) { // Logical binary boolean operators can only be evaluated for // boolean or null arguments.
- Ok(Signature::uniform(DataType::Boolean)) + Ok(Signature::uniform(Boolean)) } else { plan_err!( "Cannot infer common argument type for logical boolean operation {lhs} {op} {rhs}" @@ -191,7 +194,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result<Signature> } } -/// returns the resulting type of a binary expression evaluating the `op` with the left and right hand types +/// Returns the resulting type of a binary expression evaluating the `op` with the left and right hand types pub fn get_result_type( lhs: &DataType, op: &Operator, @@ -370,17 +373,21 @@ impl From<&DataType> for TypeCategory { /// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules /// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted /// decimal precision and scale when coercing decimal types. +/// +/// This function doesn't preserve the correct field names or nullability for struct types; only the data types are considered. +/// +/// Returns `Option` because callers may want to continue even if the data types are not coercible to a common type pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> { if data_types.is_empty() { return None; } - // if all the data_types is the same return first one + // If all the data_types are the same, return the first one if data_types.iter().all(|t| t == &data_types[0]) { return Some(data_types[0].clone()); } - // if all the data_types are null, return string + // If all the data_types are null, return string if data_types.iter().all(|t| t == &DataType::Null) { return Some(DataType::Utf8); } @@ -399,7 +406,7 @@ pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> { return None; } - // check if there is only one category excluding Unknown + // Check if there is only one category excluding Unknown let categories: HashSet<TypeCategory> = HashSet::from_iter( data_types_category .iter() @@ -471,16 +478,145 @@ fn type_union_resolution_coercion( let new_value_type = type_union_resolution_coercion(value_type, other_type); new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t))) } + (DataType::List(lhs), DataType::List(rhs)) => { + let new_item_type = + type_union_resolution_coercion(lhs.data_type(), rhs.data_type()); + new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true)))) + } + (DataType::Struct(lhs), DataType::Struct(rhs)) => { + if lhs.len() != rhs.len() { + return None; + } + + // Search the field in the right hand side with the SAME field name + fn search_corresponding_coerced_type( + lhs_field: &FieldRef, + rhs: &Fields, + ) -> Option<DataType> { + for rhs_field in rhs.iter() { + if lhs_field.name() == rhs_field.name() { + if let Some(t) = type_union_resolution_coercion( + lhs_field.data_type(), + rhs_field.data_type(), + ) { + return Some(t); + } else { + return None; + } + } + } + + None + } + + let types = lhs + .iter() + .map(|lhs_field| search_corresponding_coerced_type(lhs_field, rhs)) + .collect::<Option<Vec<_>>>()?; + + let fields = types + .into_iter() + .enumerate() + .map(|(i, datatype)| { + Arc::new(Field::new(format!("c{i}"), datatype, true)) + }) + .collect::<Vec<_>>(); + Some(DataType::Struct(fields.into())) + } _ => { - // numeric coercion is the same as comparison coercion, both find the narrowest type + // Numeric coercion is the same as comparison coercion, both find the narrowest type // that can accommodate both types binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) .or_else(||
string_coercion(lhs_type, rhs_type)) .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) } } } +/// Handle type union resolution including struct type and others. +pub fn try_type_union_resolution(data_types: &[DataType]) -> Result<Vec<DataType>> { + let err = match try_type_union_resolution_with_struct(data_types) { + Ok(struct_types) => return Ok(struct_types), + Err(e) => Some(e), + }; + + if let Some(new_type) = type_union_resolution(data_types) { + Ok(vec![new_type; data_types.len()]) + } else { + exec_err!("Failed to find the coerced type, errors: {:?}", err) + } +} + +// Handle structs where we only change the data type but preserve the field name and nullability. +// Since the field name is the key of the struct, it shouldn't be updated to a common column name like "c0" or "c1" +pub fn try_type_union_resolution_with_struct( + data_types: &[DataType], +) -> Result<Vec<DataType>> { + let mut keys_string: Option<String> = None; + for data_type in data_types { + if let DataType::Struct(fields) = data_type { + let keys = fields.iter().map(|f| f.name().to_owned()).join(","); + if let Some(ref k) = keys_string { + if *k != keys { + return exec_err!("Expected the same keys for struct types but got mismatched pair {} and {}", *k, keys); + } + } else { + keys_string = Some(keys); + } + } else { + return exec_err!("Expected a struct type but got {}", data_type); + } + } + + let mut struct_types: Vec<DataType> = if let DataType::Struct(fields) = &data_types[0] + { + fields.iter().map(|f| f.data_type().to_owned()).collect() + } else { + return internal_err!("Struct type is checked in the previous function, so this should be unreachable"); + }; + + for data_type in data_types.iter().skip(1) { + if let DataType::Struct(fields) = data_type { + let incoming_struct_types: Vec<DataType> = + fields.iter().map(|f| f.data_type().to_owned()).collect(); + // The order of fields is verified above + for (lhs_type, rhs_type) in + struct_types.iter_mut().zip(incoming_struct_types.iter()) + { + if let Some(coerced_type) = + type_union_resolution_coercion(lhs_type, rhs_type) + { + *lhs_type = coerced_type; + } else { + return exec_err!( + "Failed to find the coerced type for {} and {}", + lhs_type, + rhs_type + ); + } + } + } else { + return exec_err!("Expected a struct type but got {}", data_type); + } + } + + let mut final_struct_types = vec![]; + for s in data_types { + let mut new_fields = vec![]; + if let DataType::Struct(fields) = s { + for (i, f) in fields.iter().enumerate() { + let field = Arc::unwrap_or_clone(Arc::clone(f)) + .with_data_type(struct_types[i].to_owned()); + new_fields.push(Arc::new(field)); + } + } + final_struct_types.push(DataType::Struct(new_fields.into())) + } + + Ok(final_struct_types) +} + /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a /// comparison operation /// @@ -507,22 +643,6 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> -fn values_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { - if lhs_type == rhs_type { - // same type => equality is possible - return Some(lhs_type.clone()); - } - binary_numeric_coercion(lhs_type, rhs_type) - .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) - .or_else(|| string_coercion(lhs_type, rhs_type)) - .or_else(|| binary_coercion(lhs_type, rhs_type)) -} - /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation /// where one is numeric and one is `Utf8`/`LargeUtf8`.
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { @@ -598,7 +718,7 @@ pub fn binary_numeric_coercion( return Some(t); } - // these are ordered from most informative to least informative so + // These are ordered from most informative to least informative so // that the coercion does not lose information via truncation match (lhs_type, rhs_type) { (Float64, _) | (_, Float64) => Some(Float64), @@ -824,12 +944,12 @@ fn mathematics_numerical_coercion( ) -> Option<DataType> { use arrow::datatypes::DataType::*; - // error on any non-numeric type + // Error on any non-numeric type if !both_numeric_or_null_and_numeric(lhs_type, rhs_type) { return None; }; - // these are ordered from most informative to least informative so + // These are ordered from most informative to least informative so // that the coercion removes the least amount of information match (lhs_type, rhs_type) { (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => { @@ -969,7 +1089,7 @@ fn string_concat_internal_coercion( /// based on the observation that StringArray to StringViewArray is cheap but not vice versa. /// /// Between Utf8 and LargeUtf8, we coerce to LargeUtf8. -fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { +pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { // If Utf8View is in any side, we coerce to Utf8View. @@ -1052,12 +1172,16 @@ fn binary_to_string_coercion( match (lhs_type, rhs_type) { (Binary, Utf8) => Some(Utf8), (Binary, LargeUtf8) => Some(LargeUtf8), + (BinaryView, Utf8) => Some(Utf8View), + (BinaryView, LargeUtf8) => Some(LargeUtf8), (LargeBinary, Utf8) => Some(LargeUtf8), (LargeBinary, LargeUtf8) => Some(LargeUtf8), (Utf8, Binary) => Some(Utf8), (Utf8, LargeBinary) => Some(LargeUtf8), + (Utf8, BinaryView) => Some(Utf8View), (LargeUtf8, Binary) => Some(LargeUtf8), (LargeUtf8, LargeBinary) => Some(LargeUtf8), + (LargeUtf8, BinaryView) => Some(LargeUtf8), _ => None, } } @@ -1086,7 +1210,7 @@ fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> } } -/// coercion rules for like operations. +/// Coercion rules for like operations. /// This is a union of string coercion rules and dictionary coercion rules pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { string_coercion(lhs_type, rhs_type) @@ -1097,13 +1221,13 @@ pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (DataType::Null, Utf8View | Utf8 | LargeUtf8) => Some(rhs_type.clone()), - (Utf8View | Utf8 | LargeUtf8, DataType::Null) => Some(lhs_type.clone()), - (DataType::Null, DataType::Null) => Some(Utf8), + (Null, Utf8View | Utf8 | LargeUtf8) => Some(rhs_type.clone()), + (Utf8View | Utf8 | LargeUtf8, Null) => Some(lhs_type.clone()), + (Null, Null) => Some(Utf8), _ => None, } } @@ -1259,7 +1383,7 @@ fn timeunit_coercion(lhs_unit: &TimeUnit, rhs_unit: &TimeUnit) -> TimeUnit { } } -/// coercion rules from NULL type. Since NULL can be casted to any other type in arrow, +/// Coercion rules from NULL type. Since NULL can be cast to any other type in Arrow, /// either lhs or rhs is NULL, if NULL can be casted to type of the other side, the coercion is valid.
fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { match (lhs_type, rhs_type) { diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 55387fea22eeb..d7dc1afe4d505 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -48,6 +48,7 @@ datafusion-expr-common = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } +indexmap = { workspace = true } paste = "^1.0" serde_json = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs index b136d6cacec8f..ab41395ad371c 100644 --- a/datafusion/expr/src/built_in_window_function.rs +++ b/datafusion/expr/src/built_in_window_function.rs @@ -22,7 +22,7 @@ use std::str::FromStr; use crate::type_coercion::functions::data_types; use crate::utils; -use crate::{Signature, TypeSignature, Volatility}; +use crate::{Signature, Volatility}; use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; use arrow::datatypes::DataType; @@ -37,34 +37,14 @@ impl fmt::Display for BuiltInWindowFunction { /// A [window function] built in to DataFusion /// -/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) +/// [Window Function]: https://en.wikipedia.org/wiki/Window_function_(SQL) #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum BuiltInWindowFunction { - /// rank of the current row with gaps; same as row_number of its first peer - Rank, - /// rank of the current row without gaps; this function counts peer groups - DenseRank, - /// relative rank of the current row: (rank - 1) / (total rows - 1) - PercentRank, - /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) - CumeDist, - /// integer ranging from 1 to the argument value, dividing the partition as equally as possible - Ntile, - /// returns value evaluated at the row that is offset rows before the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lag, - /// returns value evaluated at the row that is offset rows after the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. 
- /// If omitted, offset defaults to 1 and default to null - Lead, /// returns value evaluated at the row that is the first row of the window frame FirstValue, - /// returns value evaluated at the row that is the last row of the window frame + /// Returns value evaluated at the row that is the last row of the window frame LastValue, - /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row + /// Returns value evaluated at the row that is the nth row of the window frame (counting from 1); returns null if no such row NthValue, } @@ -72,13 +52,6 @@ impl BuiltInWindowFunction { pub fn name(&self) -> &str { use BuiltInWindowFunction::*; match self { - Rank => "RANK", - DenseRank => "DENSE_RANK", - PercentRank => "PERCENT_RANK", - CumeDist => "CUME_DIST", - Ntile => "NTILE", - Lag => "LAG", - Lead => "LEAD", FirstValue => "first_value", LastValue => "last_value", NthValue => "NTH_VALUE", @@ -90,13 +63,6 @@ impl FromStr for BuiltInWindowFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { Ok(match name.to_uppercase().as_str() { - "RANK" => BuiltInWindowFunction::Rank, - "DENSE_RANK" => BuiltInWindowFunction::DenseRank, - "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, - "CUME_DIST" => BuiltInWindowFunction::CumeDist, - "NTILE" => BuiltInWindowFunction::Ntile, - "LAG" => BuiltInWindowFunction::Lag, - "LEAD" => BuiltInWindowFunction::Lead, "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, "LAST_VALUE" => BuiltInWindowFunction::LastValue, "NTH_VALUE" => BuiltInWindowFunction::NthValue, @@ -111,10 +77,10 @@ impl BuiltInWindowFunction { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. - // verify that this is a valid set of data types for this function + // Verify that this is a valid set of data types for this function data_types(input_expr_types, &self.signature()) - // original errors are all related to wrong function signature - // aggregate them for better error message + // Original errors are all related to wrong function signature + // Aggregate them for better error message .map_err(|_| { plan_datafusion_err!( "{}", @@ -127,55 +93,19 @@ impl BuiltInWindowFunction { })?; match self { - BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), - BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { - Ok(DataType::Float64) - } - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue + BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), } } - /// the signatures supported by the built-in window function `fun`. + /// The signatures supported by the built-in window function `fun`. pub fn signature(&self) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. + // Note: The physical expression must accept the type returned by this function or the execution panics. 
match self { - BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::PercentRank - | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), - BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { - Signature::one_of( - vec![ - TypeSignature::Any(1), - TypeSignature::Any(2), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ) - } BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { Signature::any(1, Volatility::Immutable) } - BuiltInWindowFunction::Ntile => Signature::uniform( - 1, - vec![ - DataType::UInt64, - DataType::UInt32, - DataType::UInt16, - DataType::UInt8, - DataType::Int64, - DataType::Int32, - DataType::Int16, - DataType::Int8, - ], - Volatility::Immutable, - ), BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), } } diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index 7a2bf4b6c44a0..23cc88f1c0ff3 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -64,7 +64,7 @@ impl CaseBuilder { } fn build(&self) -> Result { - // collect all "then" expressions + // Collect all "then" expressions let mut then_expr = self.then_expr.clone(); if let Some(e) = &self.else_expr { then_expr.push(e.as_ref().to_owned()); @@ -79,7 +79,7 @@ impl CaseBuilder { .collect::>>()?; if then_types.contains(&DataType::Null) { - // cannot verify types until execution type + // Cannot verify types until execution type } else { let unique_types: HashSet<&DataType> = then_types.iter().collect(); if unique_types.len() != 1 { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 02a2edb98016d..bda4d7ae3d7fa 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -29,11 +29,12 @@ use crate::logical_plan::Subquery; use crate::utils::expr_to_columns; use crate::Volatility; use crate::{ - built_in_window_function, udaf, BuiltInWindowFunction, ExprSchemable, Operator, - Signature, WindowFrame, WindowUDF, + udaf, BuiltInWindowFunction, ExprSchemable, Operator, Signature, WindowFrame, + WindowUDF, }; use arrow::datatypes::{DataType, FieldRef}; +use datafusion_common::cse::HashNode; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; @@ -694,11 +695,11 @@ impl AggregateFunction { pub enum WindowFunctionDefinition { /// A built in aggregate function that leverages an aggregate function /// A a built-in window function - BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction), + BuiltInWindowFunction(BuiltInWindowFunction), /// A user defined aggregate function AggregateUDF(Arc), /// A user defined aggregate function - WindowUDF(Arc), + WindowUDF(Arc), } impl WindowFunctionDefinition { @@ -722,7 +723,7 @@ impl WindowFunctionDefinition { } } - /// the signatures supported by the function `fun`. + /// The signatures supported by the function `fun`. 
pub fn signature(&self) -> Signature { match self { WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(), @@ -741,14 +742,12 @@ impl WindowFunctionDefinition { } } -impl fmt::Display for WindowFunctionDefinition { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl Display for WindowFunctionDefinition { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { - WindowFunctionDefinition::BuiltInWindowFunction(fun) => { - std::fmt::Display::fmt(fun, f) - } - WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Display::fmt(fun, f), - WindowFunctionDefinition::WindowUDF(fun) => std::fmt::Display::fmt(fun, f), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => Display::fmt(fun, f), + WindowFunctionDefinition::AggregateUDF(fun) => Display::fmt(fun, f), + WindowFunctionDefinition::WindowUDF(fun) => Display::fmt(fun, f), } } } @@ -832,9 +831,7 @@ pub fn find_df_window_func(name: &str) -> Option { // may have different implementations for these cases. If the sought // function is not found among built-in window functions, we search for // it among aggregate functions. - if let Ok(built_in_function) = - built_in_window_function::BuiltInWindowFunction::from_str(name.as_str()) - { + if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) { Some(WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, )) @@ -846,7 +843,7 @@ pub fn find_df_window_func(name: &str) -> Option { /// EXISTS expression #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Exists { - /// subquery that will produce a single column of data + /// Subquery that will produce a single column of data pub subquery: Subquery, /// Whether the expression is negated pub negated: bool, @@ -1329,7 +1326,7 @@ impl Expr { expr, Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_) ) { - // subqueries could contain aliases so don't recurse into those + // Subqueries could contain aliases so don't recurse into those TreeNodeRecursion::Jump } else { TreeNodeRecursion::Continue @@ -1346,7 +1343,7 @@ impl Expr { } }, ) - // unreachable code: internal closure doesn't return err + // Unreachable code: internal closure doesn't return err .unwrap() } @@ -1416,7 +1413,7 @@ impl Expr { )) } - /// return `self NOT BETWEEN low AND high` + /// Return `self NOT BETWEEN low AND high` pub fn not_between(self, low: Expr, high: Expr) -> Expr { Expr::Between(Between::new( Box::new(self), @@ -1652,47 +1649,39 @@ impl Expr { | Expr::Placeholder(..) => false, } } +} - /// Hashes the direct content of an `Expr` without recursing into its children. - /// - /// This method is useful to incrementally compute hashes, such as in - /// `CommonSubexprEliminate` which builds a deep hash of a node and its descendants - /// during the bottom-up phase of the first traversal and so avoid computing the hash - /// of the node and then the hash of its descendants separately. - /// - /// If a node doesn't have any children then this method is similar to `.hash()`, but - /// not necessarily returns the same value. - /// +impl HashNode for Expr { /// As it is pretty easy to forget changing this method when `Expr` changes the /// implementation doesn't use wildcard patterns (`..`, `_`) to catch changes /// compile time. 
- pub fn hash_node(&self, hasher: &mut H) { - mem::discriminant(self).hash(hasher); + fn hash_node(&self, state: &mut H) { + mem::discriminant(self).hash(state); match self { Expr::Alias(Alias { expr: _expr, relation, name, }) => { - relation.hash(hasher); - name.hash(hasher); + relation.hash(state); + name.hash(state); } Expr::Column(column) => { - column.hash(hasher); + column.hash(state); } Expr::ScalarVariable(data_type, name) => { - data_type.hash(hasher); - name.hash(hasher); + data_type.hash(state); + name.hash(state); } Expr::Literal(scalar_value) => { - scalar_value.hash(hasher); + scalar_value.hash(state); } Expr::BinaryExpr(BinaryExpr { left: _left, op, right: _right, }) => { - op.hash(hasher); + op.hash(state); } Expr::Like(Like { negated, @@ -1708,9 +1697,9 @@ impl Expr { escape_char, case_insensitive, }) => { - negated.hash(hasher); - escape_char.hash(hasher); - case_insensitive.hash(hasher); + negated.hash(state); + escape_char.hash(state); + case_insensitive.hash(state); } Expr::Not(_expr) | Expr::IsNotNull(_expr) @@ -1728,7 +1717,7 @@ impl Expr { low: _low, high: _high, }) => { - negated.hash(hasher); + negated.hash(state); } Expr::Case(Case { expr: _expr, @@ -1743,10 +1732,10 @@ impl Expr { expr: _expr, data_type, }) => { - data_type.hash(hasher); + data_type.hash(state); } Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { - func.hash(hasher); + func.hash(state); } Expr::AggregateFunction(AggregateFunction { func, @@ -1756,9 +1745,9 @@ impl Expr { order_by: _order_by, null_treatment, }) => { - func.hash(hasher); - distinct.hash(hasher); - null_treatment.hash(hasher); + func.hash(state); + distinct.hash(state); + null_treatment.hash(state); } Expr::WindowFunction(WindowFunction { fun, @@ -1768,56 +1757,56 @@ impl Expr { window_frame, null_treatment, }) => { - fun.hash(hasher); - window_frame.hash(hasher); - null_treatment.hash(hasher); + fun.hash(state); + window_frame.hash(state); + null_treatment.hash(state); } Expr::InList(InList { expr: _expr, list: _list, negated, }) => { - negated.hash(hasher); + negated.hash(state); } Expr::Exists(Exists { subquery, negated }) => { - subquery.hash(hasher); - negated.hash(hasher); + subquery.hash(state); + negated.hash(state); } Expr::InSubquery(InSubquery { expr: _expr, subquery, negated, }) => { - subquery.hash(hasher); - negated.hash(hasher); + subquery.hash(state); + negated.hash(state); } Expr::ScalarSubquery(subquery) => { - subquery.hash(hasher); + subquery.hash(state); } Expr::Wildcard { qualifier, options } => { - qualifier.hash(hasher); - options.hash(hasher); + qualifier.hash(state); + options.hash(state); } Expr::GroupingSet(grouping_set) => { - mem::discriminant(grouping_set).hash(hasher); + mem::discriminant(grouping_set).hash(state); match grouping_set { GroupingSet::Rollup(_exprs) | GroupingSet::Cube(_exprs) => {} GroupingSet::GroupingSets(_exprs) => {} } } Expr::Placeholder(place_holder) => { - place_holder.hash(hasher); + place_holder.hash(state); } Expr::OuterReferenceColumn(data_type, column) => { - data_type.hash(hasher); - column.hash(hasher); + data_type.hash(state); + column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} }; } } -// modifies expr if it is a placeholder with datatype of right +// Modifies expr if it is a placeholder with datatype of right fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr { if data_type.is_none() { @@ -1890,7 +1879,7 @@ impl<'a> Display for 
SchemaDisplay<'a> { Ok(()) } - // expr is not shown since it is aliased + // Expr is not shown since it is aliased Expr::Alias(Alias { name, .. }) => write!(f, "{name}"), Expr::Between(Between { expr, @@ -1945,7 +1934,7 @@ impl<'a> Display for SchemaDisplay<'a> { write!(f, "END") } - // cast expr is not shown to be consistant with Postgres and Spark + // Cast expr is not shown to be consistent with Postgres and Spark Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) => { write!(f, "{}", SchemaDisplay(expr)) } @@ -2148,8 +2137,8 @@ pub fn schema_name_from_sorts(sorts: &[Sort]) -> Result<String> { /// Format expressions for display as part of a logical plan. In many cases, this will produce /// similar output to `Expr.name()` except that column names will be prefixed with '#'. -impl fmt::Display for Expr { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl Display for Expr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { Expr::Alias(Alias { expr, name, .. }) => write!(f, "{expr} AS {name}"), Expr::Column(c) => write!(f, "{c}"), @@ -2353,7 +2342,7 @@ impl fmt::Display for Expr { } fn fmt_function( - f: &mut fmt::Formatter, + f: &mut Formatter, fun: &str, distinct: bool, args: &[Expr], @@ -2415,7 +2404,7 @@ mod test { let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, expr.canonical_name()); assert_eq!(expected_canonical, format!("{expr}")); - // note that CAST intentionally has a name that is different from its `Display` + // Note that CAST intentionally has a name that is different from its `Display` // representation. CAST does not change the name of expressions. assert_eq!("Float32(1.23)", expr.schema_name().to_string()); Ok(()) @@ -2560,30 +2549,6 @@ mod test { Ok(()) } - #[test] - fn test_lead_return_type() -> Result<()> { - let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8], &[true], "")?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64], &[true], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lag_return_type() -> Result<()> { - let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8], &[true], "")?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64], &[true], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - #[test] fn test_nth_value_return_type() -> Result<()> { let fun = find_df_window_func("nth_value").unwrap(); @@ -2598,47 +2563,9 @@ mod test { Ok(()) } - #[test] - fn test_percent_rank_return_type() -> Result<()> { - let fun = find_df_window_func("percent_rank").unwrap(); - let observed = fun.return_type(&[], &[], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_cume_dist_return_type() -> Result<()> { - let fun = find_df_window_func("cume_dist").unwrap(); - let observed = fun.return_type(&[], &[], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_ntile_return_type() -> Result<()> { - let fun = find_df_window_func("ntile").unwrap(); - let observed = fun.return_type(&[DataType::Int16], &[true], "")?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - #[test] fn test_window_function_case_insensitive() -> Result<()> { - let names = vec![ - "rank", - "dense_rank", - "percent_rank", - "cume_dist", - "ntile", - "lag", - "lead", - "first_value", - "last_value", - "nth_value", - ]; + let names =
vec!["first_value", "last_value", "nth_value"]; for name in names { let fun = find_df_window_func(name).unwrap(); let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); @@ -2654,34 +2581,16 @@ mod test { #[test] fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("cume_dist"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::CumeDist - )) - ); assert_eq!( find_df_window_func("first_value"), Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::FirstValue + BuiltInWindowFunction::FirstValue )) ); assert_eq!( find_df_window_func("LAST_value"), Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::LastValue - )) - ); - assert_eq!( - find_df_window_func("LAG"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::Lag - )) - ); - assert_eq!( - find_df_window_func("LEAD"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::Lead + BuiltInWindowFunction::LastValue )) ); assert_eq!(find_df_window_func("not_exist"), None) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 2975e36488dca..7fd4e64e0e627 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -27,8 +27,8 @@ use crate::function::{ }; use crate::{ conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery, - AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, - Signature, Volatility, + AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, + ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; use crate::{ AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, @@ -39,6 +39,7 @@ use arrow::compute::kernels::cast_utils::{ use arrow::datatypes::{DataType, Field}; use datafusion_common::{plan_err, Column, Result, ScalarValue, TableReference}; use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use sqlparser::ast::NullTreatment; use std::any::Any; use std::fmt::Debug; @@ -658,7 +659,10 @@ impl WindowUDFImpl for SimpleWindowUDF { &self.signature } - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { (self.partition_evaluator_factory)() } @@ -697,7 +701,6 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { /// # use datafusion_expr::test::function_stub::count; /// # use sqlparser::ast::NullTreatment; /// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col}; -/// # use datafusion_expr::window_function::percent_rank; /// # // first_value is an aggregate function in another crate /// # fn first_value(_arg: Expr) -> Expr { /// unimplemented!() } @@ -717,6 +720,9 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { /// // Create a window expression for percent rank partitioned on column a /// // equivalent to: /// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)` +/// // percent_rank is an udwf function in another crate +/// # fn percent_rank() -> Expr { +/// unimplemented!() } /// let window = percent_rank() /// .partition_by(vec![col("a")]) /// .order_by(vec![col("b").sort(true, true)]) diff --git a/datafusion/expr/src/expr_rewriter/mod.rs 
b/datafusion/expr/src/expr_rewriter/mod.rs index 15930914dd59a..d6d5c3e2931c8 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -306,9 +306,15 @@ impl NamePreserver { /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan pub fn new(plan: &LogicalPlan) -> Self { Self { - // The schema of Filter and Join nodes comes from their inputs rather than their output expressions, - // so there is no need to use aliases to preserve expression names. - use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)), + // The expressions of these plans do not contribute to their output schema, + // so there is no need to preserve expression names to prevent a schema change. + use_alias: !matches!( + plan, + LogicalPlan::Filter(_) + | LogicalPlan::Join(_) + | LogicalPlan::TableScan(_) + | LogicalPlan::Limit(_) + ), } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index ad617c53d6178..07a36672f2722 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -35,27 +35,27 @@ use datafusion_functions_window_common::field::WindowUDFFieldArgs; use std::collections::HashMap; use std::sync::Arc; -/// trait to allow expr to typable with respect to a schema +/// Trait to allow an expr to be typed with respect to a schema pub trait ExprSchemable { - /// given a schema, return the type of the expr + /// Given a schema, return the type of the expr fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>; - /// given a schema, return the nullability of the expr + /// Given a schema, return the nullability of the expr fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>; - /// given a schema, return the expr's optional metadata + /// Given a schema, return the expr's optional metadata fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>>; - /// convert to a field with respect to a schema + /// Convert to a field with respect to a schema fn to_field( &self, input_schema: &dyn ExprSchema, ) -> Result<(Option<TableReference>, Arc<Field>)>; - /// cast to a type with respect to a schema + /// Cast to a type with respect to a schema fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>; - /// given a schema, return the type and nullability of the expr + /// Given a schema, return the type and nullability of the expr fn data_type_and_nullable(&self, schema: &dyn ExprSchema) -> Result<(DataType, bool)>; } @@ -150,7 +150,7 @@ impl ExprSchemable for Expr { .map(|e| e.get_type(schema)) .collect::<Result<Vec<_>>>()?; - // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_data_types, func) .map_err(|err| { plan_datafusion_err!( @@ -164,7 +164,7 @@ impl ExprSchemable for Expr { ) })?; - // perform additional function arguments validation (due to limited + // Perform additional function arguments validation (due to limited // expressiveness of `TypeSignature`), then infer return type Ok(func.return_type_from_exprs(args, schema, &new_data_types)?) } @@ -223,7 +223,7 @@ impl ExprSchemable for Expr { } Expr::Wildcard { ..
} => Ok(DataType::Null), Expr::GroupingSet(_) => { - // grouping sets do not really have a type and do not appear in projections + // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } } @@ -279,7 +279,7 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(_, _) => Ok(true), Expr::Literal(value) => Ok(value.is_null()), Expr::Case(case) => { - // this expression is nullable if any of the input expressions are nullable + // This expression is nullable if any of the input expressions are nullable let then_nullable = case .when_then_expr .iter() @@ -336,7 +336,7 @@ impl ExprSchemable for Expr { } Expr::Wildcard { .. } => Ok(false), Expr::GroupingSet(_) => { - // grouping sets do not really have the concept of nullable and do not appear + // Grouping sets do not really have the concept of nullable and do not appear // in projections Ok(true) } @@ -439,7 +439,7 @@ impl ExprSchemable for Expr { return Ok(self); } - // TODO(kszucs): most of the operations do not validate the type correctness + // TODO(kszucs): Most of the operations do not validate the type correctness // like all of the binary expressions below. Perhaps Expr should track the // type of the expression? @@ -526,7 +526,14 @@ impl Expr { } } -/// cast subquery in InSubquery/ScalarSubquery to a given type. +/// Cast subquery in InSubquery/ScalarSubquery to a given type. +/// +/// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific +/// columns), it casts the first expression in the projection to the target type and creates a +/// new projection with the cast expression. +/// 2. **Non-projection plan**: If the subquery isn't a projection, it adds a projection to the plan +/// with the first column cast to the target type. +/// pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> { if subquery.subquery.schema().field(0).data_type() == cast_to_type { return Ok(subquery); } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 9814d16ddfa36..23ffc83e3549c 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -27,7 +27,9 @@ pub use datafusion_functions_aggregate_common::accumulator::{ AccumulatorArgs, AccumulatorFactoryFunction, StateFieldsArgs, }; +pub use datafusion_functions_window_common::expr::ExpressionArgs; pub use datafusion_functions_window_common::field::WindowUDFFieldArgs; +pub use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; #[derive(Debug, Clone, Copy)] pub enum Hint { @@ -67,7 +69,7 @@ pub type StateTypeFunction = /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked /// * 'info': [crate::simplify::SimplifyInfo] /// -/// closure returns simplified [Expr] or an error. +/// Closure returns simplified [Expr] or an error. pub type AggregateFunctionSimplification = Box< dyn Fn( crate::expr::AggregateFunction, @@ -80,7 +82,7 @@ pub type AggregateFunctionSimplification = Box< /// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked /// * 'info': [crate::simplify::SimplifyInfo] /// -/// closure returns simplified [Expr] or an error. +/// Closure returns simplified [Expr] or an error.
pub type WindowFunctionSimplification = Box< dyn Fn( crate::expr::WindowFunction, diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 260065f69af98..849d9604808ca 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -34,6 +34,7 @@ mod partition_evaluator; mod table_source; mod udaf; mod udf; +mod udf_docs; mod udwf; pub mod conditional_expressions; @@ -90,9 +91,12 @@ pub use logical_plan::*; pub use partition_evaluator::PartitionEvaluator; pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; -pub use udf::{ScalarUDF, ScalarUDFImpl}; -pub use udwf::{WindowUDF, WindowUDFImpl}; +pub use udaf::{ + aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs, +}; +pub use udf::{scalar_doc_sections, ScalarUDF, ScalarUDFImpl}; +pub use udf_docs::{DocSection, Documentation, DocumentationBuilder}; +pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; #[cfg(test)] diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index ad96f6a85d0e5..2547aa23d3cdf 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -30,31 +30,33 @@ use crate::expr_rewriter::{ rewrite_sort_cols_by_aggs, }; use crate::logical_plan::{ - Aggregate, Analyze, CrossJoin, Distinct, DistinctOn, EmptyRelation, Explain, Filter, - Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join, + JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, Window, }; -use crate::type_coercion::binary::values_coercion; use crate::utils::{ can_hash, columnize_expr, compare_sort_expr, expr_to_columns, find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, }; use crate::{ - and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, + and, binary_expr, lit, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, TableProviderFilterPushDown, TableSource, WriteOp, }; +use super::dml::InsertOp; +use super::plan::ColumnUnnestList; +use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ - get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err, - plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, - TableReference, ToDFSchema, UnnestOptions, + exec_err, get_target_functional_dependencies, internal_err, not_impl_err, + plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, + FunctionalDependencies, Result, ScalarValue, TableReference, ToDFSchema, + UnnestOptions, }; - -use super::plan::{ColumnUnnestList, ColumnUnnestType}; +use datafusion_expr_common::type_coercion::binary::type_union_resolution; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -172,12 +174,45 @@ impl LogicalPlanBuilder { /// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. 
/// + /// so it's usually better to override the default names with a table alias list. + /// + /// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided. + pub fn values(values: Vec<Vec<Expr>>) -> Result<Self> { + if values.is_empty() { + return plan_err!("Values list cannot be empty"); + } + let n_cols = values[0].len(); + if n_cols == 0 { + return plan_err!("Values list cannot be zero length"); + } + for (i, row) in values.iter().enumerate() { + if row.len() != n_cols { + return plan_err!( + "Inconsistent data length across values list: got {} values in row {} but expected {}", + row.len(), + i, + n_cols + ); + } + } + + // Infer from data itself + Self::infer_data(values) + } + + /// Create a values list based relation, and the schema is inferred from data itself or table schema if provided, consuming + /// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) + /// documentation for more details. + /// /// By default, it assigns the names column1, column2, etc. to the columns of a VALUES table. /// The column names are not specified by the SQL standard and different database systems do it differently, /// so it's usually better to override the default names with a table alias list. /// /// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided. - pub fn values(mut values: Vec<Vec<Expr>>) -> Result<Self> { + pub fn values_with_schema( + values: Vec<Vec<Expr>>, + schema: &DFSchemaRef, + ) -> Result<Self> { if values.is_empty() { return plan_err!("Values list cannot be empty"); } @@ -196,19 +231,57 @@ impl LogicalPlanBuilder { } } - let empty_schema = DFSchema::empty(); + // Check the type of value against the schema + Self::infer_values_from_schema(values, schema) + } + + fn infer_values_from_schema( + values: Vec<Vec<Expr>>, + schema: &DFSchema, + ) -> Result<Self> { + let n_cols = values[0].len(); + let mut field_types: Vec<DataType> = Vec::with_capacity(n_cols); + for j in 0..n_cols { + let field_type = schema.field(j).data_type(); + for row in values.iter() { + let value = &row[j]; + let data_type = value.get_type(schema)?; + + if !data_type.equals_datatype(field_type) + && !can_cast_types(&data_type, field_type) + { + return exec_err!( + "type mismatch: cannot cast {} to {}", + data_type, + field_type + ); + } + } + field_types.push(field_type.to_owned()); + } + + Self::infer_inner(values, &field_types, schema) + } + + fn infer_data(values: Vec<Vec<Expr>>) -> Result<Self> { + let n_cols = values[0].len(); + let schema = DFSchema::empty(); + let mut field_types: Vec<DataType> = Vec::with_capacity(n_cols); for j in 0..n_cols { let mut common_type: Option<DataType> = None; for (i, row) in values.iter().enumerate() { let value = &row[j]; - let data_type = value.get_type(&empty_schema)?; + let data_type = value.get_type(&schema)?; if data_type == DataType::Null { continue; } + if let Some(prev_type) = common_type { // get common type of each column values. - let Some(new_type) = values_coercion(&data_type, &prev_type) else { + let data_types = vec![prev_type.clone(), data_type.clone()]; + let Some(new_type) = type_union_resolution(&data_types) else { return plan_err!("Inconsistent data type across values list at row {i} column {j}.
Was {prev_type} but found {data_type}"); }; common_type = Some(new_type); @@ -220,14 +293,22 @@ impl LogicalPlanBuilder { // since the code loop skips NULL field_types.push(common_type.unwrap_or(DataType::Null)); } + + Self::infer_inner(values, &field_types, &schema) + } + + fn infer_inner( + mut values: Vec>, + field_types: &[DataType], + schema: &DFSchema, + ) -> Result { // wrap cast if data type is not same as common type. for row in &mut values { for (j, field_type) in field_types.iter().enumerate() { if let Expr::Literal(ScalarValue::Null) = row[j] { row[j] = Expr::Literal(ScalarValue::try_from(field_type)?); } else { - row[j] = - std::mem::take(&mut row[j]).cast_to(field_type, &empty_schema)?; + row[j] = std::mem::take(&mut row[j]).cast_to(field_type, schema)?; } } } @@ -242,6 +323,7 @@ impl LogicalPlanBuilder { .collect::>(); let dfschema = DFSchema::from_unqualified_fields(fields.into(), HashMap::new())?; let schema = DFSchemaRef::new(dfschema); + Ok(Self::new(LogicalPlan::Values(Values { schema, values }))) } @@ -307,20 +389,14 @@ impl LogicalPlanBuilder { input: LogicalPlan, table_name: impl Into, table_schema: &Schema, - overwrite: bool, + insert_op: InsertOp, ) -> Result { let table_schema = table_schema.clone().to_dfschema_ref()?; - let op = if overwrite { - WriteOp::InsertOverwrite - } else { - WriteOp::InsertInto - }; - Ok(Self::new(LogicalPlan::Dml(DmlStatement::new( table_name.into(), table_schema, - op, + WriteOp::Insert(insert_op), Arc::new(input), )))) } @@ -436,9 +512,22 @@ impl LogicalPlanBuilder { /// `fetch` - Maximum number of rows to fetch, after skipping `skip` rows, /// if specified. pub fn limit(self, skip: usize, fetch: Option) -> Result { + let skip_expr = if skip == 0 { + None + } else { + Some(lit(skip as i64)) + }; + let fetch_expr = fetch.map(|f| lit(f as i64)); + self.limit_by_expr(skip_expr, fetch_expr) + } + + /// Limit the number of rows returned + /// + /// Similar to `limit` but uses expressions for `skip` and `fetch` + pub fn limit_by_expr(self, skip: Option, fetch: Option) -> Result { Ok(Self::new(LogicalPlan::Limit(Limit { - skip, - fetch, + skip: skip.map(Box::new), + fetch: fetch.map(Box::new), input: self.plan, }))) } @@ -954,9 +1043,14 @@ impl LogicalPlanBuilder { pub fn cross_join(self, right: LogicalPlan) -> Result { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; - Ok(Self::new(LogicalPlan::CrossJoin(CrossJoin { + Ok(Self::new(LogicalPlan::Join(Join { left: self.plan, right: Arc::new(right), + on: vec![], + filter: None, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + null_equals_null: false, schema: DFSchemaRef::new(join_schema), }))) } @@ -1185,7 +1279,7 @@ impl LogicalPlanBuilder { ) -> Result { unnest_with_options( Arc::unwrap_or_clone(self.plan), - vec![(column.into(), ColumnUnnestType::Inferred)], + vec![column.into()], options, ) .map(Self::new) @@ -1196,26 +1290,6 @@ impl LogicalPlanBuilder { self, columns: Vec, options: UnnestOptions, - ) -> Result { - unnest_with_options( - Arc::unwrap_or_clone(self.plan), - columns - .into_iter() - .map(|c| (c, ColumnUnnestType::Inferred)) - .collect(), - options, - ) - .map(Self::new) - } - - /// Unnest the given columns with the given [`UnnestOptions`] - /// if one column is a list type, it can be recursively and simultaneously - /// unnested into the desired recursion levels - /// e.g select unnest(list_col,depth=1), unnest(list_col,depth=2) - pub fn unnest_columns_recursive_with_options( - self, - columns: Vec<(Column, 
ColumnUnnestType)>, - options: UnnestOptions, ) -> Result<Self> { unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options) .map(Self::new) } @@ -1328,8 +1402,12 @@ pub fn build_join_schema( join_type, left.fields().len(), ); - let mut metadata = left.metadata().clone(); - metadata.extend(right.metadata().clone()); + let metadata = left + .metadata() + .clone() + .into_iter() + .chain(right.metadata().clone()) + .collect(); let dfschema = DFSchema::new_with_metadata(qualified_fields, metadata)?; dfschema.with_functional_dependencies(func_dependencies) } @@ -1404,9 +1482,23 @@ pub fn validate_unique_names<'a>( /// [`TypeCoercionRewriter::coerce_union`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/struct.TypeCoercionRewriter.html#method.coerce_union /// [`coerce_union_schema`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/fn.coerce_union_schema.html pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result<LogicalPlan> { + if left_plan.schema().fields().len() != right_plan.schema().fields().len() { + return plan_err!( + "UNION queries have different number of columns: \ + left has {} columns whereas right has {} columns", + left_plan.schema().fields().len(), + right_plan.schema().fields().len() + ); + } + // Temporarily use the schema from the left input and later rely on the analyzer to // coerce the two schemas into a common one. - let schema = Arc::clone(left_plan.schema()); + + // Functional dependencies are not preserved after a UNION operation + let schema = (**left_plan.schema()).clone(); + let schema = + Arc::new(schema.with_functional_dependencies(FunctionalDependencies::empty())?); + Ok(LogicalPlan::Union(Union { inputs: vec![Arc::new(left_plan), Arc::new(right_plan)], schema, @@ -1586,21 +1678,19 @@ impl TableSource for LogicalTableSource { fn supports_filters_pushdown( &self, filters: &[&Expr], - ) -> Result> { + ) -> Result> { Ok(vec![TableProviderFilterPushDown::Exact; filters.len()]) } } /// Create a [`LogicalPlan::Unnest`] plan pub fn unnest(input: LogicalPlan, columns: Vec<Column>) -> Result<LogicalPlan> { - let unnestings = columns - .into_iter() - .map(|c| (c, ColumnUnnestType::Inferred)) - .collect(); - unnest_with_options(input, unnestings, UnnestOptions::default()) + unnest_with_options(input, columns, UnnestOptions::default()) } -pub fn get_unnested_list_datatype_recursive( +// Get the data type of a multi-dimensional type after unnesting it +// with a given depth +fn get_unnested_list_datatype_recursive( data_type: &DataType, depth: usize, ) -> Result<DataType> { @@ -1619,27 +1709,6 @@ pub fn get_unnested_list_datatype_recursive( internal_err!("trying to unnest on invalid data type {:?}", data_type) } -/// Infer the unnest type based on the data type: -/// - list type: infer to unnest(list(col, depth=1)) -/// - struct type: infer to unnest(struct) -fn infer_unnest_type( - col_name: &String, - data_type: &DataType, - ) -> Result<ColumnUnnestType> { - match data_type { - DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => { - Ok(ColumnUnnestType::List(vec![ColumnUnnestList { - output_column: Column::from_name(col_name), - depth: 1, - }])) - } - DataType::Struct(_) => Ok(ColumnUnnestType::Struct), - _ => { - internal_err!("trying to unnest on invalid data type {:?}", data_type) - } - } -} - pub fn get_struct_unnested_columns( col_name: &String, inner_fields: &Fields, @@ -1728,20 +1797,15 @@ pub fn get_unnested_columns( /// ``` pub fn unnest_with_options( input: LogicalPlan, - columns_to_unnest:
Vec<(Column, ColumnUnnestType)>, + columns_to_unnest: Vec<Column>, options: UnnestOptions, ) -> Result<LogicalPlan> { let mut list_columns: Vec<(usize, ColumnUnnestList)> = vec![]; let mut struct_columns = vec![]; let indices_to_unnest = columns_to_unnest .iter() - .map(|col_unnesting| { - Ok(( - input.schema().index_of_column(&col_unnesting.0)?, - col_unnesting, - )) - }) - .collect::<Result<HashMap<usize, &(Column, ColumnUnnestType)>>>()?; + .map(|c| Ok((input.schema().index_of_column(c)?, c))) + .collect::<Result<HashMap<usize, &Column>>>()?; let input_schema = input.schema(); @@ -1766,51 +1830,59 @@ pub fn unnest_with_options( .enumerate() .map(|(index, (original_qualifier, original_field))| { match indices_to_unnest.get(&index) { - Some((column_to_unnest, unnest_type)) => { - let mut inferred_unnest_type = unnest_type.clone(); - if let ColumnUnnestType::Inferred = unnest_type { - inferred_unnest_type = infer_unnest_type( + Some(column_to_unnest) => { + let recursions_on_column = options + .recursions + .iter() + .filter(|p| -> bool { &p.input_column == *column_to_unnest }) + .collect::<Vec<_>>(); + let mut transformed_columns = recursions_on_column + .iter() + .map(|r| { + list_columns.push(( + index, + ColumnUnnestList { + output_column: r.output_column.clone(), + depth: r.depth, + }, + )); + Ok(get_unnested_columns( + &r.output_column.name, + original_field.data_type(), + r.depth, + )? + .into_iter() + .next() + .unwrap()) // because unnesting a list column always results in exactly one column + }) + .collect::<Result<Vec<(Column, Arc<Field>)>>>()?; + if transformed_columns.is_empty() { + transformed_columns = get_unnested_columns( + &column_to_unnest.name, + original_field.data_type(), + 1, )?; - } - let transformed_columns: Vec<(Column, Arc<Field>)> = - match inferred_unnest_type { - ColumnUnnestType::Struct => { + match original_field.data_type() { + DataType::Struct(_) => { struct_columns.push(index); - get_unnested_columns( - &column_to_unnest.name, - original_field.data_type(), - 1, - )? } - ColumnUnnestType::List(unnest_lists) => { - list_columns.extend( - unnest_lists - .iter() - .map(|ul| (index, ul.to_owned().clone())), - ); - unnest_lists - .iter() - .map( - |ColumnUnnestList { - output_column, - depth, - }| { - get_unnested_columns( - &output_column.name, - original_field.data_type(), - *depth, - ) - }, - ) - .collect::<Result<Vec<Vec<(Column, Arc<Field>)>>>>()? - .into_iter() - .flatten() - .collect::<Vec<_>>() + DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) => { + list_columns.push(( + index, + ColumnUnnestList { + output_column: Column::from_name( + &column_to_unnest.name, + ), + depth: 1, + }, + )); } - _ => return internal_err!("Invalid unnest type"), + _ => {} }; + } + // new columns dependent on the same original index dependency_indices .extend(std::iter::repeat(index).take(transformed_columns.len())); @@ -1859,7 +1931,7 @@ mod tests { use crate::logical_plan::StringifiedPlan; use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; - use datafusion_common::SchemaError; + use datafusion_common::{RecursionUnnestOption, SchemaError}; #[test] fn plan_builder_simple() -> Result<()> { @@ -2267,24 +2339,19 @@ mod tests { // Simultaneously unnesting a list (with different depth) and a struct column let plan = nested_table_scan("test_table")?
- .unnest_columns_recursive_with_options( - vec![ - ( - "stringss".into(), - ColumnUnnestType::List(vec![ - ColumnUnnestList { - output_column: Column::from_name("stringss_depth_1"), - depth: 1, - }, - ColumnUnnestList { - output_column: Column::from_name("stringss_depth_2"), - depth: 2, - }, - ]), - ), - ("struct_singular".into(), ColumnUnnestType::Inferred), - ], - UnnestOptions::default(), + .unnest_columns_with_options( + vec!["stringss".into(), "struct_singular".into()], + UnnestOptions::default() + .with_recursions(RecursionUnnestOption { + input_column: "stringss".into(), + output_column: "stringss_depth_1".into(), + depth: 1, + }) + .with_recursions(RecursionUnnestOption { + input_column: "stringss".into(), + output_column: "stringss_depth_2".into(), + depth: 2, + }), )? .build()?; diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 9aaa5c98037ac..93e8b5fd045e7 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -120,7 +120,7 @@ impl DdlStatement { /// children. /// /// See [crate::LogicalPlan::display] for an example - pub fn display(&self) -> impl fmt::Display + '_ { + pub fn display(&self) -> impl Display + '_ { struct Wrapper<'a>(&'a DdlStatement); impl<'a> Display for Wrapper<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -202,6 +202,8 @@ pub struct CreateExternalTable { pub table_partition_cols: Vec<String>, /// Option to not error if table already exists pub if_not_exists: bool, + /// Whether the table is a temporary table + pub temporary: bool, /// SQL used to create the table, if available pub definition: Option<String>, /// Order expressions supplied by user @@ -298,6 +300,8 @@ pub struct CreateMemoryTable { pub or_replace: bool, /// Default values for columns pub column_defaults: Vec<(String, Expr)>, + /// Whether the table is `TableType::Temporary` + pub temporary: bool, } /// Creates a view. @@ -311,6 +315,8 @@ pub struct CreateView { pub or_replace: bool, /// SQL used to create the view, if available pub definition: Option<String>, + /// Whether the view is ephemeral + pub temporary: bool, } /// Creates a catalog (aka "Database"). diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 26d54803d4036..c0549451a7763 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -504,11 +504,6 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Filter": format!("{}", filter_expr) }) } - LogicalPlan::CrossJoin(_) => { - json!({ - "Node Type": "Cross Join" - }) - } LogicalPlan::Repartition(Repartition { partitioning_scheme, ..
@@ -549,11 +544,13 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { let mut object = serde_json::json!( { "Node Type": "Limit", - "Skip": skip, } ); + if let Some(s) = skip { + object["Skip"] = s.to_string().into() + }; if let Some(f) = fetch { - object["Fetch"] = serde_json::Value::Number((*f).into()); + object["Fetch"] = f.to_string().into() }; object } diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index c2ed9dc0781cc..669bc8e8a7d34 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -146,8 +146,7 @@ impl PartialOrd for DmlStatement { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum WriteOp { - InsertOverwrite, - InsertInto, + Insert(InsertOp), Delete, Update, Ctas, } @@ -157,8 +156,7 @@ impl WriteOp { /// Return a descriptive name of this [`WriteOp`] pub fn name(&self) -> &str { match self { - WriteOp::InsertOverwrite => "Insert Overwrite", - WriteOp::InsertInto => "Insert Into", + WriteOp::Insert(insert) => insert.name(), WriteOp::Delete => "Delete", WriteOp::Update => "Update", WriteOp::Ctas => "Ctas", @@ -167,7 +165,38 @@ impl WriteOp { } impl Display for WriteOp { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name()) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] +pub enum InsertOp { + /// Appends new rows to the existing table without modifying any + /// existing rows. This corresponds to the SQL `INSERT INTO` query. + Append, + /// Overwrites all existing rows in the table with the new rows. + /// This corresponds to the SQL `INSERT OVERWRITE` query. + Overwrite, + /// If any existing rows collide with the inserted rows (typically based + /// on a unique key or primary key), those existing rows are replaced. + /// This corresponds to the SQL `REPLACE INTO` query and its equivalents. + Replace, +} + +impl InsertOp { + /// Return a descriptive name of this [`InsertOp`] + pub fn name(&self) -> &str { + match self { + InsertOp::Append => "Insert Into", + InsertOp::Overwrite => "Insert Overwrite", + InsertOp::Replace => "Replace Into", + } + } +} + +impl Display for InsertOp { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}", self.name()) } } diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index d49c85fb6fd69..19d4cb3db9ce5 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -195,6 +195,16 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// directly because it must remain object safe. fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool; fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering>; + + /// Returns `true` if a limit can be safely pushed down through this + /// `UserDefinedLogicalNode` node. + /// + /// If this method returns `true`, and the query plan contains a limit at + /// the output of this node, DataFusion will push the limit to the input + /// of this node. + fn supports_limit_pushdown(&self) -> bool { + false + } } impl Hash for dyn UserDefinedLogicalNode { @@ -295,6 +305,16 @@ pub trait UserDefinedLogicalNodeCore: ) -> Option>> { None } + + /// Returns `true` if a limit can be safely pushed down through this + /// `UserDefinedLogicalNode` node.
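Stepping back to the `InsertOp` enum introduced just above: a sketch (not from the patch; module paths assumed from this hunk) of mapping parsed SQL onto the new `WriteOp::Insert` shape.

```rust
use datafusion_expr::dml::InsertOp;
use datafusion_expr::WriteOp;

// Hypothetical mapping from a parsed statement to the new enum shape.
fn write_op(overwrite: bool, replace_into: bool) -> WriteOp {
    WriteOp::Insert(if overwrite {
        InsertOp::Overwrite
    } else if replace_into {
        InsertOp::Replace
    } else {
        InsertOp::Append
    })
}

fn main() {
    // Display delegates to `name()`, so this prints "Replace Into".
    println!("{}", write_op(false, true));
}
```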
+ /// + /// If this method returns `true`, and the query plan contains a limit at + /// the output of this node, DataFusion will push the limit to the input + /// of this node. + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } /// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode` @@ -361,6 +381,10 @@ impl UserDefinedLogicalNode for T { .downcast_ref::() .and_then(|other| self.partial_cmp(other)) } + + fn supports_limit_pushdown(&self) -> bool { + self.supports_limit_pushdown() + } } fn get_all_columns_from_schema(schema: &DFSchema) -> HashSet { diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index a189d4635e001..80a8962124428 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -35,10 +35,10 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - projection_schema, Aggregate, Analyze, ColumnUnnestList, ColumnUnnestType, CrossJoin, - DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, + projection_schema, Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, + DistinctOn, EmptyRelation, Explain, Extension, FetchType, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, - Projection, RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery, + Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 443d23804adb2..a301c48659d7c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -21,7 +21,7 @@ use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use super::dml::CopyTo; use super::DdlStatement; @@ -49,8 +49,10 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, - FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions, + FunctionalDependencies, ParamValues, Result, ScalarValue, TableReference, + UnnestOptions, }; +use indexmap::IndexSet; // backwards compatibility use crate::display::PgJsonVisitor; @@ -219,9 +221,6 @@ pub enum LogicalPlan { /// Join two logical plans on one or more join columns. /// This is used to implement SQL `JOIN` Join(Join), - /// Apply Cross Join to two logical plans. - /// This is used to implement SQL `CROSS JOIN` - CrossJoin(CrossJoin), /// Repartitions the input based on a partitioning scheme. This is /// used to add parallelism and is sometimes referred to as an /// "exchange" operator in other systems @@ -309,7 +308,6 @@ impl LogicalPlan { LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema, LogicalPlan::Sort(Sort { input, .. }) => input.schema(), LogicalPlan::Join(Join { schema, .. }) => schema, - LogicalPlan::CrossJoin(CrossJoin { schema, .. }) => schema, LogicalPlan::Repartition(Repartition { input, .. }) => input.schema(), LogicalPlan::Limit(Limit { input, .. 
}) => input.schema(), LogicalPlan::Statement(statement) => statement.schema(), @@ -342,8 +340,7 @@ impl LogicalPlan { | LogicalPlan::Projection(_) | LogicalPlan::Aggregate(_) | LogicalPlan::Unnest(_) - | LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) => self + | LogicalPlan::Join(_) => self .inputs() .iter() .map(|input| input.schema().as_ref()) @@ -421,27 +418,6 @@ impl LogicalPlan { exprs } - #[deprecated(since = "37.0.0", note = "Use `apply_expressions` instead")] - pub fn inspect_expressions(self: &LogicalPlan, mut f: F) -> Result<(), E> - where - F: FnMut(&Expr) -> Result<(), E>, - { - let mut err = Ok(()); - self.apply_expressions(|e| { - if let Err(e) = f(e) { - // save the error for later (it may not be a DataFusionError - err = Err(e); - Ok(TreeNodeRecursion::Stop) - } else { - Ok(TreeNodeRecursion::Continue) - } - }) - // The closure always returns OK, so this will always too - .expect("no way to return error during recursion"); - - err - } - /// Returns all inputs / children of this `LogicalPlan` node. /// /// Note does not include inputs to inputs, or subqueries. @@ -454,7 +430,6 @@ impl LogicalPlan { LogicalPlan::Aggregate(Aggregate { input, .. }) => vec![input], LogicalPlan::Sort(Sort { input, .. }) => vec![input], LogicalPlan::Join(Join { left, right, .. }) => vec![left, right], - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => vec![left, right], LogicalPlan::Limit(Limit { input, .. }) => vec![input], LogicalPlan::Subquery(Subquery { subquery, .. }) => vec![subquery], LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => vec![input], @@ -560,13 +535,6 @@ impl LogicalPlan { JoinType::LeftSemi | JoinType::LeftAnti => left.head_output_expr(), JoinType::RightSemi | JoinType::RightAnti => right.head_output_expr(), }, - LogicalPlan::CrossJoin(cross) => { - if cross.left.schema().fields().is_empty() { - cross.right.head_output_expr() - } else { - cross.left.head_output_expr() - } - } LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { static_term.head_output_expr() } @@ -692,20 +660,6 @@ impl LogicalPlan { null_equals_null, })) } - LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema: _, - }) => { - let join_schema = - build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; - - Ok(LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema: join_schema.into(), - })) - } LogicalPlan::Subquery(_) => Ok(self), LogicalPlan::SubqueryAlias(SubqueryAlias { input, @@ -956,11 +910,6 @@ impl LogicalPlan { null_equals_null: *null_equals_null, })) } - LogicalPlan::CrossJoin(_) => { - self.assert_no_expressions(expr)?; - let (left, right) = self.only_two_inputs(inputs)?; - LogicalPlanBuilder::from(left).cross_join(right)?.build() - } LogicalPlan::Subquery(Subquery { outer_ref_columns, .. }) => { @@ -979,11 +928,20 @@ impl LogicalPlan { .map(LogicalPlan::SubqueryAlias) } LogicalPlan::Limit(Limit { skip, fetch, .. 
}) => { - self.assert_no_expressions(expr)?; + let old_expr_len = skip.iter().chain(fetch.iter()).count(); + if old_expr_len != expr.len() { + return internal_err!( + "Invalid number of new Limit expressions: expected {}, got {}", + old_expr_len, + expr.len() + ); + } + // `expressions()` returns `[skip, fetch]`, so pop `fetch` first. + let new_fetch = fetch.as_ref().and_then(|_| expr.pop()); + let new_skip = skip.as_ref().and_then(|_| expr.pop()); let input = self.only_input(inputs)?; Ok(LogicalPlan::Limit(Limit { - skip: *skip, - fetch: *fetch, + skip: new_skip.map(Box::new), + fetch: new_fetch.map(Box::new), input: Arc::new(input), })) } @@ -992,6 +950,7 @@ if_not_exists, or_replace, column_defaults, + temporary, .. })) => { self.assert_no_expressions(expr)?; @@ -1004,6 +963,7 @@ if_not_exists: *if_not_exists, or_replace: *or_replace, column_defaults: column_defaults.clone(), + temporary: *temporary, }, ))) } @@ -1011,6 +971,7 @@ name, or_replace, definition, + temporary, .. })) => { self.assert_no_expressions(expr)?; @@ -1019,6 +980,7 @@ input: Arc::new(input), name: name.clone(), or_replace: *or_replace, + temporary: *temporary, definition: definition.clone(), }))) } @@ -1331,12 +1293,6 @@ JoinType::LeftSemi | JoinType::LeftAnti => left.max_rows(), JoinType::RightSemi | JoinType::RightAnti => right.max_rows(), }, - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - match (left.max_rows(), right.max_rows()) { - (Some(left_max), Some(right_max)) => Some(left_max * right_max), - _ => None, - } - } LogicalPlan::Repartition(Repartition { input, .. }) => input.max_rows(), LogicalPlan::Union(Union { inputs, .. }) => inputs .iter() @@ -1354,7 +1310,10 @@ LogicalPlan::RecursiveQuery(_) => None, LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), - LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, + LogicalPlan::Limit(limit) => match limit.get_fetch_type() { + Ok(FetchType::Literal(s)) => s, + _ => None, + }, LogicalPlan::Distinct( Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), ) => input.max_rows(), @@ -1868,6 +1827,11 @@ .as_ref() .map(|expr| format!(" Filter: {expr}")) .unwrap_or_else(|| "".to_string()); + let join_type = if filter.is_none() && keys.is_empty() && matches!(join_type, JoinType::Inner) { + "Cross".to_string() + } else { + join_type.to_string() + }; match join_constraint { JoinConstraint::On => { write!( @@ -1889,9 +1853,6 @@ } } } - LogicalPlan::CrossJoin(_) => { - write!(f, "CrossJoin:") - } LogicalPlan::Repartition(Repartition { partitioning_scheme, .. @@ -1919,16 +1880,20 @@ ) } }, - LogicalPlan::Limit(Limit { - ref skip, - ref fetch, - .. - }) => { + LogicalPlan::Limit(limit) => { + // Attempt to display `skip` and `fetch` as literals if possible, otherwise as expressions. + let skip_str = match limit.get_skip_type() { + Ok(SkipType::Literal(n)) => n.to_string(), + _ => limit.skip.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()), + }; + let fetch_str = match limit.get_fetch_type() { + Ok(FetchType::Literal(Some(n))) => n.to_string(), + Ok(FetchType::Literal(None)) => "None".to_string(), + _ => limit.fetch.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()) + }; write!( f, - "Limit: skip={}, fetch={}", - skip, - fetch.map_or_else(|| "None".to_string(), |x| x.to_string()) + "Limit: skip={}, fetch={}", skip_str, fetch_str, ) } LogicalPlan::Subquery(Subquery { ..
}) => { @@ -2593,28 +2558,7 @@ impl TableScan { } } -/// Apply Cross Join to two logical plans -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct CrossJoin { - /// Left input - pub left: Arc<LogicalPlan>, - /// Right input - pub right: Arc<LogicalPlan>, - /// The output schema, containing fields from the left and right inputs - pub schema: DFSchemaRef, -} - -// Manual implementation needed because of `schema` field. Comparison excludes this field. -impl PartialOrd for CrossJoin { - fn partial_cmp(&self, other: &Self) -> Option<Ordering> { - match self.left.partial_cmp(&other.left) { - Some(Ordering::Equal) => self.right.partial_cmp(&other.right), - cmp => cmp, - } - } -} - -/// Repartition the plan based on a partitioning scheme. +/// Repartition the plan based on a partitioning scheme. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Repartition { /// The incoming logical plan @@ -2788,14 +2732,71 @@ impl PartialOrd for Extension { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Limit { /// Number of rows to skip before fetch - pub skip: usize, + pub skip: Option<Box<Expr>>, /// Maximum number of rows to fetch, /// None means fetching all rows - pub fetch: Option<usize>, + pub fetch: Option<Box<Expr>>, /// The logical plan pub input: Arc<LogicalPlan>, } +/// Different types of skip expression in Limit plan. +pub enum SkipType { + /// The skip expression is a literal value. + Literal(usize), + /// Currently only supports expressions that can be folded into constants. + UnsupportedExpr, +} + +/// Different types of fetch expression in Limit plan. +pub enum FetchType { + /// The fetch expression is a literal value. + /// `Literal(None)` means the fetch expression is not provided. + Literal(Option<usize>), + /// Currently only supports expressions that can be folded into constants. + UnsupportedExpr, +} + +impl Limit { + /// Get the skip type from the limit plan. + pub fn get_skip_type(&self) -> Result<SkipType> { + match self.skip.as_deref() { + Some(expr) => match *expr { + Expr::Literal(ScalarValue::Int64(s)) => { + // `skip = NULL` is equivalent to `skip = 0` + let s = s.unwrap_or(0); + if s >= 0 { + Ok(SkipType::Literal(s as usize)) + } else { + plan_err!("OFFSET must be >=0, '{}' was provided", s) + } + } + _ => Ok(SkipType::UnsupportedExpr), + }, + // `skip = None` is equivalent to `skip = 0` + None => Ok(SkipType::Literal(0)), + } + } + + /// Get the fetch type from the limit plan. + pub fn get_fetch_type(&self) -> Result<FetchType> { + match self.fetch.as_deref() { + Some(expr) => match *expr { + Expr::Literal(ScalarValue::Int64(Some(s))) => { + if s >= 0 { + Ok(FetchType::Literal(Some(s as usize))) + } else { + plan_err!("LIMIT must be >= 0, '{}' was provided", s) + } + } + Expr::Literal(ScalarValue::Int64(None)) => Ok(FetchType::Literal(None)), + _ => Ok(FetchType::UnsupportedExpr), + }, + None => Ok(FetchType::Literal(None)), + } + } +} + /// Removes duplicate rows from the input #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum Distinct { @@ -2964,6 +2965,15 @@ impl Aggregate { .into_iter() .map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into())) .collect::<Vec<_>>(); + qualified_fields.push(( + None, + Field::new( + Self::INTERNAL_GROUPING_ID, + Self::grouping_id_type(qualified_fields.len()), + false, + ) + .into(), + )); } qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?); @@ -3015,9 +3025,19 @@ impl Aggregate { }) } + fn is_grouping_set(&self) -> bool { + matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)]) + } + /// Get the output expressions.
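A quick aside before the Aggregate changes continue: a sketch (not from the patch) of round-tripping the new expression-based `Limit` through the accessors defined above. It assumes `lit` produces the `Int64` literals these accessors expect.

```rust
use datafusion_common::Result;
use datafusion_expr::{lit, FetchType, LogicalPlan, LogicalPlanBuilder, SkipType};

fn main() -> Result<()> {
    // skip/fetch are now expressions; plain usize still works via `limit()`.
    let plan = LogicalPlanBuilder::empty(false)
        .limit_by_expr(Some(lit(5i64)), Some(lit(10i64)))?
        .build()?;
    if let LogicalPlan::Limit(limit) = &plan {
        // Both fold back to literals, so downstream code can keep treating
        // them as constants via SkipType/FetchType.
        assert!(matches!(limit.get_skip_type()?, SkipType::Literal(5)));
        assert!(matches!(limit.get_fetch_type()?, FetchType::Literal(Some(10))));
    }
    Ok(())
}
```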
fn output_expressions(&self) -> Result> { + static INTERNAL_ID_EXPR: OnceLock = OnceLock::new(); let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?; + if self.is_grouping_set() { + exprs.push(INTERNAL_ID_EXPR.get_or_init(|| { + Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID)) + })); + } exprs.extend(self.aggr_expr.iter()); debug_assert!(exprs.len() == self.schema.fields().len()); Ok(exprs) @@ -3029,6 +3049,41 @@ impl Aggregate { pub fn group_expr_len(&self) -> Result { grouping_set_expr_count(&self.group_expr) } + + /// Returns the data type of the grouping id. + /// The grouping ID value is a bitmask where each set bit + /// indicates that the corresponding grouping expression is + /// null + pub fn grouping_id_type(group_exprs: usize) -> DataType { + if group_exprs <= 8 { + DataType::UInt8 + } else if group_exprs <= 16 { + DataType::UInt16 + } else if group_exprs <= 32 { + DataType::UInt32 + } else { + DataType::UInt64 + } + } + + /// Internal column used when the aggregation is a grouping set. + /// + /// This column contains a bitmask where each bit represents a grouping + /// expression. The least significant bit corresponds to the rightmost + /// grouping expression. A bit value of 0 indicates that the corresponding + /// column is included in the grouping set, while a value of 1 means it is excluded. + /// + /// For example, for the grouping expressions CUBE(a, b), the grouping ID + /// column will have the following values: + /// 0b00: Both `a` and `b` are included + /// 0b01: `b` is excluded + /// 0b10: `a` is excluded + /// 0b11: Both `a` and `b` are excluded + /// + /// This internal column is necessary because excluded columns are replaced + /// with `NULL` values. To handle these cases correctly, we must distinguish + /// between an actual `NULL` value in a column and a column being excluded from the set. + pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id"; } // Manual implementation needed because of `schema` field. Comparison excludes this field. @@ -3071,6 +3126,8 @@ fn calc_func_dependencies_for_aggregate( let group_by_expr_names = group_expr .iter() .map(|item| item.schema_name().to_string()) + .collect::>() + .into_iter() .collect::>(); let aggregate_func_dependencies = aggregate_functional_dependencies( input.schema(), @@ -3300,39 +3357,6 @@ pub enum Partitioning { DistributeBy(Vec), } -/// Represents the unnesting operation on a column based on the context (a known struct -/// column, a list column, or let the planner infer the unnesting type). -/// -/// The inferred unnesting type works for both struct and list column, but the unnesting -/// will only be done once (depth = 1). 
In case recursion is needed on a multi-dimensional -/// list type, use [`ColumnUnnestList`] -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)] -pub enum ColumnUnnestType { - // Unnesting a list column, a vector of ColumnUnnestList is used because - // a column can be unnested at different levels, resulting different output columns - List(Vec), - // for struct, there can only be one unnest performed on one column at a time - Struct, - // Infer the unnest type based on column schema - // If column is a list column, the unnest depth will be 1 - // This value is to support sugar syntax of old api in Dataframe (unnest(either_list_or_struct_column)) - Inferred, -} - -impl fmt::Display for ColumnUnnestType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ColumnUnnestType::List(lists) => { - let list_strs: Vec = - lists.iter().map(|list| list.to_string()).collect(); - write!(f, "List([{}])", list_strs.join(", ")) - } - ColumnUnnestType::Struct => write!(f, "Struct"), - ColumnUnnestType::Inferred => write!(f, "Inferred"), - } - } -} - /// Represent the unnesting operation on a list column, such as the recursion depth and /// the output column name after unnesting /// @@ -3358,8 +3382,8 @@ pub struct ColumnUnnestList { pub depth: usize, } -impl fmt::Display for ColumnUnnestList { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl Display for ColumnUnnestList { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}|depth={}", self.output_column, self.depth) } } @@ -3371,7 +3395,7 @@ pub struct Unnest { /// The incoming logical plan pub input: Arc, /// Columns to run unnest on, can be a list of (List/Struct) columns - pub exec_columns: Vec<(Column, ColumnUnnestType)>, + pub exec_columns: Vec, /// refer to the indices(in the input schema) of columns /// that have type list to run unnest on pub list_type_columns: Vec<(usize, ColumnUnnestList)>, @@ -3395,7 +3419,7 @@ impl PartialOrd for Unnest { /// The incoming logical plan pub input: &'a Arc, /// Columns to run unnest on, can be a list of (List/Struct) columns - pub exec_columns: &'a Vec<(Column, ColumnUnnestType)>, + pub exec_columns: &'a Vec, /// refer to the indices(in the input schema) of columns /// that have type list to run unnest on pub list_type_columns: &'a Vec<(usize, ColumnUnnestList)>, @@ -4076,4 +4100,25 @@ digraph { ); assert_eq!(describe_table.partial_cmp(&describe_table_clone), None); } + + #[test] + fn test_limit_with_new_children() { + let limit = LogicalPlan::Limit(Limit { + skip: None, + fetch: Some(Box::new(Expr::Literal( + ScalarValue::new_ten(&DataType::UInt32).unwrap(), + ))), + input: Arc::new(LogicalPlan::Values(Values { + schema: Arc::new(DFSchema::empty()), + values: vec![vec![]], + })), + }); + let new_limit = limit + .with_new_exprs( + limit.expressions(), + limit.inputs().into_iter().cloned().collect(), + ) + .unwrap(); + assert_eq!(limit, new_limit); + } } diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index ed06375157c94..7ad18ce7bbf77 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -61,7 +61,7 @@ impl Statement { /// children. 
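Returning to the `__grouping_id` column documented in the plan.rs hunk above: a self-contained sketch (illustrative only, not patch code) of how its bitmask encoding can be decoded for `CUBE(a, b)`.

```rust
// Bit 0 is the rightmost grouping expression (`b` in CUBE(a, b));
// a set bit means "excluded from this grouping set".
fn is_excluded(grouping_id: u64, idx_from_right: usize) -> bool {
    (grouping_id >> idx_from_right) & 1 == 1
}

fn main() {
    let id = 0b10; // `a` excluded, `b` included
    assert!(is_excluded(id, 1)); // a (bit 1)
    assert!(!is_excluded(id, 0)); // b (bit 0)
}
```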
/// /// See [crate::LogicalPlan::display] for an example - pub fn display(&self) -> impl fmt::Display + '_ { + pub fn display(&self) -> impl Display + '_ { struct Wrapper<'a>(&'a Statement); impl<'a> Display for Wrapper<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 8ba68697bd4d7..0658f7029740f 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -37,12 +37,13 @@ //! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different expressions //! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions use crate::{ - dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, - DdlStatement, Distinct, DistinctOn, DmlStatement, Explain, Expr, Extension, Filter, - Join, Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, - Repartition, Sort, Subquery, SubqueryAlias, TableScan, Union, Unnest, - UserDefinedLogicalNode, Values, Window, + dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, + Distinct, DistinctOn, DmlStatement, Explain, Expr, Extension, Filter, Join, Limit, + LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, Sort, + Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, + Window, }; +use std::ops::Deref; use std::sync::Arc; use crate::expr::{Exists, InSubquery}; @@ -159,22 +160,6 @@ impl TreeNode for LogicalPlan { null_equals_null, }) }), - LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema, - }) => map_until_stop_and_collect!( - rewrite_arc(left, &mut f), - right, - rewrite_arc(right, &mut f) - )? - .update_data(|(left, right)| { - LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema, - }) - }), LogicalPlan::Limit(Limit { skip, fetch, input }) => rewrite_arc(input, f)? .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })), LogicalPlan::Subquery(Subquery { @@ -285,6 +270,7 @@ impl TreeNode for LogicalPlan { if_not_exists, or_replace, column_defaults, + temporary, }) => rewrite_arc(input, f)?.update_data(|input| { DdlStatement::CreateMemoryTable(CreateMemoryTable { name, @@ -293,6 +279,7 @@ impl TreeNode for LogicalPlan { if_not_exists, or_replace, column_defaults, + temporary, }) }), DdlStatement::CreateView(CreateView { @@ -300,12 +287,14 @@ impl TreeNode for LogicalPlan { input, or_replace, definition, + temporary, }) => rewrite_arc(input, f)?.update_data(|input| { DdlStatement::CreateView(CreateView { name, input, or_replace, definition, + temporary, }) }), // no inputs in these statements @@ -497,7 +486,7 @@ impl LogicalPlan { let exprs = columns .iter() - .map(|(c, _)| Expr::Column(c.clone())) + .map(|c| Expr::Column(c.clone())) .collect::>(); exprs.iter().apply_until_stop(f) } @@ -511,14 +500,17 @@ impl LogicalPlan { .chain(select_expr.iter()) .chain(sort_expr.iter().flatten().map(|sort| &sort.expr)) .apply_until_stop(f), + LogicalPlan::Limit(Limit { skip, fetch, .. 
}) => skip + .iter() + .chain(fetch.iter()) + .map(|e| e.deref()) + .apply_until_stop(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) @@ -722,15 +714,33 @@ impl LogicalPlan { schema, })) }), + LogicalPlan::Limit(Limit { skip, fetch, input }) => { + let skip = skip.map(|e| *e); + let fetch = fetch.map(|e| *e); + map_until_stop_and_collect!( + skip.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { + Ok(f(e)?.update_data(Some)) + }), + fetch, + fetch.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { + Ok(f(e)?.update_data(Some)) + }) + )? + .update_data(|(skip, fetch)| { + LogicalPlan::Limit(Limit { + skip: skip.map(Box::new), + fetch: fetch.map(Box::new), + input, + }) + }) + } // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 24f589c41582c..7dd7360e478f2 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -17,6 +17,7 @@ //! [`ContextProvider`] and [`ExprPlanner`] APIs to customize SQL query planning +use std::fmt::Debug; use std::sync::Arc; use arrow::datatypes::{DataType, Field, SchemaRef}; @@ -88,7 +89,7 @@ pub trait ContextProvider { } /// This trait allows users to customize the behavior of the SQL planner -pub trait ExprPlanner: Send + Sync { +pub trait ExprPlanner: Debug + Send + Sync { /// Plan the binary operation between two expressions, returns original /// BinaryExpr if not possible fn plan_binary_op( diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 988dc0f5aeda5..6d3457f70d4c7 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -21,8 +21,9 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use datafusion_common::{not_impl_err, plan_datafusion_err, Result}; -use std::collections::HashMap; -use std::{collections::HashSet, sync::Arc}; +use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; +use std::sync::Arc; /// A registry knows how to build logical expressions out of user-defined function' names pub trait FunctionRegistry { @@ -123,7 +124,7 @@ pub trait FunctionRegistry { } /// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. -pub trait SerializerRegistry: Send + Sync { +pub trait SerializerRegistry: Debug + Send + Sync { /// Serialize this node to a byte array. This serialization should not include /// input plans. fn serialize_logical_plan( diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index a55cb49b1f402..e636fabf10fb7 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -29,10 +29,10 @@ use crate::{execution_props::ExecutionProps, Expr, ExprSchemable}; /// information in without having to create `DFSchema` objects. 
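One practical consequence of the new `Debug` supertrait on `ExprPlanner` (and `SerializerRegistry`) added above: existing implementations now need a `Debug` impl, and a derive is usually enough. A sketch, assuming the trait's planning methods all have default implementations:

```rust
use datafusion_expr::planner::ExprPlanner;

// Deriving Debug satisfies the new `ExprPlanner: Debug + Send + Sync` bound.
#[derive(Debug)]
struct MyCustomPlanner;

impl ExprPlanner for MyCustomPlanner {}
```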
If you /// have a [`DFSchemaRef`] you can use [`SimplifyContext`] pub trait SimplifyInfo { - /// returns true if this Expr has boolean type + /// Returns true if this Expr has boolean type fn is_boolean_type(&self, expr: &Expr) -> Result; - /// returns true of this expr is nullable (could possibly be NULL) + /// Returns true of this expr is nullable (could possibly be NULL) fn nullable(&self, expr: &Expr) -> Result; /// Returns details needed for partial expression evaluation @@ -72,7 +72,7 @@ impl<'a> SimplifyContext<'a> { } impl<'a> SimplifyInfo for SimplifyContext<'a> { - /// returns true if this Expr has boolean type + /// Returns true if this Expr has boolean type fn is_boolean_type(&self, expr: &Expr) -> Result { if let Some(schema) = &self.schema { if let Ok(DataType::Boolean) = expr.get_type(schema) { @@ -113,7 +113,7 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> { pub enum ExprSimplifyResult { /// The function call was simplified to an entirely new Expr Simplified(Expr), - /// the function call could not be simplified, and the arguments + /// The function call could not be simplified, and the arguments /// are return unmodified. Original(Vec), } diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index b4f768085fcc3..262aa99e50075 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -34,7 +34,6 @@ use crate::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::AggregateOrderSensitivity, Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature, - Volatility, }; macro_rules! create_func { @@ -106,7 +105,7 @@ pub struct Sum { impl Sum { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::user_defined(Immutable), } } } @@ -236,13 +235,13 @@ impl Count { pub fn new() -> Self { Self { aliases: vec!["count".to_string()], - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::variadic_any(Immutable), } } } impl AggregateUDFImpl for Count { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -318,13 +317,13 @@ impl Default for Min { impl Min { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::variadic_any(Immutable), } } } impl AggregateUDFImpl for Min { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -403,13 +402,13 @@ impl Default for Max { impl Max { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::variadic_any(Immutable), } } } impl AggregateUDFImpl for Max { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index c7c498dd3f017..90afe5722abbc 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Tree node implementation for logical expr +//! 
Tree node implementation for Logical Expressions use crate::expr::{ AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList, @@ -28,7 +28,16 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{map_until_stop_and_collect, Result}; +/// Implementation of the [`TreeNode`] trait +/// +/// This allows logical expressions (`Expr`) to be traversed and transformed. +/// It facilitates tasks such as optimization and rewriting during query +/// planning. impl TreeNode for Expr { + /// Applies a function `f` to each child expression of `self`. + /// + /// The function `f` determines whether to continue traversing the tree or to stop. + /// This method collects all child expressions and applies `f` to each. fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>( &'n self, f: F, @@ -122,6 +131,10 @@ children.into_iter().apply_until_stop(f) } + /// Maps each child of `self` using the provided closure `f`. + /// + /// The closure `f` takes ownership of an expression and returns a `Transformed` result, + /// indicating whether the expression was transformed or left unchanged. fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>( self, mut f: F, @@ -346,6 +359,7 @@ impl TreeNode for Expr { } } +/// Transforms a boxed expression by applying the provided closure `f`. fn transform_box<F: FnMut(Expr) -> Result<Transformed<Expr>>>( be: Box<Expr>, f: &mut F, ) -> Result<Transformed<Box<Expr>>> { Ok(f(*be)?.update_data(Box::new)) } +/// Transforms an optional boxed expression by applying the provided closure `f`. fn transform_option_box<F: FnMut(Expr) -> Result<Transformed<Expr>>>( obe: Option<Box<Expr>>, f: &mut F, @@ -380,6 +395,7 @@ fn transform_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>( ve.into_iter().map_until_stop_and_collect(f) } +/// Transforms an optional vector of sort expressions by applying the provided closure `f`. pub fn transform_sort_option_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>( sorts_option: Option<Vec<Sort>>, f: &mut F, @@ -389,6 +405,7 @@ }) } +/// Transforms a vector of sort expressions by applying the provided closure `f`. pub fn transform_sort_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>( sorts: Vec<Sort>, mut f: &mut F, diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index d30d202df0505..85f8e20ba4a5c 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -26,8 +26,9 @@ use datafusion_common::{ utils::{coerced_fixed_size_list_to_list, list_ndims}, Result, }; -use datafusion_expr_common::signature::{ - ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD, +use datafusion_expr_common::{ + signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, + type_coercion::binary::string_coercion, }; use std::sync::Arc; @@ -167,6 +168,21 @@ pub fn data_types( try_coerce_types(valid_types, current_types, &signature.type_signature) } +fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { + if let TypeSignature::OneOf(signatures) = type_signature { + return signatures.iter().all(is_well_supported_signature); + } + + matches!( + type_signature, + TypeSignature::UserDefined + | TypeSignature::Numeric(_) + | TypeSignature::String(_) + | TypeSignature::Coercible(_) + | TypeSignature::Any(_) + ) +} + fn try_coerce_types( valid_types: Vec<Vec<DataType>>, current_types: &[DataType], @@ -175,14 +191,7 @@ let mut valid_types = valid_types; // Well-supported signature that returns exact valid types.
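For context on `is_well_supported_signature` just above: a small sketch (not patch code) of a signature that now qualifies, since `OneOf` is accepted when every branch is in the allow-list.

```rust
use datafusion_expr_common::signature::TypeSignature;

fn main() {
    // Both branches (String, Numeric) are well supported, so the coercion
    // fast path in try_coerce_types applies to this OneOf as a whole.
    let sig = TypeSignature::OneOf(vec![
        TypeSignature::String(2),
        TypeSignature::Numeric(2),
    ]);
    println!("{sig:?}");
}
```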
- if !valid_types.is_empty() - && matches!( - type_signature, - TypeSignature::UserDefined - | TypeSignature::Numeric(_) - | TypeSignature::Coercible(_) - ) - { + if !valid_types.is_empty() && is_well_supported_signature(type_signature) { // exact valid types assert_eq!(valid_types.len(), 1); let valid_types = valid_types.swap_remove(0); @@ -212,20 +221,37 @@ fn get_valid_types_with_scalar_udf( current_types: &[DataType], func: &ScalarUDF, ) -> Result>> { - let valid_types = match signature { + match signature { TypeSignature::UserDefined => match func.coerce_types(current_types) { - Ok(coerced_types) => vec![coerced_types], - Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), + Ok(coerced_types) => Ok(vec![coerced_types]), + Err(e) => exec_err!("User-defined coercion failed with {:?}", e), }, - TypeSignature::OneOf(signatures) => signatures - .iter() - .filter_map(|t| get_valid_types_with_scalar_udf(t, current_types, func).ok()) - .flatten() - .collect::>(), - _ => get_valid_types(signature, current_types)?, - }; + TypeSignature::OneOf(signatures) => { + let mut res = vec![]; + let mut errors = vec![]; + for sig in signatures { + match get_valid_types_with_scalar_udf(sig, current_types, func) { + Ok(valid_types) => { + res.extend(valid_types); + } + Err(e) => { + errors.push(e.to_string()); + } + } + } - Ok(valid_types) + // Every signature failed, return the joined error + if res.is_empty() { + internal_err!( + "Failed to match any signature, errors: {}", + errors.join(",") + ) + } else { + Ok(res) + } + } + _ => get_valid_types(signature, current_types), + } } fn get_valid_types_with_aggregate_udf( @@ -374,6 +400,67 @@ fn get_valid_types( .iter() .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) .collect(), + TypeSignature::String(number) => { + if *number < 1 { + return plan_err!( + "The signature expected at least one argument but received {}", + current_types.len() + ); + } + if *number != current_types.len() { + return plan_err!( + "The signature expected {} arguments but received {}", + number, + current_types.len() + ); + } + + fn coercion_rule( + lhs_type: &DataType, + rhs_type: &DataType, + ) -> Result { + match (lhs_type, rhs_type) { + (DataType::Null, DataType::Null) => Ok(DataType::Utf8), + (DataType::Null, data_type) | (data_type, DataType::Null) => { + coercion_rule(data_type, &DataType::Utf8) + } + (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => { + coercion_rule(lhs, rhs) + } + (DataType::Dictionary(_, v), other) + | (other, DataType::Dictionary(_, v)) => coercion_rule(v, other), + _ => { + if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) { + Ok(coerced_type) + } else { + plan_err!( + "{} and {} are not coercible to a common string type", + lhs_type, + rhs_type + ) + } + } + } + } + + // Length checked above, safe to unwrap + let mut coerced_type = current_types.first().unwrap().to_owned(); + for t in current_types.iter().skip(1) { + coerced_type = coercion_rule(&coerced_type, t)?; + } + + fn base_type_or_default_type(data_type: &DataType) -> DataType { + if data_type.is_null() { + DataType::Utf8 + } else if let DataType::Dictionary(_, v) = data_type { + base_type_or_default_type(v) + } else { + data_type.to_owned() + } + } + + vec![vec![base_type_or_default_type(&coerced_type); *number]] + } TypeSignature::Numeric(number) => { if *number < 1 { return plan_err!( @@ -602,89 +689,48 @@ fn coerced_from<'a>( Some(type_into.clone()) } // coerced into type_into - (Int8, _) if matches!(type_from, 
Null | Int8) => Some(type_into.clone()), - (Int16, _) if matches!(type_from, Null | Int8 | Int16 | UInt8) => { - Some(type_into.clone()) - } - (Int32, _) - if matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => - { - Some(type_into.clone()) - } - (Int64, _) - if matches!( - type_from, - Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 - ) => - { - Some(type_into.clone()) - } - (UInt8, _) if matches!(type_from, Null | UInt8) => Some(type_into.clone()), - (UInt16, _) if matches!(type_from, Null | UInt8 | UInt16) => { - Some(type_into.clone()) - } - (UInt32, _) if matches!(type_from, Null | UInt8 | UInt16 | UInt32) => { - Some(type_into.clone()) - } - (UInt64, _) if matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64) => { - Some(type_into.clone()) - } - (Float32, _) - if matches!( - type_from, - Null | Int8 - | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float32 - ) => - { - Some(type_into.clone()) - } - (Float64, _) - if matches!( - type_from, - Null | Int8 - | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float32 - | Float64 - | Decimal128(_, _) - ) => - { - Some(type_into.clone()) - } - (Timestamp(TimeUnit::Nanosecond, None), _) - if matches!( - type_from, - Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8 - ) => - { - Some(type_into.clone()) - } - (Interval(_), _) if matches!(type_from, Utf8 | LargeUtf8) => { + (Int8, Null | Int8) => Some(type_into.clone()), + (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()), + (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()), + (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => { Some(type_into.clone()) } + (UInt8, Null | UInt8) => Some(type_into.clone()), + (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()), + (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()), + (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()), + ( + Float32, + Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 + | Float32, + ) => Some(type_into.clone()), + ( + Float64, + Null + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float32 + | Float64 + | Decimal128(_, _), + ) => Some(type_into.clone()), + ( + Timestamp(TimeUnit::Nanosecond, None), + Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8, + ) => Some(type_into.clone()), + (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()), // We can go into a Utf8View from a Utf8 or LargeUtf8 - (Utf8View, _) if matches!(type_from, Utf8 | LargeUtf8 | Null) => { - Some(type_into.clone()) - } + (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()), // Any type can be coerced into strings (Utf8 | LargeUtf8, _) => Some(type_into.clone()), (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()), - (List(_), _) if matches!(type_from, FixedSizeList(_, _)) => { - Some(type_into.clone()) - } + (List(_), FixedSizeList(_, _)) => Some(type_into.clone()), // Only accept list and largelist with the same number of dimensions unless the type is Null. 
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this @@ -695,18 +741,16 @@ fn coerced_from<'a>( Some(type_into.clone()) } // should be able to coerce wildcard fixed size list to non wildcard fixed size list - (FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD), _) => match type_from { - FixedSizeList(f_from, size_from) => { - match coerced_from(f_into.data_type(), f_from.data_type()) { - Some(data_type) if &data_type != f_into.data_type() => { - let new_field = - Arc::new(f_into.as_ref().clone().with_data_type(data_type)); - Some(FixedSizeList(new_field, *size_from)) - } - Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)), - _ => None, - } + ( + FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD), + FixedSizeList(f_from, size_from), + ) => match coerced_from(f_into.data_type(), f_from.data_type()) { + Some(data_type) if &data_type != f_into.data_type() => { + let new_field = + Arc::new(f_into.as_ref().clone().with_data_type(data_type)); + Some(FixedSizeList(new_field, *size_from)) } + Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)), _ => None, }, (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => { @@ -721,12 +765,7 @@ fn coerced_from<'a>( _ => None, } } - (Timestamp(_, Some(_)), _) - if matches!( - type_from, - Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8 - ) => - { + (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => { Some(type_into.clone()) } _ => None, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index e3ef672daf5ff..dbbf88447ba39 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -26,7 +26,8 @@ use std::vec; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use crate::expr::AggregateFunction; use crate::function::{ @@ -35,8 +36,8 @@ use crate::function::{ use crate::groups_accumulator::GroupsAccumulator; use crate::utils::format_state_name; use crate::utils::AggregateOrderSensitivity; -use crate::Signature; use crate::{Accumulator, Expr}; +use crate::{Documentation, Signature}; /// Logical representation of a user-defined [aggregate function] (UDAF). /// @@ -94,6 +95,22 @@ impl fmt::Display for AggregateUDF { } } +/// Arguments passed to [`AggregateUDFImpl::value_from_stats`] +pub struct StatisticsArgs<'a> { + /// The statistics of the aggregate input + pub statistics: &'a Statistics, + /// The resolved return type of the aggregate function + pub return_type: &'a DataType, + /// Whether the aggregate function is distinct. + /// + /// ```sql + /// SELECT COUNT(DISTINCT column1) FROM t; + /// ``` + pub is_distinct: bool, + /// The physical expression of arguments the aggregate function takes. + pub exprs: &'a [Arc], +} + impl AggregateUDF { /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object /// @@ -123,7 +140,7 @@ impl AggregateUDF { )) } - /// creates an [`Expr`] that calls the aggregate function. + /// Creates an [`Expr`] that calls the aggregate function. /// /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. 
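Note: `StatisticsArgs` above is the whole interface an aggregate needs in order to be answered from table statistics alone. A minimal sketch of how a `COUNT(*)`-style UDAF might implement `value_from_stats` (the helper below is illustrative and not part of this patch; the `Precision` enum from `datafusion_common::stats` and the `datafusion_expr::udaf` import path are assumptions):

```rust
use datafusion_common::stats::Precision;
use datafusion_common::ScalarValue;
use datafusion_expr::udaf::StatisticsArgs;

// Hypothetical helper for a COUNT(*)-like UDAF: if the input statistics
// carry an exact row count and the aggregate is not DISTINCT, the result
// is known without scanning any rows.
fn count_from_stats(args: &StatisticsArgs) -> Option<ScalarValue> {
    if args.is_distinct {
        // A distinct count cannot be derived from the row count alone.
        return None;
    }
    match &args.statistics.num_rows {
        Precision::Exact(n) => Some(ScalarValue::Int64(Some(*n as i64))),
        // Inexact or absent statistics: fall back to normal execution.
        _ => None,
    }
}
```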
@@ -237,17 +254,35 @@ impl AggregateUDF { } /// Returns true if the function is max, false if the function is min - /// None in all other cases, used in certain optimizations or + /// None in all other cases, used in certain optimizations for /// or aggregate - /// pub fn is_descending(&self) -> Option { self.inner.is_descending() } + /// Return the value of this aggregate function if it can be determined + /// entirely from statistics and arguments. + /// + /// See [`AggregateUDFImpl::value_from_stats`] for more details. + pub fn value_from_stats( + &self, + statistics_args: &StatisticsArgs, + ) -> Option { + self.inner.value_from_stats(statistics_args) + } + /// See [`AggregateUDFImpl::default_value`] for more details. pub fn default_value(&self, data_type: &DataType) -> Result { self.inner.default_value(data_type) } + + /// Returns the documentation for this Aggregate UDF. + /// + /// Documentation can be accessed programmatically as well as + /// generating publicly facing documentation. + pub fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } } impl From for AggregateUDF @@ -272,25 +307,42 @@ where /// # Basic Example /// ``` /// # use std::any::Any; +/// # use std::sync::OnceLock; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr}; +/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation}; /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}}; +/// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE; /// # use arrow::datatypes::Schema; /// # use arrow::datatypes::Field; +/// /// #[derive(Debug, Clone)] /// struct GeoMeanUdf { -/// signature: Signature +/// signature: Signature, /// } /// /// impl GeoMeanUdf { /// fn new() -> Self { /// Self { -/// signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable) +/// signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), /// } /// } /// } /// +/// static DOCUMENTATION: OnceLock = OnceLock::new(); +/// +/// fn get_doc() -> &'static Documentation { +/// DOCUMENTATION.get_or_init(|| { +/// Documentation::builder() +/// .with_doc_section(DOC_SECTION_AGGREGATE) +/// .with_description("calculates a geometric mean") +/// .with_syntax_example("geo_mean(2.0)") +/// .with_argument("arg1", "The Float64 number for the geometric mean") +/// .build() +/// .unwrap() +/// }) +/// } +/// /// /// Implement the AggregateUDFImpl trait for GeoMeanUdf /// impl AggregateUDFImpl for GeoMeanUdf { /// fn as_any(&self) -> &dyn Any { self } @@ -298,7 +350,7 @@ where /// fn signature(&self) -> &Signature { &self.signature } /// fn return_type(&self, args: &[DataType]) -> Result { /// if !matches!(args.get(0), Some(&DataType::Float64)) { -/// return plan_err!("add_one only accepts Float64 arguments"); +/// return plan_err!("geo_mean only accepts Float64 arguments"); /// } /// Ok(DataType::Float64) /// } @@ -310,6 +362,9 @@ where /// Field::new("ordering", DataType::UInt32, true) /// ]) /// } +/// fn documentation(&self) -> Option<&Documentation> { +/// Some(get_doc()) +/// } /// } /// /// // Create a new AggregateUDF from the implementation @@ -548,8 +603,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { } /// If this function is max, return true - /// if the function is min, return false - /// otherwise return None (the default) + /// If the 
function is min, return false + /// Otherwise return None (the default) /// /// Note: this is used to use special aggregate implementations in certain conditions @@ -557,6 +612,18 @@ impl AggregateUDFImpl: Debug + Send + Sync { None } + /// Return the value of this aggregate function if it can be determined + /// entirely from statistics and arguments. + /// + /// Using a [`ScalarValue`] rather than a runtime computation can significantly + /// improve query performance. + /// + /// For example, if the minimum value of column `x` is known to be `42` from + /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))`. + fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option { + None + } + /// Returns default value of the function given the input is all `null`. /// /// Most of the aggregate function return Null if input is Null, @@ -564,6 +631,14 @@ impl AggregateUDFImpl: Debug + Send + Sync { fn default_value(&self, data_type: &DataType) -> Result { ScalarValue::try_from(data_type) } + + /// Returns the documentation for this Aggregate UDF. + /// + /// Documentation can be accessed programmatically as well as + /// generating publicly facing documentation. + fn documentation(&self) -> Option<&Documentation> { + None + } } impl PartialEq for dyn AggregateUDFImpl { @@ -572,7 +647,7 @@ impl PartialEq for dyn AggregateUDFImpl { } } -// manual implementation of `PartialOrd` +// Manual implementation of `PartialOrd` // There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl // https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5 impl PartialOrd for dyn AggregateUDFImpl { @@ -710,6 +785,41 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { fn is_descending(&self) -> Option { self.inner.is_descending() } + + fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } +} + +// Aggregate UDF doc sections for use in public documentation +pub mod aggregate_doc_sections { + use crate::DocSection; + + pub fn doc_sections() -> Vec { + vec![ + DOC_SECTION_GENERAL, + DOC_SECTION_STATISTICAL, + DOC_SECTION_APPROXIMATE, + ] + } + + pub const DOC_SECTION_GENERAL: DocSection = DocSection { + include: true, + label: "General Functions", + description: None, + }; + + pub const DOC_SECTION_STATISTICAL: DocSection = DocSection { + include: true, + label: "Statistical Functions", + description: None, + }; + + pub const DOC_SECTION_APPROXIMATE: DocSection = DocSection { + include: true, + label: "Approximate Functions", + description: None, + }; } #[cfg(test)] diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 938e1181d85d4..83563603f2f3b 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -20,7 +20,9 @@ use crate::expr::schema_name_from_exprs_comma_seperated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; -use crate::{ColumnarValue, Expr, ScalarFunctionImplementation, Signature}; +use crate::{ + ColumnarValue, Documentation, Expr, ScalarFunctionImplementation, Signature, +}; use arrow::datatypes::DataType; use datafusion_common::{not_impl_err, ExprSchema, Result}; use datafusion_expr_common::interval_arithmetic::Interval; @@ -199,6 +201,17 @@ impl ScalarUDF { self.inner.is_nullable(args, schema) } + /// Invoke the function with `args` and number of rows, returning the appropriate result. 
+ /// + /// See [`ScalarUDFImpl::invoke_batch`] for more details. + pub fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + self.inner.invoke_batch(args, number_rows) + } + /// Invoke the function without `args` but number of rows, returning the appropriate result. /// /// See [`ScalarUDFImpl::invoke_no_args`] for more details. @@ -274,6 +287,14 @@ impl ScalarUDF { pub fn coerce_types(&self, arg_types: &[DataType]) -> Result> { self.inner.coerce_types(arg_types) } + + /// Returns the documentation for this Scalar UDF. + /// + /// Documentation can be accessed programmatically as well as + /// generating publicly facing documentation. + pub fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } } impl From for ScalarUDF @@ -298,22 +319,39 @@ where /// # Basic Example /// ``` /// # use std::any::Any; +/// # use std::sync::OnceLock; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; +/// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility}; /// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; +/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +/// /// #[derive(Debug)] /// struct AddOne { -/// signature: Signature +/// signature: Signature, /// } /// /// impl AddOne { /// fn new() -> Self { /// Self { -/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable), /// } /// } /// } +/// +/// static DOCUMENTATION: OnceLock = OnceLock::new(); +/// +/// fn get_doc() -> &'static Documentation { +/// DOCUMENTATION.get_or_init(|| { +/// Documentation::builder() +/// .with_doc_section(DOC_SECTION_MATH) +/// .with_description("Add one to an int32") +/// .with_syntax_example("add_one(2)") +/// .with_argument("arg1", "The int32 number to add one to") +/// .build() +/// .unwrap() +/// }) +/// } /// /// /// Implement the ScalarUDFImpl trait for AddOne /// impl ScalarUDFImpl for AddOne { @@ -328,6 +366,9 @@ where /// } /// // The actual implementation would add one to the argument /// fn invoke(&self, args: &[ColumnarValue]) -> Result { unimplemented!() } +/// fn documentation(&self) -> Option<&Documentation> { +/// Some(get_doc()) +/// } /// } /// /// // Create a new ScalarUDF from the implementation @@ -437,7 +478,25 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// to arrays, which will likely be simpler code, but be slower. /// /// [invoke_no_args]: ScalarUDFImpl::invoke_no_args - fn invoke(&self, _args: &[ColumnarValue]) -> Result; + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + not_impl_err!( + "Function {} does not implement invoke but called", + self.name() + ) + } + + /// Invoke the function with `args` and the number of rows, + /// returning the appropriate result. + fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + match args.is_empty() { + true => self.invoke_no_args(number_rows), + false => self.invoke(args), + } + } /// Invoke the function without `args`, instead the number of rows are provided, /// returning the appropriate result. @@ -596,6 +655,14 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { self.signature().hash(hasher); hasher.finish() } + + /// Returns the documentation for this Scalar UDF. 
+ /// + /// Documentation can be accessed programmatically as well as + /// generating publicly facing documentation. + fn documentation(&self) -> Option<&Documentation> { + None + } } /// ScalarUDF that adds an alias to the underlying function. It is better to @@ -709,4 +776,100 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.aliases.hash(hasher); hasher.finish() } + + fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } +} + +// Scalar UDF doc sections for use in public documentation +pub mod scalar_doc_sections { + use crate::DocSection; + + pub fn doc_sections() -> Vec { + vec![ + DOC_SECTION_MATH, + DOC_SECTION_CONDITIONAL, + DOC_SECTION_STRING, + DOC_SECTION_BINARY_STRING, + DOC_SECTION_REGEX, + DOC_SECTION_DATETIME, + DOC_SECTION_ARRAY, + DOC_SECTION_STRUCT, + DOC_SECTION_MAP, + DOC_SECTION_HASHING, + DOC_SECTION_OTHER, + ] + } + + pub const DOC_SECTION_MATH: DocSection = DocSection { + include: true, + label: "Math Functions", + description: None, + }; + + pub const DOC_SECTION_CONDITIONAL: DocSection = DocSection { + include: true, + label: "Conditional Functions", + description: None, + }; + + pub const DOC_SECTION_STRING: DocSection = DocSection { + include: true, + label: "String Functions", + description: None, + }; + + pub const DOC_SECTION_BINARY_STRING: DocSection = DocSection { + include: true, + label: "Binary String Functions", + description: None, + }; + + pub const DOC_SECTION_REGEX: DocSection = DocSection { + include: true, + label: "Regular Expression Functions", + description: Some( + r#"Apache DataFusion uses a [PCRE-like](https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions) +regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) +(minus support for several features including look-around and backreferences). +The following regular expression functions are supported:"#, + ), + }; + + pub const DOC_SECTION_DATETIME: DocSection = DocSection { + include: true, + label: "Time and Date Functions", + description: None, + }; + + pub const DOC_SECTION_ARRAY: DocSection = DocSection { + include: true, + label: "Array Functions", + description: None, + }; + + pub const DOC_SECTION_STRUCT: DocSection = DocSection { + include: true, + label: "Struct Functions", + description: None, + }; + + pub const DOC_SECTION_MAP: DocSection = DocSection { + include: true, + label: "Map Functions", + description: None, + }; + + pub const DOC_SECTION_HASHING: DocSection = DocSection { + include: true, + label: "Hashing Functions", + description: None, + }; + + pub const DOC_SECTION_OTHER: DocSection = DocSection { + include: true, + label: "Other Functions", + description: None, + }; } diff --git a/datafusion/expr/src/udf_docs.rs b/datafusion/expr/src/udf_docs.rs new file mode 100644 index 0000000000000..a124361e42a3d --- /dev/null +++ b/datafusion/expr/src/udf_docs.rs @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::exec_err; +use datafusion_common::Result; + +/// Documentation for use by [`ScalarUDFImpl`](crate::ScalarUDFImpl), +/// [`AggregateUDFImpl`](crate::AggregateUDFImpl) and [`WindowUDFImpl`](crate::WindowUDFImpl) functions +/// that will be used to generate public documentation. +/// +/// The name of the UDF will be pulled from the [`ScalarUDFImpl::name`](crate::ScalarUDFImpl::name), +/// [`AggregateUDFImpl::name`](crate::AggregateUDFImpl::name) or [`WindowUDFImpl::name`](crate::WindowUDFImpl::name) +/// function as appropriate. +/// +/// All strings in the documentation are required to be +/// in [markdown format](https://www.markdownguide.org/basic-syntax/). +/// +/// Currently, documentation only supports a single language, +/// so all text should be in English. +#[derive(Debug, Clone)] +pub struct Documentation { + /// The section in the documentation where the UDF will be documented + pub doc_section: DocSection, + /// The description for the UDF + pub description: String, + /// A brief example of the syntax. For example "ascii(str)" + pub syntax_example: String, + /// A sql example for the UDF, usually in the form of a sql prompt + /// query and output. It is strongly recommended to provide an + /// example for anything but the most basic UDFs + pub sql_example: Option, + /// Arguments for the UDF which will be displayed in array order. + /// Left member of a pair is the argument name, right is a + /// description for the argument + pub arguments: Option>, + /// A list of alternative syntax examples for a function + pub alternative_syntax: Option>, + /// Related functions if any. Values should match the related + /// UDF's name exactly. Related UDFs must be of the same + /// UDF type (scalar, aggregate or window) for proper linking to + /// occur + pub related_udfs: Option>, +} + +impl Documentation { + /// Returns a new [`DocumentationBuilder`] with no options set. + pub fn builder() -> DocumentationBuilder { + DocumentationBuilder::new() + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct DocSection { + /// True to include this doc section in the public + /// documentation, false otherwise + pub include: bool, + /// A display label for the doc section. For example: "Math Expressions" + pub label: &'static str, + /// An optional description for the doc section + pub description: Option<&'static str>, +} + +/// A builder to be used for building [`Documentation`]s. 
+/// +/// Example: +/// +/// ```rust +/// # use datafusion_expr::Documentation; +/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +/// # use datafusion_common::Result; +/// # +/// # fn main() -> Result<()> { +/// let documentation = Documentation::builder() +/// .with_doc_section(DOC_SECTION_MATH) +/// .with_description("Add one to an int32") +/// .with_syntax_example("add_one(2)") +/// .with_argument("arg_1", "The int32 number to add one to") +/// .build()?; +/// Ok(()) +/// # } +/// ``` pub struct DocumentationBuilder { + pub doc_section: Option, + pub description: Option, + pub syntax_example: Option, + pub sql_example: Option, + pub arguments: Option>, + pub alternative_syntax: Option>, + pub related_udfs: Option>, +} + +impl DocumentationBuilder { + pub fn new() -> Self { + Self { + doc_section: None, + description: None, + syntax_example: None, + sql_example: None, + arguments: None, + alternative_syntax: None, + related_udfs: None, + } + } + + pub fn with_doc_section(mut self, doc_section: DocSection) -> Self { + self.doc_section = Some(doc_section); + self + } + + pub fn with_description(mut self, description: impl Into) -> Self { + self.description = Some(description.into()); + self + } + + pub fn with_syntax_example(mut self, syntax_example: impl Into) -> Self { + self.syntax_example = Some(syntax_example.into()); + self + } + + pub fn with_sql_example(mut self, sql_example: impl Into) -> Self { + self.sql_example = Some(sql_example.into()); + self + } + + /// Adds documentation for a specific argument to the documentation. + /// + /// Arguments are displayed in the order they are added. + pub fn with_argument( + mut self, + arg_name: impl Into, + arg_description: impl Into, + ) -> Self { + let mut args = self.arguments.unwrap_or_default(); + args.push((arg_name.into(), arg_description.into())); + self.arguments = Some(args); + self + } + + /// Add a standard "expression" argument to the documentation + /// + /// The argument is rendered like below if Some() is passed through: + /// + /// ```text + /// : + /// expression to operate on. Can be a constant, column, or function, and any combination of operators. + /// ``` + /// + /// The argument is rendered like below if None is passed through: + /// + /// ```text + /// : + /// The expression to operate on. Can be a constant, column, or function, and any combination of operators. + /// ``` + pub fn with_standard_argument( + self, + arg_name: impl Into, + expression_type: Option<&str>, + ) -> Self { + let description = format!( + "{} expression to operate on. 
Can be a constant, column, or function, and any combination of operators.", + expression_type.unwrap_or("The") + ); + self.with_argument(arg_name, description) + } + + pub fn with_alternative_syntax(mut self, syntax_name: impl Into) -> Self { + let mut alternative_syntax_array = self.alternative_syntax.unwrap_or_default(); + alternative_syntax_array.push(syntax_name.into()); + self.alternative_syntax = Some(alternative_syntax_array); + self + } + + pub fn with_related_udf(mut self, related_udf: impl Into) -> Self { + let mut related = self.related_udfs.unwrap_or_default(); + related.push(related_udf.into()); + self.related_udfs = Some(related); + self + } + + pub fn build(self) -> Result { + let Self { + doc_section, + description, + syntax_example, + sql_example, + arguments, + alternative_syntax, + related_udfs, + } = self; + + if doc_section.is_none() { + return exec_err!("Documentation must have a doc section"); + } + if description.is_none() { + return exec_err!("Documentation must have a description"); + } + if syntax_example.is_none() { + return exec_err!("Documentation must have a syntax_example"); + } + + Ok(Documentation { + doc_section: doc_section.unwrap(), + description: description.unwrap(), + syntax_example: syntax_example.unwrap(), + sql_example, + arguments, + alternative_syntax, + related_udfs, + }) + } +} + +impl Default for DocumentationBuilder { + fn default() -> Self { + Self::new() + } +} diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 7cc57523a14df..6ab94c1e841a8 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -28,13 +28,16 @@ use std::{ use arrow::datatypes::{DataType, Field}; -use datafusion_common::{not_impl_err, Result}; -use datafusion_functions_window_common::field::WindowUDFFieldArgs; - use crate::expr::WindowFunction; use crate::{ - function::WindowFunctionSimplification, Expr, PartitionEvaluator, Signature, + function::WindowFunctionSimplification, Documentation, Expr, PartitionEvaluator, + Signature, }; +use datafusion_common::{not_impl_err, Result}; +use datafusion_functions_window_common::expr::ExpressionArgs; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// Logical representation of a user-defined window function (UDWF) /// A UDWF is different from a UDF in that it is stateful across batches. @@ -148,9 +151,18 @@ impl WindowUDF { self.inner.simplify() } + /// Expressions that are passed to the [`PartitionEvaluator`]. + /// + /// See [`WindowUDFImpl::expressions`] for more details. + pub fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + self.inner.expressions(expr_args) + } /// Return a `PartitionEvaluator` for evaluating this window function - pub fn partition_evaluator_factory(&self) -> Result> { - self.inner.partition_evaluator() + pub fn partition_evaluator_factory( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + self.inner.partition_evaluator(partition_evaluator_args) } /// Returns the field of the final result of evaluating this window function. @@ -172,6 +184,22 @@ impl WindowUDF { pub fn coerce_types(&self, arg_types: &[DataType]) -> Result> { self.inner.coerce_types(arg_types) } + + /// Returns the reversed user-defined window function when the + /// order of evaluation is reversed. + /// + /// See [`WindowUDFImpl::reverse_expr`] for more details. 
+ pub fn reverse_expr(&self) -> ReversedUDWF { + self.inner.reverse_expr() + } + + /// Returns the documentation for this Window UDF. + /// + /// Documentation can be accessed programmatically as well as + /// generating publicly facing documentation. + pub fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } } impl From for WindowUDF @@ -196,31 +224,54 @@ where /// # Basic Example /// ``` /// # use std::any::Any; +/// # use std::sync::OnceLock; /// # use arrow::datatypes::{DataType, Field}; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt}; +/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt, Documentation}; /// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; -/// use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// # use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; +/// /// #[derive(Debug, Clone)] /// struct SmoothIt { -/// signature: Signature +/// signature: Signature, /// } /// /// impl SmoothIt { /// fn new() -> Self { /// Self { -/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable), /// } /// } /// } /// -/// /// Implement the WindowUDFImpl trait for AddOne +/// static DOCUMENTATION: OnceLock = OnceLock::new(); +/// +/// fn get_doc() -> &'static Documentation { +/// DOCUMENTATION.get_or_init(|| { +/// Documentation::builder() +/// .with_doc_section(DOC_SECTION_ANALYTICAL) +/// .with_description("smooths the windows") +/// .with_syntax_example("smooth_it(2)") +/// .with_argument("arg1", "The int32 number to smooth by") +/// .build() +/// .unwrap() +/// }) +/// } +/// +/// /// Implement the WindowUDFImpl trait for SmoothIt /// impl WindowUDFImpl for SmoothIt { /// fn as_any(&self) -> &dyn Any { self } /// fn name(&self) -> &str { "smooth_it" } /// fn signature(&self) -> &Signature { &self.signature } -/// // The actual implementation would add one to the argument -/// fn partition_evaluator(&self) -> Result> { unimplemented!() } +/// // The actual implementation would smooth the window +/// fn partition_evaluator( +/// &self, +/// _partition_evaluator_args: PartitionEvaluatorArgs, +/// ) -> Result> { +/// unimplemented!() +/// } /// fn field(&self, field_args: WindowUDFFieldArgs) -> Result { /// if let Some(DataType::Int32) = field_args.get_input_type(0) { /// Ok(Field::new(field_args.name(), DataType::Int32, false)) @@ -228,6 +279,9 @@ where /// plan_err!("smooth_it only accepts Int32 arguments") /// } /// } +/// fn documentation(&self) -> Option<&Documentation> { +/// Some(get_doc()) +/// } /// } /// /// // Create a new WindowUDF from the implementation @@ -256,8 +310,19 @@ pub trait WindowUDFImpl: Debug + Send + Sync { /// types are accepted and the function's Volatility. fn signature(&self) -> &Signature; + /// Returns the expressions that are passed to the [`PartitionEvaluator`]. 
+ fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + expr_args + .input_exprs() + .first() + .map_or(vec![], |expr| vec![Arc::clone(expr)]) + } + /// Invoke the function, returning the [`PartitionEvaluator`] instance - fn partition_evaluator(&self) -> Result>; + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result>; /// Returns any aliases (alternate names) for this function. /// @@ -318,6 +383,10 @@ pub trait WindowUDFImpl: Debug + Send + Sync { } /// The [`Field`] of the final result of evaluating this window function. + /// + /// Call `field_args.name()` to get the fully qualified name for defining + /// the [`Field`]. For a complete example see the implementation in the + /// [Basic Example](WindowUDFImpl#basic-example) section. fn field(&self, field_args: WindowUDFFieldArgs) -> Result; /// Allows the window UDF to define a custom result ordering. @@ -351,6 +420,32 @@ pub trait WindowUDFImpl: Debug + Send + Sync { fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { not_impl_err!("Function {} does not implement coerce_types", self.name()) } + + /// Allows customizing the behavior of the user-defined window + /// function when it is evaluated in reverse order. + fn reverse_expr(&self) -> ReversedUDWF { + ReversedUDWF::NotSupported + } + + /// Returns the documentation for this Window UDF. + /// + /// Documentation can be accessed programmatically as well as + /// generating publicly facing documentation. + fn documentation(&self) -> Option<&Documentation> { + None + } +} + +pub enum ReversedUDWF { + /// The result of evaluating the user-defined window function + /// remains identical when reversed. + Identical, + /// A window function which does not support evaluating the result + /// in reverse order. + NotSupported, + /// Customize the user-defined window function for evaluating the + /// result in reverse order. 
+ Reversed(Arc), } impl PartialEq for dyn WindowUDFImpl { @@ -401,8 +496,18 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { self.inner.signature() } - fn partition_evaluator(&self) -> Result> { - self.inner.partition_evaluator() + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + expr_args + .input_exprs() + .first() + .map_or(vec![], |expr| vec![Arc::clone(expr)]) + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + self.inner.partition_evaluator(partition_evaluator_args) } fn aliases(&self) -> &[String] { @@ -439,6 +544,41 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { fn coerce_types(&self, arg_types: &[DataType]) -> Result> { self.inner.coerce_types(arg_types) } + + fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } +} + +// Window UDF doc sections for use in public documentation +pub mod window_doc_sections { + use crate::DocSection; + + pub fn doc_sections() -> Vec { + vec![ + DOC_SECTION_AGGREGATE, + DOC_SECTION_RANKING, + DOC_SECTION_ANALYTICAL, + ] + } + + pub const DOC_SECTION_AGGREGATE: DocSection = DocSection { + include: true, + label: "Aggregate Functions", + description: Some("All aggregate functions can be used as window functions."), + }; + + pub const DOC_SECTION_RANKING: DocSection = DocSection { + include: true, + label: "Ranking Functions", + description: None, + }; + + pub const DOC_SECTION_ANALYTICAL: DocSection = DocSection { + include: true, + label: "Analytical Functions", + description: None, + }; } #[cfg(test)] @@ -448,6 +588,7 @@ mod test { use datafusion_common::Result; use datafusion_expr_common::signature::{Signature, Volatility}; use datafusion_functions_window_common::field::WindowUDFFieldArgs; + use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use std::any::Any; use std::cmp::Ordering; @@ -479,7 +620,10 @@ mod test { fn signature(&self) -> &Signature { &self.signature } - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { unimplemented!() } fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { @@ -515,7 +659,10 @@ mod test { fn signature(&self) -> &Signature { &self.signature } - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { unimplemented!() } fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 1d8eb9445edad..29c62440abb12 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -19,6 +19,7 @@ use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; +use std::ops::Deref; use std::sync::Arc; use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction}; @@ -38,6 +39,7 @@ use datafusion_common::{ DataFusionError, Result, TableReference, }; +use indexmap::IndexSet; use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem}; pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; @@ -65,9 +67,10 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { "Invalid group by expressions, GroupingSet must be the only expression" ); } - Ok(grouping_set.distinct_expr().len()) + // Grouping sets have an additional internal column for the grouping id + Ok(grouping_set.distinct_expr().len() + 1) } else { - Ok(group_expr.len()) + grouping_set_to_exprlist(group_expr).map(|exprs| 
exprs.len()) } } @@ -202,7 +205,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { if !has_grouping_set || group_expr.len() == 1 { return Ok(group_expr); } - // only process mix grouping sets + // Only process mix grouping sets let partial_sets = group_expr .iter() .map(|expr| { @@ -231,7 +234,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { }) .collect::>>()?; - // cross join + // Cross Join let grouping_sets = partial_sets .into_iter() .map(Ok) @@ -260,7 +263,11 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { } Ok(grouping_set.distinct_expr()) } else { - Ok(group_expr.iter().collect()) + Ok(group_expr + .iter() + .collect::>() + .into_iter() + .collect()) } } @@ -335,7 +342,7 @@ fn get_excluded_columns( // Excluded columns should be unique let n_elem = idents.len(); let unique_idents = idents.into_iter().collect::>(); - // if HashSet size, and vector length are different, this means that some of the excluded columns + // If HashSet size, and vector length are different, this means that some of the excluded columns // are not unique. In this case return error. if n_elem != unique_idents.len() { return plan_err!("EXCLUDE or EXCEPT contains duplicate column names"); @@ -430,7 +437,10 @@ pub fn expand_qualified_wildcard( return plan_err!("Invalid qualifier {qualifier}"); } - let qualified_schema = Arc::new(Schema::new(fields_with_qualified)); + let qualified_schema = Arc::new(Schema::new_with_metadata( + fields_with_qualified, + schema.metadata().clone(), + )); let qualified_dfschema = DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)? .with_functional_dependencies(projected_func_dependencies)?; @@ -459,7 +469,7 @@ pub fn expand_qualified_wildcard( } /// (expr, "is the SortExpr for window (either comes from PARTITION BY or ORDER BY columns)") -/// if bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column +/// If bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column type WindowSortKey = Vec<(Sort, bool)>; /// Generate a sort key for a given window expr's partition_by and order_by expr @@ -566,7 +576,7 @@ pub fn compare_sort_expr( Ordering::Equal } -/// group a slice of window expression expr by their order by expressions +/// Group a slice of window expression expr by their order by expressions pub fn group_window_expr_by_sort_keys( window_expr: Vec, ) -> Result)>> { @@ -593,7 +603,7 @@ pub fn group_window_expr_by_sort_keys( /// Collect all deeply nested `Expr::AggregateFunction`. /// They are returned in order of occurrence (depth /// first), with duplicates omitted. -pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { +pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator) -> Vec { find_exprs_in_exprs(exprs, &|nested_expr| { matches!(nested_expr, Expr::AggregateFunction { .. }) }) @@ -618,12 +628,15 @@ pub fn find_out_reference_exprs(expr: &Expr) -> Vec { /// Search the provided `Expr`'s, and all of their nested `Expr`, for any that /// pass the provided test. The returned `Expr`'s are deduplicated and returned /// in order of appearance (depth first). 
-fn find_exprs_in_exprs(exprs: &[Expr], test_fn: &F) -> Vec +fn find_exprs_in_exprs<'a, F>( + exprs: impl IntoIterator, + test_fn: &F, +) -> Vec where F: Fn(&Expr) -> bool, { exprs - .iter() + .into_iter() .flat_map(|expr| find_exprs_in_expr(expr, test_fn)) .fold(vec![], |mut acc, expr| { if !acc.contains(&expr) { @@ -646,7 +659,7 @@ where if !(exprs.contains(expr)) { exprs.push(expr.clone()) } - // stop recursing down this expr once we find a match + // Stop recursing down this expr once we find a match return Ok(TreeNodeRecursion::Jump); } @@ -665,7 +678,7 @@ where let mut err = Ok(()); expr.apply(|expr| { if let Err(e) = f(expr) { - // save the error for later (it may not be a DataFusionError + // Save the error for later (it may not be a DataFusionError) err = Err(e); Ok(TreeNodeRecursion::Stop) } else { @@ -684,7 +697,7 @@ pub fn exprlist_to_fields<'a>( exprs: impl IntoIterator, plan: &LogicalPlan, ) -> Result, Arc)>> { - // look for exact match in plan's output schema + // Look for exact match in plan's output schema let wildcard_schema = find_base_plan(plan).schema(); let input_schema = plan.schema(); let result = exprs @@ -758,6 +771,15 @@ pub fn find_base_plan(input: &LogicalPlan) -> &LogicalPlan { match input { LogicalPlan::Window(window) => find_base_plan(&window.input), LogicalPlan::Aggregate(agg) => find_base_plan(&agg.input), + // [SqlToRel::try_process_unnest] will convert Expr(Unnest(Expr)) to Projection/Unnest/Projection + // We should expand the wildcard expression based on the input plan of the inner Projection. + LogicalPlan::Unnest(unnest) => { + if let LogicalPlan::Projection(projection) = unnest.input.deref() { + find_base_plan(&projection.input) + } else { + input + } + } LogicalPlan::Filter(filter) => { if filter.having { // If a filter is used for a having clause, its input plan is an aggregation. @@ -934,8 +956,8 @@ pub(crate) fn find_column_indexes_referenced_by_expr( indexes } -/// can this data type be used in hash join equal conditions?? -/// data types here come from function 'equal_rows', if more data types are supported +/// Can this data type be used in hash join equal conditions? +/// Data types here come from function 'equal_rows', if more data types are supported /// in equal_rows(hash join), add those data types here to generate join logical plan. pub fn can_hash(data_type: &DataType) -> bool { match data_type { @@ -959,6 +981,7 @@ pub fn can_hash(data_type: &DataType) -> bool { }, DataType::Utf8 => true, DataType::LargeUtf8 => true, + DataType::Utf8View => true, DataType::Decimal128(_, _) => true, DataType::Date32 => true, DataType::Date64 => true, @@ -1082,6 +1105,54 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& } } +/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// See [`split_conjunction_owned`] for more details and an example. +pub fn iter_conjunction(expr: &Expr) -> impl Iterator { + let mut stack = vec![expr]; + std::iter::from_fn(move || { + while let Some(expr) = stack.pop() { + match expr { + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::And, + left, + }) => { + stack.push(right); + stack.push(left); + } + Expr::Alias(Alias { expr, .. }) => stack.push(expr), + other => return Some(other), + } + } + None + }) +} + +/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// See [`split_conjunction_owned`] for more details and an example. 
+pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator { + let mut stack = vec![expr]; + std::iter::from_fn(move || { + while let Some(expr) = stack.pop() { + match expr { + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::And, + left, + }) => { + stack.push(*right); + stack.push(*left); + } + Expr::Alias(Alias { expr, .. }) => stack.push(*expr), + other => return Some(other), + } + } + None + }) +} + /// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// This is often used to "split" filter expressions such as `col1 = 5 @@ -1328,7 +1399,7 @@ pub fn format_state_name(name: &str, state_name: &str) -> String { mod tests { use super::*; use crate::{ - col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, + col, cube, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::max_udaf, test::function_stub::min_udaf, test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, }; @@ -1343,19 +1414,19 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let max1 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(expr::WindowFunction::new( + let min3 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )); - let sum4 = Expr::WindowFunction(expr::WindowFunction::new( + let sum4 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )); @@ -1370,28 +1441,28 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys() -> Result<()> { - let age_asc = expr::Sort::new(col("age"), true, true); - let name_desc = expr::Sort::new(col("name"), false, true); - let created_at_desc = expr::Sort::new(col("created_at"), false, true); - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let age_asc = Sort::new(col("age"), true, true); + let name_desc = Sort::new(col("name"), false, true); + let created_at_desc = Sort::new(col("created_at"), false, true); + let max1 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(expr::WindowFunction::new( + let min3 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let sum4 = Expr::WindowFunction(expr::WindowFunction::new( + let sum4 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )) diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index b2e8268aa332c..222914315d700 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,11 +23,11 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. 
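For context on the parsing changes in this file: after this patch, ROWS and GROUPS offsets are stored as non-negative integers (`UInt64`) while RANGE offsets remain strings. A small sketch using the `new_bounds` constructor that the `TryFrom` impl below already calls (the crate-root re-exports are assumed):

```rust
use datafusion_common::ScalarValue;
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};

// ROWS BETWEEN 2 PRECEDING AND CURRENT ROW, as modeled after this patch:
// the offset is a UInt64 rather than the old string representation.
fn rows_frame() -> WindowFrame {
    WindowFrame::new_bounds(
        WindowFrameUnits::Rows,
        WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
        WindowFrameBound::CurrentRow,
    )
}
```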
+use crate::{expr::Sort, lit}; +use arrow::datatypes::DataType; use std::fmt::{self, Formatter}; use std::hash::Hash; -use crate::{expr::Sort, lit}; - use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; use sqlparser::parser::ParserError::ParserError; @@ -94,7 +94,7 @@ pub struct WindowFrame { } impl fmt::Display for WindowFrame { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!( f, "{} BETWEEN {} AND {}", @@ -119,9 +119,9 @@ impl TryFrom for WindowFrame { type Error = DataFusionError; fn try_from(value: ast::WindowFrame) -> Result { - let start_bound = value.start_bound.try_into()?; + let start_bound = WindowFrameBound::try_parse(value.start_bound, &value.units)?; let end_bound = match value.end_bound { - Some(value) => value.try_into()?, + Some(bound) => WindowFrameBound::try_parse(bound, &value.units)?, None => WindowFrameBound::CurrentRow, }; @@ -138,6 +138,7 @@ impl TryFrom for WindowFrame { )? } }; + let units = value.units.into(); Ok(Self::new_bounds(units, start_bound, end_bound)) } @@ -334,17 +335,18 @@ impl WindowFrameBound { } } -impl TryFrom for WindowFrameBound { - type Error = DataFusionError; - - fn try_from(value: ast::WindowFrameBound) -> Result { +impl WindowFrameBound { + fn try_parse( + value: ast::WindowFrameBound, + units: &ast::WindowFrameUnits, + ) -> Result { Ok(match value { ast::WindowFrameBound::Preceding(Some(v)) => { - Self::Preceding(convert_frame_bound_to_scalar_value(*v)?) + Self::Preceding(convert_frame_bound_to_scalar_value(*v, units)?) } ast::WindowFrameBound::Preceding(None) => Self::Preceding(ScalarValue::Null), ast::WindowFrameBound::Following(Some(v)) => { - Self::Following(convert_frame_bound_to_scalar_value(*v)?) + Self::Following(convert_frame_bound_to_scalar_value(*v, units)?) } ast::WindowFrameBound::Following(None) => Self::Following(ScalarValue::Null), ast::WindowFrameBound::CurrentRow => Self::CurrentRow, @@ -352,37 +354,69 @@ impl TryFrom for WindowFrameBound { } } -pub fn convert_frame_bound_to_scalar_value(v: ast::Expr) -> Result { - Ok(ScalarValue::Utf8(Some(match v { - ast::Expr::Value(ast::Value::Number(value, false)) - | ast::Expr::Value(ast::Value::SingleQuotedString(value)) => value, - ast::Expr::Interval(ast::Interval { - value, - leading_field, - .. - }) => { - let result = match *value { - ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, - e => { - return sql_err!(ParserError(format!( - "INTERVAL expression cannot be {e:?}" - ))); +fn convert_frame_bound_to_scalar_value( + v: ast::Expr, + units: &ast::WindowFrameUnits, +) -> Result { + match units { + // For ROWS and GROUPS we are sure that the ScalarValue must be a non-negative integer ... + ast::WindowFrameUnits::Rows | ast::WindowFrameUnits::Groups => match v { + ast::Expr::Value(ast::Value::Number(value, false)) => { + Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?) + }, + ast::Expr::Interval(ast::Interval { + value, + leading_field: None, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }) => { + let value = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, + e => { + return sql_err!(ParserError(format!( + "INTERVAL expression cannot be {e:?}" + ))); + } + }; + Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?) + } + _ => plan_err!( + "Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers" + ), + }, + // ... 
instead for RANGE it could be anything depending on the type of the ORDER BY clause, + // so we use a ScalarValue::Utf8. + ast::WindowFrameUnits::Range => Ok(ScalarValue::Utf8(Some(match v { + ast::Expr::Value(ast::Value::Number(value, false)) => value, + ast::Expr::Interval(ast::Interval { + value, + leading_field, + .. + }) => { + let result = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, + e => { + return sql_err!(ParserError(format!( + "INTERVAL expression cannot be {e:?}" + ))); + } + }; + if let Some(leading_field) = leading_field { + format!("{result} {leading_field}") + } else { + result } - }; - if let Some(leading_field) = leading_field { - format!("{result} {leading_field}") - } else { - result } - } - _ => plan_err!( - "Invalid window frame: frame offsets must be non negative integers" - )?, - }))) + _ => plan_err!( + "Invalid window frame: frame offsets for RANGE must be either a numeric value, a string value or an interval" + )?, + }))), + } } impl fmt::Display for WindowFrameBound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { WindowFrameBound::Preceding(n) => { if n.is_null() { @@ -423,7 +457,7 @@ pub enum WindowFrameUnits { } impl fmt::Display for WindowFrameUnits { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.write_str(match self { WindowFrameUnits::Rows => "ROWS", WindowFrameUnits::Range => "RANGE", @@ -479,8 +513,91 @@ mod tests { ast::Expr::Value(ast::Value::Number("1".to_string(), false)), )))), }; - let result = WindowFrame::try_from(window_frame); - assert!(result.is_ok()); + + let window_frame = WindowFrame::try_from(window_frame)?; + assert_eq!(window_frame.units, WindowFrameUnits::Rows); + assert_eq!( + window_frame.start_bound, + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))) + ); + assert_eq!( + window_frame.end_bound, + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))) + ); + + Ok(()) + } + + macro_rules! test_bound { + ($unit:ident, $value:expr, $expected:expr) => { + let preceding = WindowFrameBound::try_parse( + ast::WindowFrameBound::Preceding($value), + &ast::WindowFrameUnits::$unit, + )?; + assert_eq!(preceding, WindowFrameBound::Preceding($expected)); + let following = WindowFrameBound::try_parse( + ast::WindowFrameBound::Following($value), + &ast::WindowFrameUnits::$unit, + )?; + assert_eq!(following, WindowFrameBound::Following($expected)); + }; + } + + macro_rules! 
test_bound_err { + ($unit:ident, $value:expr, $expected:expr) => { + let err = WindowFrameBound::try_parse( + ast::WindowFrameBound::Preceding($value), + &ast::WindowFrameUnits::$unit, + ) + .unwrap_err(); + assert_eq!(err.strip_backtrace(), $expected); + let err = WindowFrameBound::try_parse( + ast::WindowFrameBound::Following($value), + &ast::WindowFrameUnits::$unit, + ) + .unwrap_err(); + assert_eq!(err.strip_backtrace(), $expected); + }; + } + + #[test] + fn test_window_frame_bound_creation() -> Result<()> { + // Unbounded + test_bound!(Rows, None, ScalarValue::Null); + test_bound!(Groups, None, ScalarValue::Null); + test_bound!(Range, None, ScalarValue::Null); + + // Number + let number = Some(Box::new(ast::Expr::Value(ast::Value::Number( + "42".to_string(), + false, + )))); + test_bound!(Rows, number.clone(), ScalarValue::UInt64(Some(42))); + test_bound!(Groups, number.clone(), ScalarValue::UInt64(Some(42))); + test_bound!( + Range, + number.clone(), + ScalarValue::Utf8(Some("42".to_string())) + ); + + // Interval + let number = Some(Box::new(ast::Expr::Interval(ast::Interval { + value: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( + "1".to_string(), + ))), + leading_field: Some(ast::DateTimeField::Day), + fractional_seconds_precision: None, + last_field: None, + leading_precision: None, + }))); + test_bound_err!(Rows, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"); + test_bound_err!(Groups, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"); + test_bound!( + Range, + number.clone(), + ScalarValue::Utf8(Some("1 DAY".to_string())) + ); + Ok(()) } } diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index a80718147c3a4..be2b6575e2e9c 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -15,73 +15,8 @@ // specific language governing permissions and limitations // under the License. 
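Before the removals below: `rank`, `dense_rank`, `percent_rank`, `cume_dist`, `ntile`, `lag` and `lead` leave this module, while `nth_value` stays. A usage sketch of the surviving helper (the column names are made up, and the public `window_function` module path is an assumption):

```rust
use datafusion_common::Result;
use datafusion_expr::expr::Sort;
use datafusion_expr::window_function::nth_value;
use datafusion_expr::{col, Expr, ExprFunctionExt};

// Roughly: NTH_VALUE(price, 2) OVER (ORDER BY ts ASC) AS second_price
fn second_price() -> Result<Expr> {
    Ok(nth_value(col("price"), 2)
        .order_by(vec![Sort::new(col("ts"), true, true)])
        .build()?
        .alias("second_price"))
}
```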
-use datafusion_common::ScalarValue; - use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; -/// Create an expression to represent the `rank` window function -pub fn rank() -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Rank, vec![])) -} - -/// Create an expression to represent the `dense_rank` window function -pub fn dense_rank() -> Expr { - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::DenseRank, - vec![], - )) -} - -/// Create an expression to represent the `percent_rank` window function -pub fn percent_rank() -> Expr { - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::PercentRank, - vec![], - )) -} - -/// Create an expression to represent the `cume_dist` window function -pub fn cume_dist() -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::CumeDist, vec![])) -} - -/// Create an expression to represent the `ntile` window function -pub fn ntile(arg: Expr) -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg])) -} - -/// Create an expression to represent the `lag` window function -pub fn lag( - arg: Expr, - shift_offset: Option, - default_value: Option, -) -> Expr { - let shift_offset_lit = shift_offset - .map(|v| v.lit()) - .unwrap_or(ScalarValue::Null.lit()); - let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::Lag, - vec![arg, shift_offset_lit, default_lit], - )) -} - -/// Create an expression to represent the `lead` window function -pub fn lead( - arg: Expr, - shift_offset: Option, - default_value: Option, -) -> Expr { - let shift_offset_lit = shift_offset - .map(|v| v.lit()) - .unwrap_or(ScalarValue::Null.lit()); - let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::Lead, - vec![arg, shift_offset_lit, default_lit], - )) -} - /// Create an expression to represent the `nth_value` window function pub fn nth_value(arg: Expr, n: i64) -> Expr { Expr::WindowFunction(WindowFunction::new( diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index e7f31bbfbf2bd..f1d0ead23ab19 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -48,7 +48,7 @@ pub struct WindowAggState { /// Keeps track of how many rows should be generated to be in sync with input record_batch. // (For each row in the input record batch we need to generate a window result). 
pub n_row_result_missing: usize, - /// flag indicating whether we have received all data for this partition + /// Flag indicating whether we have received all data for this partition pub is_end: bool, } diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs index ee61128979e10..07fa4efc990e5 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs @@ -25,6 +25,7 @@ use datafusion_expr_common::accumulator::Accumulator; use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewSet; use std::fmt::Debug; +use std::mem::size_of_val; use std::sync::Arc; /// Specialized implementation of @@ -86,7 +87,7 @@ impl Accumulator for BytesDistinctCountAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + self.0.size() + size_of_val(self) + self.0.size() } } @@ -146,6 +147,6 @@ impl Accumulator for BytesViewDistinctCountAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + self.0.size() + size_of_val(self) + self.0.size() } } diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs index d128a8af58eef..405b2c2db7bdd 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs @@ -23,6 +23,7 @@ use std::collections::HashSet; use std::fmt::Debug; use std::hash::Hash; +use std::mem::size_of_val; use std::sync::Arc; use ahash::RandomState; @@ -117,8 +118,7 @@ where fn size(&self) -> usize { let num_elements = self.values.len(); - let fixed_size = - std::mem::size_of_val(self) + std::mem::size_of_val(&self.values); + let fixed_size = size_of_val(self) + size_of_val(&self.values); estimate_memory_size::(num_elements, fixed_size).unwrap() } @@ -206,8 +206,7 @@ where fn size(&self) -> usize { let num_elements = self.values.len(); - let fixed_size = - std::mem::size_of_val(self) + std::mem::size_of_val(&self.values); + let fixed_size = size_of_val(self) + size_of_val(&self.values); estimate_memory_size::(num_elements, fixed_size).unwrap() } diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index e60f68972074e..03e4ef557269f 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -23,15 +23,16 @@ pub mod bool_op; pub mod nulls; pub mod prim_op; +use std::mem::{size_of, size_of_val}; + +use arrow::array::new_empty_array; use arrow::{ array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}, compute, + compute::take_arrays, datatypes::UInt32Type, }; -use datafusion_common::{ - arrow_datafusion_err, utils::get_arrayref_at_indices, DataFusionError, Result, - ScalarValue, -}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use datafusion_expr_common::accumulator::Accumulator; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; @@ -123,9 +124,7 @@ impl AccumulatorState { /// Returns the amount of memory taken by this structure and its accumulator fn size(&self) -> usize 
{ - self.accumulator.size() - + std::mem::size_of_val(self) - + self.indices.allocated_size() + self.accumulator.size() + size_of_val(self) + self.indices.allocated_size() } } @@ -239,7 +238,7 @@ impl GroupsAccumulatorAdapter { // reorder the values and opt_filter by batch_indices so that // all values for each group are contiguous, then invoke the // accumulator once per group with values - let values = get_arrayref_at_indices(values, &batch_indices)?; + let values = take_arrays(values, &batch_indices, None)?; let opt_filter = get_filter_at_indices(opt_filter, &batch_indices)?; // invoke each accumulator with the appropriate rows, first @@ -406,6 +405,18 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { ) -> Result<Vec<ArrayRef>> { let num_rows = values[0].len(); + // If there are no rows, return empty arrays + if num_rows == 0 { + // create empty accumulator to get the state types + let empty_state = (self.factory)()?.state()?; + let empty_arrays = empty_state + .into_iter() + .map(|state_val| new_empty_array(&state_val.data_type())) + .collect::<Vec<_>>(); + + return Ok(empty_arrays); + } + // Each row has its respective group let mut results = vec![]; for row_idx in 0..num_rows { @@ -453,7 +464,7 @@ pub trait VecAllocExt { impl<T> VecAllocExt for Vec<T> { type T = T; fn allocated_size(&self) -> usize { - std::mem::size_of::<T>() * self.capacity() + size_of::<T>() * self.capacity() } } diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index a0475fe8e4464..3efd348937ed4 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -95,7 +95,7 @@ impl NullState { /// /// When value_fn is called it also sets /// - /// 1. `self.seen_values[group_index]` to true for all rows that had a non null vale + /// 1. `self.seen_values[group_index]` to true for all rows that had a non null value pub fn accumulate( &mut self, group_indices: &[usize], diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index 25212f7f0f5ff..6a8946034cbc3 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -15,13 +15,22 @@ // specific language governing permissions and limitations // under the License. -//! [`set_nulls`], and [`filtered_null_mask`], utilities for working with nulls +//!
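The adapter hunk above replaces DataFusion's own `get_arrayref_at_indices` helper with arrow's `take_arrays`, which reorders a whole set of columns through a single index array. A minimal usage sketch against the arrow-rs 53 API, using small made-up columns:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array, StringArray, UInt32Array};
use arrow::compute::take_arrays;
use arrow::error::ArrowError;

fn main() -> Result<(), ArrowError> {
    // Two parallel columns, as GroupsAccumulatorAdapter sees them.
    let cols: Vec<ArrayRef> = vec![
        Arc::new(Int64Array::from(vec![10, 20, 30])),
        Arc::new(StringArray::from(vec!["a", "b", "c"])),
    ];

    // Pull rows 2 and 0, in that order, from every column at once.
    let indices = UInt32Array::from(vec![2u32, 0]);
    let taken = take_arrays(&cols, &indices, None)?;

    assert_eq!(taken.len(), 2); // still two columns
    assert_eq!(taken[0].len(), 2); // now two rows each
    Ok(())
}
```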
[`set_nulls`], other utilities for working with nulls -use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray}; +use arrow::array::{ + Array, ArrayRef, ArrowNumericType, AsArray, BinaryArray, BinaryViewArray, + BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray, + StringViewArray, +}; use arrow::buffer::NullBuffer; +use arrow::datatypes::DataType; +use datafusion_common::{not_impl_err, Result}; +use std::sync::Arc; /// Sets the validity mask for a `PrimitiveArray` to `nulls` /// replacing any existing null mask +/// +/// See [`set_nulls_dyn`] for a version that works with `Array` pub fn set_nulls( array: PrimitiveArray, nulls: Option, @@ -91,3 +100,105 @@ pub fn filtered_null_mask( let opt_filter = opt_filter.and_then(filter_to_nulls); NullBuffer::union(opt_filter.as_ref(), input.nulls()) } + +/// Applies optional filter to input, returning a new array of the same type +/// with the same data, but with any values that were filtered out set to null +pub fn apply_filter_as_nulls( + input: &dyn Array, + opt_filter: Option<&BooleanArray>, +) -> Result { + let nulls = filtered_null_mask(opt_filter, input); + set_nulls_dyn(input, nulls) +} + +/// Replaces the nulls in the input array with the given `NullBuffer` +/// +/// TODO: replace when upstreamed in arrow-rs: +pub fn set_nulls_dyn(input: &dyn Array, nulls: Option) -> Result { + if let Some(nulls) = nulls.as_ref() { + assert_eq!(nulls.len(), input.len()); + } + + let output: ArrayRef = match input.data_type() { + DataType::Utf8 => { + let input = input.as_string::(); + // safety: values / offsets came from a valid string array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(StringArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::LargeUtf8 => { + let input = input.as_string::(); + // safety: values / offsets came from a valid string array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(LargeStringArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::Utf8View => { + let input = input.as_string_view(); + // safety: values / views came from a valid string view array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(StringViewArray::new_unchecked( + input.views().clone(), + input.data_buffers().to_vec(), + nulls, + )) + } + } + + DataType::Binary => { + let input = input.as_binary::(); + // safety: values / offsets came from a valid binary array + // and we checked nulls has the same length as values + unsafe { + Arc::new(BinaryArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::LargeBinary => { + let input = input.as_binary::(); + // safety: values / offsets came from a valid large binary array + // and we checked nulls has the same length as values + unsafe { + Arc::new(LargeBinaryArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::BinaryView => { + let input = input.as_binary_view(); + // safety: values / views came from a valid binary view array + // and we checked nulls has the same length as values + unsafe { + Arc::new(BinaryViewArray::new_unchecked( + input.views().clone(), + input.data_buffers().to_vec(), + nulls, + )) + } + } + _ => { + return not_impl_err!("Applying nulls {:?}", input.data_type()); + } + }; + 
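Before the closing assertions of `set_nulls_dyn` below, it helps to picture what the new helpers compute: rows removed by the filter become nulls while the array keeps its length. A sketch of that masking semantics using plain arrow building blocks rather than the helper itself (it assumes a filter with no nulls of its own):

```rust
use arrow::array::{Array, BooleanArray, StringArray};
use arrow::buffer::NullBuffer;

fn main() {
    let input = StringArray::from(vec![Some("a"), Some("b"), None]);
    let filter = BooleanArray::from(vec![true, false, true]);

    // Treat the filter as a validity mask and union it with the input's
    // own nulls -- the same combination filtered_null_mask performs.
    let filter_nulls = NullBuffer::new(filter.values().clone());
    let mask = NullBuffer::union(Some(&filter_nulls), input.nulls()).unwrap();

    // Row 0 survives, row 1 is masked out, row 2 keeps its original null.
    assert!(mask.is_valid(0));
    assert!(!mask.is_valid(1));
    assert!(!mask.is_valid(2));
}
```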
assert_eq!(input.len(), output.len()); + assert_eq!(input.data_type(), output.data_type()); + + Ok(output) +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs index 8bbcf756c37c1..078982c983fc7 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::mem::size_of; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}; @@ -195,6 +196,6 @@ where } fn size(&self) -> usize { - self.values.capacity() * std::mem::size_of::<T::Native>() + self.null_state.size() + self.values.capacity() * size_of::<T::Native>() + self.null_state.size() } } diff --git a/datafusion/functions-aggregate-common/src/tdigest.rs b/datafusion/functions-aggregate-common/src/tdigest.rs index 620a68e83ecdc..786d7ea3e3610 100644 --- a/datafusion/functions-aggregate-common/src/tdigest.rs +++ b/datafusion/functions-aggregate-common/src/tdigest.rs @@ -33,6 +33,7 @@ use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::ScalarValue; use std::cmp::Ordering; +use std::mem::{size_of, size_of_val}; pub const DEFAULT_MAX_SIZE: usize = 100; @@ -203,8 +204,7 @@ impl TDigest { /// Size in bytes including `Self`. pub fn size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::<Centroid>() * self.centroids.capacity()) + size_of_val(self) + (size_of::<Centroid>() * self.centroids.capacity()) } } @@ -644,7 +644,9 @@ impl TDigest { let max = cast_scalar_f64!(&state[3]); let min = cast_scalar_f64!(&state[4]); - assert!(max.total_cmp(&min).is_ge()); + if min.is_finite() && max.is_finite() { + assert!(max.total_cmp(&min).is_ge()); + } Self { max_size, diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index d78f68a2604e7..37e4c7f4a5ad8 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -48,9 +48,9 @@ datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } +indexmap = { workspace = true } log = { workspace = true } paste = "1.0.14" -sqlparser = { workspace = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index cf8217fe981de..1df106feb4d38 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -31,13 +31,17 @@ use datafusion_common::ScalarValue; use datafusion_common::{ downcast_value, internal_err, not_impl_err, DataFusionError, Result, }; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_APPROXIMATE; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; use std::any::Any; use std::fmt::{Debug, Formatter}; use std::hash::Hash; use std::marker::PhantomData; +use std::sync::OnceLock; make_udaf_expr_and_func!( ApproxDistinct,
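On the newly guarded tdigest assertion above: a plausible failure mode is an empty digest, whose serialized state carries the usual initialization sentinels min = +∞ and max = −∞ (an assumption about the surrounding code, not shown in this diff). `total_cmp` orders those correctly, so the old unconditional `assert!` would fire even though nothing is wrong. A tiny sketch of that reading:

```rust
fn main() {
    // Sentinels an empty digest would round-trip through its state.
    let (min, max) = (f64::INFINITY, f64::NEG_INFINITY);

    // The new guard skips the comparison entirely for non-finite bounds...
    assert!(!(min.is_finite() && max.is_finite()));

    // ...which matters because the old assert!(max.total_cmp(&min).is_ge())
    // would have failed here: -inf < +inf under IEEE total ordering.
    assert!(max.total_cmp(&min).is_lt());
}
```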
approx_distinct, @@ -303,4 +307,33 @@ impl AggregateUDFImpl for ApproxDistinct { }; Ok(accumulator) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_approx_distinct_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_approx_distinct_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_APPROXIMATE) + .with_description( + "Returns the approximate number of distinct input values calculated using the HyperLogLog algorithm.", + ) + .with_syntax_example("approx_distinct(expression)") + .with_sql_example(r#"```sql +> SELECT approx_distinct(column_name) FROM table_name; ++-----------------------------------+ +| approx_distinct(column_name) | ++-----------------------------------+ +| 42 | ++-----------------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 7a7b12432544a..96609622a51e4 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -19,15 +19,19 @@ use std::any::Any; use std::fmt::Debug; +use std::sync::OnceLock; use arrow::{datatypes::DataType, datatypes::Field}; use arrow_schema::DataType::{Float64, UInt64}; use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_APPROXIMATE; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; use crate::approx_percentile_cont::ApproxPercentileAccumulator; @@ -116,4 +120,33 @@ impl AggregateUDFImpl for ApproxMedian { acc_args.exprs[0].data_type(acc_args.schema)?, ))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_approx_median_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_approx_median_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_APPROXIMATE) + .with_description( + "Returns the approximate median (50th percentile) of input values. 
It is an alias of `approx_percentile_cont(x, 0.5)`.", + ) + .with_syntax_example("approx_median(expression)") + .with_sql_example(r#"```sql +> SELECT approx_median(column_name) FROM table_name; ++-----------------------------------+ +| approx_median(column_name) | ++-----------------------------------+ +| 23.5 | ++-----------------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 5578aebbf4031..53fcfd641ddfc 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -17,7 +17,8 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; -use std::sync::Arc; +use std::mem::size_of_val; +use std::sync::{Arc, OnceLock}; use arrow::array::{Array, RecordBatch}; use arrow::compute::{filter, is_not_null}; @@ -34,12 +35,13 @@ use datafusion_common::{ downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_APPROXIMATE; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, TypeSignature, - Volatility, + Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature, + TypeSignature, Volatility, }; use datafusion_functions_aggregate_common::tdigest::{ TDigest, TryIntoF64, DEFAULT_MAX_SIZE, @@ -268,6 +270,36 @@ impl AggregateUDFImpl for ApproxPercentileCont { } Ok(arg_types[0].clone()) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_approx_percentile_cont_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_approx_percentile_cont_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_APPROXIMATE) + .with_description( + "Returns the approximate percentile of input values using the t-digest algorithm.", + ) + .with_syntax_example("approx_percentile_cont(expression, percentile, centroids)") + .with_sql_example(r#"```sql +> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; ++-------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++-------------------------------------------------+ +| 65.0 | ++-------------------------------------------------+ +```"#) + .with_standard_argument("expression", None) + .with_argument("percentile", "Percentile to compute. Must be a float value between 0 and 1 (inclusive).") + .with_argument("centroids", "Number of centroids to use in the t-digest algorithm. _Default is 100_. 
A higher number results in more accurate approximation but requires more memory.") + .build() + .unwrap() + }) } #[derive(Debug)] @@ -455,10 +487,9 @@ impl Accumulator for ApproxPercentileAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + self.digest.size() - - std::mem::size_of_val(&self.digest) + size_of_val(self) + self.digest.size() - size_of_val(&self.digest) + self.return_type.size() - - std::mem::size_of_val(&self.return_type) + - size_of_val(&self.return_type) } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index fee67ba1623db..5458d0f792b92 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -17,7 +17,8 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; -use std::sync::Arc; +use std::mem::size_of_val; +use std::sync::{Arc, OnceLock}; use arrow::{ array::ArrayRef, @@ -26,10 +27,13 @@ use arrow::{ use datafusion_common::ScalarValue; use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_APPROXIMATE; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::Volatility::Immutable; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, TypeSignature}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, +}; use datafusion_functions_aggregate_common::tdigest::{ Centroid, TDigest, DEFAULT_MAX_SIZE, }; @@ -151,6 +155,37 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.approx_percentile_cont.state_fields(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_approx_percentile_cont_with_weight_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_approx_percentile_cont_with_weight_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_APPROXIMATE) + .with_description( + "Returns the weighted approximate percentile of input values using the t-digest algorithm.", + ) + .with_syntax_example("approx_percentile_cont_with_weight(expression, weight, percentile)") + .with_sql_example(r#"```sql +> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++----------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++----------------------------------------------------------------------+ +| 78.5 | ++----------------------------------------------------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .with_argument("weight", "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators.") + .with_argument("percentile", "Percentile to compute. 
Must be a float value between 0 and 1 (inclusive).") + .build() + .unwrap() + }) } #[derive(Debug)] @@ -205,8 +240,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - - std::mem::size_of_val(&self.approx_percentile_cont_accumulator) + size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator) + self.approx_percentile_cont_accumulator.size() } } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 15146fc4a2d89..b3e04c5584ef8 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -25,15 +25,17 @@ use datafusion_common::cast::as_list_array; use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{internal_err, Result}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; -use datafusion_expr::AggregateUDFImpl; use datafusion_expr::{Accumulator, Signature, Volatility}; +use datafusion_expr::{AggregateUDFImpl, Documentation}; use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; use datafusion_functions_aggregate_common::utils::ordering_fields; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use std::collections::{HashSet, VecDeque}; -use std::sync::Arc; +use std::mem::{size_of, size_of_val}; +use std::sync::{Arc, OnceLock}; make_udaf_expr_and_func!( ArrayAgg, @@ -142,6 +144,35 @@ impl AggregateUDFImpl for ArrayAgg { fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf()) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_agg_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_agg_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns an array created from the expression elements. 
If ordering is required, elements are inserted in the specified order.", + ) + .with_syntax_example("array_agg(expression [ORDER BY expression])") + .with_sql_example(r#"```sql +> SELECT array_agg(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| array_agg(column_name ORDER BY other_column) | ++-----------------------------------------------+ +| [element1, element2, element3] | ++-----------------------------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } #[derive(Debug)] @@ -215,15 +246,15 @@ impl Accumulator for ArrayAggAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) + size_of_val(self) + + (size_of::() * self.values.capacity()) + self .values .iter() .map(|arr| arr.get_array_memory_size()) .sum::() + self.datatype.size() - - std::mem::size_of_val(&self.datatype) + - size_of_val(&self.datatype) } } @@ -288,10 +319,10 @@ impl Accumulator for DistinctArrayAggAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.values) - - std::mem::size_of_val(&self.values) + size_of_val(self) + ScalarValue::size_of_hashset(&self.values) + - size_of_val(&self.values) + self.datatype.size() - - std::mem::size_of_val(&self.datatype) + - size_of_val(&self.datatype) } } @@ -456,25 +487,23 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } fn size(&self) -> usize { - let mut total = std::mem::size_of_val(self) - + ScalarValue::size_of_vec(&self.values) - - std::mem::size_of_val(&self.values); + let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.values) + - size_of_val(&self.values); // Add size of the `self.ordering_values` - total += - std::mem::size_of::>() * self.ordering_values.capacity(); + total += size_of::>() * self.ordering_values.capacity(); for row in &self.ordering_values { - total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row); + total += ScalarValue::size_of_vec(row) - size_of_val(row); } // Add size of the `self.datatypes` - total += std::mem::size_of::() * self.datatypes.capacity(); + total += size_of::() * self.datatypes.capacity(); for dtype in &self.datatypes { - total += dtype.size() - std::mem::size_of_val(dtype); + total += dtype.size() - size_of_val(dtype); } // Add size of the `self.ordering_req` - total += std::mem::size_of::() * self.ordering_req.capacity(); + total += size_of::() * self.ordering_req.capacity(); // TODO: Calculate size of each `PhysicalSortExpr` more accurately. total } diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index ddad76a8734b0..710b7e69ac5c3 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -18,8 +18,8 @@ //! 
Defines `Avg` & `Mean` aggregate & accumulators use arrow::array::{ - self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, - AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, + Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray, + BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, }; use arrow::compute::sum; @@ -28,12 +28,14 @@ use arrow::datatypes::{ Float64Type, UInt64Type, }; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type}; use datafusion_expr::utils::format_state_name; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, + Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, + ReversedUDAF, Signature, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; @@ -45,7 +47,8 @@ use datafusion_functions_aggregate_common::utils::DecimalAverager; use log::debug; use std::any::Any; use std::fmt::Debug; -use std::sync::Arc; +use std::mem::{size_of, size_of_val}; +use std::sync::{Arc, OnceLock}; make_udaf_expr_and_func!( Avg, @@ -235,6 +238,36 @@ impl AggregateUDFImpl for Avg { } coerce_avg_type(self.name(), arg_types) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_avg_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_avg_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the average of numeric values in the specified column.", + ) + .with_syntax_example("avg(expression)") + .with_sql_example( + r#"```sql +> SELECT avg(column_name) FROM table_name; ++---------------------------+ +| avg(column_name) | ++---------------------------+ +| 42.75 | ++---------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } /// An accumulator to compute the average @@ -262,7 +295,7 @@ impl Accumulator for AvgAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -340,7 +373,7 @@ impl Accumulator for DecimalAvgAccumu } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -439,7 +472,7 @@ where &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&array::BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); @@ -522,7 +555,7 @@ where &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&array::BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); @@ -582,7 +615,6 @@ where } fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - + self.sums.capacity() * std::mem::size_of::() + self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() } } diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index aa65062e3330c..249ff02e72221 100644 
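A note on the documentation plumbing repeated across these files: each UDAF builds its `Documentation` once, lazily, behind a process-wide `OnceLock`, and `documentation()` hands out the cached reference. A stripped-down sketch of the same pattern with a plain `String` payload standing in for `Documentation`:

```rust
use std::sync::OnceLock;

static GREETING: OnceLock<String> = OnceLock::new();

/// The closure runs at most once, on first call; every later call
/// returns the same cached &'static reference.
fn get_greeting() -> &'static String {
    GREETING.get_or_init(|| format!("built {}", "lazily"))
}

fn main() {
    // Both calls observe the exact same allocation.
    assert!(std::ptr::eq(get_greeting(), get_greeting()));
}
```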
--- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -20,6 +20,7 @@ use std::any::Any; use std::collections::HashSet; use std::fmt::{Display, Formatter}; +use std::mem::{size_of, size_of_val}; use ahash::RandomState; use arrow::array::{downcast_integer, Array, ArrayRef, AsArray}; @@ -35,11 +36,14 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::INTEGERS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, + Signature, Volatility, }; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign}; +use std::sync::OnceLock; /// This macro helps create group accumulators based on bitwise operations typically used internally /// and might not be necessary for users to call directly. @@ -110,8 +114,9 @@ macro_rules! downcast_bitwise_accumulator { /// `EXPR_FN` identifier used to name the generated expression function. /// `AGGREGATE_UDF_FN` is an identifier used to name the underlying UDAF function. /// `OPR_TYPE` is an expression that evaluates to the type of bitwise operation to be performed. +/// `DOCUMENTATION` documentation for the UDAF macro_rules! make_bitwise_udaf_expr_and_func { - ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $OPR_TYPE:expr) => { + ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $OPR_TYPE:expr, $DOCUMENTATION:expr) => { make_udaf_expr!( $EXPR_FN, expr_x, @@ -125,14 +130,73 @@ macro_rules! 
make_bitwise_udaf_expr_and_func { create_func!( $EXPR_FN, $AGGREGATE_UDF_FN, - BitwiseOperation::new($OPR_TYPE, stringify!($EXPR_FN)) + BitwiseOperation::new($OPR_TYPE, stringify!($EXPR_FN), $DOCUMENTATION) ); }; } -make_bitwise_udaf_expr_and_func!(bit_and, bit_and_udaf, BitwiseOperationType::And); -make_bitwise_udaf_expr_and_func!(bit_or, bit_or_udaf, BitwiseOperationType::Or); -make_bitwise_udaf_expr_and_func!(bit_xor, bit_xor_udaf, BitwiseOperationType::Xor); +static BIT_AND_DOC: OnceLock = OnceLock::new(); + +fn get_bit_and_doc() -> &'static Documentation { + BIT_AND_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description("Computes the bitwise AND of all non-null input values.") + .with_syntax_example("bit_and(expression)") + .with_standard_argument("expression", Some("Integer")) + .build() + .unwrap() + }) +} + +static BIT_OR_DOC: OnceLock = OnceLock::new(); + +fn get_bit_or_doc() -> &'static Documentation { + BIT_OR_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description("Computes the bitwise OR of all non-null input values.") + .with_syntax_example("bit_or(expression)") + .with_standard_argument("expression", Some("Integer")) + .build() + .unwrap() + }) +} + +static BIT_XOR_DOC: OnceLock = OnceLock::new(); + +fn get_bit_xor_doc() -> &'static Documentation { + BIT_XOR_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Computes the bitwise exclusive OR of all non-null input values.", + ) + .with_syntax_example("bit_xor(expression)") + .with_standard_argument("expression", Some("Integer")) + .build() + .unwrap() + }) +} + +make_bitwise_udaf_expr_and_func!( + bit_and, + bit_and_udaf, + BitwiseOperationType::And, + get_bit_and_doc() +); +make_bitwise_udaf_expr_and_func!( + bit_or, + bit_or_udaf, + BitwiseOperationType::Or, + get_bit_or_doc() +); +make_bitwise_udaf_expr_and_func!( + bit_xor, + bit_xor_udaf, + BitwiseOperationType::Xor, + get_bit_xor_doc() +); /// The different types of bitwise operations that can be performed. #[derive(Debug, Clone, Eq, PartialEq)] @@ -155,14 +219,20 @@ struct BitwiseOperation { /// `operation` indicates the type of bitwise operation to be performed. 
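The macro change above threads a new `$DOCUMENTATION:expr` parameter through to the UDAF constructor, so a single invocation both generates the expression function and attaches the docs. A toy macro with the same shape (names here are hypothetical, not DataFusion's):

```rust
/// Like make_bitwise_udaf_expr_and_func, reduced to its skeleton:
/// the extra expr argument rides along into the generated item.
macro_rules! make_named_op {
    ($FN:ident, $DOC:expr) => {
        fn $FN() -> (&'static str, &'static str) {
            (stringify!($FN), $DOC)
        }
    };
}

make_named_op!(bit_and_demo, "Computes the bitwise AND of all non-null input values.");

fn main() {
    let (name, doc) = bit_and_demo();
    assert_eq!(name, "bit_and_demo");
    assert!(doc.starts_with("Computes the bitwise AND"));
}
```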
operation: BitwiseOperationType, func_name: &'static str, + documentation: &'static Documentation, } impl BitwiseOperation { - pub fn new(operator: BitwiseOperationType, func_name: &'static str) -> Self { + pub fn new( + operator: BitwiseOperationType, + func_name: &'static str, + documentation: &'static Documentation, + ) -> Self { Self { operation: operator, signature: Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable), func_name, + documentation, } } } @@ -239,6 +309,10 @@ impl AggregateUDFImpl for BitwiseOperation { fn reverse_expr(&self) -> ReversedUDAF { ReversedUDAF::Identical } + + fn documentation(&self) -> Option<&Documentation> { + Some(self.documentation) + } } struct BitAndAccumulator { @@ -274,7 +348,7 @@ where } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -319,7 +393,7 @@ where } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -373,7 +447,7 @@ where } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -436,8 +510,7 @@ where } fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.values.capacity() * std::mem::size_of::() + size_of_val(self) + self.values.capacity() * size_of::() } fn state(&mut self) -> Result> { diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index 7cc7d9ff7fec3..87293ccfa21f5 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -18,6 +18,8 @@ //! Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; +use std::mem::size_of_val; +use std::sync::OnceLock; use arrow::array::ArrayRef; use arrow::array::BooleanArray; @@ -29,10 +31,12 @@ use arrow::datatypes::Field; use datafusion_common::internal_err; use datafusion_common::{downcast_value, not_impl_err}; use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, + Signature, Volatility, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; @@ -172,6 +176,36 @@ impl AggregateUDFImpl for BoolAnd { fn reverse_expr(&self) -> ReversedUDAF { ReversedUDAF::Identical } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_bool_and_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_bool_and_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns true if all non-null input values are true, otherwise false.", + ) + .with_syntax_example("bool_and(expression)") + .with_sql_example( + r#"```sql +> SELECT bool_and(column_name) FROM table_name; ++----------------------------+ +| bool_and(column_name) | ++----------------------------+ +| true | ++----------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } #[derive(Debug, Default)] @@ -196,7 +230,7 @@ impl Accumulator for 
BoolAndAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -293,6 +327,34 @@ impl AggregateUDFImpl for BoolOr { fn reverse_expr(&self) -> ReversedUDAF { ReversedUDAF::Identical } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_bool_or_doc()) + } +} + +fn get_bool_or_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns true if any non-null input value is true, otherwise false.", + ) + .with_syntax_example("bool_or(expression)") + .with_sql_example( + r#"```sql +> SELECT bool_or(column_name) FROM table_name; ++----------------------------+ +| bool_or(column_name) | ++----------------------------+ +| true | ++----------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } #[derive(Debug, Default)] @@ -317,7 +379,7 @@ impl Accumulator for BoolOrAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 88f01b06d2d9b..187a43ecbea3c 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -19,7 +19,8 @@ use std::any::Any; use std::fmt::Debug; -use std::sync::Arc; +use std::mem::size_of_val; +use std::sync::{Arc, OnceLock}; use arrow::compute::{and, filter, is_not_null}; use arrow::{ @@ -30,11 +31,12 @@ use arrow::{ use crate::covariance::CovarianceAccumulator; use crate::stddev::StddevAccumulator; use datafusion_common::{plan_err, Result, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, type_coercion::aggregates::NUMERICS, utils::format_state_name, - Accumulator, AggregateUDFImpl, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; use datafusion_functions_aggregate_common::stats::StatsType; @@ -107,6 +109,37 @@ impl AggregateUDFImpl for Correlation { ), ]) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_corr_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_corr_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Returns the coefficient of correlation between two numeric values.", + ) + .with_syntax_example("corr(expression1, expression2)") + .with_sql_example( + r#"```sql +> SELECT corr(column1, column2) FROM table_name; ++--------------------------------+ +| corr(column1, column2) | ++--------------------------------+ +| 0.85 | ++--------------------------------+ +```"#, + ) + .with_standard_argument("expression1", Some("First")) + .with_standard_argument("expression2", Some("Second")) + .build() + .unwrap() + }) } /// An accumulator to compute correlation @@ -172,11 +205,10 @@ impl Accumulator for CorrelationAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) - + self.covar.size() - - std::mem::size_of_val(&self.stddev1) + size_of_val(self) - size_of_val(&self.covar) + self.covar.size() + - size_of_val(&self.stddev1) + self.stddev1.size() - - std::mem::size_of_val(&self.stddev2) + - size_of_val(&self.stddev2) + 
self.stddev2.size() } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 417e28e72a71f..bade589a908a7 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -16,10 +16,14 @@ // under the License. use ahash::RandomState; +use datafusion_common::stats::Precision; use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; +use datafusion_physical_expr::expressions; use std::collections::HashSet; +use std::fmt::Debug; +use std::mem::{size_of, size_of_val}; use std::ops::BitAnd; -use std::{fmt::Debug, sync::Arc}; +use std::sync::{Arc, OnceLock}; use arrow::{ array::{ArrayRef, AsArray}, @@ -41,12 +45,13 @@ use arrow::{ use datafusion_common::{ downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, - EmitTo, GroupsAccumulator, Signature, Volatility, + Documentation, EmitTo, GroupsAccumulator, Signature, Volatility, }; -use datafusion_expr::{Expr, ReversedUDAF, TypeSignature}; +use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature}; use datafusion_functions_aggregate_common::aggregate::count_distinct::{ BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, @@ -54,6 +59,7 @@ use datafusion_functions_aggregate_common::aggregate::count_distinct::{ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; use datafusion_physical_expr_common::binary_map::OutputType; +use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; make_udaf_expr_and_func!( Count, count, @@ -291,6 +297,71 @@ impl AggregateUDFImpl for Count { fn default_value(&self, _data_type: &DataType) -> Result { Ok(ScalarValue::Int64(Some(0))) } + + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + if statistics_args.is_distinct { + return None; + } + if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows { + if statistics_args.exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = statistics_args.exprs[0] + .as_any() + .downcast_ref::() + { + let current_val = &statistics_args.statistics.column_statistics + [col_expr.index()] + .null_count; + if let &Precision::Exact(val) = current_val { + return Some(ScalarValue::Int64(Some((num_rows - val) as i64))); + } + } else if let Some(lit_expr) = statistics_args.exprs[0] + .as_any() + .downcast_ref::() + { + if lit_expr.value() == &COUNT_STAR_EXPANSION { + return Some(ScalarValue::Int64(Some(num_rows as i64))); + } + } + } + } + None + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_count_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_count_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the number of non-null values in the specified column. 
To include null values in the total count, use `count(*)`.", + ) + .with_syntax_example("count(expression)") + .with_sql_example(r#"```sql +> SELECT count(column_name) FROM table_name; ++-----------------------+ +| count(column_name) | ++-----------------------+ +| 100 | ++-----------------------+ + +> SELECT count(*) FROM table_name; ++------------------+ +| count(*) | ++------------------+ +| 120 | ++------------------+ +```"#) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } #[derive(Debug)] @@ -324,7 +395,7 @@ impl Accumulator for CountAccumulator { fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { let counts = downcast_value!(states[0], Int64Array); - let delta = &arrow::compute::sum(counts); + let delta = &compute::sum(counts); if let Some(d) = delta { self.count += *d; } @@ -340,7 +411,7 @@ impl Accumulator for CountAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -513,7 +584,7 @@ impl GroupsAccumulator for CountGroupsAccumulator { } fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() + self.counts.capacity() * size_of::() } } @@ -557,28 +628,28 @@ impl DistinctCountAccumulator { // number of batches This method is faster than .full_size(), however it is // not suitable for variable length values like strings or complex types fn fixed_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) + size_of_val(self) + + (size_of::() * self.values.capacity()) + self .values .iter() .next() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) + .map(|vals| ScalarValue::size(vals) - size_of_val(vals)) .unwrap_or(0) - + std::mem::size_of::() + + size_of::() } // calculates the size as accurately as possible. Note that calling this // method is expensive fn full_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) + size_of_val(self) + + (size_of::() * self.values.capacity()) + self .values .iter() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) + .map(|vals| ScalarValue::size(vals) - size_of_val(vals)) .sum::() - + std::mem::size_of::() + + size_of::() } } @@ -645,3 +716,17 @@ impl Accumulator for DistinctCountAccumulator { } } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::NullArray; + + #[test] + fn count_accumulator_nulls() -> Result<()> { + let mut accumulator = CountAccumulator::new(); + accumulator.update_batch(&[Arc::new(NullArray::new(10))])?; + assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index d0abb079ef155..063aaa92059dd 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -18,6 +18,8 @@ //! [`CovarianceSample`]: covariance sample aggregations. 
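Stepping back to the `count.rs` hunk above: the new `value_from_stats` answers `COUNT` from exact table statistics without touching the data, since `COUNT(col)` is rows minus nulls and `COUNT(*)` is just rows. The arithmetic in a self-contained sketch (this `Precision` is a simplified stand-in for `datafusion_common::stats::Precision`):

```rust
#[derive(Clone, Copy)]
enum Precision {
    Exact(usize),
    Absent,
}

/// COUNT(col) from statistics: only safe when both counts are exact.
fn count_from_stats(num_rows: Precision, null_count: Precision) -> Option<i64> {
    match (num_rows, null_count) {
        (Precision::Exact(rows), Precision::Exact(nulls)) => {
            Some((rows - nulls) as i64)
        }
        // Anything inexact or missing: fall back to running the aggregate.
        _ => None,
    }
}

fn main() {
    assert_eq!(
        count_from_stats(Precision::Exact(120), Precision::Exact(20)),
        Some(100)
    );
    assert_eq!(count_from_stats(Precision::Exact(120), Precision::Absent), None);
}
```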
use std::fmt::Debug; +use std::mem::size_of_val; +use std::sync::OnceLock; use arrow::{ array::{ArrayRef, Float64Array, UInt64Array}, @@ -29,11 +31,12 @@ use datafusion_common::{ downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, type_coercion::aggregates::NUMERICS, utils::format_state_name, - Accumulator, AggregateUDFImpl, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; use datafusion_functions_aggregate_common::stats::StatsType; @@ -124,6 +127,35 @@ impl AggregateUDFImpl for CovarianceSample { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_covar_samp_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_covar_samp_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description("Returns the sample covariance of a set of number pairs.") + .with_syntax_example("covar_samp(expression1, expression2)") + .with_sql_example( + r#"```sql +> SELECT covar_samp(column1, column2) FROM table_name; ++-----------------------------------+ +| covar_samp(column1, column2) | ++-----------------------------------+ +| 8.25 | ++-----------------------------------+ +```"#, + ) + .with_standard_argument("expression1", Some("First")) + .with_standard_argument("expression2", Some("Second")) + .build() + .unwrap() + }) } pub struct CovariancePopulation { @@ -193,6 +225,35 @@ impl AggregateUDFImpl for CovariancePopulation { StatsType::Population, )?)) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_covar_pop_doc()) + } +} + +fn get_covar_pop_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Returns the population covariance of a set of number pairs.", + ) + .with_syntax_example("covar_pop(expression1, expression2)") + .with_sql_example( + r#"```sql +> SELECT covar_pop(column1, column2) FROM table_name; ++-----------------------------------+ +| covar_pop(column1, column2) | ++-----------------------------------+ +| 7.63 | ++-----------------------------------+ +```"#, + ) + .with_standard_argument("expression1", Some("First")) + .with_standard_argument("expression2", Some("Second")) + .build() + .unwrap() + }) } /// An accumulator to compute covariance @@ -388,6 +449,6 @@ impl Accumulator for CovarianceAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 30f5d5b07561b..da3fc62f8c8c6 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -19,20 +19,22 @@ use std::any::Any; use std::fmt::Debug; -use std::sync::Arc; +use std::mem::size_of_val; +use std::sync::{Arc, OnceLock}; use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use arrow::compute::{self, lexsort_to_indices, SortColumn}; +use arrow::compute::{self, lexsort_to_indices, take_arrays, SortColumn}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; +use datafusion_common::utils::{compare_rows, get_row_at_idx}; use 
datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, ExprFunctionExt, - Signature, SortExpr, TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Documentation, Expr, + ExprFunctionExt, Signature, SortExpr, TypeSignature, Volatility, }; use datafusion_functions_aggregate_common::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -165,6 +167,35 @@ impl AggregateUDFImpl for FirstValue { fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { datafusion_expr::ReversedUDAF::Reversed(last_value_udaf()) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_first_value_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_first_value_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.", + ) + .with_syntax_example("first_value(expression [ORDER BY expression])") + .with_sql_example(r#"```sql +> SELECT first_value(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| first_value(column_name ORDER BY other_column)| ++-----------------------------------------------+ +| first_element | ++-----------------------------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } #[derive(Debug)] @@ -310,7 +341,7 @@ impl Accumulator for FirstValueAccumulator { filtered_states } else { let indices = lexsort_to_indices(&sort_cols, None)?; - get_arrayref_at_indices(&filtered_states, &indices)? + take_arrays(&filtered_states, &indices, None)? }; if !ordered_states[0].is_empty() { let first_row = get_row_at_idx(&ordered_states, 0)?; @@ -335,10 +366,10 @@ impl Accumulator for FirstValueAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.first) + size_of_val(self) - size_of_val(&self.first) + self.first.size() + ScalarValue::size_of_vec(&self.orderings) - - std::mem::size_of_val(&self.orderings) + - size_of_val(&self.orderings) } } @@ -466,6 +497,33 @@ impl AggregateUDFImpl for LastValue { fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { datafusion_expr::ReversedUDAF::Reversed(first_value_udaf()) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_last_value_doc()) + } +} + +fn get_last_value_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the last element in an aggregation group according to the requested ordering. 
If no ordering is given, returns an arbitrary element from the group.", + ) + .with_syntax_example("last_value(expression [ORDER BY expression])") + .with_sql_example(r#"```sql +> SELECT last_value(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| last_value(column_name ORDER BY other_column) | ++-----------------------------------------------+ +| last_element | ++-----------------------------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } #[derive(Debug)] @@ -613,7 +671,7 @@ impl Accumulator for LastValueAccumulator { filtered_states } else { let indices = lexsort_to_indices(&sort_cols, None)?; - get_arrayref_at_indices(&filtered_states, &indices)? + take_arrays(&filtered_states, &indices, None)? }; if !ordered_states[0].is_empty() { @@ -641,10 +699,10 @@ impl Accumulator for LastValueAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.last) + size_of_val(self) - size_of_val(&self.last) + self.last.size() + ScalarValue::size_of_vec(&self.orderings) - - std::mem::size_of_val(&self.orderings) + - size_of_val(&self.orderings) } } @@ -738,7 +796,7 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(arrow::compute::concat(&[ + states.push(compute::concat(&[ &state1[idx].to_array()?, &state2[idx].to_array()?, ])?); @@ -768,7 +826,7 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(arrow::compute::concat(&[ + states.push(compute::concat(&[ &state1[idx].to_array()?, &state2[idx].to_array()?, ])?); diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 6fb7c3800f4ed..27949aa3df277 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -19,14 +19,18 @@ use std::any::Any; use std::fmt; +use std::sync::OnceLock; use arrow::datatypes::DataType; use arrow::datatypes::Field; use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; make_udaf_expr_and_func!( Grouping, @@ -41,7 +45,7 @@ pub struct Grouping { } impl fmt::Debug for Grouping { - fn fmt(&self, f: &mut std::fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Grouping") .field("name", &self.name()) .field("signature", &self.signature) @@ -59,7 +63,7 @@ impl Grouping { /// Create a new GROUPING aggregate function. 
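Back in `first_last.rs` above, merging ordered state now sorts with `lexsort_to_indices` and reorders every state column in one `take_arrays` call instead of the removed `get_arrayref_at_indices`. A compact sketch of that sort-then-take flow on toy columns:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, StringArray};
use arrow::compute::{lexsort_to_indices, take_arrays, SortColumn};
use arrow::error::ArrowError;

fn main() -> Result<(), ArrowError> {
    let keys: ArrayRef = Arc::new(Int32Array::from(vec![3, 1, 2]));
    let vals: ArrayRef = Arc::new(StringArray::from(vec!["c", "a", "b"]));

    // Compute the sort permutation from the ordering column(s)...
    let sort_cols = vec![SortColumn { values: Arc::clone(&keys), options: None }];
    let indices = lexsort_to_indices(&sort_cols, None)?;

    // ...then apply it to all columns at once.
    let sorted = take_arrays(&[keys, vals], &indices, None)?;

    // Under this ordering, the "first value" row is (1, "a").
    let sorted_vals = sorted[1].as_any().downcast_ref::<StringArray>().unwrap();
    assert_eq!(sorted_vals.value(0), "a");
    Ok(())
}
```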
pub fn new() -> Self { Self { - signature: Signature::any(1, Volatility::Immutable), + signature: Signature::variadic_any(Volatility::Immutable), } } } @@ -94,4 +98,37 @@ impl AggregateUDFImpl for Grouping { "physical plan is not yet implemented for GROUPING aggregate function" ) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_grouping_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_grouping_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set.", + ) + .with_syntax_example("grouping(expression)") + .with_sql_example(r#"```sql +> SELECT column_name, GROUPING(column_name) AS group_column + FROM table_name + GROUP BY GROUPING SETS ((column_name), ()); ++-------------+-------------+ +| column_name | group_column | ++-------------+-------------+ +| value1 | 0 | +| value2 | 0 | +| NULL | 1 | ++-------------+-------------+ +```"#, + ) + .with_argument("expression", "Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function.") + .build() + .unwrap() + }) } diff --git a/datafusion/functions-aggregate/src/kurtosis_pop.rs b/datafusion/functions-aggregate/src/kurtosis_pop.rs deleted file mode 100644 index ac173a0ee5795..0000000000000 --- a/datafusion/functions-aggregate/src/kurtosis_pop.rs +++ /dev/null @@ -1,190 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use arrow::array::{Array, ArrayRef, Float64Array, UInt64Array}; -use arrow_schema::{DataType, Field}; -use datafusion_common::cast::as_float64_array; -use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; -use datafusion_functions_aggregate_common::accumulator::{ - AccumulatorArgs, StateFieldsArgs, -}; -use std::any::Any; -use std::fmt::Debug; - -make_udaf_expr_and_func!( - KurtosisPopFunction, - kurtosis_pop, - x, - "Calculates the excess kurtosis (Fisher’s definition) without bias correction.", - kurtosis_pop_udaf -); - -pub struct KurtosisPopFunction { - signature: Signature, -} - -impl Debug for KurtosisPopFunction { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("KurtosisPopFunction") - .field("signature", &self.signature) - .finish() - } -} - -impl Default for KurtosisPopFunction { - fn default() -> Self { - Self::new() - } -} - -impl KurtosisPopFunction { - pub fn new() -> Self { - Self { - signature: Signature::coercible( - vec![DataType::Float64], - Volatility::Immutable, - ), - } - } -} - -impl AggregateUDFImpl for KurtosisPopFunction { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "kurtosis_pop" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Float64) - } - - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { - Ok(vec![ - Field::new("count", DataType::UInt64, true), - Field::new("sum", DataType::Float64, true), - Field::new("sum_sqr", DataType::Float64, true), - Field::new("sum_cub", DataType::Float64, true), - Field::new("sum_four", DataType::Float64, true), - ]) - } - - fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { - Ok(Box::new(KurtosisPopAccumulator::new())) - } -} - -/// Accumulator for calculating the excess kurtosis (Fisher’s definition) without bias correction. 
-/// This implementation follows the [DuckDB implementation]: -/// -#[derive(Debug, Default)] -pub struct KurtosisPopAccumulator { - count: u64, - sum: f64, - sum_sqr: f64, - sum_cub: f64, - sum_four: f64, -} - -impl KurtosisPopAccumulator { - pub fn new() -> Self { - Self { - count: 0, - sum: 0.0, - sum_sqr: 0.0, - sum_cub: 0.0, - sum_four: 0.0, - } - } -} - -impl Accumulator for KurtosisPopAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let array = as_float64_array(&values[0])?; - for value in array.iter().flatten() { - self.count += 1; - self.sum += value; - self.sum_sqr += value.powi(2); - self.sum_cub += value.powi(3); - self.sum_four += value.powi(4); - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); - let sums = downcast_value!(states[1], Float64Array); - let sum_sqrs = downcast_value!(states[2], Float64Array); - let sum_cubs = downcast_value!(states[3], Float64Array); - let sum_fours = downcast_value!(states[4], Float64Array); - - for i in 0..counts.len() { - let c = counts.value(i); - if c == 0 { - continue; - } - self.count += c; - self.sum += sums.value(i); - self.sum_sqr += sum_sqrs.value(i); - self.sum_cub += sum_cubs.value(i); - self.sum_four += sum_fours.value(i); - } - - Ok(()) - } - - fn evaluate(&mut self) -> Result<ScalarValue> { - if self.count < 1 { - return Ok(ScalarValue::Float64(None)); - } - - let count_64 = 1_f64 / self.count as f64; - let m4 = count_64 - * (self.sum_four - 4.0 * self.sum_cub * self.sum * count_64 - + 6.0 * self.sum_sqr * self.sum.powi(2) * count_64.powi(2) - - 3.0 * self.sum.powi(4) * count_64.powi(3)); - - let m2 = (self.sum_sqr - self.sum.powi(2) * count_64) * count_64; - if m2 <= 0.0 { - return Ok(ScalarValue::Float64(None)); - } - - let target = m4 / (m2.powi(2)) - 3.0; - Ok(ScalarValue::Float64(Some(target))) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } - - fn state(&mut self) -> Result<Vec<ScalarValue>> { - Ok(vec![ - ScalarValue::from(self.count), - ScalarValue::from(self.sum), - ScalarValue::from(self.sum_sqr), - ScalarValue::from(self.sum_cub), - ScalarValue::from(self.sum_four), - ]) - } -} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 60e2602eb6eda..ca0276d326a49 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -78,7 +78,6 @@ pub mod average; pub mod bit_and_or_xor; pub mod bool_and_or; pub mod grouping; -pub mod kurtosis_pop; pub mod nth_value; pub mod string_agg; @@ -171,7 +170,6 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> { average::avg_udaf(), grouping::grouping_udaf(), nth_value::nth_value_udaf(), - kurtosis_pop::kurtosis_pop_udaf(), ] } diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index 573b9fd5bdb2f..ffb5183278e67 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -15,23 +15,6 @@ // specific language governing permissions and limitations // under the License. -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License.
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - macro_rules! make_udaf_expr { ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { // "fluent expr_fn" style function diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 7dd0de14c3c0c..ff0a930d490bf 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -16,8 +16,9 @@ // under the License. use std::collections::HashSet; -use std::fmt::Formatter; -use std::{fmt::Debug, sync::Arc}; +use std::fmt::{Debug, Formatter}; +use std::mem::{size_of, size_of_val}; +use std::sync::{Arc, OnceLock}; use arrow::array::{downcast_integer, ArrowNumericType}; use arrow::{ @@ -33,10 +34,11 @@ use arrow::array::ArrowNativeTypeOp; use arrow::datatypes::ArrowNativeType; use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, - Signature, Volatility, + Documentation, Signature, Volatility, }; use datafusion_functions_aggregate_common::utils::Hashable; @@ -61,7 +63,7 @@ pub struct Median { } impl Debug for Median { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { f.debug_struct("Median") .field("name", &self.name()) .field("signature", &self.signature) @@ -152,6 +154,34 @@ impl AggregateUDFImpl for Median { fn aliases(&self) -> &[String] { &[] } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_median_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_median_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description("Returns the median value in the specified column.") + .with_syntax_example("median(expression)") + .with_sql_example( + r#"```sql +> SELECT median(column_name) FROM table_name; ++----------------------+ +| median(column_name)  | ++----------------------+ +| 45.5                 | ++----------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } /// The median accumulator accumulates the raw input values @@ -166,7 +196,7 @@ struct MedianAccumulator<T: ArrowNumericType> { all_values: Vec<T::Native>, } -impl<T: ArrowNumericType> std::fmt::Debug for MedianAccumulator<T> { +impl<T: ArrowNumericType> Debug for MedianAccumulator<T> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "MedianAccumulator({})", self.data_type) } @@ -206,8 +236,7 @@ impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> { } fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.all_values.capacity() * std::mem::size_of::<T::Native>() + size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>() } } @@ -223,7 +252,7 @@ struct DistinctMedianAccumulator<T: ArrowNumericType> { distinct_values: HashSet<Hashable<T::Native>>, } -impl<T: ArrowNumericType> std::fmt::Debug for DistinctMedianAccumulator<T> { +impl<T: ArrowNumericType> Debug for DistinctMedianAccumulator<T> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "DistinctMedianAccumulator({})",
self.data_type) } @@ -278,8 +307,7 @@ impl<T: ArrowNumericType> Accumulator for DistinctMedianAccumulator<T> { } fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.distinct_values.capacity() * std::mem::size_of::<Hashable<T::Native>>() + size_of_val(self) + self.distinct_values.capacity() * size_of::<Hashable<T::Native>>() } } diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 961e8639604c8..b4256508e3515 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -15,22 +15,9 @@ // under the License. //! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function -//! [`Min`] and [`MinAccumulator`] accumulator for the `max` function +//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. +mod min_max_bytes; use arrow::array::{ ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, @@ -49,10 +36,13 @@ use arrow::datatypes::{ UInt8Type, }; use arrow_schema::IntervalUnit; +use datafusion_common::stats::Precision; use datafusion_common::{ - downcast_value, exec_err, internal_err, DataFusionError, Result, + downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result, }; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use datafusion_physical_expr::expressions; use std::fmt::Debug; use arrow::datatypes::i256; @@ -62,13 +52,17 @@ use arrow::datatypes::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; +use crate::min_max::min_max_bytes::MinMaxBytesAccumulator; use datafusion_common::ScalarValue; -use datafusion_expr::GroupsAccumulator; use datafusion_expr::{ - function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, + function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, Signature, + Volatility, }; +use datafusion_expr::{GroupsAccumulator, StatisticsArgs}; use half::f16; +use std::mem::size_of_val; use std::ops::Deref; +use std::sync::OnceLock; fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> { // make sure that the input types only has one element. @@ -114,7 +108,7 @@ impl Default for Max { /// the specified [`ArrowPrimitiveType`]. /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! instantiate_max_accumulator { +macro_rules! primitive_max_accumulator { ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| { @@ -133,7 +127,7 @@ macro_rules! instantiate_max_accumulator { /// /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! instantiate_min_accumulator { +macro_rules!
primitive_min_accumulator { ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| { @@ -147,6 +141,54 @@ macro_rules! instantiate_min_accumulator { }}; } +trait FromColumnStatistics { + fn value_from_column_statistics( + &self, + stats: &ColumnStatistics, + ) -> Option<ScalarValue>; + + fn value_from_statistics( + &self, + statistics_args: &StatisticsArgs, + ) -> Option<ScalarValue> { + if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows { + match *num_rows { + 0 => return ScalarValue::try_from(statistics_args.return_type).ok(), + value if value > 0 => { + let col_stats = &statistics_args.statistics.column_statistics; + if statistics_args.exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = statistics_args.exprs[0] + .as_any() + .downcast_ref::<expressions::Column>() + { + return self.value_from_column_statistics( + &col_stats[col_expr.index()], + ); + } + } + } + _ => {} + } + } + None + } +} + +impl FromColumnStatistics for Max { + fn value_from_column_statistics( + &self, + col_stats: &ColumnStatistics, + ) -> Option<ScalarValue> { + if let Precision::Exact(ref val) = col_stats.max_value { + if !val.is_null() { + return Some(val.clone()); + } + } + None + } +} + impl AggregateUDFImpl for Max { fn as_any(&self) -> &dyn std::any::Any { self @@ -193,6 +235,12 @@ impl AggregateUDFImpl for Max { | Time32(_) | Time64(_) | Timestamp(_, _) + | Utf8 + | LargeUtf8 + | Utf8View + | Binary + | LargeBinary + | BinaryView ) } @@ -204,58 +252,58 @@ impl AggregateUDFImpl for Max { use TimeUnit::*; let data_type = args.return_type; match data_type { - Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type), - Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type), - Int32 => instantiate_max_accumulator!(data_type, i32, Int32Type), - Int64 => instantiate_max_accumulator!(data_type, i64, Int64Type), - UInt8 => instantiate_max_accumulator!(data_type, u8, UInt8Type), - UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type), - UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type), - UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type), + Int8 => primitive_max_accumulator!(data_type, i8, Int8Type), + Int16 => primitive_max_accumulator!(data_type, i16, Int16Type), + Int32 => primitive_max_accumulator!(data_type, i32, Int32Type), + Int64 => primitive_max_accumulator!(data_type, i64, Int64Type), + UInt8 => primitive_max_accumulator!(data_type, u8, UInt8Type), + UInt16 => primitive_max_accumulator!(data_type, u16, UInt16Type), + UInt32 => primitive_max_accumulator!(data_type, u32, UInt32Type), + UInt64 => primitive_max_accumulator!(data_type, u64, UInt64Type), Float16 => { - instantiate_max_accumulator!(data_type, f16, Float16Type) + primitive_max_accumulator!(data_type, f16, Float16Type) } Float32 => { - instantiate_max_accumulator!(data_type, f32, Float32Type) + primitive_max_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_max_accumulator!(data_type, f64, Float64Type) + primitive_max_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_max_accumulator!(data_type, i32, Date32Type), - Date64 => instantiate_max_accumulator!(data_type, i64, Date64Type), + Date32 => primitive_max_accumulator!(data_type, i32, Date32Type), + Date64 => primitive_max_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_max_accumulator!(data_type, i32, Time32SecondType) + primitive_max_accumulator!(data_type, i32, Time32SecondType) }
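The `FromColumnStatistics` plumbing above is what lets `MIN`/`MAX` be answered from table metadata instead of a scan, and it deliberately fires only when both the row count and the column bound are exactly known. A compilable toy restatement of that short-circuit (`Precision` and the statistics structs here are simplified stand-ins, not the `datafusion_common` types):

```rust
// Stand-ins for the statistics types, reduced to the shape needed here.
#[allow(dead_code)]
#[derive(Debug, Clone)]
enum Precision<T> {
    Exact(T),
    Inexact(T),
    Absent,
}

struct ColStats {
    max_value: Precision<i64>,
}

struct Stats {
    num_rows: Precision<usize>,
    columns: Vec<ColStats>,
}

/// Returns Some(max) only when skipping the scan is provably safe:
/// an exact, non-zero row count and an exact column maximum.
fn max_from_stats(stats: &Stats, col: usize) -> Option<i64> {
    match (&stats.num_rows, &stats.columns[col].max_value) {
        (Precision::Exact(n), Precision::Exact(v)) if *n > 0 => Some(*v),
        // Inexact or absent statistics must fall back to real execution.
        _ => None,
    }
}

fn main() {
    let stats = Stats {
        num_rows: Precision::Exact(3),
        columns: vec![ColStats { max_value: Precision::Exact(42) }],
    };
    assert_eq!(max_from_stats(&stats, 0), Some(42));

    let unknown = Stats { num_rows: Precision::Absent, columns: stats.columns };
    assert_eq!(max_from_stats(&unknown, 0), None);
}
```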
Time32(Millisecond) => { - instantiate_max_accumulator!(data_type, i32, Time32MillisecondType) + primitive_max_accumulator!(data_type, i32, Time32MillisecondType) } Time64(Microsecond) => { - instantiate_max_accumulator!(data_type, i64, Time64MicrosecondType) + primitive_max_accumulator!(data_type, i64, Time64MicrosecondType) } Time64(Nanosecond) => { - instantiate_max_accumulator!(data_type, i64, Time64NanosecondType) + primitive_max_accumulator!(data_type, i64, Time64NanosecondType) } Timestamp(Second, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampSecondType) + primitive_max_accumulator!(data_type, i64, TimestampSecondType) } Timestamp(Millisecond, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampMillisecondType) + primitive_max_accumulator!(data_type, i64, TimestampMillisecondType) } Timestamp(Microsecond, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampMicrosecondType) + primitive_max_accumulator!(data_type, i64, TimestampMicrosecondType) } Timestamp(Nanosecond, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampNanosecondType) + primitive_max_accumulator!(data_type, i64, TimestampNanosecondType) } Decimal128(_, _) => { - instantiate_max_accumulator!(data_type, i128, Decimal128Type) + primitive_max_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_max_accumulator!(data_type, i256, Decimal256Type) + primitive_max_accumulator!(data_type, i256, Decimal256Type) + } + Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { + Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone()))) } - - // It would be nice to have a fast implementation for Strings as well - // https://github.com/apache/datafusion/issues/6906 // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), @@ -272,6 +320,7 @@ impl AggregateUDFImpl for Max { fn is_descending(&self) -> Option<bool> { Some(true) } + fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { datafusion_expr::utils::AggregateOrderSensitivity::Insensitive } @@ -282,6 +331,37 @@ impl AggregateUDFImpl for Max { fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { datafusion_expr::ReversedUDAF::Identical } + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> { + self.value_from_statistics(statistics_args) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_max_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_max_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description("Returns the maximum value in the specified column.") + .with_syntax_example("max(expression)") + .with_sql_example( + r#"```sql +> SELECT max(column_name) FROM table_name; ++----------------------+ +| max(column_name)     | ++----------------------+ +| 150                  | ++----------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } // Statically-typed version of min/max(array) -> ScalarValue for string types @@ -844,7 +924,7 @@ impl Accumulator for MaxAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() + size_of_val(self) - size_of_val(&self.max) + self.max.size() } } @@ -903,7 +983,7 @@ impl Accumulator for SlidingMaxAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) -
std::mem::size_of_val(&self.max) + self.max.size() + size_of_val(self) - size_of_val(&self.max) + self.max.size() } } @@ -926,6 +1006,20 @@ impl Default for Min { } } +impl FromColumnStatistics for Min { + fn value_from_column_statistics( + &self, + col_stats: &ColumnStatistics, + ) -> Option<ScalarValue> { + if let Precision::Exact(ref val) = col_stats.min_value { + if !val.is_null() { + return Some(val.clone()); + } + } + None + } +} + impl AggregateUDFImpl for Min { fn as_any(&self) -> &dyn std::any::Any { self @@ -972,6 +1066,12 @@ impl AggregateUDFImpl for Min { | Time32(_) | Time64(_) | Timestamp(_, _) + | Utf8 + | LargeUtf8 + | Utf8View + | Binary + | LargeBinary + | BinaryView ) } @@ -983,58 +1083,58 @@ impl AggregateUDFImpl for Min { use TimeUnit::*; let data_type = args.return_type; match data_type { - Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type), - Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type), - Int32 => instantiate_min_accumulator!(data_type, i32, Int32Type), - Int64 => instantiate_min_accumulator!(data_type, i64, Int64Type), - UInt8 => instantiate_min_accumulator!(data_type, u8, UInt8Type), - UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type), - UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type), - UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type), + Int8 => primitive_min_accumulator!(data_type, i8, Int8Type), + Int16 => primitive_min_accumulator!(data_type, i16, Int16Type), + Int32 => primitive_min_accumulator!(data_type, i32, Int32Type), + Int64 => primitive_min_accumulator!(data_type, i64, Int64Type), + UInt8 => primitive_min_accumulator!(data_type, u8, UInt8Type), + UInt16 => primitive_min_accumulator!(data_type, u16, UInt16Type), + UInt32 => primitive_min_accumulator!(data_type, u32, UInt32Type), + UInt64 => primitive_min_accumulator!(data_type, u64, UInt64Type), Float16 => { - instantiate_min_accumulator!(data_type, f16, Float16Type) + primitive_min_accumulator!(data_type, f16, Float16Type) } Float32 => { - instantiate_min_accumulator!(data_type, f32, Float32Type) + primitive_min_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_min_accumulator!(data_type, f64, Float64Type) + primitive_min_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_min_accumulator!(data_type, i32, Date32Type), - Date64 => instantiate_min_accumulator!(data_type, i64, Date64Type), + Date32 => primitive_min_accumulator!(data_type, i32, Date32Type), + Date64 => primitive_min_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_min_accumulator!(data_type, i32, Time32SecondType) + primitive_min_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_min_accumulator!(data_type, i32, Time32MillisecondType) + primitive_min_accumulator!(data_type, i32, Time32MillisecondType) } Time64(Microsecond) => { - instantiate_min_accumulator!(data_type, i64, Time64MicrosecondType) + primitive_min_accumulator!(data_type, i64, Time64MicrosecondType) } Time64(Nanosecond) => { - instantiate_min_accumulator!(data_type, i64, Time64NanosecondType) + primitive_min_accumulator!(data_type, i64, Time64NanosecondType) } Timestamp(Second, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampSecondType) + primitive_min_accumulator!(data_type, i64, TimestampSecondType) } Timestamp(Millisecond, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampMillisecondType) + primitive_min_accumulator!(data_type, i64, TimestampMillisecondType) }
Timestamp(Microsecond, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampMicrosecondType) + primitive_min_accumulator!(data_type, i64, TimestampMicrosecondType) } Timestamp(Nanosecond, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampNanosecondType) + primitive_min_accumulator!(data_type, i64, TimestampNanosecondType) } Decimal128(_, _) => { - instantiate_min_accumulator!(data_type, i128, Decimal128Type) + primitive_min_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_min_accumulator!(data_type, i256, Decimal256Type) + primitive_min_accumulator!(data_type, i256, Decimal256Type) + } + Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { + Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone()))) } - - // It would be nice to have a fast implementation for Strings as well - // https://github.com/apache/datafusion/issues/6906 // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), @@ -1052,6 +1152,9 @@ Some(false) } + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> { + self.value_from_statistics(statistics_args) + } fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { datafusion_expr::utils::AggregateOrderSensitivity::Insensitive } @@ -1063,7 +1166,34 @@ impl AggregateUDFImpl for Min { fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { datafusion_expr::ReversedUDAF::Identical } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_min_doc()) + } } + +static MIN_DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_min_doc() -> &'static Documentation { + MIN_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description("Returns the minimum value in the specified column.") + .with_syntax_example("min(expression)") + .with_sql_example( + r#"```sql +> SELECT min(column_name) FROM table_name; ++----------------------+ +| min(column_name)     | ++----------------------+ +| 12                   | ++----------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} + /// An accumulator to compute the minimum value #[derive(Debug)] pub struct MinAccumulator { @@ -1102,7 +1232,7 @@ impl Accumulator for MinAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() + size_of_val(self) - size_of_val(&self.min) + self.min.size() } } @@ -1165,30 +1295,28 @@ impl Accumulator for SlidingMinAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() + size_of_val(self) - size_of_val(&self.min) + self.min.size() } } -// -// Moving min and moving max -// The implementation is taken from https://github.com/spebern/moving_min_max/blob/master/src/lib.rs. - -// Keep track of the minimum or maximum value in a sliding window. -// -// `moving min max` provides one data structure for keeping track of the -// minimum value and one for keeping track of the maximum value in a sliding -// window. -// -// Each element is stored with the current min/max. One stack to push and another one for pop. If pop stack is empty, -// push to this stack all elements popped from first stack while updating their current min/max. Now pop from -// the second stack (MovingMin/Max struct works as a queue).
To find the minimum element of the queue, -// look at the smallest/largest two elements of the individual stacks, then take the minimum of those two values. -// -// The complexity of the operations are -// - O(1) for getting the minimum/maximum -// - O(1) for push -// - amortized O(1) for pop - +/// Keep track of the minimum value in a sliding window. +/// +/// The implementation is taken from +/// <https://github.com/spebern/moving_min_max/blob/master/src/lib.rs> +/// +/// `moving min max` provides one data structure for keeping track of the +/// minimum value and one for keeping track of the maximum value in a sliding +/// window. +/// +/// Each element is stored with the current min/max. One stack to push and another one for pop. If pop stack is empty, +/// push to this stack all elements popped from first stack while updating their current min/max. Now pop from +/// the second stack (MovingMin/Max struct works as a queue). To find the minimum element of the queue, +/// look at the smallest/largest two elements of the individual stacks, then take the minimum of those two values. +/// +/// The complexity of the operations are +/// - O(1) for getting the minimum/maximum +/// - O(1) for push +/// - amortized O(1) for pop +/// /// ``` /// # use datafusion_functions_aggregate::min_max::MovingMin; /// let mut moving_min = MovingMin::<i32>::new(); @@ -1304,6 +1432,11 @@ impl<T: Clone + PartialOrd> MovingMin<T> { self.len() == 0 } } + +/// Keep track of the maximum value in a sliding window. +/// +/// See [`MovingMin`] for more details. +/// /// ``` /// # use datafusion_functions_aggregate::min_max::MovingMax; /// let mut moving_max = MovingMax::<i32>::new(); diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs new file mode 100644 index 0000000000000..501454edf77ca --- /dev/null +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -0,0 +1,515 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
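Before the new file's body, a compilable sketch of the two-stack queue that the `MovingMin`/`MovingMax` docs above describe. Each stack entry carries `(value, running_min)`, so the overall minimum is readable in O(1) and pops are amortized O(1); this is an `i32` simplification, while the DataFusion types are generic:

```rust
struct MinQueue {
    push_stack: Vec<(i32, i32)>, // (value, min of this stack up to this entry)
    pop_stack: Vec<(i32, i32)>,
}

impl MinQueue {
    fn new() -> Self {
        Self { push_stack: Vec::new(), pop_stack: Vec::new() }
    }

    fn push(&mut self, v: i32) {
        // O(1): extend the running minimum of the push stack.
        let min = self.push_stack.last().map_or(v, |&(_, m)| m.min(v));
        self.push_stack.push((v, min));
    }

    fn pop(&mut self) -> Option<i32> {
        if self.pop_stack.is_empty() {
            // Refill: reverse the push stack while recomputing running minima,
            // so elements leave in FIFO order. Amortized O(1) per pop.
            while let Some((v, _)) = self.push_stack.pop() {
                let min = self.pop_stack.last().map_or(v, |&(_, m)| m.min(v));
                self.pop_stack.push((v, min));
            }
        }
        self.pop_stack.pop().map(|(v, _)| v)
    }

    fn min(&self) -> Option<i32> {
        // O(1): the window minimum is the smaller of the two stack minima.
        match (self.push_stack.last(), self.pop_stack.last()) {
            (Some(&(_, a)), Some(&(_, b))) => Some(a.min(b)),
            (Some(&(_, a)), None) | (None, Some(&(_, a))) => Some(a),
            (None, None) => None,
        }
    }
}

fn main() {
    let mut q = MinQueue::new();
    q.push(2);
    q.push(1);
    q.push(3);
    assert_eq!(q.min(), Some(1));
    assert_eq!(q.pop(), Some(2)); // FIFO: the oldest element leaves first
    assert_eq!(q.pop(), Some(1));
    assert_eq!(q.min(), Some(3));
}
```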
+ +use arrow::array::{ + Array, ArrayRef, AsArray, BinaryBuilder, BinaryViewBuilder, BooleanArray, + LargeBinaryBuilder, LargeStringBuilder, StringBuilder, StringViewBuilder, +}; +use arrow_schema::DataType; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; +use std::mem::size_of; +use std::sync::Arc; + +/// Implements fast Min/Max [`GroupsAccumulator`] for "bytes" types ([`StringArray`], +/// [`BinaryArray`], [`StringViewArray`], etc) +/// +/// This implementation dispatches to the appropriate specialized code in +/// [`MinMaxBytesState`] based on data type and comparison function +/// +/// [`StringArray`]: arrow::array::StringArray +/// [`BinaryArray`]: arrow::array::BinaryArray +/// [`StringViewArray`]: arrow::array::StringViewArray +#[derive(Debug)] +pub(crate) struct MinMaxBytesAccumulator { + /// Inner data storage. + inner: MinMaxBytesState, + /// if true, is `MIN` otherwise is `MAX` + is_min: bool, +} + +impl MinMaxBytesAccumulator { + /// Create a new accumulator for computing `min(val)` + pub fn new_min(data_type: DataType) -> Self { + Self { + inner: MinMaxBytesState::new(data_type), + is_min: true, + } + } + + /// Create a new accumulator for computing `max(val)` + pub fn new_max(data_type: DataType) -> Self { + Self { + inner: MinMaxBytesState::new(data_type), + is_min: false, + } + } +} + +impl GroupsAccumulator for MinMaxBytesAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let array = &values[0]; + assert_eq!(array.len(), group_indices.len()); + assert_eq!(array.data_type(), &self.inner.data_type); + + // apply filter if needed + let array = apply_filter_as_nulls(array, opt_filter)?; + + // dispatch to appropriate kernel / specialized implementation + fn string_min(a: &[u8], b: &[u8]) -> bool { + // safety: only called from this function, which ensures a and b come + // from an array with valid utf8 data + unsafe { + let a = std::str::from_utf8_unchecked(a); + let b = std::str::from_utf8_unchecked(b); + a < b + } + } + fn string_max(a: &[u8], b: &[u8]) -> bool { + // safety: only called from this function, which ensures a and b come + // from an array with valid utf8 data + unsafe { + let a = std::str::from_utf8_unchecked(a); + let b = std::str::from_utf8_unchecked(b); + a > b + } + } + fn binary_min(a: &[u8], b: &[u8]) -> bool { + a < b + } + + fn binary_max(a: &[u8], b: &[u8]) -> bool { + a > b + } + + fn str_to_bytes<'a>( + it: impl Iterator<Item = Option<&'a str>>, + ) -> impl Iterator<Item = Option<&'a [u8]>> { + it.map(|s| s.map(|s| s.as_bytes())) + } + + match (self.is_min, &self.inner.data_type) { + // Utf8/LargeUtf8/Utf8View Min + (true, &DataType::Utf8) => self.inner.update_batch( + str_to_bytes(array.as_string::<i32>().iter()), + group_indices, + total_num_groups, + string_min, + ), + (true, &DataType::LargeUtf8) => self.inner.update_batch( + str_to_bytes(array.as_string::<i64>().iter()), + group_indices, + total_num_groups, + string_min, + ), + (true, &DataType::Utf8View) => self.inner.update_batch( + str_to_bytes(array.as_string_view().iter()), + group_indices, + total_num_groups, + string_min, + ), + + // Utf8/LargeUtf8/Utf8View Max + (false, &DataType::Utf8) => self.inner.update_batch( + str_to_bytes(array.as_string::<i32>().iter()), + group_indices, + total_num_groups, + string_max, + ), + (false, &DataType::LargeUtf8) =>
self.inner.update_batch( + str_to_bytes(array.as_string::<i64>().iter()), + group_indices, + total_num_groups, + string_max, + ), + (false, &DataType::Utf8View) => self.inner.update_batch( + str_to_bytes(array.as_string_view().iter()), + group_indices, + total_num_groups, + string_max, + ), + + // Binary/LargeBinary/BinaryView Min + (true, &DataType::Binary) => self.inner.update_batch( + array.as_binary::<i32>().iter(), + group_indices, + total_num_groups, + binary_min, + ), + (true, &DataType::LargeBinary) => self.inner.update_batch( + array.as_binary::<i64>().iter(), + group_indices, + total_num_groups, + binary_min, + ), + (true, &DataType::BinaryView) => self.inner.update_batch( + array.as_binary_view().iter(), + group_indices, + total_num_groups, + binary_min, + ), + + // Binary/LargeBinary/BinaryView Max + (false, &DataType::Binary) => self.inner.update_batch( + array.as_binary::<i32>().iter(), + group_indices, + total_num_groups, + binary_max, + ), + (false, &DataType::LargeBinary) => self.inner.update_batch( + array.as_binary::<i64>().iter(), + group_indices, + total_num_groups, + binary_max, + ), + (false, &DataType::BinaryView) => self.inner.update_batch( + array.as_binary_view().iter(), + group_indices, + total_num_groups, + binary_max, + ), + + _ => internal_err!( + "Unexpected combination for MinMaxBytesAccumulator: ({:?}, {:?})", + self.is_min, + self.inner.data_type + ), + } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> { + let (data_capacity, min_maxes) = self.inner.emit_to(emit_to); + + // Convert the Vec of bytes to a vec of Strings (at no cost) + fn bytes_to_str( + min_maxes: Vec<Option<Vec<u8>>>, + ) -> impl Iterator<Item = Option<String>> { + min_maxes.into_iter().map(|opt| { + opt.map(|bytes| { + // Safety: only called on data added from update_batch which ensures + // the input type matched the output type + unsafe { String::from_utf8_unchecked(bytes) } + }) + }) + } + + let result: ArrayRef = match self.inner.data_type { + DataType::Utf8 => { + let mut builder = + StringBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in bytes_to_str(min_maxes) { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_str()), + } + } + Arc::new(builder.finish()) + } + DataType::LargeUtf8 => { + let mut builder = + LargeStringBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in bytes_to_str(min_maxes) { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_str()), + } + } + Arc::new(builder.finish()) + } + DataType::Utf8View => { + let block_size = capacity_to_view_block_size(data_capacity); + + let mut builder = StringViewBuilder::with_capacity(min_maxes.len()) + .with_fixed_block_size(block_size); + for opt in bytes_to_str(min_maxes) { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_str()), + } + } + Arc::new(builder.finish()) + } + DataType::Binary => { + let mut builder = + BinaryBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in min_maxes { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_ref() as &[u8]), + } + } + Arc::new(builder.finish()) + } + DataType::LargeBinary => { + let mut builder = + LargeBinaryBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in min_maxes { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_ref() as &[u8]), + } + } + Arc::new(builder.finish()) + } + DataType::BinaryView => { + let block_size = capacity_to_view_block_size(data_capacity); + + let mut builder =
BinaryViewBuilder::with_capacity(min_maxes.len()) + .with_fixed_block_size(block_size); + for opt in min_maxes { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_ref() as &[u8]), + } + } + Arc::new(builder.finish()) + } + _ => { + return internal_err!( + "Unexpected data type for MinMaxBytesAccumulator: {:?}", + self.inner.data_type + ); + } + }; + + assert_eq!(&self.inner.data_type, result.data_type()); + Ok(result) + } + + fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { + // min/max are their own states (no transition needed) + self.evaluate(emit_to).map(|arr| vec![arr]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // min/max are their own states (no transition needed) + self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result<Vec<ArrayRef>> { + // Min/max do not change the values as they are their own states + // apply the filter by combining with the null mask, if any + let output = apply_filter_as_nulls(&values[0], opt_filter)?; + Ok(vec![output]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.inner.size() + } +} + +/// Returns the block size in (contiguous buffer size) to use +/// for a given data capacity (total string length) +/// +/// This is a heuristic to avoid allocating too many small buffers +fn capacity_to_view_block_size(data_capacity: usize) -> u32 { + let max_block_size = 2 * 1024 * 1024; + if let Ok(block_size) = u32::try_from(data_capacity) { + block_size.min(max_block_size) + } else { + max_block_size + } +} + +/// Stores internal Min/Max state for "bytes" types. +/// +/// This implementation is general and stores the minimum/maximum for each +/// groups in an individual byte array, which balances allocations and memory +/// fragmentation (aka garbage). +/// +/// ```text +/// ┌─────────────────────────────────┐ +/// ┌─────┐ ┌────▶│Option<Vec<u8>> (["A"]) │───────────▶ "A" +/// │ 0 │────┘ └─────────────────────────────────┘ +/// ├─────┤ ┌─────────────────────────────────┐ +/// │ 1 │─────────▶│Option<Vec<u8>> (["Z"]) │───────────▶ "Z" +/// └─────┘ └─────────────────────────────────┘ ... +/// ... ... +/// ┌─────┐ ┌────────────────────────────────┐ +/// │ N-2 │─────────▶│Option<Vec<u8>> (["A"]) │────────────▶ "A" +/// ├─────┤ └────────────────────────────────┘ +/// │ N-1 │────┐ ┌────────────────────────────────┐ +/// └─────┘ └────▶│Option<Vec<u8>> (["Q"]) │────────────▶ "Q" +/// └────────────────────────────────┘ +/// +/// min_max: Vec<Option<Vec<u8>>> +/// ``` +/// +/// Note that for `StringViewArray` and `BinaryViewArray`, there are potentially +/// more efficient implementations (e.g. by managing a string data buffer +/// directly), but then garbage collection, memory management, and final array +/// construction becomes more complex.
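The diagram above boils down to a `Vec<Option<Vec<u8>>>` indexed by group id, where an entry is replaced only when the comparison function prefers the incoming bytes. A toy, null-free version of that update path for a grouped byte-minimum (my simplification, not the DataFusion API, which additionally batches updates through `MinMaxLocation` to avoid repeated copies):

```rust
// Toy grouped min over byte strings: state[g] holds the smallest value
// seen so far for group g as an owned buffer (None until a value arrives).
fn update_groups(state: &mut Vec<Option<Vec<u8>>>, values: &[&[u8]], groups: &[usize]) {
    for (v, &g) in values.iter().zip(groups) {
        if state.len() <= g {
            state.resize(g + 1, None);
        }
        match &state[g] {
            // Current minimum is already <= the new value: keep it.
            Some(cur) if cur.as_slice() <= *v => {}
            // First value for this group, or a new minimum: copy it in.
            _ => state[g] = Some(v.to_vec()),
        }
    }
}

fn main() {
    let mut state = Vec::new();
    let values: [&[u8]; 3] = [b"b", b"a", b"z"];
    update_groups(&mut state, &values, &[0, 0, 1]);
    assert_eq!(state[0].as_deref(), Some(&b"a"[..]));
    assert_eq!(state[1].as_deref(), Some(&b"z"[..]));
}
```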
+/// +/// See discussion on +#[derive(Debug)] +struct MinMaxBytesState { + /// The minimum/maximum value for each group + min_max: Vec<Option<Vec<u8>>>, + /// The data type of the array + data_type: DataType, + /// The total bytes of the string data (for pre-allocating the final array, + /// and tracking memory usage) + total_data_bytes: usize, +} + +#[derive(Debug, Clone, Copy)] +enum MinMaxLocation<'a> { + /// the min/max value is stored in the existing `min_max` array + ExistingMinMax, + /// the min/max value is stored in the input array at the given index + Input(&'a [u8]), +} + +/// Implement the MinMaxBytesAccumulator with a comparison function +/// for comparing strings +impl MinMaxBytesState { + /// Create a new MinMaxBytesAccumulator + /// + /// # Arguments: + /// * `data_type`: The data type of the arrays that will be passed to this accumulator + fn new(data_type: DataType) -> Self { + Self { + min_max: vec![], + data_type, + total_data_bytes: 0, + } + } + + /// Set the specified group to the given value, updating memory usage appropriately + fn set_value(&mut self, group_index: usize, new_val: &[u8]) { + match self.min_max[group_index].as_mut() { + None => { + self.min_max[group_index] = Some(new_val.to_vec()); + self.total_data_bytes += new_val.len(); + } + Some(existing_val) => { + // Copy data over to avoid re-allocating + self.total_data_bytes -= existing_val.len(); + self.total_data_bytes += new_val.len(); + existing_val.clear(); + existing_val.extend_from_slice(new_val); + } + } + } + + /// Updates the min/max values for the given string values + /// + /// `cmp` is the comparison function to use, called like `cmp(new_val, existing_val)` + /// returns true if the `new_val` should replace `existing_val` + fn update_batch<'a, F, I>( + &mut self, + iter: I, + group_indices: &[usize], + total_num_groups: usize, + mut cmp: F, + ) -> Result<()> + where + F: FnMut(&[u8], &[u8]) -> bool + Send + Sync, + I: IntoIterator<Item = Option<&'a [u8]>>, + { + self.min_max.resize(total_num_groups, None); + // Minimize value copies by calculating the new min/maxes for each group + // in this batch (either the existing min/max or the new input value) + // and updating the owned values in `self.min_max` at most once + let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups]; + + // Figure out the new min value for each group + for (new_val, group_index) in iter.into_iter().zip(group_indices.iter()) { + let group_index = *group_index; + let Some(new_val) = new_val else { + continue; // skip nulls + }; + + let existing_val = match locations[group_index] { + // previous input value was the min/max, so compare it + MinMaxLocation::Input(existing_val) => existing_val, + MinMaxLocation::ExistingMinMax => { + let Some(existing_val) = self.min_max[group_index].as_ref() else { + // no existing min/max, so this is the new min/max + locations[group_index] = MinMaxLocation::Input(new_val); + continue; + }; + existing_val.as_ref() + } + }; + + // Compare the new value to the existing value, replacing if necessary + if cmp(new_val, existing_val) { + locations[group_index] = MinMaxLocation::Input(new_val); + } + } + + // Update self.min_max with any new min/max values we found in the input + for (group_index, location) in locations.iter().enumerate() { + match location { + MinMaxLocation::ExistingMinMax => {} + MinMaxLocation::Input(new_val) => self.set_value(group_index, new_val), + } + } + Ok(()) + } + + /// Emits the specified min_max values + /// + /// Returns (data_capacity, min_maxes), updating the current value of
total_data_bytes +/// +/// - `data_capacity`: the total length of all strings and their contents, +/// - `min_maxes`: the actual min/max values for each group + fn emit_to(&mut self, emit_to: EmitTo) -> (usize, Vec<Option<Vec<u8>>>) { + match emit_to { + EmitTo::All => { + ( + std::mem::take(&mut self.total_data_bytes), // reset total bytes and min_max + std::mem::take(&mut self.min_max), + ) + } + EmitTo::First(n) => { + let first_min_maxes: Vec<_> = self.min_max.drain(..n).collect(); + let first_data_capacity: usize = first_min_maxes + .iter() + .map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0)) + .sum(); + self.total_data_bytes -= first_data_capacity; + (first_data_capacity, first_min_maxes) + } + } + } + + fn size(&self) -> usize { + self.total_data_bytes + self.min_max.len() * size_of::<Option<Vec<u8>>>() + } +} diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index bbfe56914c910..2a1778d8b232b 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -20,18 +20,20 @@ use std::any::Any; use std::collections::VecDeque; -use std::sync::Arc; +use std::mem::{size_of, size_of_val}; +use std::sync::{Arc, OnceLock}; use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; use arrow_schema::{DataType, Field, Fields}; use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - lit, Accumulator, AggregateUDFImpl, ExprFunctionExt, ReversedUDAF, Signature, - SortExpr, Volatility, + lit, Accumulator, AggregateUDFImpl, Documentation, ExprFunctionExt, ReversedUDAF, + Signature, SortExpr, Volatility, }; use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; use datafusion_functions_aggregate_common::utils::ordering_fields; @@ -161,6 +163,40 @@ impl AggregateUDFImpl for NthValueAgg { fn reverse_expr(&self) -> ReversedUDAF { ReversedUDAF::Reversed(nth_value_udaf()) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_nth_value_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_nth_value_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Returns the nth value in a group of values.", + ) + .with_syntax_example("nth_value(expression, n ORDER BY expression)") + .with_sql_example(r#"```sql +> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept + FROM employee; ++---------+--------+-------------------------+ +| dept_id | salary | second_salary_by_dept   | ++---------+--------+-------------------------+ +| 1       | 30000  | NULL                    | +| 1       | 40000  | 40000                   | +| 1       | 50000  | 40000                   | +| 2       | 35000  | NULL                    | +| 2       | 45000  | 45000                   | ++---------+--------+-------------------------+ +```"#) + .with_argument("expression", "The column or expression to retrieve the nth value from.") + .with_argument("n", "The position (nth) of the value to retrieve, based on the ordering.") + .build() + .unwrap() + }) } #[derive(Debug)] @@ -343,25 +379,23 @@ impl Accumulator for NthValueAccumulator { } fn size(&self) -> usize { - let mut total = std::mem::size_of_val(self) - +
ScalarValue::size_of_vec_deque(&self.values) - - std::mem::size_of_val(&self.values); + let mut total = size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values) - size_of_val(&self.values); // Add size of the `self.ordering_values` - total += - std::mem::size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity(); + total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity(); for row in &self.ordering_values { - total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row); + total += ScalarValue::size_of_vec(row) - size_of_val(row); } // Add size of the `self.datatypes` - total += std::mem::size_of::<DataType>() * self.datatypes.capacity(); + total += size_of::<DataType>() * self.datatypes.capacity(); for dtype in &self.datatypes { - total += dtype.size() - std::mem::size_of_val(dtype); + total += dtype.size() - size_of_val(dtype); } // Add size of the `self.ordering_req` - total += std::mem::size_of::<PhysicalSortExpr>() * self.ordering_req.capacity(); + total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity(); // TODO: Calculate size of each `PhysicalSortExpr` more accurately. total } diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index 390a769aca7f8..bf1e81949d23a 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -17,9 +17,6 @@ //! Defines physical expressions that can be evaluated at runtime during query execution -use std::any::Any; -use std::fmt::Debug; - use arrow::array::Float64Array; use arrow::{ array::{ArrayRef, UInt64Array}, @@ -29,10 +26,18 @@ use arrow::{ }; use datafusion_common::{downcast_value, plan_err, unwrap_or_internal_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; +use std::any::Any; +use std::collections::HashMap; +use std::fmt::Debug; +use std::mem::size_of_val; +use std::sync::OnceLock; macro_rules! make_regr_udaf_expr_and_func { ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => { @@ -76,23 +81,7 @@ impl Regr { } } -/* -#[derive(Debug)] -pub struct Regr { - name: String, - regr_type: RegrType, - expr_y: Arc<dyn PhysicalExpr>, - expr_x: Arc<dyn PhysicalExpr>, -} - -impl Regr { - pub fn get_regr_type(&self) -> RegrType { - self.regr_type.clone() - } -} -*/ - -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Hash, Eq)] #[allow(clippy::upper_case_acronyms)] pub enum RegrType { /// Variant for `regr_slope` aggregate expression @@ -135,6 +124,148 @@ pub enum RegrType { SXY, } +impl RegrType { + /// return the documentation for the `RegrType` + fn documentation(&self) -> Option<&Documentation> { + get_regr_docs().get(self) + } +} + +static DOCUMENTATION: OnceLock<HashMap<RegrType, Documentation>> = OnceLock::new(); +fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> { + DOCUMENTATION.get_or_init(|| { + let mut hash_map = HashMap::new(); + hash_map.insert( + RegrType::Slope, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Returns the slope of the linear regression line for non-null pairs in aggregate columns.
\ + Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.", + ) + .with_syntax_example("regr_slope(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::Intercept, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \ + this function returns b.", + ) + .with_syntax_example("regr_intercept(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::Count, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Counts the number of non-null paired data points.", + ) + .with_syntax_example("regr_count(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::R2, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the square of the correlation coefficient between the independent and dependent variables.", + ) + .with_syntax_example("regr_r2(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::AvgX, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the average of the independent variable (input) expression_x for the non-null paired data points.", + ) + .with_syntax_example("regr_avgx(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::AvgY, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.", + ) + .with_syntax_example("regr_avgy(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::SXX, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the sum of squares of the independent variable.", + ) + .with_syntax_example("regr_sxx(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::SYY, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the sum of squares of the dependent variable.", + ) + .with_syntax_example("regr_syy(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + 
.build() + .unwrap() + ); + + hash_map.insert( + RegrType::SXY, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the sum of products of paired data points.", + ) + .with_syntax_example("regr_sxy(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + hash_map + }) +} + impl AggregateUDFImpl for Regr { fn as_any(&self) -> &dyn Any { self @@ -198,22 +329,11 @@ impl AggregateUDFImpl for Regr { ), ]) } -} -/* -impl PartialEq<dyn Any> for Regr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::<Self>() - .map(|x| { - self.name == x.name - && self.expr_y.eq(&x.expr_y) - && self.expr_x.eq(&x.expr_x) - }) - .unwrap_or(false) + fn documentation(&self) -> Option<&Documentation> { + self.regr_type.documentation() } } -*/ /// `RegrAccumulator` is used to compute linear regression aggregate functions /// by maintaining statistics needed to compute them in an online fashion. @@ -495,6 +615,6 @@ impl Accumulator for RegrAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index a25ab5e319915..355d1d5ad2db9 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -19,17 +19,20 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; -use std::sync::Arc; +use std::mem::align_of_val; +use std::sync::{Arc, OnceLock}; use arrow::array::Float64Array; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_common::{plan_err, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, GroupsAccumulator, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, + Volatility, }; use datafusion_functions_aggregate_common::stats::StatsType; @@ -132,6 +135,34 @@ impl AggregateUDFImpl for Stddev { ) -> Result<Box<dyn GroupsAccumulator>> { Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_stddev_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_stddev_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description("Returns the standard deviation of a set of numbers.") + .with_syntax_example("stddev(expression)") + .with_sql_example( + r#"```sql +> SELECT stddev(column_name) FROM table_name; ++----------------------+ +| stddev(column_name)  | ++----------------------+ +| 12.34                | ++----------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } make_udaf_expr_and_func!( @@ -228,6 +259,34 @@ impl AggregateUDFImpl for StddevPop { StatsType::Population, ))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_stddev_pop_doc()) + } +} + +static STDDEV_POP_DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_stddev_pop_doc() -> &'static Documentation { + STDDEV_POP_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Returns the population standard
deviation of a set of numbers.", + ) + .with_syntax_example("stddev_pop(expression)") + .with_sql_example( + r#"```sql +> SELECT stddev_pop(column_name) FROM table_name; ++--------------------------+ +| stddev_pop(column_name) | ++--------------------------+ +| 10.56 | ++--------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } /// An accumulator to compute the average @@ -285,8 +344,7 @@ impl Accumulator for StddevAccumulator { } fn size(&self) -> usize { - std::mem::align_of_val(self) - std::mem::align_of_val(&self.variance) - + self.variance.size() + align_of_val(self) - align_of_val(&self.variance) + self.variance.size() } fn supports_retract_batch(&self) -> bool { diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index a7e9a37e23ad6..68267b9f72c7d 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -22,12 +22,15 @@ use arrow_schema::DataType; use datafusion_common::cast::as_generic_string_array; use datafusion_common::Result; use datafusion_common::{not_impl_err, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, }; use datafusion_physical_expr::expressions::Literal; use std::any::Any; +use std::mem::size_of_val; +use std::sync::OnceLock; make_udaf_expr_and_func!( StringAgg, @@ -98,6 +101,37 @@ impl AggregateUDFImpl for StringAgg { not_impl_err!("expect literal") } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_string_agg_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_string_agg_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Concatenates the values of string expressions and places separator values between them." + ) + .with_syntax_example("string_agg(expression, delimiter)") + .with_sql_example(r#"```sql +> SELECT string_agg(name, ', ') AS names_list + FROM employee; ++--------------------------+ +| names_list | ++--------------------------+ +| Alice, Bob, Charlie | ++--------------------------+ +```"#, + ) + .with_argument("expression", "The string expression to concatenate. 
Can be a column or any valid string expression.") + .with_argument("delimiter", "A literal string used as a separator between the concatenated values.") + .build() + .unwrap() + }) } #[derive(Debug)] @@ -146,7 +180,7 @@ impl Accumulator for StringAggAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) + self.delimiter.capacity() } diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 7e40c1bd17a8d..6ad376db4fb9c 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -21,6 +21,8 @@ use ahash::RandomState; use datafusion_expr::utils::AggregateOrderSensitivity; use std::any::Any; use std::collections::HashSet; +use std::mem::{size_of, size_of_val}; +use std::sync::OnceLock; use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; @@ -33,11 +35,13 @@ use arrow::datatypes::{ }; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, + Signature, Volatility, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use datafusion_functions_aggregate_common::utils::Hashable; @@ -233,6 +237,34 @@ impl AggregateUDFImpl for Sum { fn order_sensitivity(&self) -> AggregateOrderSensitivity { AggregateOrderSensitivity::Insensitive } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_sum_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_sum_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description("Returns the sum of all values in the specified column.") + .with_syntax_example("sum(expression)") + .with_sql_example( + r#"```sql +> SELECT sum(column_name) FROM table_name; ++-----------------------+ +| sum(column_name) | ++-----------------------+ +| 12345 | ++-----------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } /// This accumulator computes SUM incrementally @@ -279,7 +311,7 @@ impl Accumulator for SumAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -339,7 +371,7 @@ impl Accumulator for SlidingSumAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { @@ -433,7 +465,6 @@ impl Accumulator for DistinctSumAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.values.capacity() * std::mem::size_of::() + size_of_val(self) + self.values.capacity() * size_of::() } } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 3648ec0d13127..810247a2884a9 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -18,22 +18,25 @@ //! [`VarianceSample`]: variance sample aggregations. //! 
[`VariancePopulation`]: variance population aggregations. -use std::{fmt::Debug, sync::Arc}; - use arrow::{ array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array}, buffer::NullBuffer, compute::kernels::cast, datatypes::{DataType, Field}, }; +use std::mem::{size_of, size_of_val}; +use std::sync::OnceLock; +use std::{fmt::Debug, sync::Arc}; use datafusion_common::{ downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, - Accumulator, AggregateUDFImpl, GroupsAccumulator, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, + Volatility, }; use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, @@ -135,6 +138,26 @@ impl AggregateUDFImpl for VarianceSample { ) -> Result> { Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_variance_sample_doc()) + } +} + +static VARIANCE_SAMPLE_DOC: OnceLock = OnceLock::new(); + +fn get_variance_sample_doc() -> &'static Documentation { + VARIANCE_SAMPLE_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the statistical sample variance of a set of numbers.", + ) + .with_syntax_example("var(expression)") + .with_standard_argument("expression", Some("Numeric")) + .build() + .unwrap() + }) } pub struct VariancePopulation { @@ -222,6 +245,25 @@ impl AggregateUDFImpl for VariancePopulation { StatsType::Population, ))) } + fn documentation(&self) -> Option<&Documentation> { + Some(get_variance_population_doc()) + } +} + +static VARIANCE_POPULATION_DOC: OnceLock = OnceLock::new(); + +fn get_variance_population_doc() -> &'static Documentation { + VARIANCE_POPULATION_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the statistical population variance of a set of numbers.", + ) + .with_syntax_example("var_pop(expression)") + .with_standard_argument("expression", Some("Numeric")) + .build() + .unwrap() + }) } /// An accumulator to compute variance @@ -383,7 +425,7 @@ impl Accumulator for VarianceAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn supports_retract_batch(&self) -> bool { @@ -488,7 +530,7 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&arrow::array::BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); @@ -514,7 +556,7 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&arrow::array::BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 3, "two arguments to merge_batch"); @@ -565,8 +607,8 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { } fn size(&self) -> usize { - self.m2s.capacity() * std::mem::size_of::() - + self.means.capacity() * std::mem::size_of::() - + self.counts.capacity() * std::mem::size_of::() + self.m2s.capacity() * size_of::() + + self.means.capacity() * size_of::() + + self.counts.capacity() * size_of::() } } diff --git 
a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 8f8d123bf5f9c..fe1d05199e80d 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -25,14 +25,17 @@ use arrow_buffer::BooleanBuffer; use datafusion_common::cast::as_generic_list_array; use datafusion_common::utils::string_utils::string_array_to_vec; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; use crate::utils::make_scalar_function; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; // Create static instances of ScalarUDFs for each function make_udf_expr_and_func!(ArrayHas, @@ -129,6 +132,43 @@ impl ScalarUDFImpl for ArrayHas { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_has_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_has_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns true if the array contains the element.", + ) + .with_syntax_example("array_has(array, element)") + .with_sql_example( + r#"```sql +> select array_has([1, 2, 3], 2); ++-----------------------------+ +| array_has(List([1,2,3]), 2) | ++-----------------------------+ +| true | ++-----------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "element", + "Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } fn array_has_inner_for_scalar( @@ -289,6 +329,41 @@ impl ScalarUDFImpl for ArrayHasAll { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_has_all_doc()) + } +} + +fn get_array_has_all_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns true if all elements of sub-array exist in array.", + ) + .with_syntax_example("array_has_all(array, sub-array)") + .with_sql_example( + r#"```sql +> select array_has_all([1, 2, 3, 4], [2, 3]); ++--------------------------------------------+ +| array_has_all(List([1,2,3,4]), List([2,3])) | ++--------------------------------------------+ +| true | ++--------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "sub-array", + "Array expression. 
Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } #[derive(Debug)] @@ -335,6 +410,41 @@ impl ScalarUDFImpl for ArrayHasAny { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_has_any_doc()) + } +} + +fn get_array_has_any_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns true if any elements exist in both arrays.", + ) + .with_syntax_example("array_has_any(array, sub-array)") + .with_sql_example( + r#"```sql +> select array_has_any([1, 2, 3], [3, 4]); ++------------------------------------------+ +| array_has_any(List([1,2,3]), List([3,4])) | ++------------------------------------------+ +| true | ++------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "sub-array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } /// Represents the type of comparison for array_has. diff --git a/datafusion/functions-nested/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs index ea07ac381affd..b6661e0807f4e 100644 --- a/datafusion/functions-nested/src/cardinality.rs +++ b/datafusion/functions-nested/src/cardinality.rs @@ -26,12 +26,13 @@ use arrow_schema::DataType::{FixedSizeList, LargeList, List, Map, UInt64}; use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array}; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; make_udf_expr_and_func!( Cardinality, @@ -89,6 +90,39 @@ impl ScalarUDFImpl for Cardinality { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_cardinality_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_cardinality_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the total number of elements in the array.", + ) + .with_syntax_example("cardinality(array)") + .with_sql_example( + r#"```sql +> select cardinality([[1, 2, 3, 4], [5, 6, 7, 8]]); ++--------------------------------------+ +| cardinality(List([1,2,3,4,5,6,7,8])) | ++--------------------------------------+ +| 8 | ++--------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } /// Cardinality SQL function diff --git a/datafusion/functions-nested/src/concat.rs b/datafusion/functions-nested/src/concat.rs index c52118d0a5e2b..1bdcf74aee2a7 100644 --- a/datafusion/functions-nested/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -17,7 +17,8 @@ //! [`ScalarUDFImpl`] definitions for `array_append`, `array_prepend` and `array_concat` functions. 
-use std::{any::Any, cmp::Ordering, sync::Arc}; +use std::sync::{Arc, OnceLock}; +use std::{any::Any, cmp::Ordering}; use arrow::array::{Capacities, MutableArrayData}; use arrow_array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; @@ -27,9 +28,10 @@ use datafusion_common::Result; use datafusion_common::{ cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims, }; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; use datafusion_expr::{ - type_coercion::binary::get_wider_type, ColumnarValue, ScalarUDFImpl, Signature, - Volatility, + type_coercion::binary::get_wider_type, ColumnarValue, Documentation, ScalarUDFImpl, + Signature, Volatility, }; use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; @@ -91,6 +93,43 @@ impl ScalarUDFImpl for ArrayAppend { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_append_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_append_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Appends an element to the end of an array.", + ) + .with_syntax_example("array_append(array, element)") + .with_sql_example( + r#"```sql +> select array_append([1, 2, 3], 4); ++--------------------------------------+ +| array_append(List([1,2,3]),Int64(4)) | ++--------------------------------------+ +| [1, 2, 3, 4] | ++--------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "element", + "Element to append to the array.", + ) + .build() + .unwrap() + }) } make_udf_expr_and_func!( @@ -150,6 +189,41 @@ impl ScalarUDFImpl for ArrayPrepend { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_prepend_doc()) + } +} + +fn get_array_prepend_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Prepends an element to the beginning of an array.", + ) + .with_syntax_example("array_prepend(element, array)") + .with_sql_example( + r#"```sql +> select array_prepend(1, [2, 3, 4]); ++---------------------------------------+ +| array_prepend(Int64(1),List([2,3,4])) | ++---------------------------------------+ +| [1, 2, 3, 4] | ++---------------------------------------+ +```"#, + ) + .with_argument( + "element", + "Element to prepend to the array.", + ) + .with_argument( + "array", + "Array expression. 
Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } make_udf_expr_and_func!( @@ -233,6 +307,41 @@ impl ScalarUDFImpl for ArrayConcat { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_concat_doc()) + } +} + +fn get_array_concat_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Concatenates arrays.", + ) + .with_syntax_example("array_concat(array[, ..., array_n])") + .with_sql_example( + r#"```sql +> select array_concat([1, 2], [3, 4], [5, 6]); ++---------------------------------------------------+ +| array_concat(List([1,2]),List([3,4]),List([5,6])) | ++---------------------------------------------------+ +| [1, 2, 3, 4, 5, 6] | ++---------------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression to concatenate. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "array_n", + "Subsequent array column or literal array to concatenate.", + ) + .build() + .unwrap() + }) } /// Array_concat/Array_cat SQL function diff --git a/datafusion/functions-nested/src/dimension.rs b/datafusion/functions-nested/src/dimension.rs index d84fa0c19ee9a..7df0ed2b40bdb 100644 --- a/datafusion/functions-nested/src/dimension.rs +++ b/datafusion/functions-nested/src/dimension.rs @@ -29,8 +29,11 @@ use datafusion_common::{exec_err, plan_err, Result}; use crate::utils::{compute_array_dims, make_scalar_function}; use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; use arrow_schema::Field; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use std::sync::Arc; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use std::sync::{Arc, OnceLock}; make_udf_expr_and_func!( ArrayDims, @@ -85,6 +88,39 @@ impl ScalarUDFImpl for ArrayDims { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_dims_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_dims_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns an array of the array's dimensions.", + ) + .with_syntax_example("array_dims(array)") + .with_sql_example( + r#"```sql +> select array_dims([[1, 2, 3], [4, 5, 6]]); ++---------------------------------+ +| array_dims(List([1,2,3,4,5,6])) | ++---------------------------------+ +| [2, 3] | ++---------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. 
Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } make_udf_expr_and_func!( @@ -137,6 +173,41 @@ impl ScalarUDFImpl for ArrayNdims { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_ndims_doc()) + } +} + +fn get_array_ndims_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the number of dimensions of the array.", + ) + .with_syntax_example("array_ndims(array, element)") + .with_sql_example( + r#"```sql +> select array_ndims([[1, 2, 3], [4, 5, 6]]); ++----------------------------------+ +| array_ndims(List([1,2,3,4,5,6])) | ++----------------------------------+ +| 2 | ++----------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "element", + "Array element.", + ) + .build() + .unwrap() + }) } /// Array_dims SQL function diff --git a/datafusion/functions-nested/src/distance.rs b/datafusion/functions-nested/src/distance.rs index fa9394c73bcb0..4f890e4166e9f 100644 --- a/datafusion/functions-nested/src/distance.rs +++ b/datafusion/functions-nested/src/distance.rs @@ -31,9 +31,12 @@ use datafusion_common::cast::{ use datafusion_common::utils::coerced_fixed_size_list_to_list; use datafusion_common::DataFusionError; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; make_udf_expr_and_func!( ArrayDistance, @@ -100,6 +103,43 @@ impl ScalarUDFImpl for ArrayDistance { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_distance_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_distance_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the Euclidean distance between two input arrays of equal length.", + ) + .with_syntax_example("array_distance(array1, array2)") + .with_sql_example( + r#"```sql +> select array_distance([1, 2], [1, 4]); ++------------------------------------+ +| array_distance(List([1,2], [1,4])) | ++------------------------------------+ +| 2.0 | ++------------------------------------+ +```"#, + ) + .with_argument( + "array1", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "array2", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } pub fn array_distance_inner(args: &[ArrayRef]) -> Result { @@ -207,7 +247,7 @@ fn compute_array_distance( /// Converts an array of any numeric type to a Float64Array. 
fn convert_to_f64_array(array: &ArrayRef) -> Result { match array.data_type() { - DataType::Float64 => Ok(as_float64_array(array)?.clone()), + Float64 => Ok(as_float64_array(array)?.clone()), DataType::Float32 => { let array = as_float32_array(array)?; let converted: Float64Array = diff --git a/datafusion/functions-nested/src/empty.rs b/datafusion/functions-nested/src/empty.rs index 36c82e92081d2..5d310eb23952e 100644 --- a/datafusion/functions-nested/src/empty.rs +++ b/datafusion/functions-nested/src/empty.rs @@ -23,9 +23,12 @@ use arrow_schema::DataType; use arrow_schema::DataType::{Boolean, FixedSizeList, LargeList, List}; use datafusion_common::cast::as_generic_list_array; use datafusion_common::{exec_err, plan_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; make_udf_expr_and_func!( ArrayEmpty, @@ -77,6 +80,39 @@ impl ScalarUDFImpl for ArrayEmpty { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_empty_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_empty_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns 1 for an empty array or 0 for a non-empty array.", + ) + .with_syntax_example("empty(array)") + .with_sql_example( + r#"```sql +> select empty([1]); ++------------------+ +| empty(List([1])) | ++------------------+ +| 0 | ++------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. 
Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } /// Array_empty SQL function diff --git a/datafusion/functions-nested/src/except.rs b/datafusion/functions-nested/src/except.rs index 50ef20a7d4162..947d3c0182214 100644 --- a/datafusion/functions-nested/src/except.rs +++ b/datafusion/functions-nested/src/except.rs @@ -24,10 +24,13 @@ use arrow_array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow_buffer::OffsetBuffer; use arrow_schema::{DataType, FieldRef}; use datafusion_common::{exec_err, internal_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; use std::collections::HashSet; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; make_udf_expr_and_func!( ArrayExcept, @@ -78,6 +81,49 @@ impl ScalarUDFImpl for ArrayExcept { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_except_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_except_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns an array of the elements that appear in the first array but not in the second.", + ) + .with_syntax_example("array_except(array1, array2)") + .with_sql_example( + r#"```sql +> select array_except([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +> select array_except([1, 2, 3, 4], [3, 4, 5, 6]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [3, 4, 5, 6]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +```"#, + ) + .with_argument( + "array1", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "array2", + "Array expression. 
Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } /// Array_except SQL function diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 7dfc736b76d3e..275095832edb9 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -35,10 +35,13 @@ use datafusion_common::cast::as_list_array; use datafusion_common::{ exec_err, internal_datafusion_err, plan_err, DataFusionError, Result, }; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; use datafusion_expr::Expr; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use crate::utils::make_scalar_function; @@ -147,6 +150,43 @@ impl ScalarUDFImpl for ArrayElement { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_element_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_element_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Extracts the element with the index n from the array.", + ) + .with_syntax_example("array_element(array, index)") + .with_sql_example( + r#"```sql +> select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "index", + "Index to extract the element from the array.", + ) + .build() + .unwrap() + }) } /// array_element SQL function @@ -314,6 +354,49 @@ impl ScalarUDFImpl for ArraySlice { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_slice_doc()) + } +} + +fn get_array_slice_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns a slice of the array based on 1-indexed start and end positions.", + ) + .with_syntax_example("array_slice(array, begin, end)") + .with_sql_example( + r#"```sql +> select array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6); ++--------------------------------------------------------+ +| array_slice(List([1,2,3,4,5,6,7,8]),Int64(3),Int64(6)) | ++--------------------------------------------------------+ +| [3, 4, 5, 6] | ++--------------------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "begin", + "Index of the first element. If negative, it counts backward from the end of the array.", + ) + .with_argument( + "end", + "Index of the last element. If negative, it counts backward from the end of the array.", + ) + .with_argument( + "stride", + "Stride of the array slice. 
The default is 1.",
+            )
+            .build()
+            .unwrap()
+    })
+}
 
 /// array_slice SQL function
@@ -580,6 +663,37 @@ impl ScalarUDFImpl for ArrayPopFront {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_pop_front_doc())
+    }
+}
+
+fn get_array_pop_front_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Returns the array without the first element.",
+            )
+            .with_syntax_example("array_pop_front(array)")
+            .with_sql_example(
+                r#"```sql
+> select array_pop_front([1, 2, 3]);
++--------------------------------+
+| array_pop_front(List([1,2,3])) |
++--------------------------------+
+| [2, 3]                         |
++--------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .build()
+            .unwrap()
+    })
+}
 
 /// array_pop_front SQL function
@@ -655,6 +769,37 @@ impl ScalarUDFImpl for ArrayPopBack {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_pop_back_doc())
+    }
+}
+
+fn get_array_pop_back_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Returns the array without the last element.",
+            )
+            .with_syntax_example("array_pop_back(array)")
+            .with_sql_example(
+                r#"```sql
+> select array_pop_back([1, 2, 3]);
++-------------------------------+
+| array_pop_back(List([1,2,3])) |
++-------------------------------+
+| [1, 2]                        |
++-------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .build()
+            .unwrap()
+    })
+}
 
 /// array_pop_back SQL function
@@ -738,6 +883,37 @@ impl ScalarUDFImpl for ArrayAnyValue {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_any_value_doc())
+    }
+}
+
+fn get_array_any_value_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Returns the first non-null element in the array.",
+            )
+            .with_syntax_example("array_any_value(array)")
+            .with_sql_example(
+                r#"```sql
+> select array_any_value([NULL, 1, 2, 3]);
++-------------------------------------+
+| array_any_value(List([NULL,1,2,3])) |
++-------------------------------------+
+| 1                                   |
++-------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression.
Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } fn array_any_value_inner(args: &[ArrayRef]) -> Result { diff --git a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs index b04c35667226c..4fe631517b09d 100644 --- a/datafusion/functions-nested/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -26,9 +26,12 @@ use datafusion_common::cast::{ as_generic_list_array, as_large_list_array, as_list_array, }; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; make_udf_expr_and_func!( Flatten, @@ -95,6 +98,38 @@ impl ScalarUDFImpl for Flatten { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_flatten_doc()) + } +} +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_flatten_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Converts an array of arrays to a flat array.\n\n- Applies to any depth of nested arrays\n- Does not change arrays that are already flat\n\nThe flattened array contains all the elements from all source arrays.", + ) + .with_syntax_example("flatten(array)") + .with_sql_example( + r#"```sql +> select flatten([[1, 2], [3, 4]]); ++------------------------------+ +| flatten(List([1,2], [3,4])) | ++------------------------------+ +| [1, 2, 3, 4] | ++------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. 
Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } /// Flatten SQL function diff --git a/datafusion/functions-nested/src/length.rs b/datafusion/functions-nested/src/length.rs index 5d9ccd2901cfa..3e039f286421a 100644 --- a/datafusion/functions-nested/src/length.rs +++ b/datafusion/functions-nested/src/length.rs @@ -27,9 +27,12 @@ use core::any::type_name; use datafusion_common::cast::{as_generic_list_array, as_int64_array}; use datafusion_common::DataFusionError; use datafusion_common::{exec_err, plan_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; make_udf_expr_and_func!( ArrayLength, @@ -81,6 +84,43 @@ impl ScalarUDFImpl for ArrayLength { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_length_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_length_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the length of the array dimension.", + ) + .with_syntax_example("array_length(array, dimension)") + .with_sql_example( + r#"```sql +> select array_length([1, 2, 3, 4, 5], 1); ++-------------------------------------------+ +| array_length(List([1,2,3,4,5]), 1) | ++-------------------------------------------+ +| 5 | ++-------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "dimension", + "Array dimension.", + ) + .build() + .unwrap() + }) } /// Array_length SQL function diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index 79858041d3cac..c2c6f24948b8f 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -17,7 +17,9 @@ //! [`ScalarUDFImpl`] definitions for `make_array` function. 
-use std::{any::Any, sync::Arc}; +use std::any::Any; +use std::sync::{Arc, OnceLock}; +use std::vec; use arrow::array::{ArrayData, Capacities, MutableArrayData}; use arrow_array::{ @@ -26,11 +28,15 @@ use arrow_array::{ use arrow_buffer::OffsetBuffer; use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, Field}; -use datafusion_common::internal_err; use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result}; -use datafusion_expr::type_coercion::binary::comparison_coercion; +use datafusion_expr::binary::{ + try_type_union_resolution_with_struct, type_union_resolution, +}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; use datafusion_expr::TypeSignature; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use crate::utils::make_scalar_function; @@ -82,19 +88,12 @@ impl ScalarUDFImpl for MakeArray { match arg_types.len() { 0 => Ok(empty_array_type()), _ => { - let mut expr_type = DataType::Null; - for arg_type in arg_types { - if !arg_type.equals_datatype(&DataType::Null) { - expr_type = arg_type.clone(); - break; - } - } - - if expr_type.is_null() { - expr_type = DataType::Int64; - } - - Ok(List(Arc::new(Field::new("item", expr_type, true)))) + // At this point, all the type in array should be coerced to the same one + Ok(List(Arc::new(Field::new( + "item", + arg_types[0].to_owned(), + true, + )))) } } } @@ -112,28 +111,70 @@ impl ScalarUDFImpl for MakeArray { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let new_type = arg_types.iter().skip(1).try_fold( - arg_types.first().unwrap().clone(), - |acc, x| { - // The coerced types found by `comparison_coercion` are not guaranteed to be - // coercible for the arguments. `comparison_coercion` returns more loose - // types that can be coerced to both `acc` and `x` for comparison purpose. - // See `maybe_data_types` for the actual coercion. 
- let coerced_type = comparison_coercion(&acc, x); - if let Some(coerced_type) = coerced_type { - Ok(coerced_type) - } else { - internal_err!("Coercion from {acc:?} to {x:?} failed.") - } - }, - )?; - Ok(vec![new_type; arg_types.len()]) + let mut errors = vec![]; + match try_type_union_resolution_with_struct(arg_types) { + Ok(r) => return Ok(r), + Err(e) => { + errors.push(e); + } + } + + if let Some(new_type) = type_union_resolution(arg_types) { + // TODO: Move FixedSizeList to List in type_union_resolution + if let DataType::FixedSizeList(field, _) = new_type { + Ok(vec![List(field); arg_types.len()]) + } else if new_type.is_null() { + Ok(vec![DataType::Int64; arg_types.len()]) + } else { + Ok(vec![new_type; arg_types.len()]) + } + } else { + plan_err!( + "Fail to find the valid type between {:?} for {}, errors are {:?}", + arg_types, + self.name(), + errors + ) + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_make_array_doc()) } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_make_array_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns an array using the specified input expressions.", + ) + .with_syntax_example("make_array(expression1[, ..., expression_n])") + .with_sql_example( + r#"```sql +> select make_array(1, 2, 3, 4, 5); ++----------------------------------------------------------+ +| make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)) | ++----------------------------------------------------------+ +| [1, 2, 3, 4, 5] | ++----------------------------------------------------------+ +```"#, + ) + .with_argument( + "expression_n", + "Expression to include in the output array. Can be a constant, column, or function, and any combination of arithmetic or string operators.", + ) + .build() + .unwrap() + }) +} + // Empty array is a special case that is useful for many other array functions pub(super) fn empty_array_type() -> DataType { - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))) + List(Arc::new(Field::new("item", DataType::Int64, true))) } /// `make_array_inner` is the implementation of the `make_array` function. 
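Every file touched by this patch repeats one registration pattern: a `Documentation` value built lazily inside a `std::sync::OnceLock` and returned from the `documentation()` hook on `ScalarUDFImpl` (or `AggregateUDFImpl`). A minimal sketch of that pattern follows. `ArrayContainsNull`, `array_contains_null`, and `ARRAY_CONTAINS_NULL_DOC` are hypothetical names used only for illustration, and the trait surface is assumed from the method signatures visible in this diff, not a definitive implementation.

```rust
use std::any::Any;
use std::sync::OnceLock;

use arrow_schema::DataType;
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
use datafusion_expr::{
    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};

#[derive(Debug)]
struct ArrayContainsNull {
    signature: Signature,
}

impl ArrayContainsNull {
    fn new() -> Self {
        Self {
            // One argument of any type; the result depends only on the input
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

// One OnceLock per UDF: the Documentation is built on first access and
// shared, immutably, for the rest of the process lifetime.
static ARRAY_CONTAINS_NULL_DOC: OnceLock<Documentation> = OnceLock::new();

fn get_array_contains_null_doc() -> &'static Documentation {
    ARRAY_CONTAINS_NULL_DOC.get_or_init(|| {
        Documentation::builder()
            .with_doc_section(DOC_SECTION_ARRAY)
            .with_description("Returns true if the array contains a null element.")
            .with_syntax_example("array_contains_null(array)")
            .with_argument("array", "Array expression.")
            .build()
            .unwrap()
    })
}

impl ScalarUDFImpl for ArrayContainsNull {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "array_contains_null"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
        not_impl_err!("sketch only")
    }

    // The hook this patch adds to every function: generated markdown
    // reference pages are produced from these Documentation values.
    fn documentation(&self) -> Option<&Documentation> {
        Some(get_array_contains_null_doc())
    }
}
```

Note that `build()` returns a `Result`, so the `.unwrap()` at the end of every builder chain fails fast, at first documentation access, if a builder is malformed.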
diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index 29afe4a7f3bea..d7dce3bacbe1e 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::collections::{HashSet, VecDeque}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::ArrayData; use arrow_array::{Array, ArrayRef, MapArray, OffsetSizeTrait, StructArray}; @@ -27,7 +27,10 @@ use arrow_schema::{DataType, Field, SchemaBuilder}; use datafusion_common::utils::{fixed_size_list_to_arrays, list_to_arrays}; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MAP; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; use crate::make_array::make_array; @@ -238,7 +241,69 @@ impl ScalarUDFImpl for MapFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_map_batch(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_map_doc()) + } } + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_map_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MAP) + .with_description( + "Returns an Arrow map with the specified key-value pairs.\n\n\ + The `make_map` function creates a map from two lists: one for keys and one for values. Each key must be unique and non-null." + ) + .with_syntax_example( + "map(key, value)\nmap(key: value)\nmake_map(['key1', 'key2'], ['value1', 'value2'])" + ) + .with_sql_example( + r#"```sql + -- Using map function + SELECT MAP('type', 'test'); + ---- + {type: test} + + SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]); + ---- + {POST: 41, HEAD: 33, PATCH: } + + SELECT MAP([[1,2], [3,4]], ['a', 'b']); + ---- + {[1, 2]: a, [3, 4]: b} + + SELECT MAP { 'a': 1, 'b': 2 }; + ---- + {a: 1, b: 2} + + -- Using make_map function + SELECT MAKE_MAP(['POST', 'HEAD'], [41, 33]); + ---- + {POST: 41, HEAD: 33} + + SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]); + ---- + {key1: value1, key2: } + ```"# + ) + .with_argument( + "key", + "For `map`: Expression to be used for key. Can be a constant, column, function, or any combination of arithmetic or string operators.\n\ + For `make_map`: The list of keys to be used in the map. Each key must be unique and non-null." + ) + .with_argument( + "value", + "For `map`: Expression to be used for value. Can be a constant, column, function, or any combination of arithmetic or string operators.\n\ + For `make_map`: The list of values to be mapped to the corresponding keys." 
+ ) + .build() + .unwrap() + }) +} + fn get_element_type(data_type: &DataType) -> Result<&DataType> { match data_type { DataType::List(element) => Ok(element.data_type()), diff --git a/datafusion/functions-nested/src/map_extract.rs b/datafusion/functions-nested/src/map_extract.rs index 9f0c4ad29c60e..d2bb6595fe76e 100644 --- a/datafusion/functions-nested/src/map_extract.rs +++ b/datafusion/functions-nested/src/map_extract.rs @@ -26,9 +26,12 @@ use arrow_buffer::OffsetBuffer; use arrow_schema::Field; use datafusion_common::{cast::as_map_array, exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MAP; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use std::vec; use crate::utils::{get_map_entry_field, make_scalar_function}; @@ -101,6 +104,48 @@ impl ScalarUDFImpl for MapExtract { field.first().unwrap().data_type().clone(), ]) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_map_extract_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_map_extract_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MAP) + .with_description( + "Returns a list containing the value for the given key or an empty list if the key is not present in the map.", + ) + .with_syntax_example("map_extract(map, key)") + .with_sql_example( + r#"```sql +SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'); +---- +[1] + +SELECT map_extract(MAP {1: 'one', 2: 'two'}, 2); +---- +['two'] + +SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'y'); +---- +[] +```"#, + ) + .with_argument( + "map", + "Map expression. Can be a constant, column, or function, and any combination of map operators.", + ) + .with_argument( + "key", + "Key to extract from the map. 
Can be a constant, column, or function, any combination of arithmetic or string operators, or a named expression of the previously listed.", + ) + .build() + .unwrap() + }) } fn general_map_extract_inner( diff --git a/datafusion/functions-nested/src/map_keys.rs b/datafusion/functions-nested/src/map_keys.rs index 0b1cebb27c866..03e381e372f64 100644 --- a/datafusion/functions-nested/src/map_keys.rs +++ b/datafusion/functions-nested/src/map_keys.rs @@ -21,12 +21,13 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow_array::{Array, ArrayRef, ListArray}; use arrow_schema::{DataType, Field}; use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MAP; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; make_udf_expr_and_func!( MapKeysFunc, @@ -65,7 +66,7 @@ impl ScalarUDFImpl for MapKeysFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { if arg_types.len() != 1 { return exec_err!("map_keys expects single argument"); } @@ -78,9 +79,43 @@ impl ScalarUDFImpl for MapKeysFunc { )))) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(map_keys_inner)(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_map_keys_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_map_keys_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MAP) + .with_description( + "Returns a list of all keys in the map." + ) + .with_syntax_example("map_keys(map)") + .with_sql_example( + r#"```sql +SELECT map_keys(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[a, b, c] + +SELECT map_keys(map([100, 5], [42, 43])); +---- +[100, 5] +```"#, + ) + .with_argument( + "map", + "Map expression. Can be a constant, column, or function, and any combination of map operators." 
+ ) + .build() + .unwrap() + }) } fn map_keys_inner(args: &[ArrayRef]) -> Result { diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index 58c0d74eed5ff..dc7d9c9db8eec 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -21,12 +21,13 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow_array::{Array, ArrayRef, ListArray}; use arrow_schema::{DataType, Field}; use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MAP; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; make_udf_expr_and_func!( MapValuesFunc, @@ -65,7 +66,7 @@ impl ScalarUDFImpl for MapValuesFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { if arg_types.len() != 1 { return exec_err!("map_values expects single argument"); } @@ -78,9 +79,43 @@ impl ScalarUDFImpl for MapValuesFunc { )))) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(map_values_inner)(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_map_values_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_map_values_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MAP) + .with_description( + "Returns a list of all values in the map." + ) + .with_syntax_example("map_values(map)") + .with_sql_example( + r#"```sql +SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[1, , 3] + +SELECT map_values(map([100, 5], [42, 43])); +---- +[42, 43] +```"#, + ) + .with_argument( + "map", + "Map expression. Can be a constant, column, or function, and any combination of map operators." 
+ ) + .build() + .unwrap() + }) } fn map_values_inner(args: &[ArrayRef]) -> Result { diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 4cd8faa3ca98c..9ae2fa781d87e 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -34,6 +34,7 @@ use crate::{ make_array::make_array, }; +#[derive(Debug)] pub struct NestedFunctionPlanner; impl ExprPlanner for NestedFunctionPlanner { @@ -130,6 +131,7 @@ impl ExprPlanner for NestedFunctionPlanner { } } +#[derive(Debug)] pub struct FieldAccessPlanner; impl ExprPlanner for FieldAccessPlanner { diff --git a/datafusion/functions-nested/src/position.rs b/datafusion/functions-nested/src/position.rs index a48332ceb0b30..adb45141601d6 100644 --- a/datafusion/functions-nested/src/position.rs +++ b/datafusion/functions-nested/src/position.rs @@ -19,9 +19,12 @@ use arrow_schema::DataType::{LargeList, List, UInt64}; use arrow_schema::{DataType, Field}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow_array::types::UInt64Type; use arrow_array::{ @@ -86,6 +89,53 @@ impl ScalarUDFImpl for ArrayPosition { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_position_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_position_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the position of the first occurrence of the specified element in the array.", + ) + .with_syntax_example("array_position(array, element)\narray_position(array, element, index)") + .with_sql_example( + r#"```sql +> select array_position([1, 2, 2, 3, 1, 4], 2); ++----------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2)) | ++----------------------------------------------+ +| 2 | ++----------------------------------------------+ +> select array_position([1, 2, 2, 3, 1, 4], 2, 3); ++----------------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) | ++----------------------------------------------------+ +| 3 | ++----------------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. 
Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "element", + "Element to search for position in the array.", + ) + .with_argument( + "index", + "Index at which to start searching.", + ) + .build() + .unwrap() + }) } /// Array_position SQL function @@ -210,6 +260,41 @@ impl ScalarUDFImpl for ArrayPositions { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_positions_doc()) + } +} + +fn get_array_positions_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Searches for an element in the array, returns all occurrences.", + ) + .with_syntax_example("array_positions(array, element)") + .with_sql_example( + r#"```sql +> select array_positions([1, 2, 2, 3, 1, 4], 2); ++-----------------------------------------------+ +| array_positions(List([1,2,2,3,1,4]),Int64(2)) | ++-----------------------------------------------+ +| [2, 3] | ++-----------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "element", + "Element to search for positions in the array.", + ) + .build() + .unwrap() + }) } /// Array_positions SQL function diff --git a/datafusion/functions-nested/src/range.rs b/datafusion/functions-nested/src/range.rs index b3d8010cb6683..ddc56b1e4ee88 100644 --- a/datafusion/functions-nested/src/range.rs +++ b/datafusion/functions-nested/src/range.rs @@ -37,13 +37,16 @@ use datafusion_common::cast::{ use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, not_impl_datafusion_err, Result, }; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use itertools::Itertools; use std::any::Any; use std::cmp::Ordering; use std::iter::from_fn; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; make_udf_expr_and_func!( Range, @@ -133,6 +136,54 @@ impl ScalarUDFImpl for Range { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_range_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_range_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns an Arrow array between start and stop with step. The range start..end contains all values with start <= x < end. It is empty if start >= end. 
Step cannot be 0.",
+            )
+            .with_syntax_example("range(start, stop, step)")
+            .with_sql_example(
+                r#"```sql
+> select range(2, 10, 3);
++------------------------------------+
+| range(Int64(2),Int64(10),Int64(3)) |
++------------------------------------+
+| [2, 5, 8]                          |
++------------------------------------+
+
+> select range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH);
++--------------------------------------------------------------------------+
+| range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH)          |
++--------------------------------------------------------------------------+
+| [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01] |
++--------------------------------------------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "start",
+                "Start of the range. Ints, timestamps, dates or string types that can be coerced to Date32 are supported.",
+            )
+            .with_argument(
+                "end",
+                "End of the range (not included). Type must be the same as start.",
+            )
+            .with_argument(
+                "step",
+                "Increase by step (cannot be 0). Steps less than a day are supported only for timestamp ranges.",
+            )
+            .build()
+            .unwrap()
+    })
+}
 
 make_udf_expr_and_func!(
@@ -226,6 +277,47 @@ impl ScalarUDFImpl for GenSeries {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_generate_series_doc())
+    }
 }
+
+static GENERATE_SERIES_DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_generate_series_doc() -> &'static Documentation {
+    GENERATE_SERIES_DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Similar to the range function, but it includes the upper bound.",
+            )
+            .with_syntax_example("generate_series(start, stop, step)")
+            .with_sql_example(
+                r#"```sql
+> select generate_series(1,3);
++------------------------------------+
+| generate_series(Int64(1),Int64(3)) |
++------------------------------------+
+| [1, 2, 3]                          |
++------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "start",
+                "Start of the series. Ints, timestamps, dates or string types that can be coerced to Date32 are supported.",
+            )
+            .with_argument(
+                "end",
+                "End of the series (included). Type must be the same as start.",
+            )
+            .with_argument(
+                "step",
+                "Increase by step (cannot be 0). Steps less than a day are supported only for timestamp ranges.",
+            )
+            .build()
+            .unwrap()
+    })
+}
 
 /// Generates an array of integers from start to stop with a given step.
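One detail worth watching in this pattern: `OnceLock::get_or_init` only ever runs the first closure it is handed; every later call returns the value already stored. `range.rs` above therefore gives `generate_series` its own `GENERATE_SERIES_DOCUMENTATION` static, and `variance.rs` keeps separate `VARIANCE_SAMPLE_DOC` and `VARIANCE_POPULATION_DOC` statics. Several other files in this patch (`array_has.rs`, `concat.rs`, `dimension.rs`, `extract.rs`, and `position.rs` above, and `remove.rs` below) instead reuse the file's single `DOCUMENTATION` static across two or more getters, so every function in such a file reports whichever documentation happened to be built first. A self-contained sketch of the failure mode, using plain strings in place of `Documentation`:

```rust
use std::sync::OnceLock;

static DOCUMENTATION: OnceLock<String> = OnceLock::new();

fn get_array_remove_doc() -> &'static str {
    DOCUMENTATION
        .get_or_init(|| "Removes the first matching element.".to_string())
        .as_str()
}

fn get_array_remove_all_doc() -> &'static str {
    // Same OnceLock as above: if get_array_remove_doc() ran first,
    // this closure never executes and the wrong text is returned.
    DOCUMENTATION
        .get_or_init(|| "Removes all matching elements.".to_string())
        .as_str()
}

fn main() {
    assert_eq!(get_array_remove_doc(), "Removes the first matching element.");
    // The second getter returns the first getter's text.
    assert_eq!(
        get_array_remove_all_doc(),
        "Removes the first matching element."
    );
}
```

The fix is mechanical: one `OnceLock<Documentation>` per getter, as `variance.rs` and `range.rs` already do.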
diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs
index 0b7cfc283c06f..dc1ed4833c67c 100644
--- a/datafusion/functions-nested/src/remove.rs
+++ b/datafusion/functions-nested/src/remove.rs
@@ -27,9 +27,12 @@ use arrow_buffer::OffsetBuffer;
 use arrow_schema::{DataType, Field};
 use datafusion_common::cast::as_int64_array;
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 make_udf_expr_and_func!(
     ArrayRemove,
@@ -78,6 +81,43 @@ impl ScalarUDFImpl for ArrayRemove {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_remove_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_array_remove_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Removes the first element from the array equal to the given value.",
+            )
+            .with_syntax_example("array_remove(array, element)")
+            .with_sql_example(
+                r#"```sql
+> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
++----------------------------------------------+
+| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) |
++----------------------------------------------+
+| [1, 2, 3, 2, 1, 4]                           |
++----------------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .with_argument(
+                "element",
+                "Element to be removed from the array.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 make_udf_expr_and_func!(
@@ -127,6 +167,45 @@ impl ScalarUDFImpl for ArrayRemoveN {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_remove_n_doc())
+    }
+}
+
+fn get_array_remove_n_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Removes the first `max` elements from the array equal to the given value.",
+            )
+            .with_syntax_example("array_remove_n(array, element, max)")
+            .with_sql_example(
+                r#"```sql
+> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
++---------------------------------------------------------+
+| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
++---------------------------------------------------------+
+| [1, 3, 2, 1, 4]                                         |
++---------------------------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .with_argument(
+                "element",
+                "Element to be removed from the array.",
+            )
+            .with_argument(
+                "max",
+                "Number of first occurrences to remove.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 make_udf_expr_and_func!(
@@ -176,6 +255,41 @@ impl ScalarUDFImpl for ArrayRemoveAll {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_remove_all_doc())
+    }
+}
+
+fn get_array_remove_all_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Removes all elements from the array equal to the given value.",
+            )
+            .with_syntax_example("array_remove_all(array, element)")
+            .with_sql_example(
+                r#"```sql
+> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
++--------------------------------------------------+
+| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) |
++--------------------------------------------------+
+| [1, 3, 1, 4]                                     |
++--------------------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .with_argument(
+                "element",
+                "Element to be removed from the array.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 /// Array_remove SQL function
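The `array_remove_n` semantics documented above (suppress only the first `max` matches, let later matches through) can be restated in plain Rust; `remove_n` below is an illustrative helper, not code from this patch:

```rust
/// Drop the first `max` occurrences of `element`; keep everything else,
/// including later occurrences, in their original order.
fn remove_n(items: &[i64], element: i64, max: usize) -> Vec<i64> {
    let mut remaining = max;
    items
        .iter()
        .copied()
        .filter(|&v| {
            if v == element && remaining > 0 {
                remaining -= 1;
                false // suppressed occurrence
            } else {
                true
            }
        })
        .collect()
}

fn main() {
    // Mirrors the SQL example: array_remove_n([1,2,2,3,2,1,4], 2, 2) -> [1, 3, 2, 1, 4]
    assert_eq!(remove_n(&[1, 2, 2, 3, 2, 1, 4], 2, 2), vec![1, 3, 2, 1, 4]);
}
```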
diff --git a/datafusion/functions-nested/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs
index 7ed913da3f2a0..55584c143a549 100644
--- a/datafusion/functions-nested/src/repeat.rs
+++ b/datafusion/functions-nested/src/repeat.rs
@@ -29,9 +29,12 @@ use arrow_schema::DataType::{LargeList, List};
 use arrow_schema::{DataType, Field};
 use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 make_udf_expr_and_func!(
     ArrayRepeat,
@@ -83,6 +86,49 @@ impl ScalarUDFImpl for ArrayRepeat {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_repeat_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_array_repeat_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Returns an array containing element `count` times.",
+            )
+            .with_syntax_example("array_repeat(element, count)")
+            .with_sql_example(
+                r#"```sql
+> select array_repeat(1, 3);
++---------------------------------+
+| array_repeat(Int64(1),Int64(3)) |
++---------------------------------+
+| [1, 1, 1]                       |
++---------------------------------+
+> select array_repeat([1, 2], 2);
++------------------------------------+
+| array_repeat(List([1,2]),Int64(2)) |
++------------------------------------+
+| [[1, 2], [1, 2]]                   |
++------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "element",
+                "Element expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .with_argument(
+                "count",
+                "Value of how many times to repeat the element.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 /// Array_repeat SQL function
diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs
index 46a2e078aa4cd..1d0a1d1f28152 100644
--- a/datafusion/functions-nested/src/replace.rs
+++ b/datafusion/functions-nested/src/replace.rs
@@ -27,13 +27,16 @@ use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer};
 use arrow_schema::Field;
 use datafusion_common::cast::as_int64_array;
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 
 use crate::utils::compare_element_to_list;
 use crate::utils::make_scalar_function;
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 // Create static instances of ScalarUDFs for each function
 make_udf_expr_and_func!(ArrayReplace,
@@ -94,6 +97,47 @@ impl ScalarUDFImpl for ArrayReplace {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_replace_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_array_replace_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Replaces the first occurrence of the specified element with another specified element.",
+            )
+            .with_syntax_example("array_replace(array, from, to)")
+            .with_sql_example(
+                r#"```sql
+> select array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5);
++--------------------------------------------------------+
+| array_replace(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
++--------------------------------------------------------+
+| [1, 5, 2, 3, 2, 1, 4]                                  |
++--------------------------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .with_argument(
+                "from",
+                "Initial element.",
+            )
+            .with_argument(
+                "to",
+                "Final element.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 #[derive(Debug)]
@@ -135,6 +179,49 @@ impl ScalarUDFImpl for ArrayReplaceN {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_replace_n_doc())
+    }
+}
+
+fn get_array_replace_n_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Replaces the first `max` occurrences of the specified element with another specified element.",
+            )
+            .with_syntax_example("array_replace_n(array, from, to, max)")
+            .with_sql_example(
+                r#"```sql
+> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2);
++-------------------------------------------------------------------+
+| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) |
++-------------------------------------------------------------------+
+| [1, 5, 5, 3, 2, 1, 4]                                             |
++-------------------------------------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .with_argument(
+                "from",
+                "Initial element.",
+            )
+            .with_argument(
+                "to",
+                "Final element.",
+            )
+            .with_argument(
+                "max",
+                "Number of first occurrences to replace.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 #[derive(Debug)]
@@ -176,6 +263,45 @@ impl ScalarUDFImpl for ArrayReplaceAll {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_replace_all_doc())
+    }
+}
+
+fn get_array_replace_all_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Replaces all occurrences of the specified element with another specified element.",
+            )
+            .with_syntax_example("array_replace_all(array, from, to)")
+            .with_sql_example(
+                r#"```sql
+> select array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5);
++------------------------------------------------------------+
+| array_replace_all(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
++------------------------------------------------------------+
+| [1, 5, 5, 3, 5, 1, 4]                                      |
++------------------------------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .with_argument(
+                "from",
+                "Initial element.",
+            )
+            .with_argument(
+                "to",
+                "Final element.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 /// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences
diff --git a/datafusion/functions-nested/src/resize.rs b/datafusion/functions-nested/src/resize.rs
index 83c545a26eb24..294076a52b526 100644
--- a/datafusion/functions-nested/src/resize.rs
+++ b/datafusion/functions-nested/src/resize.rs
@@ -25,9 +25,12 @@ use arrow_schema::DataType::{FixedSizeList, LargeList, List};
 use arrow_schema::{DataType, FieldRef};
 use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
 use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 make_udf_expr_and_func!(
     ArrayResize,
@@ -82,6 +85,47 @@ impl ScalarUDFImpl for ArrayResize {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_resize_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_array_resize_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set.",
+            )
+            .with_syntax_example("array_resize(array, size, value)")
+            .with_sql_example(
+                r#"```sql
+> select array_resize([1, 2, 3], 5, 0);
++---------------------------------+
+| array_resize(List([1,2,3],5,0)) |
++---------------------------------+
+| [1, 2, 3, 0, 0]                 |
++---------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .with_argument(
+                "size",
+                "New size of given array.",
+            )
+            .with_argument(
+                "value",
+                "Defines new elements' value or empty if value is not set.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 /// array_resize SQL function
diff --git a/datafusion/functions-nested/src/reverse.rs b/datafusion/functions-nested/src/reverse.rs
index 581caf5daf2b8..1ecf7f8484684 100644
--- a/datafusion/functions-nested/src/reverse.rs
+++ b/datafusion/functions-nested/src/reverse.rs
@@ -25,9 +25,12 @@ use arrow_schema::DataType::{LargeList, List, Null};
 use arrow_schema::{DataType, FieldRef};
 use datafusion_common::cast::{as_large_list_array, as_list_array};
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 make_udf_expr_and_func!(
     ArrayReverse,
@@ -76,6 +79,39 @@ impl ScalarUDFImpl for ArrayReverse {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_reverse_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_array_reverse_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Returns the array with the order of the elements reversed.",
+            )
+            .with_syntax_example("array_reverse(array)")
+            .with_sql_example(
+                r#"```sql
+> select array_reverse([1, 2, 3, 4]);
++-----------------------------------+
+| array_reverse(List([1, 2, 3, 4])) |
++-----------------------------------+
+| [4, 3, 2, 1]                      |
++-----------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 /// array_reverse SQL function
diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs
index 1de9c264ddc2c..ce8d248319fe5 100644
--- a/datafusion/functions-nested/src/set_ops.rs
+++ b/datafusion/functions-nested/src/set_ops.rs
@@ -27,12 +27,15 @@ use arrow::row::{RowConverter, SortField};
 use arrow_schema::DataType::{FixedSizeList, LargeList, List, Null};
 use datafusion_common::cast::{as_large_list_array, as_list_array};
 use datafusion_common::{exec_err, internal_err, Result};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 use itertools::Itertools;
 use std::any::Any;
 use std::collections::HashSet;
 use std::fmt::{Display, Formatter};
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 // Create static instances of ScalarUDFs for each function
 make_udf_expr_and_func!(
@@ -102,6 +105,49 @@ impl ScalarUDFImpl for ArrayUnion {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_union_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_array_union_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Returns an array of elements that are present in both arrays (all elements from both arrays) without duplicates.",
+            )
+            .with_syntax_example("array_union(array1, array2)")
+            .with_sql_example(
+                r#"```sql
+> select array_union([1, 2, 3, 4], [5, 6, 3, 4]);
++------------------------------------------+
+| array_union([1, 2, 3, 4], [5, 6, 3, 4])  |
++------------------------------------------+
+| [1, 2, 3, 4, 5, 6]                       |
++------------------------------------------+
+> select array_union([1, 2, 3, 4], [5, 6, 7, 8]);
++------------------------------------------+
+| array_union([1, 2, 3, 4], [5, 6, 7, 8])  |
++------------------------------------------+
+| [1, 2, 3, 4, 5, 6, 7, 8]                 |
++------------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array1",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .with_argument(
+                "array2",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 #[derive(Debug)]
@@ -147,6 +193,47 @@ impl ScalarUDFImpl for ArrayIntersect {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_intersect_doc())
+    }
+}
+
+fn get_array_intersect_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Returns an array of elements in the intersection of array1 and array2.",
+            )
+            .with_syntax_example("array_intersect(array1, array2)")
+            .with_sql_example(
+                r#"```sql
+> select array_intersect([1, 2, 3, 4], [5, 6, 3, 4]);
++----------------------------------------------+
+| array_intersect([1, 2, 3, 4], [5, 6, 3, 4])  |
++----------------------------------------------+
+| [3, 4]                                       |
++----------------------------------------------+
+> select array_intersect([1, 2, 3, 4], [5, 6, 7, 8]);
++----------------------------------------------+
+| array_intersect([1, 2, 3, 4], [5, 6, 7, 8])  |
++----------------------------------------------+
+| []                                           |
++----------------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array1",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .with_argument(
+                "array2",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 #[derive(Debug)]
@@ -202,6 +289,37 @@ impl ScalarUDFImpl for ArrayDistinct {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_distinct_doc())
+    }
+}
+
+fn get_array_distinct_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Returns distinct values from the array after removing duplicates.",
+            )
+            .with_syntax_example("array_distinct(array)")
+            .with_sql_example(
+                r#"```sql
+> select array_distinct([1, 3, 2, 3, 1, 2, 4]);
++---------------------------------+
+| array_distinct(List([1,2,3,4])) |
++---------------------------------+
+| [1, 2, 3, 4]                    |
++---------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 /// array_distinct SQL function
diff --git a/datafusion/functions-nested/src/sort.rs b/datafusion/functions-nested/src/sort.rs
index 9c1ae507636c9..b29c187f0679c 100644
--- a/datafusion/functions-nested/src/sort.rs
+++ b/datafusion/functions-nested/src/sort.rs
@@ -25,9 +25,12 @@ use arrow_schema::DataType::{FixedSizeList, LargeList, List};
 use arrow_schema::{DataType, Field, SortOptions};
 use datafusion_common::cast::{as_list_array, as_string_array};
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 make_udf_expr_and_func!(
     ArraySort,
@@ -90,6 +93,47 @@ impl ScalarUDFImpl for ArraySort {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_sort_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_array_sort_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Sorts the array.",
+            )
+            .with_syntax_example("array_sort(array, desc, nulls_first)")
+            .with_sql_example(
+                r#"```sql
+> select array_sort([3, 1, 2]);
++---------------------------+
+| array_sort(List([3,1,2])) |
++---------------------------+
+| [1, 2, 3]                 |
++---------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .with_argument(
+                "desc",
+                "Whether to sort in descending order (`ASC` or `DESC`).",
+            )
+            .with_argument(
+                "nulls_first",
+                "Whether to sort nulls first (`NULLS FIRST` or `NULLS LAST`).",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 /// Array_sort SQL function
diff --git a/datafusion/functions-nested/src/string.rs b/datafusion/functions-nested/src/string.rs
index 2dc0a55e69519..30f3845215fc8 100644
--- a/datafusion/functions-nested/src/string.rs
+++ b/datafusion/functions-nested/src/string.rs
@@ -39,8 +39,11 @@ use datafusion_common::cast::{
     as_generic_string_array, as_large_list_array, as_list_array, as_string_array,
 };
 use datafusion_common::exec_err;
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
-use std::sync::Arc;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
+use std::sync::{Arc, OnceLock};
 
 macro_rules! to_string {
     ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{
@@ -159,6 +162,43 @@ impl ScalarUDFImpl for ArrayToString {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_array_to_string_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_array_to_string_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Converts each element to its text representation.",
+            )
+            .with_syntax_example("array_to_string(array, delimiter)")
+            .with_sql_example(
+                r#"```sql
+> select array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], ',');
++----------------------------------------------------+
+| array_to_string(List([1,2,3,4,5,6,7,8]),Utf8(",")) |
++----------------------------------------------------+
+| 1,2,3,4,5,6,7,8                                    |
++----------------------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "array",
+                "Array expression. Can be a constant, column, or function, and any combination of array operators.",
+            )
+            .with_argument(
+                "delimiter",
+                "Array element separator.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 make_udf_expr_and_func!(
@@ -228,6 +268,51 @@ impl ScalarUDFImpl for StringToArray {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_string_to_array_doc())
+    }
+}
+
+fn get_string_to_array_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ARRAY)
+            .with_description(
+                "Splits a string into an array of substrings based on a delimiter. Any substrings matching the optional `null_str` argument are replaced with NULL.",
+            )
+            .with_syntax_example("string_to_array(str, delimiter[, null_str])")
+            .with_sql_example(
+                r#"```sql
+> select string_to_array('abc##def', '##');
++-----------------------------------+
+| string_to_array(Utf8('abc##def')) |
++-----------------------------------+
+| ['abc', 'def']                    |
++-----------------------------------+
+> select string_to_array('abc def', ' ', 'def');
++-----------------------------------------------------------+
+| string_to_array(Utf8('abc def'), Utf8(' '), Utf8('def'))  |
++-----------------------------------------------------------+
+| ['abc', NULL]                                             |
++-----------------------------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "str",
+                "String expression to split.",
+            )
+            .with_argument(
+                "delimiter",
+                "Delimiter string to split on.",
+            )
+            .with_argument(
+                "null_str",
+                "Substring values to be replaced with `NULL`.",
+            )
+            .build()
+            .unwrap()
+    })
 }
 
 /// Array_to_string SQL function
diff --git a/datafusion/functions-window-common/Cargo.toml b/datafusion/functions-window-common/Cargo.toml
index 98b6f8c6dba5f..b5df212b7d2ad 100644
--- a/datafusion/functions-window-common/Cargo.toml
+++ b/datafusion/functions-window-common/Cargo.toml
@@ -39,3 +39,4 @@ path = "src/lib.rs"
 
 [dependencies]
 datafusion-common = { workspace = true }
+datafusion-physical-expr-common = { workspace = true }
diff --git a/datafusion/functions-window-common/src/expr.rs b/datafusion/functions-window-common/src/expr.rs
new file mode 100644
index 0000000000000..1d99fe7acf152
--- /dev/null
+++ b/datafusion/functions-window-common/src/expr.rs
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::arrow::datatypes::DataType;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use std::sync::Arc;
+
+/// Arguments passed to user-defined window function
+#[derive(Debug, Default)]
+pub struct ExpressionArgs<'a> {
+    /// The expressions passed as arguments to the user-defined window
+    /// function.
+    input_exprs: &'a [Arc<dyn PhysicalExpr>],
+    /// The corresponding data types of expressions passed as arguments
+    /// to the user-defined window function.
+    input_types: &'a [DataType],
+}
+
+impl<'a> ExpressionArgs<'a> {
+    /// Create an instance of [`ExpressionArgs`].
+    ///
+    /// # Arguments
+    ///
+    /// * `input_exprs` - The expressions passed as arguments
+    ///   to the user-defined window function.
+    /// * `input_types` - The data types corresponding to the
+    ///   arguments to the user-defined window function.
+    ///
+    pub fn new(
+        input_exprs: &'a [Arc<dyn PhysicalExpr>],
+        input_types: &'a [DataType],
+    ) -> Self {
+        Self {
+            input_exprs,
+            input_types,
+        }
+    }
+
+    /// Returns the expressions passed as arguments to the user-defined
+    /// window function.
+    pub fn input_exprs(&self) -> &'a [Arc<dyn PhysicalExpr>] {
+        self.input_exprs
+    }
+
+    /// Returns the [`DataType`]s corresponding to the input expressions
+    /// to the user-defined window function.
+    pub fn input_types(&self) -> &'a [DataType] {
+        self.input_types
+    }
+}
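`ExpressionArgs` only borrows its inputs, so the caller keeps ownership of the expression and type slices. A small usage sketch against the API above (the literal argument is an arbitrary example, and it assumes the `datafusion-physical-expr` crate is available for `Literal`):

```rust
use std::sync::Arc;

use datafusion_common::arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
use datafusion_functions_window_common::expr::ExpressionArgs;
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

fn main() {
    // One literal argument to some window function, plus its data type.
    let exprs: Vec<Arc<dyn PhysicalExpr>> =
        vec![Arc::new(Literal::new(ScalarValue::Int64(Some(5))))];
    let types = vec![DataType::Int64];

    // `ExpressionArgs` holds `&'a` slices; `exprs`/`types` stay owned here.
    let args = ExpressionArgs::new(&exprs, &types);
    assert_eq!(args.input_exprs().len(), args.input_types().len());
}
```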
diff --git a/datafusion/functions-window-common/src/lib.rs b/datafusion/functions-window-common/src/lib.rs
index 2e4bcbbc83b9a..da8d096da5621 100644
--- a/datafusion/functions-window-common/src/lib.rs
+++ b/datafusion/functions-window-common/src/lib.rs
@@ -18,4 +18,6 @@
 //! Common user-defined window functionality for [DataFusion]
 //!
 //! [DataFusion]: <https://crates.io/crates/datafusion>
+pub mod expr;
 pub mod field;
+pub mod partition;
diff --git a/datafusion/functions-window-common/src/partition.rs b/datafusion/functions-window-common/src/partition.rs
new file mode 100644
index 0000000000000..64786d2fe7c70
--- /dev/null
+++ b/datafusion/functions-window-common/src/partition.rs
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::arrow::datatypes::DataType;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use std::sync::Arc;
+
+/// Arguments passed to created user-defined window function state
+/// during physical execution.
+#[derive(Debug, Default)]
+pub struct PartitionEvaluatorArgs<'a> {
+    /// The expressions passed as arguments to the user-defined window
+    /// function.
+    input_exprs: &'a [Arc<dyn PhysicalExpr>],
+    /// The corresponding data types of expressions passed as arguments
+    /// to the user-defined window function.
+    input_types: &'a [DataType],
+    /// Set to `true` if the user-defined window function is reversed.
+    is_reversed: bool,
+    /// Set to `true` if `IGNORE NULLS` is specified.
+    ignore_nulls: bool,
+}
+
+impl<'a> PartitionEvaluatorArgs<'a> {
+    /// Create an instance of [`PartitionEvaluatorArgs`].
+    ///
+    /// # Arguments
+    ///
+    /// * `input_exprs` - The expressions passed as arguments
+    ///   to the user-defined window function.
+    /// * `input_types` - The data types corresponding to the
+    ///   arguments to the user-defined window function.
+    /// * `is_reversed` - Set to `true` if and only if the user-defined
+    ///   window function is reversible and is reversed.
+    /// * `ignore_nulls` - Set to `true` when `IGNORE NULLS` is
+    ///   specified.
+    ///
+    pub fn new(
+        input_exprs: &'a [Arc<dyn PhysicalExpr>],
+        input_types: &'a [DataType],
+        is_reversed: bool,
+        ignore_nulls: bool,
+    ) -> Self {
+        Self {
+            input_exprs,
+            input_types,
+            is_reversed,
+            ignore_nulls,
+        }
+    }
+
+    /// Returns the expressions passed as arguments to the user-defined
+    /// window function.
+    pub fn input_exprs(&self) -> &'a [Arc<dyn PhysicalExpr>] {
+        self.input_exprs
+    }
+
+    /// Returns the [`DataType`]s corresponding to the input expressions
+    /// to the user-defined window function.
+    pub fn input_types(&self) -> &'a [DataType] {
+        self.input_types
+    }
+
+    /// Returns `true` when the user-defined window function is
+    /// reversed, otherwise returns `false`.
+    pub fn is_reversed(&self) -> bool {
+        self.is_reversed
+    }
+
+    /// Returns `true` when `IGNORE NULLS` is specified, otherwise
+    /// returns `false`.
+    pub fn ignore_nulls(&self) -> bool {
+        self.ignore_nulls
+    }
+}
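`PartitionEvaluatorArgs` carries the two execution flags alongside the argument slices; for a zero-argument function such as `cume_dist` the slices are simply empty. A sketch of constructing it directly:

```rust
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;

fn main() {
    // No arguments, not reversed, NULLs respected (no `IGNORE NULLS`).
    let args = PartitionEvaluatorArgs::new(&[], &[], false, false);
    assert!(!args.is_reversed());
    assert!(!args.ignore_nulls());
    assert!(args.input_exprs().is_empty());

    // `#[derive(Default)]` gives the same empty, un-reversed state.
    let _default_args = PartitionEvaluatorArgs::default();
}
```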
diff --git a/datafusion/functions-window/Cargo.toml b/datafusion/functions-window/Cargo.toml
index 8dcec6bc964b4..262c21fcec65d 100644
--- a/datafusion/functions-window/Cargo.toml
+++ b/datafusion/functions-window/Cargo.toml
@@ -41,8 +41,10 @@ path = "src/lib.rs"
 datafusion-common = { workspace = true }
 datafusion-expr = { workspace = true }
 datafusion-functions-window-common = { workspace = true }
+datafusion-physical-expr = { workspace = true }
 datafusion-physical-expr-common = { workspace = true }
 log = { workspace = true }
+paste = "1.0.15"
 
 [dev-dependencies]
 arrow = { workspace = true }
diff --git a/datafusion/functions-window/src/cume_dist.rs b/datafusion/functions-window/src/cume_dist.rs
new file mode 100644
index 0000000000000..9e30c672fee52
--- /dev/null
+++ b/datafusion/functions-window/src/cume_dist.rs
@@ -0,0 +1,170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! `cume_dist` window function implementation
+
+use datafusion_common::arrow::array::{ArrayRef, Float64Array};
+use datafusion_common::arrow::datatypes::DataType;
+use datafusion_common::arrow::datatypes::Field;
+use datafusion_common::Result;
+use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING;
+use datafusion_expr::{
+    Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl,
+};
+use datafusion_functions_window_common::field;
+use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
+use field::WindowUDFFieldArgs;
+use std::any::Any;
+use std::fmt::Debug;
+use std::iter;
+use std::ops::Range;
+use std::sync::{Arc, OnceLock};
+
+define_udwf_and_expr!(
+    CumeDist,
+    cume_dist,
+    "Calculates the cumulative distribution of a value in a group of values."
+);
+
+/// CumeDist calculates the cumulative distribution (`cume_dist`) of a row
+/// within its window partition, ordered by the window's ORDER BY clause
+#[derive(Debug)]
+pub struct CumeDist {
+    signature: Signature,
+}
+
+impl CumeDist {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::any(0, Volatility::Immutable),
+        }
+    }
+}
+
+impl Default for CumeDist {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl WindowUDFImpl for CumeDist {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "cume_dist"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn partition_evaluator(
+        &self,
+        _partition_evaluator_args: PartitionEvaluatorArgs,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
+        Ok(Box::<CumeDistEvaluator>::default())
+    }
+
+    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
+        Ok(Field::new(field_args.name(), DataType::Float64, false))
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_cume_dist_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_cume_dist_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_RANKING)
+            .with_description(
+                "Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows).",
+            )
+            .with_syntax_example("cume_dist()")
+            .build()
+            .unwrap()
+    })
+}
+
+#[derive(Debug, Default)]
+pub(crate) struct CumeDistEvaluator;
+
+impl PartitionEvaluator for CumeDistEvaluator {
+    /// Computes the cumulative distribution for all rows in the partition
+    fn evaluate_all_with_rank(
+        &self,
+        num_rows: usize,
+        ranks_in_partition: &[Range<usize>],
+    ) -> Result<ArrayRef> {
+        let scalar = num_rows as f64;
+        let result = Float64Array::from_iter_values(
+            ranks_in_partition
+                .iter()
+                .scan(0_u64, |acc, range| {
+                    let len = range.end - range.start;
+                    *acc += len as u64;
+                    let value: f64 = (*acc as f64) / scalar;
+                    let result = iter::repeat(value).take(len);
+                    Some(result)
+                })
+                .flatten(),
+        );
+        Ok(Arc::new(result))
+    }
+
+    fn include_rank(&self) -> bool {
+        true
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion_common::cast::as_float64_array;
+
+    fn test_f64_result(
+        num_rows: usize,
+        ranks: Vec<Range<usize>>,
+        expected: Vec<f64>,
+    ) -> Result<()> {
+        let evaluator = CumeDistEvaluator;
+        let result = evaluator.evaluate_all_with_rank(num_rows, &ranks)?;
+        let result = as_float64_array(&result)?;
+        let result = result.values().to_vec();
+        assert_eq!(expected, result);
+        Ok(())
+    }
+
+    #[test]
+    #[allow(clippy::single_range_in_vec_init)]
+    fn test_cume_dist() -> Result<()> {
+        test_f64_result(0, vec![], vec![])?;
+
+        test_f64_result(1, vec![0..1], vec![1.0])?;
+
+        test_f64_result(2, vec![0..2], vec![1.0, 1.0])?;
+
+        test_f64_result(4, vec![0..2, 2..4], vec![0.5, 0.5, 1.0, 1.0])?;
+
+        Ok(())
+    }
+}
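The scan in `evaluate_all_with_rank` is just `(rows preceding or peer) / (total rows)`, computed per peer group and repeated across the group. A plain-Rust restatement of that arithmetic (a standalone sketch, not the evaluator itself):

```rust
use std::ops::Range;

/// cume_dist per row: rows preceding-or-peer divided by total rows,
/// where each `Range` is one group of peer rows.
fn cume_dist(num_rows: usize, ranks: &[Range<usize>]) -> Vec<f64> {
    let total = num_rows as f64;
    let mut seen = 0usize;
    let mut out = Vec::with_capacity(num_rows);
    for peers in ranks {
        seen += peers.len(); // every peer counts for the whole group
        out.extend(std::iter::repeat(seen as f64 / total).take(peers.len()));
    }
    out
}

fn main() {
    // Matches the unit test above: two peer groups over four rows.
    assert_eq!(cume_dist(4, &[0..2, 2..4]), vec![0.5, 0.5, 1.0, 1.0]);
}
```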
diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs
similarity index 54%
rename from datafusion/physical-expr/src/window/lead_lag.rs
rename to datafusion/functions-window/src/lead_lag.rs
index 1656b7c3033a4..bbe50cbbdc8af 100644
--- a/datafusion/physical-expr/src/window/lead_lag.rs
+++ b/datafusion/functions-window/src/lead_lag.rs
@@ -15,125 +15,327 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines physical expression for `lead` and `lag` that can evaluated
-//! at runtime during query execution
-use crate::window::BuiltInWindowFunctionExpr;
-use crate::PhysicalExpr;
-use arrow::array::ArrayRef;
-use arrow::datatypes::{DataType, Field};
-use arrow_array::Array;
+//! `lead` and `lag` window function implementations
+
+use crate::utils::{get_scalar_value_from_args, get_signed_integer};
+use datafusion_common::arrow::array::ArrayRef;
+use datafusion_common::arrow::datatypes::DataType;
+use datafusion_common::arrow::datatypes::Field;
 use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
-use datafusion_expr::PartitionEvaluator;
+use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
+use datafusion_expr::{
+    Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
+    Volatility, WindowUDFImpl,
+};
+use datafusion_functions_window_common::expr::ExpressionArgs;
+use datafusion_functions_window_common::field::WindowUDFFieldArgs;
+use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use std::any::Any;
 use std::cmp::min;
 use std::collections::VecDeque;
 use std::ops::{Neg, Range};
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
+
+get_or_init_udwf!(
+    Lag,
+    lag,
+    "Returns the row value that precedes the current row by a specified \
+    offset within partition. If no such row exists, then returns the \
+    default value.",
+    WindowShift::lag
+);
+get_or_init_udwf!(
+    Lead,
+    lead,
+    "Returns the value from a row that follows the current row by a \
+    specified offset within the partition. If no such row exists, then \
+    returns the default value.",
+    WindowShift::lead
+);
+
+/// Create an expression to represent the `lag` window function
+///
+/// Returns the value evaluated at the row that is `offset` rows before the
+/// current row within the partition; if there is no such row, returns
+/// `default` instead (which must be of the same type as the value). Both
+/// `offset` and `default` are evaluated with respect to the current row.
+/// If omitted, `offset` defaults to 1 and `default` to null.
+pub fn lag(
+    arg: datafusion_expr::Expr,
+    shift_offset: Option<i64>,
+    default_value: Option<ScalarValue>,
+) -> datafusion_expr::Expr {
+    let shift_offset_lit = shift_offset
+        .map(|v| v.lit())
+        .unwrap_or(ScalarValue::Null.lit());
+    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
+
+    lag_udwf().call(vec![arg, shift_offset_lit, default_lit])
+}
+
+/// Create an expression to represent the `lead` window function
+///
+/// Returns the value evaluated at the row that is `offset` rows after the
+/// current row within the partition; if there is no such row, returns
+/// `default` instead (which must be of the same type as the value). Both
+/// `offset` and `default` are evaluated with respect to the current row.
+/// If omitted, `offset` defaults to 1 and `default` to null.
+pub fn lead(
+    arg: datafusion_expr::Expr,
+    shift_offset: Option<i64>,
+    default_value: Option<ScalarValue>,
+) -> datafusion_expr::Expr {
+    let shift_offset_lit = shift_offset
+        .map(|v| v.lit())
+        .unwrap_or(ScalarValue::Null.lit());
+    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
+
+    lead_udwf().call(vec![arg, shift_offset_lit, default_lit])
+}
+
+#[derive(Debug)]
+enum WindowShiftKind {
+    Lag,
+    Lead,
+}
+
+impl WindowShiftKind {
+    fn name(&self) -> &'static str {
+        match self {
+            WindowShiftKind::Lag => "lag",
+            WindowShiftKind::Lead => "lead",
+        }
+    }
+
+    /// In [`WindowShiftEvaluator`] a positive offset is used to signal
+    /// computation of `lag()`. So here we negate the input offset
+    /// value when computing `lead()`.
+    fn shift_offset(&self, value: Option<i64>) -> i64 {
+        match self {
+            WindowShiftKind::Lag => value.unwrap_or(1),
+            WindowShiftKind::Lead => value.map(|v| v.neg()).unwrap_or(-1),
+        }
+    }
+}
 
 /// window shift expression
 #[derive(Debug)]
 pub struct WindowShift {
-    name: String,
-    /// Output data type
-    data_type: DataType,
-    shift_offset: i64,
-    expr: Arc<dyn PhysicalExpr>,
-    default_value: ScalarValue,
-    ignore_nulls: bool,
+    signature: Signature,
+    kind: WindowShiftKind,
 }
 
 impl WindowShift {
-    /// Get shift_offset of window shift expression
-    pub fn get_shift_offset(&self) -> i64 {
-        self.shift_offset
+    fn new(kind: WindowShiftKind) -> Self {
+        Self {
+            signature: Signature::one_of(
+                vec![
+                    TypeSignature::Any(1),
+                    TypeSignature::Any(2),
+                    TypeSignature::Any(3),
+                ],
+                Volatility::Immutable,
+            ),
+            kind,
+        }
     }
 
-    /// Get the default_value for window shift expression.
-    pub fn get_default_value(&self) -> ScalarValue {
-        self.default_value.clone()
+    pub fn lag() -> Self {
+        Self::new(WindowShiftKind::Lag)
     }
-}
 
-/// lead() window function
-pub fn lead(
-    name: String,
-    data_type: DataType,
-    expr: Arc<dyn PhysicalExpr>,
-    shift_offset: Option<i64>,
-    default_value: ScalarValue,
-    ignore_nulls: bool,
-) -> WindowShift {
-    WindowShift {
-        name,
-        data_type,
-        shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1),
-        expr,
-        default_value,
-        ignore_nulls,
+    pub fn lead() -> Self {
+        Self::new(WindowShiftKind::Lead)
     }
 }
 
-/// lag() window function
-pub fn lag(
-    name: String,
-    data_type: DataType,
-    expr: Arc<dyn PhysicalExpr>,
-    shift_offset: Option<i64>,
-    default_value: ScalarValue,
-    ignore_nulls: bool,
-) -> WindowShift {
-    WindowShift {
-        name,
-        data_type,
-        shift_offset: shift_offset.unwrap_or(1),
-        expr,
-        default_value,
-        ignore_nulls,
-    }
-}
+static LAG_DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_lag_doc() -> &'static Documentation {
+    LAG_DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ANALYTICAL)
+            .with_description(
+                "Returns value evaluated at the row that is offset rows before the \
+                current row within the partition; if there is no such row, instead return default \
+                (which must be of the same type as value).",
+            )
+            .with_syntax_example("lag(expression, offset, default)")
+            .with_argument("expression", "Expression to operate on")
+            .with_argument("offset", "Integer. Specifies how many rows back \
+            the value of expression should be retrieved. Defaults to 1.")
+            .with_argument("default", "The default value if the offset is \
+            not within the partition. Must be of the same type as expression.")
+            .build()
+            .unwrap()
+    })
+}
+
+static LEAD_DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_lead_doc() -> &'static Documentation {
+    LEAD_DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_ANALYTICAL)
+            .with_description(
+                "Returns value evaluated at the row that is offset rows after the \
+                current row within the partition; if there is no such row, instead return default \
+                (which must be of the same type as value).",
+            )
+            .with_syntax_example("lead(expression, offset, default)")
+            .with_argument("expression", "Expression to operate on")
+            .with_argument("offset", "Integer. Specifies how many rows \
+            forward the value of expression should be retrieved. Defaults to 1.")
+            .with_argument("default", "The default value if the offset is \
+            not within the partition. Must be of the same type as expression.")
+            .build()
+            .unwrap()
+    })
 }
 
-impl BuiltInWindowFunctionExpr for WindowShift {
-    /// Return a reference to Any that can be used for downcasting
+impl WindowUDFImpl for WindowShift {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
-    fn field(&self) -> Result<Field> {
-        let nullable = true;
-        Ok(Field::new(&self.name, self.data_type.clone(), nullable))
+    fn name(&self) -> &str {
+        self.kind.name()
     }
 
-    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        vec![Arc::clone(&self.expr)]
+    fn signature(&self) -> &Signature {
+        &self.signature
     }
 
-    fn name(&self) -> &str {
-        &self.name
+    /// Handles the case where `NULL` expression is passed as an
+    /// argument to `lead`/`lag`. The type is refined depending
+    /// on the default value argument.
+    ///
+    /// For more details see:
+    fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
+        parse_expr(expr_args.input_exprs(), expr_args.input_types())
+            .into_iter()
+            .collect::<Vec<_>>()
     }
 
-    fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+    fn partition_evaluator(
+        &self,
+        partition_evaluator_args: PartitionEvaluatorArgs,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
+        let shift_offset =
+            get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)?
+                .map(get_signed_integer)
+                .map_or(Ok(None), |v| v.map(Some))
+                .map(|n| self.kind.shift_offset(n))
+                .map(|offset| {
+                    if partition_evaluator_args.is_reversed() {
+                        -offset
+                    } else {
+                        offset
+                    }
+                })?;
+        let default_value = parse_default_value(
+            partition_evaluator_args.input_exprs(),
+            partition_evaluator_args.input_types(),
+        )?;
+
         Ok(Box::new(WindowShiftEvaluator {
-            shift_offset: self.shift_offset,
-            default_value: self.default_value.clone(),
-            ignore_nulls: self.ignore_nulls,
+            shift_offset,
+            default_value,
+            ignore_nulls: partition_evaluator_args.ignore_nulls(),
             non_null_offsets: VecDeque::new(),
         }))
     }
 
-    fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
-        Some(Arc::new(Self {
-            name: self.name.clone(),
-            data_type: self.data_type.clone(),
-            shift_offset: -self.shift_offset,
-            expr: Arc::clone(&self.expr),
-            default_value: self.default_value.clone(),
-            ignore_nulls: self.ignore_nulls,
-        }))
+    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
+        let return_type = parse_expr_type(field_args.input_types())?;
+
+        Ok(Field::new(field_args.name(), return_type, true))
+    }
+
+    fn reverse_expr(&self) -> ReversedUDWF {
+        match self.kind {
+            WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()),
+            WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()),
+        }
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        match self.kind {
+            WindowShiftKind::Lag => Some(get_lag_doc()),
+            WindowShiftKind::Lead => Some(get_lead_doc()),
+        }
     }
 }
 
+/// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to
+/// refine it by matching it with the type of the default value.
+///
+/// For e.g. in `lead(NULL, 1, false)` the generic `ScalarValue::Null`
+/// is refined into `ScalarValue::Boolean(None)`. Only the type is
+/// refined, the expression value remains `NULL`.
+///
+/// When the window function is evaluated with `NULL` expression
+/// this guarantees that the type matches with that of the default
+/// value.
+///
+/// For more details see:
+fn parse_expr(
+    input_exprs: &[Arc<dyn PhysicalExpr>],
+    input_types: &[DataType],
+) -> Result<Arc<dyn PhysicalExpr>> {
+    assert!(!input_exprs.is_empty());
+    assert!(!input_types.is_empty());
+
+    let expr = Arc::clone(input_exprs.first().unwrap());
+    let expr_type = input_types.first().unwrap();
+
+    // Handles the most common case where NULL is unexpected
+    if !expr_type.is_null() {
+        return Ok(expr);
+    }
+
+    let default_value = get_scalar_value_from_args(input_exprs, 2)?;
+    default_value.map_or(Ok(expr), |value| {
+        ScalarValue::try_from(&value.data_type()).map(|v| {
+            Arc::new(datafusion_physical_expr::expressions::Literal::new(v))
+                as Arc<dyn PhysicalExpr>
+        })
+    })
+}
+
+/// Returns the data type of the default value (if provided) when the
+/// expression is `NULL`.
+///
+/// Otherwise, returns the expression type unchanged.
+fn parse_expr_type(input_types: &[DataType]) -> Result<DataType> {
+    assert!(!input_types.is_empty());
+    let expr_type = input_types.first().unwrap_or(&DataType::Null);
+
+    // Handles the most common case where NULL is unexpected
+    if !expr_type.is_null() {
+        return Ok(expr_type.clone());
+    }
+
+    let default_value_type = input_types.get(2).unwrap_or(&DataType::Null);
+    Ok(default_value_type.clone())
+}
+
+/// Handles type coercion and null value refinement for default value
+/// argument depending on the data type of the input expression.
+fn parse_default_value(
+    input_exprs: &[Arc<dyn PhysicalExpr>],
+    input_types: &[DataType],
+) -> Result<ScalarValue> {
+    let expr_type = parse_expr_type(input_types)?;
+    let unparsed = get_scalar_value_from_args(input_exprs, 2)?;
+
+    unparsed
+        .filter(|v| !v.data_type().is_null())
+        .map(|v| v.cast_to(&expr_type))
+        .unwrap_or(ScalarValue::try_from(expr_type))
+}
+
 #[derive(Debug)]
-pub(crate) struct WindowShiftEvaluator {
+struct WindowShiftEvaluator {
     shift_offset: i64,
     default_value: ScalarValue,
     ignore_nulls: bool,
@@ -205,7 +407,7 @@ fn shift_with_default_value(
     array: &ArrayRef,
     offset: i64,
     default_value: &ScalarValue,
 ) -> Result<ArrayRef> {
-    use arrow::compute::concat;
+    use datafusion_common::arrow::compute::concat;
 
     let value_len = array.len() as i64;
     if offset == 0 {
@@ -402,19 +604,22 @@ impl PartitionEvaluator for WindowShiftEvaluator {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::expressions::Column;
-    use arrow::{array::*, datatypes::*};
+    use arrow::array::*;
     use datafusion_common::cast::as_int32_array;
-
-    fn test_i32_result(expr: WindowShift, expected: Int32Array) -> Result<()> {
+    use datafusion_physical_expr::expressions::{Column, Literal};
+    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+
+    fn test_i32_result(
+        expr: WindowShift,
+        partition_evaluator_args: PartitionEvaluatorArgs,
+        expected: Int32Array,
+    ) -> Result<()> {
         let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
         let values = vec![arr];
-        let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
-        let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
-        let values = expr.evaluate_args(&batch)?;
+        let num_rows = values.len();
         let result = expr
-            .create_evaluator()?
-            .evaluate_all(&values, batch.num_rows())?;
+            .partition_evaluator(partition_evaluator_args)?
+            .evaluate_all(&values, num_rows)?;
         let result = as_int32_array(&result)?;
         assert_eq!(expected, *result);
         Ok(())
@@ -466,16 +671,12 @@ mod tests {
     }
 
     #[test]
-    fn lead_lag_window_shift() -> Result<()> {
+    fn test_lead_window_shift() -> Result<()> {
+        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
+
         test_i32_result(
-            lead(
-                "lead".to_owned(),
-                DataType::Int32,
-                Arc::new(Column::new("c3", 0)),
-                None,
-                ScalarValue::Null.cast_to(&DataType::Int32)?,
-                false,
-            ),
+            WindowShift::lead(),
+            PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false),
             [
                 Some(-2),
                 Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
                None,
             ]
             .iter()
            .collect::<Int32Array>(),
-        )?;
+        )
+    }
+
+    #[test]
+    fn test_lag_window_shift() -> Result<()> {
+        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
 
         test_i32_result(
-            lag(
-                "lead".to_owned(),
-                DataType::Int32,
-                Arc::new(Column::new("c3", 0)),
-                None,
-                ScalarValue::Null.cast_to(&DataType::Int32)?,
-                false,
-            ),
+            WindowShift::lag(),
+            PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false),
             [
                 None,
                 Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
             ]
             .iter()
            .collect::<Int32Array>(),
-        )?;
+        )
+    }
+
+    #[test]
+    fn test_lag_with_default() -> Result<()> {
+        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
+        let shift_offset =
+            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
+        let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100))))
+            as Arc<dyn PhysicalExpr>;
+
+        let input_exprs = &[expr, shift_offset, default_value];
+        let input_types: &[DataType] =
+            &[DataType::Int32, DataType::Int32, DataType::Int32];
 
         test_i32_result(
-            lag(
-                "lead".to_owned(),
-                DataType::Int32,
-                Arc::new(Column::new("c3", 0)),
-                None,
-                ScalarValue::Int32(Some(100)),
-                false,
-            ),
+            WindowShift::lag(),
+            PartitionEvaluatorArgs::new(input_exprs, input_types, false, false),
             [
                 Some(100),
                 Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
             ]
             .iter()
            .collect::<Int32Array>(),
-        )?;
-        Ok(())
+        )
     }
 }
diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs
index 790a500f1f3f4..ff8542838df9f 100644
--- a/datafusion/functions-window/src/lib.rs
+++ b/datafusion/functions-window/src/lib.rs
@@ -29,16 +29,38 @@ use log::debug;
 
 use datafusion_expr::registry::FunctionRegistry;
 use datafusion_expr::WindowUDF;
 
+#[macro_use]
+pub mod macros;
+
+pub mod cume_dist;
+pub mod lead_lag;
+pub mod ntile;
+pub mod rank;
 pub mod row_number;
+mod utils;
 
 /// Fluent-style API for creating `Expr`s
 pub mod expr_fn {
+    pub use super::cume_dist::cume_dist;
+    pub use super::lead_lag::lag;
+    pub use super::lead_lag::lead;
+    pub use super::ntile::ntile;
+    pub use super::rank::{dense_rank, percent_rank, rank};
     pub use super::row_number::row_number;
 }
 
 /// Returns all default window functions
 pub fn all_default_window_functions() -> Vec<Arc<WindowUDF>> {
-    vec![row_number::row_number_udwf()]
+    vec![
+        cume_dist::cume_dist_udwf(),
+        row_number::row_number_udwf(),
+        lead_lag::lead_udwf(),
+        lead_lag::lag_udwf(),
+        rank::rank_udwf(),
+        rank::dense_rank_udwf(),
+        rank::percent_rank_udwf(),
+        ntile::ntile_udwf(),
+    ]
 }
 
 /// Registers all enabled packages with a [`FunctionRegistry`]
 pub fn register_all(
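With the rewrite, `lag`/`lead` become ordinary `WindowUDF`s, and the fluent API takes optional offset and default arguments (per the `Option<i64>` / `Option<ScalarValue>` signatures above). A usage sketch; the column name is an arbitrary example:

```rust
use datafusion_common::ScalarValue;
use datafusion_expr::{col, Expr};
use datafusion_functions_window::expr_fn::{lag, lead};

fn main() {
    // lag(c3, 2, 0): the value from two rows back, defaulting to 0 at partition edges.
    let e: Expr = lag(col("c3"), Some(2), Some(ScalarValue::Int32(Some(0))));
    println!("{e}");

    // Omitted arguments fall back to offset 1 and a NULL default.
    let _lead_expr: Expr = lead(col("c3"), None, None);
}
```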
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Convenience macros for defining a user-defined window function +//! and associated expression API (fluent style). +//! +//! See [`define_udwf_and_expr!`] for usage examples. +//! +//! [`define_udwf_and_expr!`]: crate::define_udwf_and_expr! + +/// Lazily initializes a user-defined window function exactly once +/// when called concurrently. Repeated calls return a reference to the +/// same instance. +/// +/// # Parameters +/// +/// * `$UDWF`: The struct which defines the [`Signature`](datafusion_expr::Signature) +/// of the user-defined window function. +/// * `$OUT_FN_NAME`: The basename to generate a unique function name like +/// `$OUT_FN_NAME_udwf`. +/// * `$DOC`: Doc comments for UDWF. +/// * (optional) `$CTOR`: Pass a custom constructor. When omitted it +/// automatically resolves to `$UDWF::default()`. +/// +/// # Example +/// +/// ``` +/// # use std::any::Any; +/// # use datafusion_common::arrow::datatypes::{DataType, Field}; +/// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; +/// # +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window::get_or_init_udwf; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// # +/// /// Defines the `simple_udwf()` user-defined window function. +/// get_or_init_udwf!( +/// SimpleUDWF, +/// simple, +/// "Simple user-defined window function doc comment." +/// ); +/// # +/// # assert_eq!(simple_udwf().name(), "simple_user_defined_window_function"); +/// # +/// # #[derive(Debug)] +/// # struct SimpleUDWF { +/// # signature: Signature, +/// # } +/// # +/// # impl Default for SimpleUDWF { +/// # fn default() -> Self { +/// # Self { +/// # signature: Signature::any(0, Volatility::Immutable), +/// # } +/// # } +/// # } +/// # +/// # impl WindowUDFImpl for SimpleUDWF { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "simple_user_defined_window_function" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::Int64, false)) +/// # } +/// # } +/// # +/// ``` +#[macro_export] +macro_rules! get_or_init_udwf { + ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr) => { + get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC, $UDWF::default); + }; + + ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr, $CTOR:path) => { + paste::paste! 
+            #[doc = concat!(" Singleton instance of [`", stringify!($OUT_FN_NAME), "`], ensures the user-defined")]
+            #[doc = concat!(" window function is only created once.")]
+            #[allow(non_upper_case_globals)]
+            static [<STATIC_ $UDWF>]: std::sync::OnceLock<std::sync::Arc<datafusion_expr::WindowUDF>> =
+                std::sync::OnceLock::new();
+
+            #[doc = concat!(" Returns a [`WindowUDF`](datafusion_expr::WindowUDF) for [`", stringify!($OUT_FN_NAME), "`].")]
+            #[doc = ""]
+            #[doc = concat!(" ", $DOC)]
+            pub fn [<$OUT_FN_NAME _udwf>]() -> std::sync::Arc<datafusion_expr::WindowUDF> {
+                [<STATIC_ $UDWF>]
+                    .get_or_init(|| {
+                        std::sync::Arc::new(datafusion_expr::WindowUDF::from($CTOR()))
+                    })
+                    .clone()
+            }
+        }
+    };
+}
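Editor's note: roughly, an invocation such as `get_or_init_udwf!(Ntile, ntile, "…")` expands to something like the following hand-expanded sketch (identifiers per the `paste!` rules in the macro body above; not emitted verbatim by the compiler):

```rust
// Hand-expanded sketch of get_or_init_udwf!(Ntile, ntile, "...").
#[allow(non_upper_case_globals)]
static STATIC_Ntile: std::sync::OnceLock<std::sync::Arc<datafusion_expr::WindowUDF>> =
    std::sync::OnceLock::new();

/// Returns a [`WindowUDF`](datafusion_expr::WindowUDF) for `ntile`.
pub fn ntile_udwf() -> std::sync::Arc<datafusion_expr::WindowUDF> {
    STATIC_Ntile
        .get_or_init(|| {
            std::sync::Arc::new(datafusion_expr::WindowUDF::from(Ntile::default()))
        })
        .clone()
}
```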
+
+/// Create a [`WindowFunction`] expression that exposes a fluent API
+/// which you can use to build more complex expressions.
+///
+/// [`WindowFunction`]: datafusion_expr::Expr::WindowFunction
+///
+/// # Parameters
+///
+/// * `$UDWF`: The struct which defines the [`Signature`] of the
+///     user-defined window function.
+/// * `$OUT_FN_NAME`: The basename to generate a unique function name like
+///     `$OUT_FN_NAME_udwf`.
+/// * `$DOC`: Doc comments for UDWF.
+/// * (optional) `[$($PARAM:ident),+]`: An array of 1 or more parameters
+///     for the generated function. The type of parameters is [`Expr`].
+///     When omitted this creates a function with zero parameters.
+///
+/// [`Signature`]: datafusion_expr::Signature
+/// [`Expr`]: datafusion_expr::Expr
+///
+/// # Example
+///
+/// 1. With Zero Parameters
+/// ```
+/// # use std::any::Any;
+/// # use datafusion_common::arrow::datatypes::{DataType, Field};
+/// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl};
+/// # use datafusion_functions_window::{create_udwf_expr, get_or_init_udwf};
+/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs;
+/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
+///
+/// # get_or_init_udwf!(
+/// #     RowNumber,
+/// #     row_number,
+/// #     "Returns a unique row number for each row in window partition beginning at 1."
+/// # );
+/// /// Creates `row_number()` API which has zero parameters:
+/// ///
+/// /// ```
+/// /// /// Returns a unique row number for each row in window partition
+/// /// /// beginning at 1.
+/// /// pub fn row_number() -> datafusion_expr::Expr {
+/// ///     row_number_udwf().call(vec![])
+/// /// }
+/// /// ```
+/// create_udwf_expr!(
+///     RowNumber,
+///     row_number,
+///     "Returns a unique row number for each row in window partition beginning at 1."
+/// );
+/// #
+/// # assert_eq!(
+/// #     row_number().name_for_alias().unwrap(),
+/// #     "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
+/// # );
+/// #
+/// # #[derive(Debug)]
+/// # struct RowNumber {
+/// #     signature: Signature,
+/// # }
+/// # impl Default for RowNumber {
+/// #     fn default() -> Self {
+/// #         Self {
+/// #             signature: Signature::any(0, Volatility::Immutable),
+/// #         }
+/// #     }
+/// # }
+/// # impl WindowUDFImpl for RowNumber {
+/// #     fn as_any(&self) -> &dyn Any {
+/// #         self
+/// #     }
+/// #     fn name(&self) -> &str {
+/// #         "row_number"
+/// #     }
+/// #     fn signature(&self) -> &Signature {
+/// #         &self.signature
+/// #     }
+/// #     fn partition_evaluator(
+/// #         &self,
+/// #         _partition_evaluator_args: PartitionEvaluatorArgs,
+/// #     ) -> datafusion_common::Result<Box<dyn PartitionEvaluator>> {
+/// #         unimplemented!()
+/// #     }
+/// #     fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result<Field> {
+/// #         Ok(Field::new(field_args.name(), DataType::UInt64, false))
+/// #     }
+/// # }
+/// ```
+///
+/// 2. With Multiple Parameters
+/// ```
+/// # use std::any::Any;
+/// #
+/// # use datafusion_expr::{
+/// #     PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl,
+/// # };
+/// #
+/// # use datafusion_functions_window::{create_udwf_expr, get_or_init_udwf};
+/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs;
+/// #
+/// # use datafusion_common::arrow::datatypes::Field;
+/// # use datafusion_common::ScalarValue;
+/// # use datafusion_expr::{col, lit};
+/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
+/// #
+/// # get_or_init_udwf!(Lead, lead, "user-defined window function");
+/// #
+/// /// Creates `lead(expr, offset, default)` with 3 parameters:
+/// ///
+/// /// ```
+/// /// /// Returns a value evaluated at the row that is offset rows
+/// /// /// after the current row within the partition.
+/// /// pub fn lead(
+/// ///     expr: datafusion_expr::Expr,
+/// ///     offset: datafusion_expr::Expr,
+/// ///     default: datafusion_expr::Expr,
+/// /// ) -> datafusion_expr::Expr {
+/// ///     lead_udwf().call(vec![expr, offset, default])
+/// /// }
+/// /// ```
+/// create_udwf_expr!(
+///     Lead,
+///     lead,
+///     [expr, offset, default],
+///     "Returns a value evaluated at the row that is offset rows after the current row within the partition."
+/// );
+/// #
+/// # assert_eq!(
+/// #     lead(col("a"), lit(1i64), lit(ScalarValue::Null))
+/// #         .name_for_alias()
+/// #         .unwrap(),
+/// #     "lead(a,Int64(1),NULL) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
+/// # );
+/// #
+/// # #[derive(Debug)]
+/// # struct Lead {
+/// #     signature: Signature,
+/// # }
+/// #
+/// # impl Default for Lead {
+/// #     fn default() -> Self {
+/// #         Self {
+/// #             signature: Signature::one_of(
+/// #                 vec![
+/// #                     TypeSignature::Any(1),
+/// #                     TypeSignature::Any(2),
+/// #                     TypeSignature::Any(3),
+/// #                 ],
+/// #                 Volatility::Immutable,
+/// #             ),
+/// #         }
+/// #     }
+/// # }
+/// #
+/// # impl WindowUDFImpl for Lead {
+/// #     fn as_any(&self) -> &dyn Any {
+/// #         self
+/// #     }
+/// #     fn name(&self) -> &str {
+/// #         "lead"
+/// #     }
+/// #     fn signature(&self) -> &Signature {
+/// #         &self.signature
+/// #     }
+/// #     fn partition_evaluator(
+/// #         &self,
+/// #         partition_evaluator_args: PartitionEvaluatorArgs,
+/// #     ) -> datafusion_common::Result<Box<dyn PartitionEvaluator>> {
+/// #         unimplemented!()
+/// #     }
+/// #     fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result<Field> {
+/// #         Ok(Field::new(
+/// #             field_args.name(),
+/// #             field_args.get_input_type(0).unwrap(),
+/// #             false,
+/// #         ))
+/// #     }
+/// # }
+/// ```
+#[macro_export]
+macro_rules! create_udwf_expr {
+    // zero arguments
+    ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr) => {
+        paste::paste! {
+            #[doc = " Create a [`WindowFunction`](datafusion_expr::Expr::WindowFunction) expression for"]
+            #[doc = concat!(" `", stringify!($UDWF), "` user-defined window function.")]
+            #[doc = ""]
+            #[doc = concat!(" ", $DOC)]
+            pub fn $OUT_FN_NAME() -> datafusion_expr::Expr {
+                [<$OUT_FN_NAME _udwf>]().call(vec![])
+            }
+        }
+    };
+
+    // 1 or more arguments
+    ($UDWF:ident, $OUT_FN_NAME:ident, [$($PARAM:ident),+], $DOC:expr) => {
+        paste::paste! {
+            #[doc = " Create a [`WindowFunction`](datafusion_expr::Expr::WindowFunction) expression for"]
+            #[doc = concat!(" `", stringify!($UDWF), "` user-defined window function.")]
+            #[doc = ""]
+            #[doc = concat!(" ", $DOC)]
+            pub fn $OUT_FN_NAME(
+                $($PARAM: datafusion_expr::Expr),+
+            ) -> datafusion_expr::Expr {
+                [<$OUT_FN_NAME _udwf>]()
+                    .call(vec![$($PARAM),+])
+            }
+        }
+    };
+}
+
+/// Defines a user-defined window function.
+///
+/// Combines [`get_or_init_udwf!`] and [`create_udwf_expr!`] into a
+/// single macro for convenience.
+///
+/// # Arguments
+///
+/// * `$UDWF`: The struct which defines the [`Signature`] of the
+///     user-defined window function.
+/// * `$OUT_FN_NAME`: The basename to generate a unique function name like
+///     `$OUT_FN_NAME_udwf`.
+/// * (optional) `[$($PARAM:ident),+]`: An array of 1 or more parameters
+///     for the generated function. The type of parameters is [`Expr`].
+///     When omitted this creates a function with zero parameters.
+/// * `$DOC`: Doc comments for UDWF.
+/// * (optional) `$CTOR`: Pass a custom constructor. When omitted it
+///     automatically resolves to `$UDWF::default()`.
+///
+/// [`Signature`]: datafusion_expr::Signature
+/// [`Expr`]: datafusion_expr::Expr
+///
+/// # Usage
+///
+/// ## Expression API With Zero parameters
+/// 1. Uses default constructor for UDWF.
+///
+/// ```
+/// # use std::any::Any;
+/// # use datafusion_common::arrow::datatypes::{DataType, Field};
+/// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl};
+/// #
+/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs;
+/// # use datafusion_functions_window::{define_udwf_and_expr, get_or_init_udwf, create_udwf_expr};
+/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
+/// #
+/// /// 1. Defines the `simple_udwf()` user-defined window function.
+/// ///
+/// /// 2. Defines the expression API:
+/// /// ```
+/// /// pub fn simple() -> datafusion_expr::Expr {
+/// ///     simple_udwf().call(vec![])
+/// /// }
+/// /// ```
+/// define_udwf_and_expr!(
+///     SimpleUDWF,
+///     simple,
+///     "a simple user-defined window function"
+/// );
+/// #
+/// # assert_eq!(simple_udwf().name(), "simple_user_defined_window_function");
+/// #
+/// # #[derive(Debug)]
+/// # struct SimpleUDWF {
+/// #     signature: Signature,
+/// # }
+/// #
+/// # impl Default for SimpleUDWF {
+/// #     fn default() -> Self {
+/// #         Self {
+/// #             signature: Signature::any(0, Volatility::Immutable),
+/// #         }
+/// #     }
+/// # }
+/// #
+/// # impl WindowUDFImpl for SimpleUDWF {
+/// #     fn as_any(&self) -> &dyn Any {
+/// #         self
+/// #     }
+/// #     fn name(&self) -> &str {
+/// #         "simple_user_defined_window_function"
+/// #     }
+/// #     fn signature(&self) -> &Signature {
+/// #         &self.signature
+/// #     }
+/// #     fn partition_evaluator(
+/// #         &self,
+/// #         partition_evaluator_args: PartitionEvaluatorArgs,
+/// #     ) -> datafusion_common::Result<Box<dyn PartitionEvaluator>> {
+/// #         unimplemented!()
+/// #     }
+/// #     fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result<Field> {
+/// #         Ok(Field::new(field_args.name(), DataType::Int64, false))
+/// #     }
+/// # }
+/// #
+/// ```
+///
+/// 2. Uses a custom constructor for UDWF.
+///
+/// ```
+/// # use std::any::Any;
+/// # use datafusion_common::arrow::datatypes::{DataType, Field};
+/// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl};
+/// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf};
+/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs;
+/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
+/// #
+/// /// 1. Defines the `row_number_udwf()` user-defined window function.
+/// ///
+/// /// 2. Defines the expression API:
+/// /// ```
+/// /// pub fn row_number() -> datafusion_expr::Expr {
+/// ///     row_number_udwf().call(vec![])
+/// /// }
+/// /// ```
+/// define_udwf_and_expr!(
+///     RowNumber,
+///     row_number,
+///     "Returns a unique row number for each row in window partition beginning at 1.",
+///     RowNumber::new // <-- custom constructor
+/// );
+/// #
+/// # assert_eq!(
+/// #     row_number().name_for_alias().unwrap(),
+/// #     "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
+/// # );
+/// #
+/// # #[derive(Debug)]
+/// # struct RowNumber {
+/// #     signature: Signature,
+/// # }
+/// # impl RowNumber {
+/// #     fn new() -> Self {
+/// #         Self {
+/// #             signature: Signature::any(0, Volatility::Immutable),
+/// #         }
+/// #     }
+/// # }
+/// # impl WindowUDFImpl for RowNumber {
+/// #     fn as_any(&self) -> &dyn Any {
+/// #         self
+/// #     }
+/// #     fn name(&self) -> &str {
+/// #         "row_number"
+/// #     }
+/// #     fn signature(&self) -> &Signature {
+/// #         &self.signature
+/// #     }
+/// #     fn partition_evaluator(
+/// #         &self,
+/// #         _partition_evaluator_args: PartitionEvaluatorArgs,
+/// #     ) -> datafusion_common::Result<Box<dyn PartitionEvaluator>> {
+/// #         unimplemented!()
+/// #     }
+/// #     fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result<Field> {
+/// #         Ok(Field::new(field_args.name(), DataType::UInt64, false))
+/// #     }
+/// # }
+/// ```
+///
+/// ## Expression API With Multiple Parameters
+/// 3. Uses default constructor for UDWF
+///
+/// ```
+/// # use std::any::Any;
+/// #
+/// # use datafusion_expr::{
+/// #     PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl,
+/// # };
+/// #
+/// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf};
+/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs;
+/// #
+/// # use datafusion_common::arrow::datatypes::Field;
+/// # use datafusion_common::ScalarValue;
+/// # use datafusion_expr::{col, lit};
+/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
+/// #
+/// /// 1. Defines the `lead_udwf()` user-defined window function.
+/// ///
+/// /// 2. Defines the expression API:
+/// /// ```
+/// /// pub fn lead(
+/// ///     expr: datafusion_expr::Expr,
+/// ///     offset: datafusion_expr::Expr,
+/// ///     default: datafusion_expr::Expr,
+/// /// ) -> datafusion_expr::Expr {
+/// ///     lead_udwf().call(vec![expr, offset, default])
+/// /// }
+/// /// ```
+/// define_udwf_and_expr!(
+///     Lead,
+///     lead,
+///     [expr, offset, default], // <- 3 parameters
+///     "user-defined window function"
+/// );
+/// #
+/// # assert_eq!(
+/// #     lead(col("a"), lit(1i64), lit(ScalarValue::Null))
+/// #         .name_for_alias()
+/// #         .unwrap(),
+/// #     "lead(a,Int64(1),NULL) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
+/// # );
+/// #
+/// # #[derive(Debug)]
+/// # struct Lead {
+/// #     signature: Signature,
+/// # }
+/// #
+/// # impl Default for Lead {
+/// #     fn default() -> Self {
+/// #         Self {
+/// #             signature: Signature::one_of(
+/// #                 vec![
+/// #                     TypeSignature::Any(1),
+/// #                     TypeSignature::Any(2),
+/// #                     TypeSignature::Any(3),
+/// #                 ],
+/// #                 Volatility::Immutable,
+/// #             ),
+/// #         }
+/// #     }
+/// # }
+/// #
+/// # impl WindowUDFImpl for Lead {
+/// #     fn as_any(&self) -> &dyn Any {
+/// #         self
+/// #     }
+/// #     fn name(&self) -> &str {
+/// #         "lead"
+/// #     }
+/// #     fn signature(&self) -> &Signature {
+/// #         &self.signature
+/// #     }
+/// #     fn partition_evaluator(
+/// #         &self,
+/// #         _partition_evaluator_args: PartitionEvaluatorArgs,
+/// #     ) -> datafusion_common::Result<Box<dyn PartitionEvaluator>> {
+/// #         unimplemented!()
+/// #     }
+/// #     fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result<Field> {
+/// #         Ok(Field::new(
+/// #             field_args.name(),
+/// #             field_args.get_input_type(0).unwrap(),
+/// #             false,
+/// #         ))
+/// #     }
+/// # }
+/// ```
+/// 4. Uses custom constructor for UDWF
+///
+/// ```
+/// # use std::any::Any;
+/// #
+/// # use datafusion_expr::{
+/// #     PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl,
+/// # };
+/// #
+/// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf};
+/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs;
+/// #
+/// # use datafusion_common::arrow::datatypes::Field;
+/// # use datafusion_common::ScalarValue;
+/// # use datafusion_expr::{col, lit};
+/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
+/// #
+/// /// 1. Defines the `lead_udwf()` user-defined window function.
+/// ///
+/// /// 2. Defines the expression API:
+/// /// ```
+/// /// pub fn lead(
+/// ///     expr: datafusion_expr::Expr,
+/// ///     offset: datafusion_expr::Expr,
+/// ///     default: datafusion_expr::Expr,
+/// /// ) -> datafusion_expr::Expr {
+/// ///     lead_udwf().call(vec![expr, offset, default])
+/// /// }
+/// /// ```
+/// define_udwf_and_expr!(
+///     Lead,
+///     lead,
+///     [expr, offset, default], // <- 3 parameters
+///     "user-defined window function",
+///     Lead::new // <- Custom constructor
+/// );
+/// #
+/// # assert_eq!(
+/// #     lead(col("a"), lit(1i64), lit(ScalarValue::Null))
+/// #         .name_for_alias()
+/// #         .unwrap(),
+/// #     "lead(a,Int64(1),NULL) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
+/// # );
+/// #
+/// # #[derive(Debug)]
+/// # struct Lead {
+/// #     signature: Signature,
+/// # }
+/// #
+/// # impl Lead {
+/// #     fn new() -> Self {
+/// #         Self {
+/// #             signature: Signature::one_of(
+/// #                 vec![
+/// #                     TypeSignature::Any(1),
+/// #                     TypeSignature::Any(2),
+/// #                     TypeSignature::Any(3),
+/// #                 ],
+/// #                 Volatility::Immutable,
+/// #             ),
+/// #         }
+/// #     }
+/// # }
+/// #
+/// # impl WindowUDFImpl for Lead {
+/// #     fn as_any(&self) -> &dyn Any {
+/// #         self
+/// #     }
+/// #     fn name(&self) -> &str {
+/// #         "lead"
+/// #     }
+/// #     fn signature(&self) -> &Signature {
+/// #         &self.signature
+/// #     }
+/// #     fn partition_evaluator(
+/// #         &self,
+/// #         _partition_evaluator_args: PartitionEvaluatorArgs,
+/// #     ) -> datafusion_common::Result<Box<dyn PartitionEvaluator>> {
+/// #         unimplemented!()
+/// #     }
+/// #     fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result<Field> {
+/// #         Ok(Field::new(
+/// #             field_args.name(),
+/// #             field_args.get_input_type(0).unwrap(),
+/// #             false,
+/// #         ))
+/// #     }
+/// # }
+/// ```
+#[macro_export]
+macro_rules! define_udwf_and_expr {
+    // Defines UDWF with default constructor
+    // Defines expression API with zero parameters
+    ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr) => {
+        get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC);
+        create_udwf_expr!($UDWF, $OUT_FN_NAME, $DOC);
+    };
+
+    // Defines UDWF by passing a custom constructor
+    // Defines expression API with zero parameters
+    ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr, $CTOR:path) => {
+        get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC, $CTOR);
+        create_udwf_expr!($UDWF, $OUT_FN_NAME, $DOC);
+    };
+
+    // Defines UDWF with default constructor
+    // Defines expression API with multiple parameters
+    ($UDWF:ident, $OUT_FN_NAME:ident, [$($PARAM:ident),+], $DOC:expr) => {
+        get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC);
+        create_udwf_expr!($UDWF, $OUT_FN_NAME, [$($PARAM),+], $DOC);
+    };
+
+    // Defines UDWF by passing a custom constructor
+    // Defines expression API with multiple parameters
+    ($UDWF:ident, $OUT_FN_NAME:ident, [$($PARAM:ident),+], $DOC:expr, $CTOR:path) => {
+        get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC, $CTOR);
+        create_udwf_expr!($UDWF, $OUT_FN_NAME, [$($PARAM),+], $DOC);
+    };
+}
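Editor's note: taken together, a single call such as the one rank.rs makes below is equivalent to the two-macro form; a sketch of the expansion of the fourth arm:

```rust
// Equivalent two-step form of:
//   define_udwf_and_expr!(Rank, rank, "…", Rank::basic);
get_or_init_udwf!(Rank, rank, "…", Rank::basic); // defines rank_udwf()
create_udwf_expr!(Rank, rank, "…");              // defines rank() -> Expr
```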
diff --git a/datafusion/functions-window/src/ntile.rs b/datafusion/functions-window/src/ntile.rs
new file mode 100644
index 0000000000000..b0a7241f24cd4
--- /dev/null
+++ b/datafusion/functions-window/src/ntile.rs
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! `ntile` window function implementation
+
+use std::any::Any;
+use std::fmt::Debug;
+use std::sync::{Arc, OnceLock};
+
+use crate::utils::{
+    get_scalar_value_from_args, get_signed_integer, get_unsigned_integer,
+};
+use datafusion_common::arrow::array::{ArrayRef, UInt64Array};
+use datafusion_common::arrow::datatypes::{DataType, Field};
+use datafusion_common::{exec_err, DataFusionError, Result};
+use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING;
+use datafusion_expr::{
+    Documentation, Expr, PartitionEvaluator, Signature, Volatility, WindowUDFImpl,
+};
+use datafusion_functions_window_common::field;
+use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
+use field::WindowUDFFieldArgs;
+
+get_or_init_udwf!(
+    Ntile,
+    ntile,
+    "integer ranging from 1 to the argument value, dividing the partition as equally as possible"
+);
+
+pub fn ntile(arg: Expr) -> Expr {
+    ntile_udwf().call(vec![arg])
+}
+
+#[derive(Debug)]
+pub struct Ntile {
+    signature: Signature,
+}
+
+impl Ntile {
+    /// Create a new `ntile` function
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::uniform(
+                1,
+                vec![
+                    DataType::UInt64,
+                    DataType::UInt32,
+                    DataType::UInt16,
+                    DataType::UInt8,
+                    DataType::Int64,
+                    DataType::Int32,
+                    DataType::Int16,
+                    DataType::Int8,
+                ],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl Default for Ntile {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_ntile_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_RANKING)
+            .with_description(
+                "Integer ranging from 1 to the argument value, dividing the partition as equally as possible",
+            )
+            .with_syntax_example("ntile(expression)")
+            .with_argument("expression", "An integer describing the number of groups the partition should be split into")
+            .build()
+            .unwrap()
+    })
+}
+
+impl WindowUDFImpl for Ntile {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "ntile"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn partition_evaluator(
+        &self,
+        partition_evaluator_args: PartitionEvaluatorArgs,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
+        let scalar_n =
+            get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 0)?
+                .ok_or_else(|| {
+                    DataFusionError::Execution(
+                        "NTILE requires a positive integer".to_string(),
+                    )
+                })?;
+
+        if scalar_n.is_null() {
+            return exec_err!("NTILE requires a positive integer, but found NULL");
+        }
+
+        if scalar_n.is_unsigned() {
+            let n = get_unsigned_integer(scalar_n)?;
+            Ok(Box::new(NtileEvaluator { n }))
+        } else {
+            let n: i64 = get_signed_integer(scalar_n)?;
+            if n <= 0 {
+                return exec_err!("NTILE requires a positive integer");
+            }
+            Ok(Box::new(NtileEvaluator { n: n as u64 }))
+        }
+    }
+
+    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
+        let nullable = false;
+
+        Ok(Field::new(field_args.name(), DataType::UInt64, nullable))
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_ntile_doc())
+    }
+}
+
+#[derive(Debug)]
+struct NtileEvaluator {
+    n: u64,
+}
+
+impl PartitionEvaluator for NtileEvaluator {
+    fn evaluate_all(
+        &mut self,
+        _values: &[ArrayRef],
+        num_rows: usize,
+    ) -> Result<ArrayRef> {
+        let num_rows = num_rows as u64;
+        let mut vec: Vec<u64> = Vec::new();
+        let n = u64::min(self.n, num_rows);
+        for i in 0..num_rows {
+            let res = i * n / num_rows;
+            vec.push(res + 1)
+        }
+        Ok(Arc::new(UInt64Array::from(vec)))
+    }
+}
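Editor's note: the `NtileEvaluator` arithmetic assigns row `i` to bucket `i * n / num_rows + 1`, so with `n = 3` over 8 rows it yields `[1, 1, 1, 2, 2, 2, 3, 3]`. A hedged, test-style sketch (written as if inside this module, since `NtileEvaluator` is private):

```rust
// Illustrative check of NtileEvaluator's bucket arithmetic; not in the diff.
let mut eval = NtileEvaluator { n: 3 };
let out = eval.evaluate_all(&[], 8)?; // input values are unused by ntile
let out = datafusion_common::cast::as_uint64_array(&out)?;
assert_eq!(vec![1, 1, 1, 2, 2, 2, 3, 3], *out.values());
```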
diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/functions-window/src/rank.rs
similarity index 55%
rename from datafusion/physical-expr/src/window/rank.rs
rename to datafusion/functions-window/src/rank.rs
index fa3d4e487f14f..06c3f49055a51 100644
--- a/datafusion/physical-expr/src/window/rank.rs
+++ b/datafusion/functions-window/src/rank.rs
@@ -15,40 +15,83 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines physical expression for `rank`, `dense_rank`, and `percent_rank` that can evaluated
-//! at runtime during query execution
-
-use crate::expressions::Column;
-use crate::window::window_expr::RankState;
-use crate::window::BuiltInWindowFunctionExpr;
-use crate::{PhysicalExpr, PhysicalSortExpr};
-
-use arrow::array::ArrayRef;
-use arrow::array::{Float64Array, UInt64Array};
-use arrow::datatypes::{DataType, Field};
-use arrow_schema::{SchemaRef, SortOptions};
-use datafusion_common::utils::get_row_at_idx;
-use datafusion_common::{exec_err, Result, ScalarValue};
-use datafusion_expr::PartitionEvaluator;
+//! Implementation of `rank`, `dense_rank`, and `percent_rank` window functions,
+//! which can be evaluated at runtime during query execution.
 
 use std::any::Any;
+use std::fmt::Debug;
 use std::iter;
 use std::ops::Range;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
+
+use crate::define_udwf_and_expr;
+use datafusion_common::arrow::array::ArrayRef;
+use datafusion_common::arrow::array::{Float64Array, UInt64Array};
+use datafusion_common::arrow::compute::SortOptions;
+use datafusion_common::arrow::datatypes::DataType;
+use datafusion_common::arrow::datatypes::Field;
+use datafusion_common::utils::get_row_at_idx;
+use datafusion_common::{exec_err, Result, ScalarValue};
+use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING;
+use datafusion_expr::{
+    Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl,
+};
+use datafusion_functions_window_common::field;
+use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
+use field::WindowUDFFieldArgs;
+
+define_udwf_and_expr!(
+    Rank,
+    rank,
+    "Returns rank of the current row with gaps. Same as `row_number` of its first peer",
+    Rank::basic
+);
+
+define_udwf_and_expr!(
+    DenseRank,
+    dense_rank,
+    "Returns rank of the current row without gaps. This function counts peer groups",
+    Rank::dense_rank
+);
+
+define_udwf_and_expr!(
+    PercentRank,
+    percent_rank,
+    "Returns the relative rank of the current row: (rank - 1) / (total rows - 1)",
+    Rank::percent_rank
+);
 
 /// Rank calculates the rank in the window function with order by
 #[derive(Debug)]
 pub struct Rank {
     name: String,
+    signature: Signature,
     rank_type: RankType,
-    /// Output data type
-    data_type: DataType,
 }
 
 impl Rank {
-    /// Get rank_type of the rank in window function with order by
-    pub fn get_type(&self) -> RankType {
-        self.rank_type
+    /// Create a new `rank` function with the specified name and rank type
+    pub fn new(name: String, rank_type: RankType) -> Self {
+        Self {
+            name,
+            signature: Signature::any(0, Volatility::Immutable),
+            rank_type,
+        }
+    }
+
+    /// Create a `rank` window function
+    pub fn basic() -> Self {
+        Rank::new("rank".to_string(), RankType::Basic)
+    }
+
+    /// Create a `dense_rank` window function
+    pub fn dense_rank() -> Self {
+        Rank::new("dense_rank".to_string(), RankType::Dense)
+    }
+
+    /// Create a `percent_rank` window function
+    pub fn percent_rank() -> Self {
+        Rank::new("percent_rank".to_string(), RankType::Percent)
    }
 }
 
@@ -59,74 +102,121 @@ pub enum RankType {
     Percent,
 }
 
-/// Create a rank window function
-pub fn rank(name: String, data_type: &DataType) -> Rank {
-    Rank {
-        name,
-        rank_type: RankType::Basic,
-        data_type: data_type.clone(),
-    }
+static RANK_DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_rank_doc() -> &'static Documentation {
+    RANK_DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_RANKING)
+            .with_description(
+                "Returns the rank of the current row within its partition, allowing \
+                gaps between ranks. This function provides a ranking similar to `row_number`, but \
+                skips ranks for identical values.",
+            )
+            .with_syntax_example("rank()")
+            .build()
+            .unwrap()
+    })
 }
 
-/// Create a dense rank window function
-pub fn dense_rank(name: String, data_type: &DataType) -> Rank {
-    Rank {
-        name,
-        rank_type: RankType::Dense,
-        data_type: data_type.clone(),
-    }
+static DENSE_RANK_DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_dense_rank_doc() -> &'static Documentation {
+    DENSE_RANK_DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_RANKING)
+            .with_description(
+                "Returns the rank of the current row without gaps. This function ranks \
                rows in a dense manner, meaning consecutive ranks are assigned even for identical \
                values.",
+            )
+            .with_syntax_example("dense_rank()")
+            .build()
+            .unwrap()
+    })
 }
 
-/// Create a percent rank window function
-pub fn percent_rank(name: String, data_type: &DataType) -> Rank {
-    Rank {
-        name,
-        rank_type: RankType::Percent,
-        data_type: data_type.clone(),
-    }
+static PERCENT_RANK_DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_percent_rank_doc() -> &'static Documentation {
+    PERCENT_RANK_DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_RANKING)
+            .with_description(
+                "Returns the percentage rank of the current row within its partition. \
                The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`.",
+            )
+            .with_syntax_example("percent_rank()")
+            .build()
+            .unwrap()
+    })
 }
 
-impl BuiltInWindowFunctionExpr for Rank {
-    /// Return a reference to Any that can be used for downcasting
+impl WindowUDFImpl for Rank {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
-    fn field(&self) -> Result<Field> {
-        let nullable = false;
-        Ok(Field::new(self.name(), self.data_type.clone(), nullable))
-    }
-
-    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        vec![]
-    }
-
     fn name(&self) -> &str {
         &self.name
     }
 
-    fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn partition_evaluator(
+        &self,
+        _partition_evaluator_args: PartitionEvaluatorArgs,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
         Ok(Box::new(RankEvaluator {
             state: RankState::default(),
             rank_type: self.rank_type,
         }))
     }
 
-    fn get_result_ordering(&self, schema: &SchemaRef) -> Option<PhysicalSortExpr> {
-        // The built-in RANK window function (in all modes) introduces a new ordering:
-        schema.column_with_name(self.name()).map(|(idx, field)| {
-            let expr = Arc::new(Column::new(field.name(), idx));
-            let options = SortOptions {
-                descending: false,
-                nulls_first: false,
-            }; // ASC, NULLS LAST
-            PhysicalSortExpr { expr, options }
+    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
+        let return_type = match self.rank_type {
+            RankType::Basic | RankType::Dense => DataType::UInt64,
+            RankType::Percent => DataType::Float64,
+        };
+
+        let nullable = false;
+        Ok(Field::new(field_args.name(), return_type, nullable))
+    }
+
+    fn sort_options(&self) -> Option<SortOptions> {
+        Some(SortOptions {
+            descending: false,
+            nulls_first: false,
         })
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        match self.rank_type {
+            RankType::Basic => Some(get_rank_doc()),
+            RankType::Dense => Some(get_dense_rank_doc()),
+            RankType::Percent => Some(get_percent_rank_doc()),
+        }
+    }
+}
+
+/// State for the RANK(rank) built-in window function.
+#[derive(Debug, Clone, Default)]
+pub struct RankState {
+    /// The last values for rank; as these values change, we increase n_rank
+    pub last_rank_data: Option<Vec<ScalarValue>>,
+    /// The index where last_rank_boundary is started
+    pub last_rank_boundary: usize,
+    /// Keep the number of entries in current rank
+    pub current_group_count: usize,
+    /// Rank number kept from the start
+    pub n_rank: usize,
 }
 
 /// State for the `rank` built-in window function.
 #[derive(Debug)]
-pub(crate) struct RankEvaluator {
+struct RankEvaluator {
     state: RankState,
     rank_type: RankType,
 }
@@ -136,7 +226,6 @@ impl PartitionEvaluator for RankEvaluator {
         matches!(self.rank_type, RankType::Basic | RankType::Dense)
     }
 
-    /// Evaluates the window function inside the given range.
     fn evaluate(
         &mut self,
         values: &[ArrayRef],
@@ -163,6 +252,7 @@ impl PartitionEvaluator for RankEvaluator {
                 // data is still in the same rank
                 self.state.current_group_count += 1;
             }
+
             match self.rank_type {
                 RankType::Basic => Ok(ScalarValue::UInt64(Some(
                     self.state.last_rank_boundary as u64 + 1,
@@ -179,8 +269,19 @@ impl PartitionEvaluator for RankEvaluator {
         num_rows: usize,
         ranks_in_partition: &[Range<usize>],
     ) -> Result<ArrayRef> {
-        // see https://www.postgresql.org/docs/current/functions-window.html
         let result: ArrayRef = match self.rank_type {
+            RankType::Basic => Arc::new(UInt64Array::from_iter_values(
+                ranks_in_partition
+                    .iter()
+                    .scan(1_u64, |acc, range| {
+                        let len = range.end - range.start;
+                        let result = iter::repeat(*acc).take(len);
+                        *acc += len as u64;
+                        Some(result)
+                    })
+                    .flatten(),
+            )),
+
             RankType::Dense => Arc::new(UInt64Array::from_iter_values(
                 ranks_in_partition
                     .iter()
@@ -190,9 +291,10 @@ impl PartitionEvaluator for RankEvaluator {
                         iter::repeat(rank).take(len)
                     }),
             )),
+
             RankType::Percent => {
-                // Returns the relative rank of the current row, that is (rank - 1) / (total partition rows - 1). The value thus ranges from 0 to 1 inclusive.
                 let denominator = num_rows as f64;
+
                 Arc::new(Float64Array::from_iter_values(
                     ranks_in_partition
                         .iter()
@@ -206,18 +308,8 @@ impl PartitionEvaluator for RankEvaluator {
                         .flatten(),
                 ))
             }
-            RankType::Basic => Arc::new(UInt64Array::from_iter_values(
-                ranks_in_partition
-                    .iter()
-                    .scan(1_u64, |acc, range| {
-                        let len = range.end - range.start;
-                        let result = iter::repeat(*acc).take(len);
-                        *acc += len as u64;
-                        Some(result)
-                    })
-                    .flatten(),
-            )),
         };
+
         Ok(result)
     }
 
@@ -244,53 +336,57 @@ mod tests {
         test_i32_result(expr, vec![0..8], expected)
    }
 
-    fn test_f64_result(
+    fn test_i32_result(
         expr: &Rank,
-        num_rows: usize,
         ranks: Vec<Range<usize>>,
-        expected: Vec<f64>,
+        expected: Vec<u64>,
     ) -> Result<()> {
+        let args = PartitionEvaluatorArgs::default();
         let result = expr
-            .create_evaluator()?
-            .evaluate_all_with_rank(num_rows, &ranks)?;
-        let result = as_float64_array(&result)?;
+            .partition_evaluator(args)?
+            .evaluate_all_with_rank(8, &ranks)?;
+        let result = as_uint64_array(&result)?;
         let result = result.values();
         assert_eq!(expected, *result);
         Ok(())
     }
 
-    fn test_i32_result(
+    fn test_f64_result(
         expr: &Rank,
+        num_rows: usize,
         ranks: Vec<Range<usize>>,
-        expected: Vec<u64>,
+        expected: Vec<f64>,
     ) -> Result<()> {
-        let result = expr.create_evaluator()?.evaluate_all_with_rank(8, &ranks)?;
-        let result = as_uint64_array(&result)?;
+        let args = PartitionEvaluatorArgs::default();
+        let result = expr
+            .partition_evaluator(args)?
+            .evaluate_all_with_rank(num_rows, &ranks)?;
+        let result = as_float64_array(&result)?;
         let result = result.values();
         assert_eq!(expected, *result);
         Ok(())
     }
 
     #[test]
-    fn test_dense_rank() -> Result<()> {
-        let r = dense_rank("arr".into(), &DataType::UInt64);
+    fn test_rank() -> Result<()> {
+        let r = Rank::basic();
         test_without_rank(&r, vec![1; 8])?;
-        test_with_rank(&r, vec![1, 1, 2, 3, 3, 3, 4, 5])?;
+        test_with_rank(&r, vec![1, 1, 3, 4, 4, 4, 7, 8])?;
         Ok(())
     }
 
     #[test]
-    fn test_rank() -> Result<()> {
-        let r = rank("arr".into(), &DataType::UInt64);
+    fn test_dense_rank() -> Result<()> {
+        let r = Rank::dense_rank();
         test_without_rank(&r, vec![1; 8])?;
-        test_with_rank(&r, vec![1, 1, 3, 4, 4, 4, 7, 8])?;
+        test_with_rank(&r, vec![1, 1, 2, 3, 3, 3, 4, 5])?;
         Ok(())
     }
 
     #[test]
     #[allow(clippy::single_range_in_vec_init)]
     fn test_percent_rank() -> Result<()> {
-        let r = percent_rank("arr".into(), &DataType::Float64);
+        let r = Rank::percent_rank();
 
         // empty case
         let expected = vec![0.0; 0];
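Editor's note: the swapped expectations in the tests above line up with standard SQL semantics. For the peer-group ranges used by `test_with_rank` over 8 rows, the three variants compare as follows (comment-only sketch, values taken from the test vectors):

```rust
// For peer groups [0..2, 2..3, 3..6, 6..7, 7..8] over 8 rows:
// rank()         -> [1, 1, 3, 4, 4, 4, 7, 8]   // gaps after ties
// dense_rank()   -> [1, 1, 2, 3, 3, 3, 4, 5]   // no gaps
// percent_rank() -> (rank - 1) / (8 - 1) per row, ranging from 0 to 1
```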
diff --git a/datafusion/functions-window/src/row_number.rs b/datafusion/functions-window/src/row_number.rs
index 7f348bf9d2a05..56af14fb84ae5 100644
--- a/datafusion/functions-window/src/row_number.rs
+++ b/datafusion/functions-window/src/row_number.rs
@@ -15,11 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines physical expression for `row_number` that can evaluated at runtime during query execution
-
-use std::any::Any;
-use std::fmt::Debug;
-use std::ops::Range;
+//! `row_number` window function implementation
 
 use datafusion_common::arrow::array::ArrayRef;
 use datafusion_common::arrow::array::UInt64Array;
@@ -27,31 +23,23 @@ use datafusion_common::arrow::compute::SortOptions;
 use datafusion_common::arrow::datatypes::DataType;
 use datafusion_common::arrow::datatypes::Field;
 use datafusion_common::{Result, ScalarValue};
-use datafusion_expr::expr::WindowFunction;
-use datafusion_expr::{Expr, PartitionEvaluator, Signature, Volatility, WindowUDFImpl};
+use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING;
+use datafusion_expr::{
+    Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl,
+};
 use datafusion_functions_window_common::field;
+use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
 use field::WindowUDFFieldArgs;
+use std::any::Any;
+use std::fmt::Debug;
+use std::ops::Range;
+use std::sync::OnceLock;
 
-/// Create a [`WindowFunction`](Expr::WindowFunction) expression for
-/// `row_number` user-defined window function.
-pub fn row_number() -> Expr {
-    Expr::WindowFunction(WindowFunction::new(row_number_udwf(), vec![]))
-}
-
-/// Singleton instance of `row_number`, ensures the UDWF is only created once.
-#[allow(non_upper_case_globals)]
-static STATIC_RowNumber: std::sync::OnceLock<std::sync::Arc<datafusion_expr::WindowUDF>> =
-    std::sync::OnceLock::new();
-
-/// Returns a [`WindowUDF`](datafusion_expr::WindowUDF) for `row_number`
-/// user-defined window function.
-pub fn row_number_udwf() -> std::sync::Arc<datafusion_expr::WindowUDF> {
-    STATIC_RowNumber
-        .get_or_init(|| {
-            std::sync::Arc::new(datafusion_expr::WindowUDF::from(RowNumber::default()))
-        })
-        .clone()
-}
+define_udwf_and_expr!(
+    RowNumber,
+    row_number,
+    "Returns a unique row number for each row in window partition beginning at 1."
+);
 
 /// row_number expression
 #[derive(Debug)]
@@ -74,6 +62,21 @@ impl Default for RowNumber {
     }
 }
 
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_row_number_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_RANKING)
+            .with_description(
+                "Number of the current row within its partition, counting from 1.",
+            )
+            .with_syntax_example("row_number()")
+            .build()
+            .unwrap()
+    })
+}
+
 impl WindowUDFImpl for RowNumber {
     fn as_any(&self) -> &dyn Any {
         self
@@ -87,7 +90,10 @@ impl WindowUDFImpl for RowNumber {
         &self.signature
     }
 
-    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+    fn partition_evaluator(
+        &self,
+        _partition_evaluator_args: PartitionEvaluatorArgs,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
         Ok(Box::<NumRowsEvaluator>::default())
     }
 
@@ -101,6 +107,10 @@ impl WindowUDFImpl for RowNumber {
             nulls_first: false,
         })
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_row_number_doc())
+    }
 }
 
 /// State for the `row_number` built-in window function.
@@ -156,7 +166,7 @@ mod tests {
         let num_rows = values.len();
 
         let actual = RowNumber::default()
-            .partition_evaluator()?
+            .partition_evaluator(PartitionEvaluatorArgs::default())?
             .evaluate_all(&[values], num_rows)?;
         let actual = as_uint64_array(&actual)?;
 
@@ -172,7 +182,7 @@ mod tests {
         let num_rows = values.len();
 
         let actual = RowNumber::default()
-            .partition_evaluator()?
+            .partition_evaluator(PartitionEvaluatorArgs::default())?
            .evaluate_all(&[values], num_rows)?;
         let actual = as_uint64_array(&actual)?;
diff --git a/datafusion/functions-window/src/utils.rs b/datafusion/functions-window/src/utils.rs
new file mode 100644
index 0000000000000..3f8061dbea3e1
--- /dev/null
+++ b/datafusion/functions-window/src/utils.rs
@@ -0,0 +1,65 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::arrow::datatypes::DataType;
+use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
+use datafusion_physical_expr::expressions::Literal;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use std::sync::Arc;
+
+pub(crate) fn get_signed_integer(value: ScalarValue) -> Result<i64> {
+    if value.is_null() {
+        return Ok(0);
+    }
+
+    if !value.data_type().is_integer() {
+        return exec_err!("Expected an integer value");
+    }
+
+    value.cast_to(&DataType::Int64)?.try_into()
+}
+
+pub(crate) fn get_scalar_value_from_args(
+    args: &[Arc<dyn PhysicalExpr>],
+    index: usize,
+) -> Result<Option<ScalarValue>> {
+    Ok(if let Some(field) = args.get(index) {
+        let tmp = field
+            .as_any()
+            .downcast_ref::<Literal>()
+            .ok_or_else(|| DataFusionError::NotImplemented(
+                format!("Only Literal types are supported for the field at idx: {index} in window functions"),
+            ))?
+            .value()
+            .clone();
+        Some(tmp)
+    } else {
+        None
+    })
+}
+
+pub(crate) fn get_unsigned_integer(value: ScalarValue) -> Result<u64> {
+    if value.is_null() {
+        return Ok(0);
+    }
+
+    if !value.data_type().is_integer() {
+        return exec_err!("Expected an integer value");
+    }
+
+    value.cast_to(&DataType::UInt64)?.try_into()
+}
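Editor's note: a sketch of how ntile.rs above consumes these helpers from within the crate; the literal-only restriction means a non-literal argument errors. Names mirror the code above, the call site itself is illustrative:

```rust
// Hypothetical call site: extract the literal `n` from the window
// function's input expressions, as ntile's partition_evaluator does.
let exprs: &[Arc<dyn PhysicalExpr>] =
    &[Arc::new(Literal::new(ScalarValue::Int64(Some(4))))];
let n = get_scalar_value_from_args(exprs, 0)? // -> Some(ScalarValue::Int64(Some(4)))
    .map(get_signed_integer)                  // -> Some(Ok(4))
    .transpose()?;
assert_eq!(n, Some(4));
```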
diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml
index ff1b926a9b822..70a988dbfefb5 100644
--- a/datafusion/functions/Cargo.toml
+++ b/datafusion/functions/Cargo.toml
@@ -54,7 +54,7 @@ math_expressions = []
 # enable regular expressions
 regex_expressions = ["regex"]
 # enable string functions
-string_expressions = ["regex_expressions", "uuid"]
+string_expressions = ["uuid"]
 # enable unicode functions
 unicode_expressions = ["hashbrown", "unicode-segmentation"]
 
@@ -102,6 +102,11 @@ harness = false
 name = "to_timestamp"
 required-features = ["datetime_expressions"]
 
+[[bench]]
+harness = false
+name = "encoding"
+required-features = ["encoding_expressions"]
+
 [[bench]]
 harness = false
 name = "regx"
@@ -112,6 +117,11 @@ harness = false
 name = "make_date"
 required-features = ["datetime_expressions"]
 
+[[bench]]
+harness = false
+name = "iszero"
+required-features = ["math_expressions"]
+
 [[bench]]
 harness = false
 name = "nullif"
@@ -127,6 +137,16 @@ harness = false
 name = "to_char"
 required-features = ["datetime_expressions"]
 
+[[bench]]
+harness = false
+name = "isnan"
+required-features = ["math_expressions"]
+
+[[bench]]
+harness = false
+name = "signum"
+required-features = ["math_expressions"]
+
 [[bench]]
 harness = false
 name = "substr_index"
@@ -172,7 +192,17 @@ harness = false
 name = "character_length"
 required-features = ["unicode_expressions"]
 
+[[bench]]
+harness = false
+name = "cot"
+required-features = ["math_expressions"]
+
 [[bench]]
 harness = false
 name = "strpos"
 required-features = ["unicode_expressions"]
+
+[[bench]]
+harness = false
+name = "trunc"
+required-features = ["math_expressions"]
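Editor's note: each new bench gates on its feature via `required-features`; presumably they are invoked along these lines (assumed commands, feature names per the `[features]` section above):

```shell
# Run the new math and encoding benches for this crate.
cargo bench -p datafusion-functions --bench cot --features math_expressions
cargo bench -p datafusion-functions --bench encoding --features encoding_expressions
```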
diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs
new file mode 100644
index 0000000000000..e655d82dec914
--- /dev/null
+++ b/datafusion/functions/benches/cot.rs
@@ -0,0 +1,47 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::{
+    datatypes::{Float32Type, Float64Type},
+    util::bench_util::create_primitive_array,
+};
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion_expr::ColumnarValue;
+use datafusion_functions::math::cot;
+
+use std::sync::Arc;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let cot_fn = cot();
+    for size in [1024, 4096, 8192] {
+        let f32_array = Arc::new(create_primitive_array::<Float32Type>(size, 0.2));
+        let f32_args = vec![ColumnarValue::Array(f32_array)];
+        c.bench_function(&format!("cot f32 array: {}", size), |b| {
+            b.iter(|| black_box(cot_fn.invoke(&f32_args).unwrap()))
+        });
+        let f64_array = Arc::new(create_primitive_array::<Float64Type>(size, 0.2));
+        let f64_args = vec![ColumnarValue::Array(f64_array)];
+        c.bench_function(&format!("cot f64 array: {}", size), |b| {
+            b.iter(|| black_box(cot_fn.invoke(&f64_args).unwrap()))
+        });
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs
new file mode 100644
index 0000000000000..d49235aac9383
--- /dev/null
+++ b/datafusion/functions/benches/encoding.rs
@@ -0,0 +1,53 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::util::bench_util::create_string_array_with_len;
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion_expr::ColumnarValue;
+use datafusion_functions::encoding;
+use std::sync::Arc;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let decode = encoding::decode();
+    for size in [1024, 4096, 8192] {
+        let str_array = Arc::new(create_string_array_with_len::<i32>(size, 0.2, 32));
+        c.bench_function(&format!("base64_decode/{size}"), |b| {
+            let method = ColumnarValue::Scalar("base64".into());
+            let encoded = encoding::encode()
+                .invoke(&[ColumnarValue::Array(str_array.clone()), method.clone()])
+                .unwrap();
+
+            let args = vec![encoded, method];
+            b.iter(|| black_box(decode.invoke(&args).unwrap()))
+        });
+
+        c.bench_function(&format!("hex_decode/{size}"), |b| {
+            let method = ColumnarValue::Scalar("hex".into());
+            let encoded = encoding::encode()
+                .invoke(&[ColumnarValue::Array(str_array.clone()), method.clone()])
+                .unwrap();
+
+            let args = vec![encoded, method];
+            b.iter(|| black_box(decode.invoke(&args).unwrap()))
+        });
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs
new file mode 100644
index 0000000000000..16bbe073daf04
--- /dev/null
+++ b/datafusion/functions/benches/isnan.rs
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::{
+    datatypes::{Float32Type, Float64Type},
+    util::bench_util::create_primitive_array,
+};
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion_expr::ColumnarValue;
+use datafusion_functions::math::isnan;
+use std::sync::Arc;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let isnan = isnan();
+    for size in [1024, 4096, 8192] {
+        let f32_array = Arc::new(create_primitive_array::<Float32Type>(size, 0.2));
+        let f32_args = vec![ColumnarValue::Array(f32_array)];
+        c.bench_function(&format!("isnan f32 array: {}", size), |b| {
+            b.iter(|| black_box(isnan.invoke(&f32_args).unwrap()))
+        });
+        let f64_array = Arc::new(create_primitive_array::<Float64Type>(size, 0.2));
+        let f64_args = vec![ColumnarValue::Array(f64_array)];
+        c.bench_function(&format!("isnan f64 array: {}", size), |b| {
+            b.iter(|| black_box(isnan.invoke(&f64_args).unwrap()))
+        });
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs
new file mode 100644
index 0000000000000..3348d172e1f20
--- /dev/null
+++ b/datafusion/functions/benches/iszero.rs
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::{
+    datatypes::{Float32Type, Float64Type},
+    util::bench_util::create_primitive_array,
+};
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion_expr::ColumnarValue;
+use datafusion_functions::math::iszero;
+use std::sync::Arc;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let iszero = iszero();
+    for size in [1024, 4096, 8192] {
+        let f32_array = Arc::new(create_primitive_array::<Float32Type>(size, 0.2));
+        let f32_args = vec![ColumnarValue::Array(f32_array)];
+        c.bench_function(&format!("iszero f32 array: {}", size), |b| {
+            b.iter(|| black_box(iszero.invoke(&f32_args).unwrap()))
+        });
+        let f64_array = Arc::new(create_primitive_array::<Float64Type>(size, 0.2));
+        let f64_args = vec![ColumnarValue::Array(f64_array)];
+        c.bench_function(&format!("iszero f64 array: {}", size), |b| {
+            b.iter(|| black_box(iszero.invoke(&f64_args).unwrap()))
+        });
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs
index 45bfa23511281..468d3d548bcf0 100644
--- a/datafusion/functions/benches/regx.rs
+++ b/datafusion/functions/benches/regx.rs
@@ -18,8 +18,11 @@
 extern crate criterion;
 
 use arrow::array::builder::StringBuilder;
-use arrow::array::{ArrayRef, AsArray, StringArray};
+use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray};
+use arrow::compute::cast;
+use arrow::datatypes::DataType;
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion_functions::regex::regexpcount::regexp_count_func;
 use datafusion_functions::regex::regexplike::regexp_like;
 use datafusion_functions::regex::regexpmatch::regexp_match;
 use datafusion_functions::regex::regexpreplace::regexp_replace;
@@ -59,6 +62,15 @@ fn regex(rng: &mut ThreadRng) -> StringArray {
     StringArray::from(data)
 }
 
+fn start(rng: &mut ThreadRng) -> Int64Array {
+    let mut data: Vec<i64> = vec![];
+    for _ in 0..1000 {
+        data.push(rng.gen_range(1..5));
+    }
+
+    Int64Array::from(data)
+}
+
 fn flags(rng: &mut ThreadRng) -> StringArray {
     let samples = [Some("i".to_string()), Some("im".to_string()), None];
     let mut sb = StringBuilder::new();
     for sample in &samples {
@@ -75,20 +87,56 @@ fn flags(rng: &mut ThreadRng) -> StringArray {
 }
 
 fn criterion_benchmark(c: &mut Criterion) {
-    c.bench_function("regexp_like_1000", |b| {
+    c.bench_function("regexp_count_1000 string", |b| {
         let mut rng = rand::thread_rng();
         let data = Arc::new(data(&mut rng)) as ArrayRef;
         let regex = Arc::new(regex(&mut rng)) as ArrayRef;
+        let start = Arc::new(start(&mut rng)) as ArrayRef;
         let flags = Arc::new(flags(&mut rng)) as ArrayRef;
 
         b.iter(|| {
             black_box(
-                regexp_like::<i32>(&[
+                regexp_count_func(&[
+                    Arc::clone(&data),
+                    Arc::clone(&regex),
+                    Arc::clone(&start),
+                    Arc::clone(&flags),
+                ])
+                .expect("regexp_count should work on utf8"),
+            )
+        })
+    });
+
+    c.bench_function("regexp_count_1000 utf8view", |b| {
+        let mut rng = rand::thread_rng();
+        let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap();
+        let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap();
+        let start = Arc::new(start(&mut rng)) as ArrayRef;
+        let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap();
+
+        b.iter(|| {
+            black_box(
+                regexp_count_func(&[
                     Arc::clone(&data),
                     Arc::clone(&regex),
+                    Arc::clone(&start),
                     Arc::clone(&flags),
                 ])
-                .expect("regexp_like should work on valid values"),
+                .expect("regexp_count should work on utf8view"),
             )
         })
    });
+
+    c.bench_function("regexp_like_1000", |b| {
+        let mut rng = rand::thread_rng();
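Editor's note: the Utf8View variants above are produced by casting the same generated data, so both benches see identical row values and measure only the layout difference; a sketch of the pattern, under the assumption that the helper `data()` is the one defined above:

```rust
use arrow::compute::cast;
use arrow::datatypes::DataType;

// Same logical values, different physical layout: Utf8 -> Utf8View.
let utf8_data: ArrayRef = Arc::new(data(&mut rng));
let utf8view_data: ArrayRef = cast(&utf8_data, &DataType::Utf8View).unwrap();
```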
         let data = Arc::new(data(&mut rng)) as ArrayRef;
+        let regex = Arc::new(regex(&mut rng)) as ArrayRef;
+        let flags = Arc::new(flags(&mut rng)) as ArrayRef;
+
+        b.iter(|| {
+            black_box(
+                regexp_like(&[Arc::clone(&data), Arc::clone(&regex), Arc::clone(&flags)])
+                    .expect("regexp_like should work on valid values"),
+            )
+        })
+    });
diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs
new file mode 100644
index 0000000000000..9f8d8258c8230
--- /dev/null
+++ b/datafusion/functions/benches/signum.rs
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::{
+    datatypes::{Float32Type, Float64Type},
+    util::bench_util::create_primitive_array,
+};
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion_expr::ColumnarValue;
+use datafusion_functions::math::signum;
+use std::sync::Arc;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let signum = signum();
+    for size in [1024, 4096, 8192] {
+        let f32_array = Arc::new(create_primitive_array::<Float32Type>(size, 0.2));
+        let f32_args = vec![ColumnarValue::Array(f32_array)];
+        c.bench_function(&format!("signum f32 array: {}", size), |b| {
+            b.iter(|| black_box(signum.invoke(&f32_args).unwrap()))
+        });
+        let f64_array = Arc::new(create_primitive_array::<Float64Type>(size, 0.2));
+        let f64_args = vec![ColumnarValue::Array(f64_array)];
+        c.bench_function(&format!("signum f64 array: {}", size), |b| {
+            b.iter(|| black_box(signum.invoke(&f64_args).unwrap()))
+        });
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs
index e734b6832f29c..5a87b34caf474 100644
--- a/datafusion/functions/benches/to_timestamp.rs
+++ b/datafusion/functions/benches/to_timestamp.rs
@@ -20,27 +20,123 @@ extern crate criterion;
 use std::sync::Arc;
 
 use arrow::array::builder::StringBuilder;
-use arrow::array::ArrayRef;
+use arrow::array::{ArrayRef, StringArray};
+use arrow::compute::cast;
+use arrow::datatypes::DataType;
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
 
 use datafusion_expr::ColumnarValue;
 use datafusion_functions::datetime::to_timestamp;
 
+fn data() -> StringArray {
+    let data: Vec<&str> = vec![
+        "1997-01-31T09:26:56.123Z",
+        "1997-01-31T09:26:56.123-05:00",
+        "1997-01-31 09:26:56.123-05:00",
+        "2023-01-01 04:05:06.789 -08",
+        "1997-01-31T09:26:56.123",
+        "1997-01-31 09:26:56.123",
+        "1997-01-31 09:26:56",
+        "1997-01-31 13:26:56",
+        "1997-01-31 13:26:56+04:00",
+        "1997-01-31",
+    ];
+
+    StringArray::from(data)
+}
+
+fn data_with_formats() -> (StringArray, StringArray, StringArray, StringArray) {
+    let mut inputs = StringBuilder::new();
+    let mut format1_builder = StringBuilder::with_capacity(2, 10);
+    let mut format2_builder = StringBuilder::with_capacity(2, 10);
+    let mut format3_builder = StringBuilder::with_capacity(2, 10);
+
+    inputs.append_value("1997-01-31T09:26:56.123Z");
+    format1_builder.append_value("%+");
+    format2_builder.append_value("%c");
+    format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%Z");
+
+    inputs.append_value("1997-01-31T09:26:56.123-05:00");
+    format1_builder.append_value("%+");
+    format2_builder.append_value("%c");
+    format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%z");
+
+    inputs.append_value("1997-01-31 09:26:56.123-05:00");
+    format1_builder.append_value("%+");
+    format2_builder.append_value("%c");
+    format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f%Z");
+
+    inputs.append_value("2023-01-01 04:05:06.789 -08");
+    format1_builder.append_value("%+");
+    format2_builder.append_value("%c");
+    format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f %#z");
+
+    inputs.append_value("1997-01-31T09:26:56.123");
+    format1_builder.append_value("%+");
+    format2_builder.append_value("%c");
+    format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f");
+
+    inputs.append_value("1997-01-31 09:26:56.123");
+    format1_builder.append_value("%+");
+    format2_builder.append_value("%c");
+    format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f");
+
+    inputs.append_value("1997-01-31 09:26:56");
+    format1_builder.append_value("%+");
+    format2_builder.append_value("%c");
+    format3_builder.append_value("%Y-%m-%d %H:%M:%S");
+
+    inputs.append_value("1997-01-31 092656");
+    format1_builder.append_value("%+");
+    format2_builder.append_value("%c");
+    format3_builder.append_value("%Y-%m-%d %H%M%S");
+
+    inputs.append_value("1997-01-31 092656+04:00");
+    format1_builder.append_value("%+");
+    format2_builder.append_value("%c");
+    format3_builder.append_value("%Y-%m-%d %H%M%S%:z");
+
+    inputs.append_value("Sun Jul  8 00:34:60 2001");
+    format1_builder.append_value("%+");
+    format2_builder.append_value("%c");
+    format3_builder.append_value("%Y-%m-%d 00:00:00");
+
+    (
+        inputs.finish(),
+        format1_builder.finish(),
+        format2_builder.finish(),
+        format3_builder.finish(),
+    )
+}
 
 fn criterion_benchmark(c: &mut Criterion) {
-    c.bench_function("to_timestamp_no_formats", |b| {
-        let mut inputs = StringBuilder::new();
-        inputs.append_value("1997-01-31T09:26:56.123Z");
-        inputs.append_value("1997-01-31T09:26:56.123-05:00");
-        inputs.append_value("1997-01-31 09:26:56.123-05:00");
-        inputs.append_value("2023-01-01 04:05:06.789 -08");
-        inputs.append_value("1997-01-31T09:26:56.123");
-        inputs.append_value("1997-01-31 09:26:56.123");
-        inputs.append_value("1997-01-31 09:26:56");
-        inputs.append_value("1997-01-31 13:26:56");
-        inputs.append_value("1997-01-31 13:26:56+04:00");
-        inputs.append_value("1997-01-31");
-
-        let string_array = ColumnarValue::Array(Arc::new(inputs.finish()) as ArrayRef);
+    c.bench_function("to_timestamp_no_formats_utf8", |b| {
+        let string_array = ColumnarValue::Array(Arc::new(data()) as ArrayRef);
+
+        b.iter(|| {
+            black_box(
+                to_timestamp()
+                    .invoke(&[string_array.clone()])
+                    .expect("to_timestamp should work on valid values"),
+            )
+        })
+    });
+
+    c.bench_function("to_timestamp_no_formats_largeutf8", |b| {
+        let data = cast(&data(), &DataType::LargeUtf8).unwrap();
+        let string_array = ColumnarValue::Array(Arc::new(data) as ArrayRef);
+
+        b.iter(|| {
+            black_box(
+                to_timestamp()
+                    .invoke(&[string_array.clone()])
+                    .expect("to_timestamp should work on valid values"),
+            )
+        })
+    });
+
+    c.bench_function("to_timestamp_no_formats_utf8view", |b| {
+        let data =
cast(&data(), &DataType::Utf8View).unwrap(); + let string_array = ColumnarValue::Array(Arc::new(data) as ArrayRef); b.iter(|| { black_box( @@ -51,67 +147,66 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("to_timestamp_with_formats", |b| { - let mut inputs = StringBuilder::new(); - let mut format1_builder = StringBuilder::with_capacity(2, 10); - let mut format2_builder = StringBuilder::with_capacity(2, 10); - let mut format3_builder = StringBuilder::with_capacity(2, 10); - - inputs.append_value("1997-01-31T09:26:56.123Z"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%Z"); - - inputs.append_value("1997-01-31T09:26:56.123-05:00"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%z"); - - inputs.append_value("1997-01-31 09:26:56.123-05:00"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f%Z"); - - inputs.append_value("2023-01-01 04:05:06.789 -08"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f %#z"); - - inputs.append_value("1997-01-31T09:26:56.123"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f"); - - inputs.append_value("1997-01-31 09:26:56.123"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f"); - - inputs.append_value("1997-01-31 09:26:56"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H:%M:%S"); - - inputs.append_value("1997-01-31 092656"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H%M%S"); - - inputs.append_value("1997-01-31 092656+04:00"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H%M%S%:z"); - - inputs.append_value("Sun Jul 8 00:34:60 2001"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d 00:00:00"); + c.bench_function("to_timestamp_with_formats_utf8", |b| { + let (inputs, format1, format2, format3) = data_with_formats(); + + let args = [ + ColumnarValue::Array(Arc::new(inputs) as ArrayRef), + ColumnarValue::Array(Arc::new(format1) as ArrayRef), + ColumnarValue::Array(Arc::new(format2) as ArrayRef), + ColumnarValue::Array(Arc::new(format3) as ArrayRef), + ]; + b.iter(|| { + black_box( + to_timestamp() + .invoke(&args.clone()) + .expect("to_timestamp should work on valid values"), + ) + }) + }); + + c.bench_function("to_timestamp_with_formats_largeutf8", |b| { + let (inputs, format1, format2, format3) = data_with_formats(); + + let args = [ + ColumnarValue::Array( + Arc::new(cast(&inputs, &DataType::LargeUtf8).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format1, &DataType::LargeUtf8).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format2, &DataType::LargeUtf8).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format3, &DataType::LargeUtf8).unwrap()) as ArrayRef + ), + ]; + b.iter(|| { + black_box( + to_timestamp() + .invoke(&args.clone()) + .expect("to_timestamp should work on valid values"), + ) + }) + }); + + 
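// Note: the benchmark variants above and below differ only in the string encoding
// of their inputs; each non-Utf8 case is derived by casting the shared Utf8
// fixtures. A minimal sketch of that pattern (`args_for` is an illustrative
// helper, not part of this diff), reusing the `data()` fixture defined above:
fn args_for(ty: &DataType) -> ColumnarValue {
    // arrow's `cast` already returns an ArrayRef, so it can be wrapped directly
    ColumnarValue::Array(cast(&data(), ty).expect("casting valid timestamps should succeed"))
}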
c.bench_function("to_timestamp_with_formats_utf8view", |b| { + let (inputs, format1, format2, format3) = data_with_formats(); let args = [ - ColumnarValue::Array(Arc::new(inputs.finish()) as ArrayRef), - ColumnarValue::Array(Arc::new(format1_builder.finish()) as ArrayRef), - ColumnarValue::Array(Arc::new(format2_builder.finish()) as ArrayRef), - ColumnarValue::Array(Arc::new(format3_builder.finish()) as ArrayRef), + ColumnarValue::Array( + Arc::new(cast(&inputs, &DataType::Utf8View).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format1, &DataType::Utf8View).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format2, &DataType::Utf8View).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format3, &DataType::Utf8View).unwrap()) as ArrayRef + ), ]; b.iter(|| { black_box( diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs new file mode 100644 index 0000000000000..92a08abf3d326 --- /dev/null +++ b/datafusion/functions/benches/trunc.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::{ + datatypes::{Float32Type, Float64Type}, + util::bench_util::create_primitive_array, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::math::trunc; + +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let trunc = trunc(); + for size in [1024, 4096, 8192] { + let f32_array = Arc::new(create_primitive_array::<Float32Type>(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(f32_array)]; + c.bench_function(&format!("trunc f32 array: {}", size), |b| { + b.iter(|| black_box(trunc.invoke(&f32_args).unwrap())) + }); + let f64_array = Arc::new(create_primitive_array::<Float64Type>(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + c.bench_function(&format!("trunc f64 array: {}", size), |b| { + b.iter(|| black_box(trunc.invoke(&f64_args).unwrap())) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index a1b74228a5039..a3e3feaa17e3d 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -17,17 +17,19 @@ //!
[`ArrowCastFunc`]: Implementation of the `arrow_cast` -use std::any::Any; - use arrow::datatypes::DataType; use datafusion_common::{ arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, ExprSchema, Result, ScalarValue, }; +use std::any::Any; +use std::sync::OnceLock; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Expr, ExprSchemable, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, Expr, ExprSchemable, ScalarUDFImpl, Signature, + Volatility, }; /// Implements casting to arbitrary arrow types (rather than SQL types) @@ -131,6 +133,39 @@ impl ScalarUDFImpl for ArrowCastFunc { // return the newly written argument to DataFusion Ok(ExprSimplifyResult::Simplified(new_expr)) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_arrow_cast_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_arrow_cast_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_OTHER) + .with_description("Casts a value to a specific Arrow data type.") + .with_syntax_example("arrow_cast(expression, datatype)") + .with_sql_example( + r#"```sql +> select arrow_cast(-5, 'Int8') as a, + arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b, + arrow_cast('bar', 'LargeUtf8') as c, + arrow_cast('2023-01-02T12:53:02', 'Timestamp(Microsecond, Some("+08:00"))') as d + ; ++----+-----+-----+---------------------------+ +| a | b | c | d | ++----+-----+-----+---------------------------+ +| -5 | foo | bar | 2023-01-02T12:53:02+08:00 | ++----+-----+-----+---------------------------+ +```"#, + ) + .with_argument("expression", "Expression to cast. The expression can be a constant, column, or function, and any combination of operators.") + .with_argument("datatype", "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string.
The format is the same as that returned by [`arrow_typeof`]") + .build() + .unwrap() + }) } /// Returns the requested type from the arguments diff --git a/datafusion/functions/src/core/arrowtypeof.rs b/datafusion/functions/src/core/arrowtypeof.rs index cc5e7e619bd8a..a425aff6caad6 100644 --- a/datafusion/functions/src/core/arrowtypeof.rs +++ b/datafusion/functions/src/core/arrowtypeof.rs @@ -17,9 +17,11 @@ use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct ArrowTypeOfFunc { @@ -69,4 +71,35 @@ impl ScalarUDFImpl for ArrowTypeOfFunc { "{input_data_type}" )))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_arrowtypeof_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_arrowtypeof_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_OTHER) + .with_description( + "Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression.", + ) + .with_syntax_example("arrow_typeof(expression)") + .with_sql_example( + r#"```sql +> select arrow_typeof('foo'), arrow_typeof(1); ++---------------------------+------------------------+ +| arrow_typeof(Utf8("foo")) | arrow_typeof(Int64(1)) | ++---------------------------+------------------------+ +| Utf8 | Int64 | ++---------------------------+------------------------+ +``` +"#, + ) + .with_argument("expression", "Expression to evaluate. The expression can be a constant, column, or function, and any combination of operators.") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 2fa6d7c197ad7..a05f3f08232c4 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -15,17 +15,18 @@ // specific language governing permissions and limitations // under the License.
-use std::any::Any; - use arrow::array::{new_null_array, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; use arrow::datatypes::DataType; use datafusion_common::{exec_err, ExprSchema, Result}; -use datafusion_expr::type_coercion::binary::type_union_resolution; -use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; +use datafusion_expr::binary::try_type_union_resolution; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use itertools::Itertools; +use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct CoalesceFunc { @@ -136,10 +137,39 @@ impl ScalarUDFImpl for CoalesceFunc { if arg_types.is_empty() { return exec_err!("coalesce must have at least one argument"); } - let new_type = type_union_resolution(arg_types) - .unwrap_or(arg_types.first().unwrap().clone()); - Ok(vec![new_type; arg_types.len()]) + + try_type_union_resolution(arg_types) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_coalesce_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_coalesce_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_CONDITIONAL) + .with_description("Returns the first of its arguments that is not _null_. Returns _null_ if all arguments are _null_. This function is often used to substitute a default value for _null_ values.") + .with_syntax_example("coalesce(expression1[, ..., expression_n])") + .with_sql_example(r#"```sql +> select coalesce(null, null, 'datafusion'); ++----------------------------------------+ +| coalesce(NULL,NULL,Utf8("datafusion")) | ++----------------------------------------+ +| datafusion | ++----------------------------------------+ +```"#, + ) + .with_argument( + "expression1, expression_n", + "Expression to use if previous expressions are _null_. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." + ) + .build() + .unwrap() + }) } #[cfg(test)] diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index a51f895c5084b..c0af4d35966b1 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -23,10 +23,11 @@ use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ exec_err, plan_datafusion_err, plan_err, ExprSchema, Result, ScalarValue, }; -use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct GetFieldFunc { @@ -133,7 +134,7 @@ impl ScalarUDFImpl for GetFieldFunc { DataType::Struct(fields) if fields.len() == 2 => { // Arrow's MapArray is essentially a ListArray of structs with two columns. They are // often named "key", and "value", but we don't require any specific naming here; - // instead, we assume that the second columnis the "value" column both here and in + // instead, we assume that the second column is the "value" column both here and in // execution.
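// For intuition (an illustrative note, not part of this diff): a map column such
// as map<Utf8, Int64> has the Arrow type Map(Field("entries", Struct([key: Utf8,
// value: Int64])), /*sorted=*/false), i.e. its entries are a two-column struct.
// That is why `fields.get(1)` below yields the value type (Int64 here) no matter
// what the two struct columns are actually named.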
let value_field = fields.get(1).expect("fields should have exactly two members"); Ok(value_field.data_type().clone()) @@ -155,7 +156,7 @@ impl ScalarUDFImpl for GetFieldFunc { "Only UTF8 strings are valid as an indexed field in a struct" ), (DataType::Null, _) => Ok(DataType::Null), - (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `Struct`, `Map` or `Null` types, got {other}"), + (other, _) => plan_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), } } @@ -190,7 +191,7 @@ impl ScalarUDFImpl for GetFieldFunc { let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; // note that this array has more entries than the expected output/input size - // because maparray is flatten + // because map_array is flattened let original_data = map_array.entries().column(1).to_data(); let capacity = Capacities::Array(original_data.len()); let mut mutable = @@ -205,7 +206,7 @@ impl ScalarUDFImpl for GetFieldFunc { keys.slice(start, end-start). iter().enumerate(). find(|(_, t)| t.unwrap()); - if maybe_matched.is_none(){ + if maybe_matched.is_none() { mutable.extend_nulls(1); continue } @@ -224,14 +225,67 @@ impl ScalarUDFImpl for GetFieldFunc { } } (DataType::Struct(_), name) => exec_err!( - "get indexed field is only possible on struct with utf8 indexes. \ - Tried with {name:?} index" + "get_field is only possible on struct with utf8 indexes. \ + Received {name:?} index" ), (DataType::Null, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), (dt, name) => exec_err!( - "get indexed field is only possible on lists with int64 indexes or struct \ - with utf8 indexes. Tried {dt:?} with {name:?} index" + "get_field is only possible on maps with utf8 indexes or struct \ + with utf8 indexes. Received {dt:?} with {name:?} index" ), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_getfield_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_getfield_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_OTHER) + .with_description(r#"Returns a field within a map or a struct with the given key. +Note: most users invoke `get_field` indirectly via field access +syntax such as `my_struct_col['field_name']` which results in a call to +`get_field(my_struct_col, 'field_name')`."#) + .with_syntax_example("get_field(expression1, expression2)") + .with_sql_example(r#"```sql +> create table t (idx varchar, v varchar) as values ('data','fusion'), ('apache', 'arrow'); +> select struct(idx, v) from t as c; ++-------------------------+ +| struct(c.idx,c.v) | ++-------------------------+ +| {c0: data, c1: fusion} | +| {c0: apache, c1: arrow} | ++-------------------------+ +> select get_field((select struct(idx, v) from t), 'c0'); ++-----------------------+ +| struct(t.idx,t.v)[c0] | ++-----------------------+ +| data | +| apache | ++-----------------------+ +> select get_field((select struct(idx, v) from t), 'c1'); ++-----------------------+ +| struct(t.idx,t.v)[c1] | ++-----------------------+ +| fusion | +| arrow | ++-----------------------+ +``` + "#) + .with_argument( + "expression1", + "The map or struct to retrieve a field for." + ) + .with_argument( + "expression2", + "The field name in the map or struct to retrieve data for. Must evaluate to a string."
+ ) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 1c69f9c9b2f37..cf64c03766cbc 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -107,5 +107,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> { get_field(), coalesce(), version(), + r#struct(), ] } diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 85c3327453556..b2c7f06d58685 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -18,11 +18,12 @@ use arrow::array::StructArray; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRUCT; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use hashbrown::HashSet; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; /// put values in a struct array. fn named_struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> { @@ -123,7 +124,7 @@ impl ScalarUDFImpl for NamedStructFunc { fn return_type_from_exprs( &self, - args: &[datafusion_expr::Expr], + args: &[Expr], schema: &dyn datafusion_common::ExprSchema, _arg_types: &[DataType], ) -> Result<DataType> { @@ -161,4 +162,46 @@ impl ScalarUDFImpl for NamedStructFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { named_struct_expr(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_named_struct_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_named_struct_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRUCT) + .with_description("Returns an Arrow struct using the specified name and input expressions pairs.") + .with_syntax_example("named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input])") + .with_sql_example(r#" +For example, this query converts two columns `a` and `b` to a single column with +a struct type of fields `field_a` and `field_b`: +```sql +> select * from t; ++---+---+ +| a | b | ++---+---+ +| 1 | 2 | +| 3 | 4 | ++---+---+ +> select named_struct('field_a', a, 'field_b', b) from t; ++-------------------------------------------------------+ +| named_struct(Utf8("field_a"),t.a,Utf8("field_b"),t.b) | ++-------------------------------------------------------+ +| {field_a: 1, field_b: 2} | +| {field_a: 3, field_b: 4} | ++-------------------------------------------------------+ +``` +"#) + .with_argument( + "expression_n_name", + "Name of the column field. Must be a constant string." + ) + .with_argument("expression_n_input", "Expression to include in the output struct.
Can be a constant, column, or function, and any combination of arithmetic or string operators.") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/core/nullif.rs b/datafusion/functions/src/core/nullif.rs index 6fcfbd36416ef..f96ee1ea7a122 100644 --- a/datafusion/functions/src/core/nullif.rs +++ b/datafusion/functions/src/core/nullif.rs @@ -17,13 +17,15 @@ use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, Documentation}; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::nullif::nullif; use datafusion_common::ScalarValue; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct NullIfFunc { @@ -93,6 +95,47 @@ impl ScalarUDFImpl for NullIfFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { nullif_func(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_nullif_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_nullif_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_CONDITIONAL) + .with_description("Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_. +This can be used to perform the inverse operation of [`coalesce`](#coalesce).") + .with_syntax_example("nullif(expression1, expression2)") + .with_sql_example(r#"```sql +> select nullif('datafusion', 'data'); ++-----------------------------------------+ +| nullif(Utf8("datafusion"),Utf8("data")) | ++-----------------------------------------+ +| datafusion | ++-----------------------------------------+ +> select nullif('datafusion', 'datafusion'); ++-----------------------------------------------+ +| nullif(Utf8("datafusion"),Utf8("datafusion")) | ++-----------------------------------------------+ +| | ++-----------------------------------------------+ +``` +"#) + .with_argument( + "expression1", + "Expression to compare and return if equal to expression2. Can be a constant, column, or function, and any combination of operators." + ) + .with_argument( + "expression2", + "Expression to compare to expression1. Can be a constant, column, or function, and any combination of operators." 
+ ) + .build() + .unwrap() + }) } /// Implements NULLIF(expr1, expr2) diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index a09224acefcdf..16438e1b6254f 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -20,8 +20,11 @@ use arrow::compute::is_not_null; use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; use datafusion_common::{internal_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use std::sync::Arc; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct NVLFunc { @@ -91,6 +94,46 @@ impl ScalarUDFImpl for NVLFunc { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_nvl_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_nvl_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_CONDITIONAL) + .with_description("Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_.") + .with_syntax_example("nvl(expression1, expression2)") + .with_sql_example(r#"```sql +> select nvl(null, 'a'); ++---------------------+ +| nvl(NULL,Utf8("a")) | ++---------------------+ +| a | ++---------------------+ +> select nvl('b', 'a'); ++--------------------------+ +| nvl(Utf8("b"),Utf8("a")) | ++--------------------------+ +| b | ++--------------------------+ +``` +"#) + .with_argument( + "expression1", + "Expression to return if not null. Can be a constant, column, or function, and any combination of operators." + ) + .with_argument( + "expression2", + "Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators."
+ ) + .build() + .unwrap() + }) } fn nvl_func(args: &[ColumnarValue]) -> Result<ColumnarValue> { diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index 1144dc0fb7c56..cfcdb4480787c 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -20,11 +20,12 @@ use arrow::compute::is_not_null; use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; use datafusion_expr::{ - type_coercion::binary::comparison_coercion, ColumnarValue, ScalarUDFImpl, Signature, - Volatility, + type_coercion::binary::comparison_coercion, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, Volatility, }; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct NVL2Func { @@ -90,6 +91,50 @@ impl ScalarUDFImpl for NVL2Func { )?; Ok(vec![new_type; arg_types.len()]) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_nvl2_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_nvl2_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_CONDITIONAL) + .with_description("Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_.") + .with_syntax_example("nvl2(expression1, expression2, expression3)") + .with_sql_example(r#"```sql +> select nvl2(null, 'a', 'b'); ++--------------------------------+ +| nvl2(NULL,Utf8("a"),Utf8("b")) | ++--------------------------------+ +| b | ++--------------------------------+ +> select nvl2('data', 'a', 'b'); ++----------------------------------------+ +| nvl2(Utf8("data"),Utf8("a"),Utf8("b")) | ++----------------------------------------+ +| a | ++----------------------------------------+ +``` +"#) + .with_argument( + "expression1", + "Expression to test for null. Can be a constant, column, or function, and any combination of operators." + ) + .with_argument( + "expression2", + "Expression to return if expr1 is not null. Can be a constant, column, or function, and any combination of operators." + ) + .with_argument( + "expression3", + "Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators."
+ ) + .build() + .unwrap() + }) } fn nvl2_func(args: &[ColumnarValue]) -> Result<ColumnarValue> { diff --git a/datafusion/functions/src/core/planner.rs b/datafusion/functions/src/core/planner.rs index 889f191d592f5..717a74797c0b5 100644 --- a/datafusion/functions/src/core/planner.rs +++ b/datafusion/functions/src/core/planner.rs @@ -17,14 +17,14 @@ use arrow::datatypes::Field; use datafusion_common::Result; -use datafusion_common::{not_impl_err, Column, DFSchema, ScalarValue, TableReference}; +use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawDictionaryExpr}; use datafusion_expr::{lit, Expr}; use super::named_struct; -#[derive(Default)] +#[derive(Default, Debug)] pub struct CoreFunctionPlanner {} impl ExprPlanner for CoreFunctionPlanner { @@ -49,7 +49,7 @@ impl ExprPlanner for CoreFunctionPlanner { Ok(PlannerResult::Planned(Expr::ScalarFunction( ScalarFunction::new_udf( if is_named_struct { - crate::core::named_struct() + named_struct() } else { crate::core::r#struct() }, @@ -70,19 +70,20 @@ impl ExprPlanner for CoreFunctionPlanner { qualifier: Option<&TableReference>, nested_names: &[String], ) -> Result<PlannerResult<Vec<Expr>>> { - // TODO: remove when can support multiple nested identifiers - if nested_names.len() > 1 { - return not_impl_err!( - "Nested identifiers not yet supported for column {}", - Column::from((qualifier, field)).quoted_flat_name() - ); + let col = Expr::Column(Column::from((qualifier, field))); + + // Start with the base column expression + let mut expr = col; + + // Iterate over nested_names and create nested get_field expressions + for nested_name in nested_names { + let get_field_args = vec![expr, lit(ScalarValue::from(nested_name.clone()))]; + expr = Expr::ScalarFunction(ScalarFunction::new_udf( + crate::core::get_field(), + get_field_args, + )); } - let nested_name = nested_names[0].to_string(); - let col = Expr::Column(Column::from((qualifier, field))); - let get_field_args = vec![col, lit(ScalarValue::from(nested_name))]; - Ok(PlannerResult::Planned(Expr::ScalarFunction( - ScalarFunction::new_udf(crate::core::get_field(), get_field_args), - ))) + Ok(PlannerResult::Planned(expr)) } } diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index bdddbb81beabe..75d1d4eca6983 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -18,10 +18,11 @@ use arrow::array::{ArrayRef, StructArray}; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRUCT; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; fn array_struct(args: &[ArrayRef]) -> Result<ArrayRef> { // do not accept 0 arguments.
@@ -57,6 +58,7 @@ fn struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> { #[derive(Debug)] pub struct StructFunc { signature: Signature, + aliases: Vec<String>, } impl Default for StructFunc { @@ -69,6 +71,7 @@ impl StructFunc { pub fn new() -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec![String::from("row")], } } } @@ -81,6 +84,10 @@ impl ScalarUDFImpl for StructFunc { "struct" } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn signature(&self) -> &Signature { &self.signature } @@ -97,4 +104,56 @@ impl ScalarUDFImpl for StructFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { struct_expr(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_struct_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_struct_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRUCT) + .with_description("Returns an Arrow struct using the specified input expressions optionally named. +Fields in the returned struct use the optional name or the `cN` naming convention. +For example: `c0`, `c1`, `c2`, etc.") + .with_syntax_example("struct(expression1[, ..., expression_n])") + .with_sql_example(r#"For example, this query converts two columns `a` and `b` to a single column with +a struct type of fields `field_a` and `c1`: +```sql +> select * from t; ++---+---+ +| a | b | ++---+---+ +| 1 | 2 | +| 3 | 4 | ++---+---+ + +-- use default names `c0`, `c1` +> select struct(a, b) from t; ++-----------------+ +| struct(t.a,t.b) | ++-----------------+ +| {c0: 1, c1: 2} | +| {c0: 3, c1: 4} | ++-----------------+ + +-- name the first field `field_a` +select struct(a as field_a, b) from t; ++--------------------------------------------------+ +| named_struct(Utf8("field_a"),t.a,Utf8("c1"),t.b) | ++--------------------------------------------------+ +| {field_a: 1, c1: 2} | +| {field_a: 3, c1: 4} | ++--------------------------------------------------+ +``` +"#) + .with_argument( + "expression1, expression_n", + "Expression to include in the output struct. Can be a constant, column, or function, any combination of arithmetic or string operators.") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 212349e689818..f726122c649ac 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -17,11 +17,14 @@ //! [`VersionFunc`]: Implementation of the `version` function.
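// Each function touched in this diff wires up its docs the same way: a file-local
// `static OnceLock<Documentation>` plus a `get_*_doc()` accessor, so the builder
// runs once on first call and every later call returns the cached reference.
// A minimal sketch of the recurring pattern (`my_func` is a hypothetical name):
//
//     static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
//
//     fn get_my_func_doc() -> &'static Documentation {
//         DOCUMENTATION.get_or_init(|| {
//             Documentation::builder()
//                 .with_doc_section(DOC_SECTION_OTHER)
//                 .with_description("What my_func computes.")
//                 .with_syntax_example("my_func(expression)")
//                 .build()
//                 .unwrap()
//         })
//     }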
-use std::any::Any; - use arrow::datatypes::DataType; use datafusion_common::{not_impl_err, plan_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct VersionFunc { @@ -78,6 +81,33 @@ impl ScalarUDFImpl for VersionFunc { ); Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(version)))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_version_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_version_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_OTHER) + .with_description("Returns the version of DataFusion.") + .with_syntax_example("version()") + .with_sql_example( + r#"```sql +> select version(); ++--------------------------------------------+ +| version() | ++--------------------------------------------+ +| Apache DataFusion 42.0.0, aarch64 on macos | ++--------------------------------------------+ +```"#, + ) + .build() + .unwrap() + }) } #[cfg(test)] diff --git a/datafusion/functions/src/crypto/digest.rs b/datafusion/functions/src/crypto/digest.rs index c9dd3c1f56a29..0e43fb7785dfd 100644 --- a/datafusion/functions/src/crypto/digest.rs +++ b/datafusion/functions/src/crypto/digest.rs @@ -19,10 +19,12 @@ use super::basic::{digest, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, TypeSignature::*, Volatility, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::*, Volatility, }; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct DigestFunc { @@ -69,4 +71,48 @@ impl ScalarUDFImpl for DigestFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { digest(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_digest_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_digest_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description( + "Computes the binary hash of an expression using the specified algorithm.", + ) + .with_syntax_example("digest(expression, algorithm)") + .with_sql_example( + r#"```sql +> select digest('foo', 'sha256'); ++------------------------------------------+ +| digest(Utf8("foo"), Utf8("sha256")) | ++------------------------------------------+ +| | ++------------------------------------------+ +```"#, + ) + .with_standard_argument( + "expression", Some("String")) + .with_argument( + "algorithm", + "String expression specifying algorithm to use.
Must be one of: + +- md5 +- sha224 +- sha256 +- sha384 +- sha512 +- blake2s +- blake2b +- blake3", + ) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index ccb6fbba80aad..062d63bcc0182 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -19,8 +19,12 @@ use crate::crypto::basic::md5; use arrow::datatypes::DataType; use datafusion_common::{plan_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct Md5Func { @@ -84,4 +88,32 @@ impl ScalarUDFImpl for Md5Func { fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { md5(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_md5_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_md5_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes an MD5 128-bit checksum for a string expression.") + .with_syntax_example("md5(expression)") + .with_sql_example( + r#"```sql +> select md5('foo'); ++-------------------------------------+ +| md5(Utf8("foo")) | ++-------------------------------------+ +| | ++-------------------------------------+ +```"#, + ) + .with_standard_argument("expression", Some("String")) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/crypto/sha224.rs b/datafusion/functions/src/crypto/sha224.rs index 2795c4a250041..39202d5bf6914 100644 --- a/datafusion/functions/src/crypto/sha224.rs +++ b/datafusion/functions/src/crypto/sha224.rs @@ -19,13 +19,18 @@ use super::basic::{sha224, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct SHA224Func { signature: Signature, } + impl Default for SHA224Func { fn default() -> Self { Self::new() @@ -44,6 +49,31 @@ impl SHA224Func { } } } + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_sha224_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes the SHA-224 hash of a binary string.") + .with_syntax_example("sha224(expression)") + .with_sql_example( + r#"```sql +> select sha224('foo'); ++------------------------------------------+ +| sha224(Utf8("foo")) | ++------------------------------------------+ +| | ++------------------------------------------+ +```"#, + ) + .with_standard_argument("expression", Some("String")) + .build() + .unwrap() + }) +} + impl ScalarUDFImpl for SHA224Func { fn as_any(&self) -> &dyn Any { self @@ -60,7 +90,12 @@ impl ScalarUDFImpl for SHA224Func { fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { sha224(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_sha224_doc()) + } } diff --git
a/datafusion/functions/src/crypto/sha256.rs b/datafusion/functions/src/crypto/sha256.rs index 0a3f3b26e4310..74deb3fc6caad 100644 --- a/datafusion/functions/src/crypto/sha256.rs +++ b/datafusion/functions/src/crypto/sha256.rs @@ -19,8 +19,12 @@ use super::basic::{sha256, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct SHA256Func { @@ -60,7 +64,36 @@ impl ScalarUDFImpl for SHA256Func { fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { sha256(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_sha256_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_sha256_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes the SHA-256 hash of a binary string.") + .with_syntax_example("sha256(expression)") + .with_sql_example( + r#"```sql +> select sha256('foo'); ++--------------------------------------+ +| sha256(Utf8("foo")) | ++--------------------------------------+ +| | ++--------------------------------------+ +```"#, + ) + .with_standard_argument("expression", Some("String")) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/crypto/sha384.rs b/datafusion/functions/src/crypto/sha384.rs index c3f7845ce7bd7..9b1e1ba9ec3cb 100644 --- a/datafusion/functions/src/crypto/sha384.rs +++ b/datafusion/functions/src/crypto/sha384.rs @@ -19,8 +19,12 @@ use super::basic::{sha384, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct SHA384Func { @@ -60,7 +64,36 @@ impl ScalarUDFImpl for SHA384Func { fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { sha384(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_sha384_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_sha384_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes the SHA-384 hash of a binary string.") + .with_syntax_example("sha384(expression)") + .with_sql_example( + r#"```sql +> select sha384('foo'); ++-----------------------------------------+ +| sha384(Utf8("foo")) | ++-----------------------------------------+ +| | ++-----------------------------------------+ +```"#, + ) + .with_standard_argument("expression", Some("String")) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/crypto/sha512.rs b/datafusion/functions/src/crypto/sha512.rs index dc3bfac9d8bdb..c88579fd08eea 100644 --- a/datafusion/functions/src/crypto/sha512.rs +++
b/datafusion/functions/src/crypto/sha512.rs @@ -19,8 +19,12 @@ use super::basic::{sha512, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct SHA512Func { @@ -60,7 +64,36 @@ impl ScalarUDFImpl for SHA512Func { fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { sha512(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_sha512_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_sha512_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes the SHA-512 hash of a binary string.") + .with_syntax_example("sha512(expression)") + .with_sql_example( + r#"```sql +> select sha512('foo'); ++-------------------------------------------+ +| sha512(Utf8("foo")) | ++-------------------------------------------+ +| | ++-------------------------------------------+ +```"#, + ) + .with_argument("expression", "String") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index 89b40a3534d3a..6e3106a5bce63 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -18,15 +18,16 @@ use std::sync::Arc; use arrow::array::{ - Array, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, + Array, ArrowPrimitiveType, AsArray, GenericStringArray, PrimitiveArray, + StringViewArray, }; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::DataType; use chrono::format::{parse, Parsed, StrftimeItems}; use chrono::LocalResult::Single; use chrono::{DateTime, TimeZone, Utc}; -use itertools::Either; +use crate::strings::StringArrayType; use datafusion_common::cast::as_generic_string_array; use datafusion_common::{ exec_err, unwrap_or_internal_err, DataFusionError, Result, ScalarType, ScalarValue, @@ -41,14 +42,15 @@ pub(crate) fn string_to_timestamp_nanos_shim(s: &str) -> Result<i64> { string_to_timestamp_nanos(s).map_err(|e| e.into()) } -/// Checks that all the arguments from the second are of type [Utf8] or [LargeUtf8] +/// Checks that all the arguments from the second are of type [Utf8], [LargeUtf8] or [Utf8View] /// /// [Utf8]: DataType::Utf8 /// [LargeUtf8]: DataType::LargeUtf8 +/// [Utf8View]: DataType::Utf8View pub(crate) fn validate_data_types(args: &[ColumnarValue], name: &str) -> Result<()> { for (idx, a) in args.iter().skip(1).enumerate() { match a.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { + DataType::Utf8View | DataType::LargeUtf8 | DataType::Utf8 => { // all good } _ => { @@ -178,26 +180,43 @@ pub(crate) fn string_to_timestamp_millis_formatted(s: &str, format: &str) -> Res .timestamp_millis()) } -pub(crate) fn handle<'a, O, F, S>( - args: &'a [ColumnarValue], +pub(crate) fn handle<O, F, S>( + args: &[ColumnarValue], op: F, name: &str, ) -> Result<ColumnarValue> where O: ArrowPrimitiveType, S: ScalarType<O::Native>, - F: Fn(&'a str) -> Result<O::Native>, + F: Fn(&str) -> Result<O::Native>, { match &args[0] { ColumnarValue::Array(a) => match a.data_type() { -
DataType::Utf8 | DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new( - unary_string_to_primitive_function::<i32, O, _>(&[a.as_ref()], op, name)?, + DataType::Utf8View => Ok(ColumnarValue::Array(Arc::new( + unary_string_to_primitive_function::<&StringViewArray, O, _>( + a.as_ref().as_string_view(), + op, + )?, + ))), + DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new( + unary_string_to_primitive_function::<&GenericStringArray<i64>, O, _>( + a.as_ref().as_string::<i64>(), + op, + )?, + ))), + DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new( + unary_string_to_primitive_function::<&GenericStringArray<i32>, O, _>( + a.as_ref().as_string::<i32>(), + op, + )?, ))), other => exec_err!("Unsupported data type {other:?} for function {name}"), }, ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| (op)(x)).transpose()?; + ScalarValue::Utf8View(a) + | ScalarValue::LargeUtf8(a) + | ScalarValue::Utf8(a) => { + let result = a.as_ref().map(|x| op(x)).transpose()?; Ok(ColumnarValue::Scalar(S::scalar(result))) } other => exec_err!("Unsupported data type {other:?} for function {name}"), @@ -205,11 +224,11 @@ where } } -// given an function that maps a `&str`, `&str` to an arrow native type, +// Given a function that maps a `&str`, `&str` to an arrow native type, // returns a `ColumnarValue` where the function is applied to either a `ArrayRef` or `ScalarValue` // depending on the `args`'s variant. -pub(crate) fn handle_multiple<'a, O, F, S, M>( - args: &'a [ColumnarValue], +pub(crate) fn handle_multiple<O, F, S, M>( + args: &[ColumnarValue], op: F, op2: M, name: &str, @@ -217,24 +236,24 @@ where O: ArrowPrimitiveType, S: ScalarType<O::Native>, - F: Fn(&'a str, &'a str) -> Result<O::Native>, + F: Fn(&str, &str) -> Result<O::Native>, M: Fn(O::Native) -> O::Native, { match &args[0] { ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { + DataType::Utf8View | DataType::LargeUtf8 | DataType::Utf8 => { // validate the column types for (pos, arg) in args.iter().enumerate() { match arg { ColumnarValue::Array(arg) => match arg.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { + DataType::Utf8View | DataType::LargeUtf8 | DataType::Utf8 => { // all good } other => return exec_err!("Unsupported data type {other:?} for function {name}, arg # {pos}"), }, ColumnarValue::Scalar(arg) => { match arg.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { + DataType::Utf8View| DataType::LargeUtf8 | DataType::Utf8 => { // all good } other => return exec_err!("Unsupported data type {other:?} for function {name}, arg # {pos}"), @@ -244,7 +263,7 @@ where } Ok(ColumnarValue::Array(Arc::new( - strings_to_primitive_function::<i32, O, _, _>(args, op, op2, name)?, + strings_to_primitive_function::<O, _, _>(args, op, op2, name)?, ))) } other => { @@ -253,7 +272,9 @@ where }, // if the first argument is a scalar utf8 all arguments are expected to be scalar utf8 ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { + ScalarValue::Utf8View(a) + | ScalarValue::LargeUtf8(a) + | ScalarValue::Utf8(a) => { let a = a.as_ref(); // ASK: Why do we trust `a` to be non-null at this point?
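// (Illustrative annotation, not part of this diff.) `unwrap_or_internal_err!`
// converts a `None` here into an internal DataFusionError rather than panicking,
// so a null scalar that reaches this point is surfaced as a bug report instead
// of crashing the process.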
let a = unwrap_or_internal_err!(a); @@ -262,7 +283,9 @@ where for (pos, v) in args.iter().enumerate().skip(1) { let ColumnarValue::Scalar( - ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x), + ScalarValue::Utf8View(x) + | ScalarValue::LargeUtf8(x) + | ScalarValue::Utf8(x), ) = v else { return exec_err!("Unsupported data type {v:?} for function {name}, arg # {pos}"); @@ -299,18 +322,16 @@ where /// # Errors /// This function errors iff: /// * the number of arguments is not > 1 or -/// * the array arguments are not castable to a `GenericStringArray` or /// * the function `op` errors for all input -pub(crate) fn strings_to_primitive_function<'a, T, O, F, F2>( - args: &'a [ColumnarValue], +pub(crate) fn strings_to_primitive_function<O, F, F2>( + args: &[ColumnarValue], op: F, op2: F2, name: &str, ) -> Result<PrimitiveArray<O>> where O: ArrowPrimitiveType, - T: OffsetSizeTrait, - F: Fn(&'a str, &'a str) -> Result<O::Native>, + F: Fn(&str, &str) -> Result<O::Native>, F2: Fn(O::Native) -> O::Native, { if args.len() < 2 { return exec_err!( @@ -321,50 +342,90 @@ where ); } - // this will throw the error if any of the array args are not castable to GenericStringArray - let data = args - .iter() - .map(|a| match a { - ColumnarValue::Array(a) => { - Ok(Either::Left(as_generic_string_array::<T>(a.as_ref())?)) } - ColumnarValue::Scalar(s) => match s { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => Ok(Either::Right(a)), - other => exec_err!( - "Unexpected scalar type encountered '{other}' for function '{name}'" - ), - }, - }) - .collect::<Result<Vec<Either<&GenericStringArray<T>, &Option<String>>>>>()?; - - let first_arg = &data.first().unwrap().left().unwrap(); + match &args[0] { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8View => { + let string_array = a.as_string_view(); + handle_array_op::<O, &StringViewArray, F, F2>( + &string_array, + &args[1..], + op, + op2, + ) } + DataType::LargeUtf8 => { + let string_array = as_generic_string_array::<i64>(&a)?; + handle_array_op::<O, &GenericStringArray<i64>, F, F2>( + &string_array, + &args[1..], + op, + op2, + ) + } + DataType::Utf8 => { + let string_array = as_generic_string_array::<i32>(&a)?; + handle_array_op::<O, &GenericStringArray<i32>, F, F2>( + &string_array, + &args[1..], + op, + op2, + ) + } + other => exec_err!( + "Unsupported data type {other:?} for function {name},\ + expected Utf8View, Utf8 or LargeUtf8." + ), + }, + other => exec_err!( + "Received {} data type, expected only array", + other.data_type() + ), + } +} - first_arg +fn handle_array_op<'a, O, V, F, F2>( + first: &V, + args: &[ColumnarValue], + op: F, + op2: F2, +) -> Result<PrimitiveArray<O>> +where + V: StringArrayType<'a>, + O: ArrowPrimitiveType, + F: Fn(&str, &str) -> Result<O::Native>, + F2: Fn(O::Native) -> O::Native, +{ + first .iter() .enumerate() .map(|(pos, x)| { let mut val = None; - if let Some(x) = x { - let param_args = data.iter().skip(1); - - // go through the args and find the first successful result. Only the last - failure will be returned if no successful result was received.
- for param_arg in param_args { - // param_arg is an array, use the corresponding index into the array as the arg - // we're currently parsing - let p = *param_arg; - let r = if p.is_left() { - let p = p.left().unwrap(); - op(x, p.value(pos)) - } - // args is a scalar, use it directly - else if let Some(p) = p.right().unwrap() { - op(x, p.as_str()) - } else { - continue; - }; + for arg in args { + let v = match arg { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8View => Ok(a.as_string_view().value(pos)), + DataType::LargeUtf8 => Ok(a.as_string::<i64>().value(pos)), + DataType::Utf8 => Ok(a.as_string::<i32>().value(pos)), + other => exec_err!("Unexpected type encountered '{other}'"), + }, + ColumnarValue::Scalar(s) => match s { + ScalarValue::Utf8View(a) + | ScalarValue::LargeUtf8(a) + | ScalarValue::Utf8(a) => { + if let Some(v) = a { + Ok(v.as_str()) + } else { + continue; + } + } + other => { + exec_err!("Unexpected scalar type encountered '{other}'") + } + }, + }?; + let r = op(x, v); if r.is_ok() { val = Some(Ok(op2(r.unwrap()))); break; } @@ -385,28 +446,16 @@ where /// # Errors /// This function errors iff: /// * the number of arguments is not 1 or -/// * the first argument is not castable to a `GenericStringArray` or /// * the function `op` errors -fn unary_string_to_primitive_function<'a, T, O, F>( - args: &[&'a dyn Array], +fn unary_string_to_primitive_function<'a, StringArrType, O, F>( + array: StringArrType, op: F, - name: &str, ) -> Result<PrimitiveArray<O>> where + StringArrType: StringArrayType<'a>, O: ArrowPrimitiveType, - T: OffsetSizeTrait, F: Fn(&'a str) -> Result<O::Native>, { - if args.len() != 1 { - return exec_err!( - "{:?} args were supplied but {} takes exactly one argument", - args.len(), - name - ); - } - - let array = as_generic_string_array::<T>(args[0])?; - // first map is the iterator, second is for the `Option<_>` array.iter().map(|x| x.map(&op).transpose()).collect() } diff --git a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index 8b180ff41b91b..24046611a71f7 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -22,8 +22,12 @@ use arrow::datatypes::DataType::Date32; use chrono::{Datelike, NaiveDate}; use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; +use std::sync::OnceLock; #[derive(Debug)] pub struct CurrentDateFunc { @@ -95,4 +99,25 @@ impl ScalarUDFImpl for CurrentDateFunc { ScalarValue::Date32(days), ))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_current_date_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_current_date_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Returns the current UTC date. + +The `current_date()` return value is determined at query time and will return the same date, no matter when in the query plan the function executes.
+"#) + .with_syntax_example("current_date()") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index 803759d4e904e..4122b54b07e89 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::datatypes::DataType; use arrow::datatypes::DataType::Time64; use arrow::datatypes::TimeUnit::Nanosecond; +use std::any::Any; +use std::sync::OnceLock; use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct CurrentTimeFunc { @@ -84,4 +87,25 @@ impl ScalarUDFImpl for CurrentTimeFunc { ScalarValue::Time64Nanosecond(nano), ))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_current_time_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_current_time_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Returns the current UTC time. + +The `current_time()` return value is determined at query time and will return the same time, no matter when in the query plan the function executes. +"#) + .with_syntax_example("current_time()") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 997f1a36ad040..e335c4e097f78 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::temporal_conversions::NANOSECONDS; use arrow::array::types::{ @@ -35,10 +35,11 @@ use datafusion_common::{exec_err, not_impl_err, plan_err, Result, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; use chrono::{DateTime, Datelike, Duration, Months, TimeDelta, Utc}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; #[derive(Debug)] pub struct DateBinFunc { @@ -163,6 +164,44 @@ impl ScalarUDFImpl for DateBinFunc { Ok(SortProperties::Unordered) } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_date_bin_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_date_bin_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Calculates time intervals and returns the start of the interval nearest to the specified timestamp. Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" and applying an aggregate or selector function to each window. 
diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 997f1a36ad040..e335c4e097f78 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::temporal_conversions::NANOSECONDS; use arrow::array::types::{ @@ -35,10 +35,11 @@ use datafusion_common::{exec_err, not_impl_err, plan_err, Result, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; use chrono::{DateTime, Datelike, Duration, Months, TimeDelta, Utc}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; #[derive(Debug)] pub struct DateBinFunc { @@ -163,6 +164,44 @@ impl ScalarUDFImpl for DateBinFunc { Ok(SortProperties::Unordered) } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_date_bin_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_date_bin_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Calculates time intervals and returns the start of the interval nearest to the specified timestamp. Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" and applying an aggregate or selector function to each window. + +For example, if you "bin" or "window" data into 15 minute intervals, an input timestamp of `2023-01-01T18:18:18Z` will be updated to the start time of the 15 minute bin it is in: `2023-01-01T18:15:00Z`. +"#) + .with_syntax_example("date_bin(interval, expression, origin-timestamp)") + .with_argument("interval", "Bin interval.") + .with_argument("expression", "Time expression to operate on. Can be a constant, column, or function.") + .with_argument("origin-timestamp", "Optional. Starting point used to determine bin boundaries. If not specified, defaults to 1970-01-01T00:00:00Z (the UNIX epoch in UTC). + +The following intervals are supported: + +- nanoseconds +- microseconds +- milliseconds +- seconds +- minutes +- hours +- days +- weeks +- months +- years +- century +") + .build() + .unwrap() + }) } enum Interval {
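For constant-width intervals, the binning described above is plain integer arithmetic: shift the timestamp by the origin, floor-divide by the stride, and scale back. A small worked sketch (simplified; the real implementation also handles month/year strides and timezones):

```rust
// Simplified sketch of fixed-stride binning; all values are nanoseconds.
fn bin_start(t: i64, stride: i64, origin: i64) -> i64 {
    // `div_euclid` floors toward negative infinity, so timestamps before
    // the origin still map to the start of their own bin.
    origin + (t - origin).div_euclid(stride) * stride
}

fn main() {
    let minute = 60 * 1_000_000_000_i64;
    // 18:18:18 binned into 15 minute windows starts at 18:15:00
    // (values measured from midnight for readability).
    let t = (18 * 3600 + 18 * 60 + 18) * 1_000_000_000_i64;
    assert_eq!(bin_start(t, 15 * minute, 0), (18 * 3600 + 15 * 60) * 1_000_000_000);
}
```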
diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index b6a9a1c7e9dba..01e094bc4e0b7 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{Array, ArrayRef, Float64Array}; use arrow::compute::kernels::cast_utils::IntervalUnit; @@ -37,9 +37,10 @@ use datafusion_common::cast::{ as_timestamp_nanosecond_array, as_timestamp_second_array, }; use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; #[derive(Debug)] @@ -217,6 +218,47 @@ impl ScalarUDFImpl for DatePartFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { + Some(get_date_part_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_date_part_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Returns the specified part of the date as an integer.") + .with_syntax_example("date_part(part, expression)") + .with_argument( + "part", + r#"Part of the date to return. The following date parts are supported: + + - year + - quarter (emits value in inclusive range [1, 4] based on which quartile of the year the date is in) + - month + - week (week of the year) + - day (day of the month) + - hour + - minute + - second + - millisecond + - microsecond + - nanosecond + - dow (day of the week) + - doy (day of the year) + - epoch (seconds since Unix epoch) +"#, + ) + .with_argument( + "expression", + "Time expression to operate on. Can be a constant, column, or function.", + ) + .with_alternative_syntax("extract(field FROM source)") + .build() + .unwrap() + }) } /// Invoke [`date_part`] and cast the result to Float64 diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index f4786b16685fa..4808f020e0ca3 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::ops::{Add, Sub}; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::temporal_conversions::{ as_datetime_with_timezone, timestamp_ns_to_datetime, @@ -36,12 +36,13 @@ use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; use chrono::{ DateTime, Datelike, Duration, LocalResult, NaiveDateTime, Offset, TimeDelta, Timelike, }; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; #[derive(Debug)] pub struct DateTruncFunc { @@ -241,6 +242,40 @@ impl ScalarUDFImpl for DateTruncFunc { Ok(SortProperties::Unordered) } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_date_trunc_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_date_trunc_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Truncates a timestamp value to a specified precision.") + .with_syntax_example("date_trunc(precision, expression)") + .with_argument( + "precision", + r#"Time precision to truncate to. The following precisions are supported: + + - year / YEAR + - quarter / QUARTER + - month / MONTH + - week / WEEK + - day / DAY + - hour / HOUR + - minute / MINUTE + - second / SECOND +"#, + ) + .with_argument( + "expression", + "Time expression to operate on. Can be a constant, column, or function.", + ) + .build() + .unwrap() + }) } fn _date_trunc_coarse<T>(granularity: &str, value: Option<T>) -> Result<Option<T>> diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index d36ebe735ee70..84aa9feec654b 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -15,14 +15,17 @@ // specific language governing permissions and limitations // under the License.
-use std::any::Any; - use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Int64, Timestamp}; use arrow::datatypes::TimeUnit::Second; +use std::any::Any; +use std::sync::OnceLock; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct FromUnixtimeFunc { @@ -78,4 +81,24 @@ impl ScalarUDFImpl for FromUnixtimeFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_from_unixtime_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_from_unixtime_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) and return the corresponding timestamp.") + .with_syntax_example("from_unixtime(expression)") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .build() + .unwrap() + }) }
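As the diff above shows, `from_unixtime` is essentially a cast of an Int64 second count to a second-precision timestamp. A hedged standalone sketch of that reinterpretation using the Arrow cast kernel directly (not DataFusion's code path):

```rust
use arrow::array::{ArrayRef, Int64Array, TimestampSecondArray};
use arrow::compute::cast;
use arrow::datatypes::{DataType, TimeUnit};
use std::sync::Arc;

fn main() -> Result<(), arrow::error::ArrowError> {
    let seconds: ArrayRef = Arc::new(Int64Array::from(vec![1_599_566_400_i64]));
    // Reinterpret the integers as seconds since the epoch.
    let ts = cast(&seconds, &DataType::Timestamp(TimeUnit::Second, None))?;
    let ts = ts.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
    assert_eq!(ts.value_as_datetime(0).unwrap().to_string(), "2020-09-08 12:00:00");
    Ok(())
}
```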
diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index ded7b454f9eb8..c8ef349dfbeb5 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::builder::PrimitiveBuilder; use arrow::array::cast::AsArray; @@ -27,7 +27,10 @@ use arrow::datatypes::DataType::{Date32, Int32, Int64, UInt32, UInt64, Utf8, Utf use chrono::prelude::*; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct MakeDateFunc { @@ -86,9 +89,9 @@ impl ScalarUDFImpl for MakeDateFunc { ColumnarValue::Array(a) => Some(a.len()), }); - let years = args[0].cast_to(&DataType::Int32, None)?; - let months = args[1].cast_to(&DataType::Int32, None)?; - let days = args[2].cast_to(&DataType::Int32, None)?; + let years = args[0].cast_to(&Int32, None)?; + let months = args[1].cast_to(&Int32, None)?; + let days = args[2].cast_to(&Int32, None)?; let scalar_value_fn = |col: &ColumnarValue| -> Result<i32> { let ColumnarValue::Scalar(s) = col else { @@ -148,6 +151,47 @@ impl ScalarUDFImpl for MakeDateFunc { Ok(value) } + fn documentation(&self) -> Option<&Documentation> { + Some(get_make_date_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_make_date_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Make a date from year/month/day component parts.") + .with_syntax_example("make_date(year, month, day)") + .with_argument( + "year", + "Year to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators.", ) + .with_argument( + "month", + "Month to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators.", + ) + .with_argument("day", "Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators.") + .with_sql_example(r#"```sql +> select make_date(2023, 1, 31); ++-------------------------------------------+ +| make_date(Int64(2023),Int64(1),Int64(31)) | ++-------------------------------------------+ +| 2023-01-31 | ++-------------------------------------------+ +> select make_date('2023', '01', '31'); ++-----------------------------------------------+ +| make_date(Utf8("2023"),Utf8("01"),Utf8("31")) | ++-----------------------------------------------+ +| 2023-01-31 | ++-----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) +"#) + .build() + .unwrap() + }) } /// Converts the year/month/day fields to an `i32` representing the days from diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index b2221215b94b7..c13bbfb181050 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -15,19 +15,23 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::datatypes::DataType; use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; +use std::any::Any; +use std::sync::OnceLock; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{internal_err, ExprSchema, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct NowFunc { signature: Signature, + aliases: Vec<String>, } impl Default for NowFunc { @@ -40,6 +44,7 @@ impl NowFunc { pub fn new() -> Self { Self { signature: Signature::uniform(0, vec![], Volatility::Stable), + aliases: vec!["current_timestamp".to_string()], } } } @@ -84,4 +89,32 @@ impl ScalarUDFImpl for NowFunc { ScalarValue::TimestampNanosecond(now_ts, Some("+00:00".into())), ))) } + fn documentation(&self) -> Option<&Documentation> { + Some(get_now_doc()) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { + false + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_now_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Returns the current UTC timestamp. + +The `now()` return value is determined at query time and will return the same timestamp, no matter when in the query plan the function executes. +"#) + .with_syntax_example("now()") + .build() + .unwrap() + }) }
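The `make_date` helper referenced above ("Converts the year/month/day fields to an `i32` representing the days from" the epoch) reduces to date arithmetic against 1970-01-01, since that is exactly what a `Date32` stores. A hedged chrono-based sketch of that conversion (a hypothetical helper, not the DataFusion implementation):

```rust
use chrono::NaiveDate;

// Hypothetical helper mirroring what a Date32 value means: whole days
// relative to the Unix epoch.
fn make_date32(year: i32, month: u32, day: u32) -> Option<i32> {
    let date = NaiveDate::from_ymd_opt(year, month, day)?;
    let epoch = NaiveDate::from_ymd_opt(1970, 1, 1)?;
    Some((date - epoch).num_days() as i32)
}

fn main() {
    // 2023-01-31 lies 19388 days after 1970-01-01, matching the
    // `make_date(2023, 1, 31)` example above.
    assert_eq!(make_date32(2023, 1, 31), Some(19388));
}
```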
diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index f2e5af978ca0b..f0c4a02c15230 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::cast::AsArray; use arrow::array::{new_null_array, Array, ArrayRef, StringArray}; @@ -29,9 +29,10 @@ use arrow::error::ArrowError; use arrow::util::display::{ArrayFormatter, DurationFormat, FormatOptions}; use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; #[derive(Debug)] @@ -53,34 +54,34 @@ impl ToCharFunc { vec![ Exact(vec![Date32, Utf8]), Exact(vec![Date64, Utf8]), + Exact(vec![Time64(Nanosecond), Utf8]), + Exact(vec![Time64(Microsecond), Utf8]), Exact(vec![Time32(Millisecond), Utf8]), Exact(vec![Time32(Second), Utf8]), - Exact(vec![Time64(Microsecond), Utf8]), - Exact(vec![Time64(Nanosecond), Utf8]), - Exact(vec![Timestamp(Second, None), Utf8]), Exact(vec![ - Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), - Exact(vec![Timestamp(Millisecond, None), Utf8]), + Exact(vec![Timestamp(Nanosecond, None), Utf8]), Exact(vec![ - Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), Exact(vec![Timestamp(Microsecond, None), Utf8]), Exact(vec![ - Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), - Exact(vec![Timestamp(Nanosecond, None), Utf8]), + Exact(vec![Timestamp(Millisecond, None), Utf8]), Exact(vec![ - Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), - Exact(vec![Duration(Second), Utf8]), - Exact(vec![Duration(Millisecond), Utf8]), - Exact(vec![Duration(Microsecond), Utf8]), + Exact(vec![Timestamp(Second, None), Utf8]), Exact(vec![Duration(Nanosecond), Utf8]), + Exact(vec![Duration(Microsecond), Utf8]), + Exact(vec![Duration(Millisecond), Utf8]), + Exact(vec![Duration(Second), Utf8]), ], Volatility::Immutable, ), @@ -137,6 +138,42 @@ impl ScalarUDFImpl for ToCharFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_char_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_to_char_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Returns a string representation of a date, time, timestamp or duration based on a [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). Unlike the PostgreSQL equivalent of this function, numerical formatting is not supported.") + .with_syntax_example("to_char(expression, format)") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function that results in a date, time, timestamp or duration." + ) + .with_argument( + "format", + "A [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) string to use to convert the expression.", + )
+ .with_sql_example(r#"```sql +> select to_char('2023-03-01'::date, '%d-%m-%Y'); ++----------------------------------------------+ +| to_char(Utf8("2023-03-01"),Utf8("%d-%m-%Y")) | ++----------------------------------------------+ +| 01-03-2023 | ++----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs) +"#) + .build() + .unwrap() + }) } fn _build_format_options<'a>( @@ -185,10 +222,7 @@ fn _to_char_scalar( if is_scalar_expression { return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); } else { - return Ok(ColumnarValue::Array(new_null_array( - &DataType::Utf8, - array.len(), - ))); + return Ok(ColumnarValue::Array(new_null_array(&Utf8, array.len()))); } }
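`to_char` builds on arrow's display machinery, which the file already imports above: a `FormatOptions` carrying a chrono strftime pattern, rendered per value by an `ArrayFormatter`. A minimal standalone use of that mechanism:

```rust
use arrow::array::Date32Array;
use arrow::util::display::{ArrayFormatter, FormatOptions};

fn main() -> Result<(), arrow::error::ArrowError> {
    // 19387 days after the epoch is 2023-01-30.
    let dates = Date32Array::from(vec![19387]);
    let options = FormatOptions::new().with_date_format(Some("%d-%m-%Y"));
    let formatter = ArrayFormatter::try_new(&dates, &options)?;
    assert_eq!(formatter.value(0).to_string(), "30-01-2023");
    Ok(())
}
```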
+"#) + .with_syntax_example("to_date('2017-05-31', '%Y-%m-%d')") + .with_sql_example(r#"```sql +> select to_date('2023-01-31'); ++-----------------------------+ +| to_date(Utf8("2023-01-31")) | ++-----------------------------+ +| 2023-01-31 | ++-----------------------------+ +> select to_date('2023/01/31', '%Y-%m-%d', '%Y/%m/%d'); ++---------------------------------------------------------------+ +| to_date(Utf8("2023/01/31"),Utf8("%Y-%m-%d"),Utf8("%Y/%m/%d")) | ++---------------------------------------------------------------+ +| 2023-01-31 | ++---------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) +"#) + .with_standard_argument("expression", Some("String")) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order + they appear with the first successful one being returned. If none of the formats successfully parse the expression + an error will be returned.", + ) + .build() + .unwrap() + }) +} + impl ScalarUDFImpl for ToDateFunc { fn as_any(&self) -> &dyn Any { self @@ -105,25 +151,28 @@ impl ScalarUDFImpl for ToDateFunc { } match args[0].data_type() { - DataType::Int32 - | DataType::Int64 - | DataType::Null - | DataType::Float64 - | DataType::Date32 - | DataType::Date64 => args[0].cast_to(&DataType::Date32, None), - DataType::Utf8 => self.to_date(args), + Int32 | Int64 | Null | Float64 | Date32 | Date64 => { + args[0].cast_to(&Date32, None) + } + Utf8View | LargeUtf8 | Utf8 => self.to_date(args), other => { exec_err!("Unsupported data type {:?} for function to_date", other) } } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_date_doc()) + } } #[cfg(test)] mod tests { + use arrow::array::{Array, Date32Array, GenericStringArray, StringViewArray}; use arrow::{compute::kernels::cast_utils::Parser, datatypes::Date32Type}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use std::sync::Arc; use super::ToDateFunc; @@ -154,9 +203,17 @@ mod tests { ]; for tc in &test_cases { - let date_scalar = ScalarValue::Utf8(Some(tc.date_str.to_string())); - let to_date_result = - ToDateFunc::new().invoke(&[ColumnarValue::Scalar(date_scalar)]); + test_scalar(ScalarValue::Utf8(Some(tc.date_str.to_string())), tc); + test_scalar(ScalarValue::LargeUtf8(Some(tc.date_str.to_string())), tc); + test_scalar(ScalarValue::Utf8View(Some(tc.date_str.to_string())), tc); + + test_array::>(tc); + test_array::>(tc); + test_array::(tc); + } + + fn test_scalar(sv: ScalarValue, tc: &TestCase) { + let to_date_result = ToDateFunc::new().invoke(&[ColumnarValue::Scalar(sv)]); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -170,6 +227,33 @@ mod tests { _ => panic!("Could not convert '{}' to Date", tc.date_str), } } + + fn test_array(tc: &TestCase) + where + A: From> + Array + 'static, + { + let date_array = A::from(vec![tc.date_str]); + let to_date_result = + ToDateFunc::new().invoke(&[ColumnarValue::Array(Arc::new(date_array))]); + + match to_date_result { + Ok(ColumnarValue::Array(a)) => { + assert_eq!(a.len(), 1); + + let expected = Date32Type::parse_formatted(tc.date_str, "%Y-%m-%d"); + let mut builder = Date32Array::builder(4); + builder.append_value(expected.unwrap()); + + assert_eq!( + &builder.finish() as &dyn Array, + a.as_ref(), + "{}: 
to_date created wrong value", + tc.name + ); + } + _ => panic!("Could not convert '{}' to Date", tc.date_str), + } + } } #[test] @@ -221,12 +305,26 @@ mod tests { ]; for tc in &test_cases { - let formatted_date_scalar = - ScalarValue::Utf8(Some(tc.formatted_date.to_string())); + test_scalar(ScalarValue::Utf8(Some(tc.formatted_date.to_string())), tc); + test_scalar( + ScalarValue::LargeUtf8(Some(tc.formatted_date.to_string())), + tc, + ); + test_scalar( + ScalarValue::Utf8View(Some(tc.formatted_date.to_string())), + tc, + ); + + test_array::<GenericStringArray<i32>>(tc); + test_array::<GenericStringArray<i64>>(tc); + test_array::<StringViewArray>(tc); + } + + fn test_scalar(sv: ScalarValue, tc: &TestCase) { let format_scalar = ScalarValue::Utf8(Some(tc.format_str.to_string())); let to_date_result = ToDateFunc::new().invoke(&[ - ColumnarValue::Scalar(formatted_date_scalar), + ColumnarValue::Scalar(sv), ColumnarValue::Scalar(format_scalar), ]); @@ -241,6 +339,41 @@ ), } } + + fn test_array<A>(tc: &TestCase) + where + A: From<Vec<&'static str>> + Array + 'static, + { + let date_array = A::from(vec![tc.formatted_date]); + let format_array = A::from(vec![tc.format_str]); + + let to_date_result = ToDateFunc::new().invoke(&[ + ColumnarValue::Array(Arc::new(date_array)), + ColumnarValue::Array(Arc::new(format_array)), + ]); + + match to_date_result { + Ok(ColumnarValue::Array(a)) => { + assert_eq!(a.len(), 1); + + let expected = Date32Type::parse_formatted(tc.date_str, "%Y-%m-%d"); + let mut builder = Date32Array::builder(4); + builder.append_value(expected.unwrap()); + + assert_eq!( + &builder.finish() as &dyn Array, a.as_ref(), + "{}: to_date created wrong value for date '{}' with format string '{}'", + tc.name, + tc.formatted_date, + tc.format_str + ); + } + _ => panic!( + "Could not convert '{}' with format string '{}' to Date: {:?}", + tc.formatted_date, tc.format_str, to_date_result + ), + } + } } #[test]
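The tests above compute expected values through arrow's `Parser` implementation for `Date32Type` (imported at the top of the test module). The same API works standalone and shows the two parsing modes `to_date` exposes, the default format and an explicit chrono format:

```rust
use arrow::array::types::Date32Type;
use arrow::compute::kernels::cast_utils::Parser;

fn main() {
    // Default parsing expects YYYY-MM-DD.
    let plain = Date32Type::parse("2023-01-31");
    // Explicit chrono format, as with to_date's format_n arguments.
    let formatted = Date32Type::parse_formatted("2023/01/31", "%Y/%m/%d");
    assert_eq!(plain, formatted);
    assert_eq!(plain, Some(19388)); // days since 1970-01-01
}
```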
diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 0e33da14547e4..376cb6f5f2f83 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::ops::Add; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::timezone::Tz; use arrow::array::{Array, ArrayRef, PrimitiveBuilder}; @@ -31,7 +31,10 @@ use arrow::datatypes::{ use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc}; use datafusion_common::cast::as_primitive_array; use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; /// A UDF function that converts a timezone-aware timestamp to local time (with no offset or /// timezone information). In other words, this function strips off the timezone from the timestamp, @@ -65,7 +68,7 @@ impl ToLocalTimeFunc { let time_value = &args[0]; let arg_type = time_value.data_type(); match arg_type { - DataType::Timestamp(_, None) => { + Timestamp(_, None) => { // if no timezone specified, just return the input Ok(time_value.clone()) } @@ -75,7 +78,7 @@ impl ToLocalTimeFunc { // for more details. // // Then remove the timezone in return type, i.e. return None - DataType::Timestamp(_, Some(timezone)) => { + Timestamp(_, Some(timezone)) => { let tz: Tz = timezone.parse()?; match time_value { @@ -351,6 +354,72 @@ impl ScalarUDFImpl for ToLocalTimeFunc { _ => plan_err!("The to_local_time function can only accept Timestamp as the arg got {first_arg}"), } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_local_time_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_to_local_time_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or timezone information). This function handles daylight saving time changes.") + .with_syntax_example("to_local_time(expression)") + .with_argument( + "expression", + "Time expression to operate on. Can be a constant, column, or function." + ) + .with_sql_example(r#"```sql +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +FROM ( + SELECT '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' AS time +); ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| time | type | to_local_time | to_local_time_type | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| 2024-04-01T00:00:20+02:00 | Timestamp(Nanosecond, Some("Europe/Brussels")) | 2024-04-01T00:00:20 | Timestamp(Nanosecond, None) | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ + +# combine `to_local_time()` with `date_bin()` to bin on boundaries in the timezone rather +# than UTC boundaries + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AS date_bin; ++---------------------+ +| date_bin | ++---------------------+ +| 2024-04-01T00:00:00 | ++---------------------+ + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels' AS date_bin_with_timezone; ++---------------------------+ +| date_bin_with_timezone | ++---------------------------+ +| 2024-04-01T00:00:00+02:00 | ++---------------------------+ +```"#) + .build() + .unwrap() + }) } #[cfg(test)]
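Semantically, `to_local_time` keeps the wall-clock reading and drops the offset/timezone, which is why the example above maps `2024-04-01T00:00:20+02:00` to `2024-04-01T00:00:20`. In chrono terms that is `naive_local()`; a hedged per-value sketch (DataFusion applies the equivalent transformation across whole arrays, including DST-aware timezones):

```rust
use chrono::{DateTime, FixedOffset, NaiveDateTime};

fn to_local(ts: DateTime<FixedOffset>) -> NaiveDateTime {
    // Keep the local wall-clock time, strip the offset entirely.
    ts.naive_local()
}

fn main() {
    let ts: DateTime<FixedOffset> = "2024-04-01T00:00:20+02:00".parse().unwrap();
    assert_eq!(to_local(ts).to_string(), "2024-04-01 00:00:20");
}
```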
diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index cbb6f37603d27..60482ee3c74a6 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -16,19 +16,21 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::datatypes::DataType::Timestamp; +use arrow::datatypes::DataType::*; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{ ArrowTimestampType, DataType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use datafusion_common::{exec_err, Result, ScalarType}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - use crate::datetime::common::*; +use datafusion_common::{exec_err, Result, ScalarType}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct ToTimestampFunc { @@ -162,16 +164,16 @@ impl ScalarUDFImpl for ToTimestampFunc { } match args[0].data_type() { - DataType::Int32 | DataType::Int64 => args[0] + Int32 | Int64 => args[0] .cast_to(&Timestamp(Second, None), None)? .cast_to(&Timestamp(Nanosecond, None), None), - DataType::Null | DataType::Float64 | Timestamp(_, None) => { + Null | Float64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Nanosecond, None), None) } - DataType::Timestamp(_, Some(tz)) => { + Timestamp(_, Some(tz)) => { args[0].cast_to(&Timestamp(Nanosecond, Some(tz)), None) } - DataType::Utf8 => { + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::<TimestampNanosecondType>(args, "to_timestamp") } other => { @@ -182,6 +184,50 @@ impl ScalarUDFImpl for ToTimestampFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_to_timestamp_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for input outside of the supported bounds. +"#) + .with_syntax_example("to_timestamp(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned.
If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------+ +| to_timestamp(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------+ +> select to_timestamp('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------+ +| to_timestamp(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++--------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } impl ScalarUDFImpl for ToTimestampSecondsFunc { @@ -215,13 +261,11 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { } match args[0].data_type() { - DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { + Null | Int32 | Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Second, None), None) } - DataType::Timestamp(_, Some(tz)) => { - args[0].cast_to(&Timestamp(Second, Some(tz)), None) - } - DataType::Utf8 => { + Timestamp(_, Some(tz)) => args[0].cast_to(&Timestamp(Second, Some(tz)), None), + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::<TimestampSecondType>(args, "to_timestamp_seconds") } other => { @@ -232,6 +276,46 @@ } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_seconds_doc()) + } +} + +static TO_TIMESTAMP_SECONDS_DOC: OnceLock<Documentation> = OnceLock::new(); + +fn get_to_timestamp_seconds_doc() -> &'static Documentation { + TO_TIMESTAMP_SECONDS_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.") + .with_syntax_example("to_timestamp_seconds(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned.
If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp_seconds('2023-01-31T09:26:56.123456789-05:00'); ++-------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-------------------------------------------------------------------+ +| 2023-01-31T14:26:56 | ++-------------------------------------------------------------------+ +> select to_timestamp_seconds('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++----------------------------------------------------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++----------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00 | ++----------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } impl ScalarUDFImpl for ToTimestampMillisFunc { @@ -265,13 +349,13 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { } match args[0].data_type() { - DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { + Null | Int32 | Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Millisecond, None), None) } - DataType::Timestamp(_, Some(tz)) => { + Timestamp(_, Some(tz)) => { args[0].cast_to(&Timestamp(Millisecond, Some(tz)), None) } - DataType::Utf8 => { + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::<TimestampMillisecondType>(args, "to_timestamp_millis") } other => { @@ -282,6 +366,46 @@ } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_millis_doc()) + } +} + +static TO_TIMESTAMP_MILLIS_DOC: OnceLock<Documentation> = OnceLock::new(); + +fn get_to_timestamp_millis_doc() -> &'static Documentation { + TO_TIMESTAMP_MILLIS_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.") + .with_syntax_example("to_timestamp_millis(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned.
If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp_millis('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123 | ++------------------------------------------------------------------+ +> select to_timestamp_millis('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123 | ++---------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } impl ScalarUDFImpl for ToTimestampMicrosFunc { @@ -315,13 +439,13 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { } match args[0].data_type() { - DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { + Null | Int32 | Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Microsecond, None), None) } - DataType::Timestamp(_, Some(tz)) => { + Timestamp(_, Some(tz)) => { args[0].cast_to(&Timestamp(Microsecond, Some(tz)), None) } - DataType::Utf8 => { + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::<TimestampMicrosecondType>(args, "to_timestamp_micros") } other => { @@ -332,6 +456,46 @@ } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_micros_doc()) + } +} + +static TO_TIMESTAMP_MICROS_DOC: OnceLock<Documentation> = OnceLock::new(); + +fn get_to_timestamp_micros_doc() -> &'static Documentation { + TO_TIMESTAMP_MICROS_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.") + .with_syntax_example("to_timestamp_micros(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned.
If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp_micros('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456 | ++------------------------------------------------------------------+ +> select to_timestamp_micros('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456 | ++---------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } impl ScalarUDFImpl for ToTimestampNanosFunc { @@ -365,13 +529,13 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { } match args[0].data_type() { - DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { + Null | Int32 | Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Nanosecond, None), None) } - DataType::Timestamp(_, Some(tz)) => { + Timestamp(_, Some(tz)) => { args[0].cast_to(&Timestamp(Nanosecond, Some(tz)), None) } - DataType::Utf8 => { + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::<TimestampNanosecondType>(args, "to_timestamp_nanos") } other => { @@ -382,6 +546,46 @@ } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_nanos_doc()) + } +} + +static TO_TIMESTAMP_NANOS_DOC: OnceLock<Documentation> = OnceLock::new(); + +fn get_to_timestamp_nanos_doc() -> &'static Documentation { + TO_TIMESTAMP_NANOS_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.") + .with_syntax_example("to_timestamp_nanos(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned.
If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp_nanos('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------------+ +> select to_timestamp_nanos('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++--------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } /// Returns the return type for the to_timestamp_* function, preserving @@ -804,7 +1008,7 @@ mod tests { for udf in &udfs { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); - assert!(matches!(rt, DataType::Timestamp(_, Some(_)))); + assert!(matches!(rt, Timestamp(_, Some(_)))); let res = udf .invoke(&[array.clone()]) @@ -814,7 +1018,7 @@ _ => panic!("Expected a columnar array"), }; let ty = array.data_type(); - assert!(matches!(ty, DataType::Timestamp(_, Some(_)))); + assert!(matches!(ty, Timestamp(_, Some(_)))); } } @@ -847,7 +1051,7 @@ for udf in &udfs { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); - assert!(matches!(rt, DataType::Timestamp(_, None))); + assert!(matches!(rt, Timestamp(_, None))); let res = udf .invoke(&[array.clone()]) @@ -857,7 +1061,7 @@ _ => panic!("Expected a columnar array"), }; let ty = array.data_type(); - assert!(matches!(ty, DataType::Timestamp(_, None))); + assert!(matches!(ty, Timestamp(_, None))); } } } @@ -933,10 +1137,7 @@ .expect("that to_timestamp with format args parsed values without error"); if let ColumnarValue::Array(parsed_array) = parsed_timestamps { assert_eq!(parsed_array.len(), 1); - assert!(matches!( - parsed_array.data_type(), - DataType::Timestamp(_, None) - )); + assert!(matches!(parsed_array.data_type(), Timestamp(_, None))); match time_unit { Nanosecond => {
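The `to_timestamp` documentation above quotes `9223372036` as the largest whole-second integer input; that bound is simply `i64::MAX` scaled down from nanoseconds, which a two-line check confirms:

```rust
fn main() {
    // Nanosecond timestamps are stored in an i64.
    assert_eq!(i64::MAX / 1_000_000_000, 9_223_372_036);
    // One more second would overflow when scaled to nanoseconds.
    assert!(9_223_372_037_i64.checked_mul(1_000_000_000).is_none());
}
```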
diff --git a/datafusion/functions/src/datetime/to_unixtime.rs b/datafusion/functions/src/datetime/to_unixtime.rs index 396dadccb4b3e..10f0f87a4ab16 100644 --- a/datafusion/functions/src/datetime/to_unixtime.rs +++ b/datafusion/functions/src/datetime/to_unixtime.rs @@ -15,15 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::datatypes::{DataType, TimeUnit}; +use std::any::Any; +use std::sync::OnceLock; +use super::to_timestamp::ToTimestampSecondsFunc; use crate::datetime::common::*; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use super::to_timestamp::ToTimestampSecondsFunc; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct ToUnixtimeFunc { @@ -86,4 +88,42 @@ impl ScalarUDFImpl for ToUnixtimeFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_unixtime_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_to_unixtime_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00Z`). Supports strings, dates, timestamps and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided.") + .with_syntax_example("to_unixtime(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ).with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.") + .with_sql_example(r#" +```sql +> select to_unixtime('2020-09-08T12:00:00+00:00'); ++------------------------------------------------+ +| to_unixtime(Utf8("2020-09-08T12:00:00+00:00")) | ++------------------------------------------------+ +| 1599566400 | ++------------------------------------------------+ +> select to_unixtime('01-14-2023 01:01:30+05:30', '%q', '%d-%m-%Y %H/%M/%S', '%+', '%m-%d-%Y %H:%M:%S%#z'); ++-----------------------------------------------------------------------------------------------------------------------------+ +| to_unixtime(Utf8("01-14-2023 01:01:30+05:30"),Utf8("%q"),Utf8("%d-%m-%Y %H/%M/%S"),Utf8("%+"),Utf8("%m-%d-%Y %H:%M:%S%#z")) | ++-----------------------------------------------------------------------------------------------------------------------------+ +| 1673638290 | ++-----------------------------------------------------------------------------------------------------------------------------+ +``` +"#) + .build() + .unwrap() + }) }
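`to_unixtime` delegates parsing to `ToTimestampSecondsFunc` (note the new `use super::to_timestamp::ToTimestampSecondsFunc` above) and then reinterprets the result as Int64 seconds. A hedged sketch of that final step with the Arrow cast kernel:

```rust
use arrow::array::{ArrayRef, Int64Array, TimestampSecondArray};
use arrow::compute::cast;
use arrow::datatypes::DataType;
use std::sync::Arc;

fn main() -> Result<(), arrow::error::ArrowError> {
    // A second-precision timestamp for 2020-09-08T12:00:00Z...
    let ts: ArrayRef = Arc::new(TimestampSecondArray::from(vec![1_599_566_400_i64]));
    // ...cast back to its underlying epoch-second count.
    let seconds = cast(&ts, &DataType::Int64)?;
    let seconds = seconds.as_any().downcast_ref::<Int64Array>().unwrap();
    assert_eq!(seconds.value(0), 1_599_566_400);
    Ok(())
}
```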
diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index 5b80c908cfc31..4f91879f94db7 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -18,9 +18,12 @@ //! Encoding expressions use arrow::{ - array::{Array, ArrayRef, BinaryArray, OffsetSizeTrait, StringArray}, - datatypes::DataType, + array::{ + Array, ArrayRef, BinaryArray, GenericByteArray, OffsetSizeTrait, StringArray, + }, + datatypes::{ByteArrayType, DataType}, }; +use arrow_buffer::{Buffer, OffsetBufferBuilder}; use base64::{engine::general_purpose, Engine as _}; use datafusion_common::{ cast::{as_generic_binary_array, as_generic_string_array}, @@ -28,10 +31,11 @@ use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::ColumnarValue; -use std::sync::Arc; +use datafusion_expr::{ColumnarValue, Documentation}; +use std::sync::{Arc, OnceLock}; use std::{fmt, str::FromStr}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_BINARY_STRING; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -54,6 +58,22 @@ impl EncodeFunc { } } +static ENCODE_DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_encode_doc() -> &'static Documentation { + ENCODE_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_BINARY_STRING) + .with_description("Encode binary data into a textual representation.") + .with_syntax_example("encode(expression, format)") + .with_argument("expression", "Expression containing string or binary data") + .with_argument("format", "Supported formats are: `base64`, `hex`") + .with_related_udf("decode") + .build() + .unwrap() + }) +} + impl ScalarUDFImpl for EncodeFunc { fn as_any(&self) -> &dyn Any { self @@ -100,6 +120,10 @@ impl ScalarUDFImpl for EncodeFunc { ), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_encode_doc()) + } } #[derive(Debug)] @@ -121,6 +145,22 @@ impl DecodeFunc { } } +static DECODE_DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_decode_doc() -> &'static Documentation { + DECODE_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_BINARY_STRING) + .with_description("Decode binary data from a textual representation in a string.") + .with_syntax_example("decode(expression, format)") + .with_argument("expression", "Expression containing encoded string data") + .with_argument("format", "Same arguments as [encode](#encode)") + .with_related_udf("encode") + .build() + .unwrap() + }) +} + impl ScalarUDFImpl for DecodeFunc { fn as_any(&self) -> &dyn Any { self @@ -167,6 +207,10 @@ impl ScalarUDFImpl for DecodeFunc { ), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_decode_doc()) + } } #[derive(Debug, Copy, Clone)] @@ -245,16 +289,22 @@ fn base64_encode(input: &[u8]) -> String { general_purpose::STANDARD_NO_PAD.encode(input) } -fn hex_decode(input: &[u8]) -> Result<Vec<u8>> { - hex::decode(input).map_err(|e| { +fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result<usize> { + // only write input / 2 bytes to buf + let out_len = input.len() / 2; + let buf = &mut buf[..out_len]; + hex::decode_to_slice(input, buf).map_err(|e| { DataFusionError::Internal(format!("Failed to decode from hex: {}", e)) - }) + })?; + Ok(out_len) } -fn base64_decode(input: &[u8]) -> Result<Vec<u8>> { - general_purpose::STANDARD_NO_PAD.decode(input).map_err(|e| { - DataFusionError::Internal(format!("Failed to decode from base64: {}", e)) - }) +fn base64_decode(input: &[u8], buf: &mut [u8]) -> Result<usize> { + general_purpose::STANDARD_NO_PAD + .decode_slice(input, buf) + .map_err(|e| { + DataFusionError::Internal(format!("Failed to decode from base64: {}", e)) + }) } macro_rules!
encode_to_array { ($METHOD: ident, $INPUT:expr) => {{ }}; } -macro_rules! decode_to_array { - ($METHOD: ident, $INPUT:expr) => {{ - let binary_array: BinaryArray = $INPUT - .iter() - .map(|x| x.map(|x| $METHOD(x.as_ref())).transpose()) - .collect::<Result<_>>()?; - Arc::new(binary_array) - }}; +fn decode_to_array<F, T: ByteArrayType>( + method: F, + input: &GenericByteArray<T>, + conservative_upper_bound_size: usize, +) -> Result<ArrayRef> +where + F: Fn(&[u8], &mut [u8]) -> Result<usize>, +{ + let mut values = vec![0; conservative_upper_bound_size]; + let mut offsets = OffsetBufferBuilder::new(input.len()); + let mut total_bytes_decoded = 0; + for v in input { + if let Some(v) = v { + let cursor = &mut values[total_bytes_decoded..]; + let decoded = method(v.as_ref(), cursor)?; + total_bytes_decoded += decoded; + offsets.push_length(decoded); + } else { + offsets.push_length(0); + } + } + // We reserved an upper bound size for the values buffer, but we only use the actual size + values.truncate(total_bytes_decoded); + let binary_array = BinaryArray::try_new( + offsets.finish(), + Buffer::from_vec(values), + input.nulls().cloned(), + )?; + Ok(Arc::new(binary_array)) } impl Encoding { @@ -381,10 +452,7 @@ T: OffsetSizeTrait, { let input_value = as_generic_binary_array::<T>(value)?; - let array: ArrayRef = match self { - Self::Base64 => decode_to_array!(base64_decode, input_value), - Self::Hex => decode_to_array!(hex_decode, input_value), - }; + let array = self.decode_byte_array(input_value)?; Ok(ColumnarValue::Array(array)) } @@ -393,12 +461,29 @@ T: OffsetSizeTrait, { let input_value = as_generic_string_array::<T>(value)?; - let array: ArrayRef = match self { - Self::Base64 => decode_to_array!(base64_decode, input_value), - Self::Hex => decode_to_array!(hex_decode, input_value), - }; + let array = self.decode_byte_array(input_value)?; Ok(ColumnarValue::Array(array)) } + + fn decode_byte_array<T: ByteArrayType>( + &self, + input_value: &GenericByteArray<T>, + ) -> Result<ArrayRef> { + match self { + Self::Base64 => { + let upper_bound = + base64::decoded_len_estimate(input_value.values().len()); + decode_to_array(base64_decode, input_value, upper_bound) + } + Self::Hex => { + // Calculate the upper bound for decoded byte size + // For hex encoding, each pair of hex characters (2 bytes) represents 1 byte when decoded + // So the upper bound is half the length of the input values. + let upper_bound = input_value.values().len() / 2; + decode_to_array(hex_decode, input_value, upper_bound) + } + } + } } impl fmt::Display for Encoding { diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index bb680f3c67dea..91f9449953e96 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -92,9 +92,6 @@ pub mod macros; pub mod string; make_stub_package!(string, "string_expressions"); -#[cfg(feature = "string_expressions")] -mod regexp_common; - /// Core datafusion expressions /// Enabled via feature flag `core_expressions` #[cfg(feature = "core_expressions")] @@ -138,6 +135,8 @@ make_stub_package!(unicode, "unicode_expressions"); #[cfg(any(feature = "datetime_expressions", feature = "unicode_expressions"))] pub mod planner; +pub mod strings; + mod utils; /// Fluent-style API for creating `Expr`s
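The new `decode_to_array` avoids allocating one `Vec` per row: it reserves a single conservatively sized buffer, decodes each value into the buffer's tail, records lengths through `OffsetBufferBuilder`, and truncates to what was actually written. A hedged standalone sketch of the same strategy for hex input (using the `hex` and `arrow` APIs visible in the diff above):

```rust
use arrow::array::BinaryArray;
use arrow_buffer::{Buffer, OffsetBufferBuilder};

fn main() -> Result<(), arrow::error::ArrowError> {
    let inputs: Vec<&[u8]> = vec![b"6869", b"21"]; // hex for "hi" and "!"
    // Hex halves in size when decoded, so input_len / 2 is a safe upper bound.
    let mut values = vec![0u8; inputs.iter().map(|v| v.len()).sum::<usize>() / 2];
    let mut offsets = OffsetBufferBuilder::new(inputs.len());
    let mut used = 0;
    for v in &inputs {
        let out = &mut values[used..used + v.len() / 2];
        hex::decode_to_slice(v, out).expect("valid hex");
        used += out.len();
        offsets.push_length(out.len());
    }
    values.truncate(used); // keep only the bytes actually decoded
    let array = BinaryArray::try_new(offsets.finish(), Buffer::from_vec(values), None)?;
    assert_eq!(array.value(0), b"hi".as_slice());
    assert_eq!(array.value(1), b"!".as_slice());
    Ok(())
}
```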
make_stub_package { }; } -/// Invokes a function on each element of an array and returns the result as a new array -/// -/// $ARG: ArrayRef -/// $NAME: name of the function (for error messages) -/// $ARGS_TYPE: the type of array to cast the argument to -/// $RETURN_TYPE: the type of array to return -/// $FUNC: the function to apply to each element of $ARG -macro_rules! make_function_scalar_inputs_return_type { - ($ARG: expr, $NAME:expr, $ARG_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARG_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$RETURN_TYPE>() - }}; -} - /// Downcast an argument to a specific array type, returning an internal error /// if the cast fails /// @@ -161,19 +141,21 @@ macro_rules! downcast_arg { /// $UNARY_FUNC: the unary function to apply to the argument /// $OUTPUT_ORDERING: the output ordering calculation method of the function macro_rules! make_math_unary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr) => { + ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr) => { make_udf_function!($NAME::$UDF, $GNAME, $NAME); mod $NAME { use std::any::Any; use std::sync::Arc; - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; - use datafusion_common::{exec_err, DataFusionError, Result}; + use arrow::array::{ArrayRef, AsArray}; + use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use datafusion_common::{exec_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + }; #[derive(Debug)] pub struct $UDF { @@ -228,26 +210,17 @@ macro_rules! make_math_unary_udf { fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => { - Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float64Array, - Float64Array, - { f64::$UNARY_FUNC } - )) - } - DataType::Float32 => { - Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float32Array, - Float32Array, - { f32::$UNARY_FUNC } - )) - } + DataType::Float64 => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>(|x: f64| f64::$UNARY_FUNC(x)), + ) as ArrayRef, + DataType::Float32 => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>(|x: f32| f32::$UNARY_FUNC(x)), + ) as ArrayRef, other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -255,8 +228,13 @@ macro_rules! make_math_unary_udf { ) } }; + Ok(ColumnarValue::Array(arr)) } + + fn documentation(&self) -> Option<&Documentation> { + Some($GET_DOC()) + } } } }; @@ -273,19 +251,21 @@ macro_rules! make_math_unary_udf { /// $BINARY_FUNC: the binary function to apply to the argument /// $OUTPUT_ORDERING: the output ordering calculation method of the function macro_rules! 
make_math_binary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $BINARY_FUNC:ident, $OUTPUT_ORDERING:expr) => { + ($UDF:ident, $GNAME:ident, $NAME:ident, $BINARY_FUNC:ident, $OUTPUT_ORDERING:expr, $GET_DOC:expr) => { make_udf_function!($NAME::$UDF, $GNAME, $NAME); mod $NAME { use std::any::Any; use std::sync::Arc; - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; - use datafusion_common::{exec_err, DataFusionError, Result}; + use arrow::array::{ArrayRef, AsArray}; + use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; - use datafusion_expr::TypeSignature::*; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + use datafusion_expr::TypeSignature; + use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + }; #[derive(Debug)] pub struct $UDF { @@ -298,8 +278,8 @@ macro_rules! make_math_binary_udf { Self { signature: Signature::one_of( vec![ - Exact(vec![Float32, Float32]), - Exact(vec![Float64, Float64]), + TypeSignature::Exact(vec![Float32, Float32]), + TypeSignature::Exact(vec![Float64, Float64]), ], Volatility::Immutable, ), @@ -338,25 +318,27 @@ macro_rules! make_math_binary_udf { fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float64Array, - { f64::$BINARY_FUNC } - )), - - DataType::Float32 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float32Array, - { f32::$BINARY_FUNC } - )), + DataType::Float64 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + y, + x, + |y, x| f64::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } + DataType::Float32 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float32Type>( + y, + x, + |y, x| f32::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -364,49 +346,14 @@ macro_rules! make_math_binary_udf { ) } }; + Ok(ColumnarValue::Array(arr)) } + + fn documentation(&self) -> Option<&Documentation> { + Some($GET_DOC()) + } } } }; } - -macro_rules! make_function_scalar_inputs { - ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; -} - -macro_rules! 
make_function_inputs2 { - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE1>() - }}; -} diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index f7a17f0caf947..5511a57d85669 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -18,7 +18,7 @@ //! math expressions use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, @@ -28,8 +28,11 @@ use arrow::datatypes::DataType; use arrow::error::ArrowError; use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result}; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; type MathArrayFunction = fn(&Vec) -> Result; @@ -184,4 +187,22 @@ impl ScalarUDFImpl for AbsFunc { Ok(SortProperties::Unordered) } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_abs_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_abs_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the absolute value of a number.") + .with_syntax_example("abs(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/math/cot.rs b/datafusion/functions/src/math/cot.rs index 66219960d9a2f..eded50a20d8d8 100644 --- a/datafusion/functions/src/math/cot.rs +++ b/datafusion/functions/src/math/cot.rs @@ -16,17 +16,17 @@ // under the License. 
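The `get_abs_doc` helper just added to abs.rs is the documentation pattern this PR repeats for every function: a `OnceLock<Documentation>` static plus a getter that builds the doc lazily and hands out a `&'static` reference. A minimal standalone sketch of that pattern, using only the builder API visible in this diff (the function name and strings here are illustrative, not from the PR):

use std::sync::OnceLock;

use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
use datafusion_expr::Documentation;

static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();

fn get_example_doc() -> &'static Documentation {
    // `get_or_init` runs the closure at most once; every later call
    // returns the cached value, so doc construction is pay-once.
    DOCUMENTATION.get_or_init(|| {
        Documentation::builder()
            .with_doc_section(DOC_SECTION_MATH)
            .with_description("Illustrative description.")
            .with_syntax_example("example(numeric_expression)")
            .with_standard_argument("numeric_expression", Some("Numeric"))
            .build()
            .unwrap()
    })
}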
use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
-use arrow::array::{ArrayRef, Float32Array, Float64Array};
-use arrow::datatypes::DataType;
+use arrow::array::{ArrayRef, AsArray};
 use arrow::datatypes::DataType::{Float32, Float64};
-
-use datafusion_common::{exec_err, DataFusionError, Result};
-use datafusion_expr::ColumnarValue;
-use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
+use arrow::datatypes::{DataType, Float32Type, Float64Type};
 
 use crate::utils::make_scalar_function;
+use datafusion_common::{exec_err, Result};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
+use datafusion_expr::{ColumnarValue, Documentation};
+use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
 
 #[derive(Debug)]
 pub struct CotFunc {
@@ -39,6 +39,20 @@ impl Default for CotFunc {
     }
 }
 
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_cot_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_MATH)
+            .with_description("Returns the cotangent of a number.")
+            .with_syntax_example(r#"cot(numeric_expression)"#)
+            .with_standard_argument("numeric_expression", Some("Numeric"))
+            .build()
+            .unwrap()
+    })
+}
+
 impl CotFunc {
     pub fn new() -> Self {
         use DataType::*;
@@ -77,6 +91,10 @@ impl ScalarUDFImpl for CotFunc {
         }
     }
 
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_cot_doc())
+    }
+
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         make_scalar_function(cot, vec![])(args)
     }
@@ -85,18 +103,16 @@ ///cot SQL function
 fn cot(args: &[ArrayRef]) -> Result<ArrayRef> {
     match args[0].data_type() {
-        Float64 => Ok(Arc::new(make_function_scalar_inputs!(
-            &args[0],
-            "x",
-            Float64Array,
-            { compute_cot64 }
-        )) as ArrayRef),
-        Float32 => Ok(Arc::new(make_function_scalar_inputs!(
-            &args[0],
-            "x",
-            Float32Array,
-            { compute_cot32 }
-        )) as ArrayRef),
+        Float64 => Ok(Arc::new(
+            args[0]
+                .as_primitive::<Float64Type>()
+                .unary::<_, Float64Type>(|x: f64| compute_cot64(x)),
+        ) as ArrayRef),
+        Float32 => Ok(Arc::new(
+            args[0]
+                .as_primitive::<Float32Type>()
+                .unary::<_, Float32Type>(|x: f32| compute_cot32(x)),
+        ) as ArrayRef),
         other => exec_err!("Unsupported data type {other:?} for function cot"),
     }
 }
diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs
index 74ad2c738a93c..bacdf47524f4f 100644
--- a/datafusion/functions/src/math/factorial.rs
+++ b/datafusion/functions/src/math/factorial.rs
@@ -20,14 +20,17 @@ use arrow::{
     error::ArrowError,
 };
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::datatypes::DataType;
 use arrow::datatypes::DataType::Int64;
 
 use crate::utils::make_scalar_function;
 use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 
 #[derive(Debug)]
 pub struct FactorialFunc {
@@ -68,12 +71,30 @@ impl ScalarUDFImpl for FactorialFunc {
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         make_scalar_function(factorial, vec![])(args)
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_factorial_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_factorial_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_MATH)
+
.with_description("Factorial. Returns 1 if value is less than 2.") + .with_syntax_example("factorial(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) } /// Factorial SQL function fn factorial(args: &[ArrayRef]) -> Result { match args[0].data_type() { - DataType::Int64 => { + Int64 => { let arg = downcast_arg!((&args[0]), "value", Int64Array); Ok(arg .iter() diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index 10faf9f390bb3..f4edef3acca38 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -19,15 +19,17 @@ use arrow::array::{ArrayRef, Int64Array}; use arrow::error::ArrowError; use std::any::Any; use std::mem::swap; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use crate::utils::make_scalar_function; use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct GcdFunc { @@ -69,6 +71,27 @@ impl ScalarUDFImpl for GcdFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(gcd, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_gcd_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_gcd_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero.", + ) + .with_syntax_example("gcd(expression_x, expression_y)") + .with_standard_argument("expression_x", Some("First numeric")) + .with_standard_argument("expression_y", Some("Second numeric")) + .build() + .unwrap() + }) } /// Gcd SQL function diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs index e6a7280533593..7e5d4fe77ffa0 100644 --- a/datafusion/functions/src/math/iszero.rs +++ b/datafusion/functions/src/math/iszero.rs @@ -16,16 +16,18 @@ // under the License. 
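The gcd documentation above pins down the edge case: `gcd(0, 0)` returns 0. The PR's real implementation is `unsigned_gcd` (not shown in this hunk); the standalone sketch below only illustrates the swap-based Euclidean recurrence that the imported `std::mem::swap` suggests, and is not the PR's code:

use std::mem::swap;

fn gcd(mut a: u64, mut b: u64) -> u64 {
    // Invariant: gcd(a, b) is preserved by replacing (a, b) with
    // (b, a % b); the loop ends when the remainder reaches zero.
    while b != 0 {
        a %= b;
        swap(&mut a, &mut b);
    }
    a
}

fn main() {
    assert_eq!(gcd(12, 18), 6);
    assert_eq!(gcd(0, 0), 0); // matches the documented convention above
}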
use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
-use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array};
-use arrow::datatypes::DataType;
+use arrow::array::{ArrayRef, AsArray, BooleanArray};
 use arrow::datatypes::DataType::{Boolean, Float32, Float64};
+use arrow::datatypes::{DataType, Float32Type, Float64Type};
 
-use datafusion_common::{exec_err, DataFusionError, Result};
-use datafusion_expr::ColumnarValue;
+use datafusion_common::{exec_err, Result};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
 use datafusion_expr::TypeSignature::Exact;
-use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 
 use crate::utils::make_scalar_function;
 
@@ -72,25 +74,39 @@ impl ScalarUDFImpl for IsZeroFunc {
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         make_scalar_function(iszero, vec![])(args)
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_iszero_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_iszero_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_MATH)
+            .with_description(
+                "Returns true if a given number is +0.0 or -0.0, otherwise returns false.",
+            )
+            .with_syntax_example("iszero(numeric_expression)")
+            .with_standard_argument("numeric_expression", Some("Numeric"))
+            .build()
+            .unwrap()
+    })
 }
 
 /// Iszero SQL function
 pub fn iszero(args: &[ArrayRef]) -> Result<ArrayRef> {
     match args[0].data_type() {
-        Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!(
-            &args[0],
-            "x",
-            Float64Array,
-            BooleanArray,
-            { |x: f64| { x == 0_f64 } }
+        Float64 => Ok(Arc::new(BooleanArray::from_unary(
+            args[0].as_primitive::<Float64Type>(),
+            |x| x == 0.0,
         )) as ArrayRef),
-        Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!(
-            &args[0],
-            "x",
-            Float32Array,
-            BooleanArray,
-            { |x: f32| { x == 0_f32 } }
+        Float32 => Ok(Arc::new(BooleanArray::from_unary(
+            args[0].as_primitive::<Float32Type>(),
+            |x| x == 0.0,
        )) as ArrayRef),
         other => exec_err!("Unsupported data type {other:?} for function iszero"),
diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs
index 21c201657e906..64b07ce606f2f 100644
--- a/datafusion/functions/src/math/lcm.rs
+++ b/datafusion/functions/src/math/lcm.rs
@@ -16,7 +16,7 @@
 // under the License.
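The iszero rewrite above swaps the removed `make_function_scalar_inputs_return_type!` macro for arrow's `BooleanArray::from_unary`, which evaluates a predicate over every value in a single pass and carries the input's null buffer across. A minimal sketch of that kernel in isolation (the input values are illustrative):

use arrow::array::{BooleanArray, Float64Array};

fn main() {
    let input = Float64Array::from(vec![0.0, -0.0, 1.5]);
    // One pass over the values; null slots (none here) would stay null.
    let result = BooleanArray::from_unary(&input, |x| x == 0.0);
    assert!(result.value(0));
    assert!(result.value(1)); // -0.0 compares equal to 0.0 under IEEE 754
    assert!(!result.value(2));
}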
use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ArrayRef, Int64Array}; use arrow::datatypes::DataType; @@ -24,8 +24,10 @@ use arrow::datatypes::DataType::Int64; use arrow::error::ArrowError; use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use super::gcd::unsigned_gcd; use crate::utils::make_scalar_function; @@ -70,6 +72,27 @@ impl ScalarUDFImpl for LcmFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(lcm, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_lcm_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_lcm_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero.", + ) + .with_syntax_example("lcm(expression_x, expression_y)") + .with_standard_argument("expression_x", Some("First numeric")) + .with_standard_argument("expression_y", Some("Second numeric")) + .build() + .unwrap() + }) } /// Lcm SQL function diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index ad7cff1f7149f..9d2e1be3df9d5 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -18,20 +18,22 @@ //! Math function: `log()`. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use super::power::PowerFunc; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, - ScalarValue, + exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, }; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::{lit, ColumnarValue, Expr, ScalarUDF, TypeSignature::*}; +use datafusion_expr::{ + lit, ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature::*, +}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] @@ -45,6 +47,22 @@ impl Default for LogFunc { } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_log_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the base-x logarithm of a number. 
Can either provide a specified base, or if omitted then takes the base-10 of a number.") + .with_syntax_example(r#"log(base, numeric_expression) +log(numeric_expression)"#) + .with_standard_argument("base", Some("Base numeric")) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + impl LogFunc { pub fn new() -> Self { use DataType::*; @@ -121,37 +139,40 @@ impl ScalarUDFImpl for LogFunc { let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => match base { ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { - |value: f64| f64::log(value, base as f64) - })) + Arc::new(x.as_primitive::().unary::<_, Float64Type>( + |value: f64| f64::log(value, base as f64), + )) + } + ColumnarValue::Array(base) => { + let x = x.as_primitive::(); + let base = base.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + x, + base, + f64::log, + )?; + Arc::new(result) as _ } - ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float64Array, - { f64::log } - )), _ => { return exec_err!("log function requires a scalar or array for base") } }, DataType::Float32 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, { - |value: f32| f32::log(value, base) - })) + ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => Arc::new( + x.as_primitive::() + .unary::<_, Float32Type>(|value: f32| f32::log(value, base)), + ), + ColumnarValue::Array(base) => { + let x = x.as_primitive::(); + let base = base.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float32Type>( + x, + base, + f32::log, + )?; + Arc::new(result) as _ } - ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float32Array, - { f32::log } - )), _ => { return exec_err!("log function requires a scalar or array for base") } @@ -164,6 +185,10 @@ impl ScalarUDFImpl for LogFunc { Ok(ColumnarValue::Array(arr)) } + fn documentation(&self) -> Option<&Documentation> { + Some(get_log_doc()) + } + /// Simplify the `log` function by the relevant rules: /// 1. Log(a, 1) ===> 0 /// 2. 
Log(a, Power(a, b)) ===> b @@ -236,12 +261,192 @@ mod tests { use super::*; + use arrow::array::{Float32Array, Float64Array, Int64Array}; use arrow::compute::SortOptions; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DFSchema; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::simplify::SimplifyContext; + #[test] + #[should_panic] + fn test_log_invalid_base_type() { + let args = [ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), + ]; + + let _ = LogFunc::new().invoke(&args); + } + + #[test] + fn test_log_invalid_value() { + let args = [ + ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num + ]; + + let result = LogFunc::new().invoke(&args); + result.expect_err("expected error"); + } + + #[test] + fn test_log_scalar_f32_unary() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_f64_unary() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_f32() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 5.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_f64() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 6.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_f64_unary() { + let args = [ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + 
.expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + assert!((floats.value(1) - 2.0).abs() < 1e-10); + assert!((floats.value(2) - 3.0).abs() < 1e-10); + assert!((floats.value(3) - 4.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_f32_unary() { + let args = [ + ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + assert!((floats.value(1) - 2.0).abs() < 1e-10); + assert!((floats.value(2) - 3.0).abs() < 1e-10); + assert!((floats.value(3) - 4.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + #[test] fn test_log_f64() { let args = [ diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index b221fb900cfa3..1452bfdee5a08 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -47,7 +47,8 @@ make_math_unary_udf!( acos, acos, super::acos_order, - super::bounds::acos_bounds + super::bounds::acos_bounds, + super::get_acos_doc ); make_math_unary_udf!( AcoshFunc, @@ -55,7 +56,8 @@ make_math_unary_udf!( acosh, acosh, super::acosh_order, - super::bounds::acosh_bounds + super::bounds::acosh_bounds, + super::get_acosh_doc ); make_math_unary_udf!( AsinFunc, @@ -63,7 +65,8 @@ make_math_unary_udf!( asin, asin, super::asin_order, - super::bounds::asin_bounds + super::bounds::asin_bounds, + super::get_asin_doc ); make_math_unary_udf!( AsinhFunc, @@ -71,7 +74,8 @@ make_math_unary_udf!( asinh, asinh, super::asinh_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_asinh_doc ); make_math_unary_udf!( AtanFunc, @@ -79,7 +83,8 @@ make_math_unary_udf!( atan, atan, super::atan_order, - super::bounds::atan_bounds + super::bounds::atan_bounds, + super::get_atan_doc ); make_math_unary_udf!( AtanhFunc, @@ -87,16 +92,25 @@ make_math_unary_udf!( atanh, atanh, super::atanh_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_atanh_doc +); +make_math_binary_udf!( + Atan2, + ATAN2, + atan2, + atan2, + super::atan2_order, + super::get_atan2_doc ); -make_math_binary_udf!(Atan2, ATAN2, atan2, atan2, super::atan2_order); make_math_unary_udf!( CbrtFunc, CBRT, cbrt, cbrt, super::cbrt_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_cbrt_doc ); make_math_unary_udf!( CeilFunc, @@ -104,7 +118,8 @@ make_math_unary_udf!( ceil, ceil, super::ceil_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_ceil_doc ); make_math_unary_udf!( CosFunc, @@ -112,7 +127,8 @@ make_math_unary_udf!( cos, cos, super::cos_order, - super::bounds::cos_bounds + super::bounds::cos_bounds, + super::get_cos_doc ); make_math_unary_udf!( CoshFunc, @@ -120,7 +136,8 @@ make_math_unary_udf!( cosh, cosh, super::cosh_order, - super::bounds::cosh_bounds + super::bounds::cosh_bounds, + super::get_cosh_doc ); make_udf_function!(cot::CotFunc, COT, cot); make_math_unary_udf!( @@ -129,7 +146,8 @@ make_math_unary_udf!( degrees, to_degrees, super::degrees_order, - 
super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_degrees_doc ); make_math_unary_udf!( ExpFunc, @@ -137,7 +155,8 @@ make_math_unary_udf!( exp, exp, super::exp_order, - super::bounds::exp_bounds + super::bounds::exp_bounds, + super::get_exp_doc ); make_udf_function!(factorial::FactorialFunc, FACTORIAL, factorial); make_math_unary_udf!( @@ -146,7 +165,8 @@ make_math_unary_udf!( floor, floor, super::floor_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_floor_doc ); make_udf_function!(log::LogFunc, LOG, log); make_udf_function!(gcd::GcdFunc, GCD, gcd); @@ -159,7 +179,8 @@ make_math_unary_udf!( ln, ln, super::ln_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_ln_doc ); make_math_unary_udf!( Log2Func, @@ -167,7 +188,8 @@ make_math_unary_udf!( log2, log2, super::log2_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_log2_doc ); make_math_unary_udf!( Log10Func, @@ -175,7 +197,8 @@ make_math_unary_udf!( log10, log10, super::log10_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_log10_doc ); make_udf_function!(nanvl::NanvlFunc, NANVL, nanvl); make_udf_function!(pi::PiFunc, PI, pi); @@ -186,7 +209,8 @@ make_math_unary_udf!( radians, to_radians, super::radians_order, - super::bounds::radians_bounds + super::bounds::radians_bounds, + super::get_radians_doc ); make_udf_function!(random::RandomFunc, RANDOM, random); make_udf_function!(round::RoundFunc, ROUND, round); @@ -197,7 +221,8 @@ make_math_unary_udf!( sin, sin, super::sin_order, - super::bounds::sin_bounds + super::bounds::sin_bounds, + super::get_sin_doc ); make_math_unary_udf!( SinhFunc, @@ -205,7 +230,8 @@ make_math_unary_udf!( sinh, sinh, super::sinh_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_sinh_doc ); make_math_unary_udf!( SqrtFunc, @@ -213,7 +239,8 @@ make_math_unary_udf!( sqrt, sqrt, super::sqrt_order, - super::bounds::sqrt_bounds + super::bounds::sqrt_bounds, + super::get_sqrt_doc ); make_math_unary_udf!( TanFunc, @@ -221,7 +248,8 @@ make_math_unary_udf!( tan, tan, super::tan_order, - super::bounds::unbounded_bounds + super::bounds::unbounded_bounds, + super::get_tan_doc ); make_math_unary_udf!( TanhFunc, @@ -229,7 +257,8 @@ make_math_unary_udf!( tanh, tanh, super::tanh_order, - super::bounds::tanh_bounds + super::bounds::tanh_bounds, + super::get_tanh_doc ); make_udf_function!(trunc::TruncFunc, TRUNC, trunc); diff --git a/datafusion/functions/src/math/monotonicity.rs b/datafusion/functions/src/math/monotonicity.rs index 52f2ec5171982..19c85f4b6e3ce 100644 --- a/datafusion/functions/src/math/monotonicity.rs +++ b/datafusion/functions/src/math/monotonicity.rs @@ -15,9 +15,13 @@ // specific language governing permissions and limitations // under the License. +use std::sync::OnceLock; + use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::Documentation; /// Non-increasing on the interval \[−1, 1\], undefined otherwise. 
pub fn acos_order(input: &[ExprProperties]) -> Result { @@ -34,6 +38,20 @@ pub fn acos_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_ACOS: OnceLock = OnceLock::new(); + +pub fn get_acos_doc() -> &'static Documentation { + DOCUMENTATION_ACOS.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the arc cosine or inverse cosine of a number.") + .with_syntax_example("acos(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for x ≥ 1, undefined otherwise. pub fn acosh_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -51,6 +69,22 @@ pub fn acosh_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_ACOSH: OnceLock = OnceLock::new(); + +pub fn get_acosh_doc() -> &'static Documentation { + DOCUMENTATION_ACOSH.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the area hyperbolic cosine or inverse hyperbolic cosine of a number.", + ) + .with_syntax_example("acosh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing on the interval \[−1, 1\], undefined otherwise. pub fn asin_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -66,16 +100,60 @@ pub fn asin_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_ASIN: OnceLock = OnceLock::new(); + +pub fn get_asin_doc() -> &'static Documentation { + DOCUMENTATION_ASIN.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the arc sine or inverse sine of a number.") + .with_syntax_example("asin(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for all real numbers. pub fn asinh_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_ASINH: OnceLock = OnceLock::new(); + +pub fn get_asinh_doc() -> &'static Documentation { + DOCUMENTATION_ASINH.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the area hyperbolic sine or inverse hyperbolic sine of a number.", + ) + .with_syntax_example("asinh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for all real numbers. pub fn atan_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_ATAN: OnceLock = OnceLock::new(); + +pub fn get_atan_doc() -> &'static Documentation { + DOCUMENTATION_ATAN.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the arc tangent or inverse tangent of a number.") + .with_syntax_example("atan(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing on the interval \[−1, 1\], undefined otherwise. 
pub fn atanh_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -91,22 +169,87 @@ pub fn atanh_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_ATANH: OnceLock = OnceLock::new(); + +pub fn get_atanh_doc() -> &'static Documentation { + DOCUMENTATION_ATANH.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the area hyperbolic tangent or inverse hyperbolic tangent of a number.", + ) + .with_syntax_example("atanh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Order depends on the quadrant. // TODO: Implement ordering rule of the ATAN2 function. pub fn atan2_order(_input: &[ExprProperties]) -> Result { Ok(SortProperties::Unordered) } +static DOCUMENTATION_ATANH2: OnceLock = OnceLock::new(); + +pub fn get_atan2_doc() -> &'static Documentation { + DOCUMENTATION_ATANH2.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the arc tangent or inverse tangent of `expression_y / expression_x`.", + ) + .with_syntax_example("atan2(expression_y, expression_x)") + .with_argument("expression_y", r#"First numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators."#) + .with_argument("expression_x", r#"Second numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators."#) + .build() + .unwrap() + }) +} + /// Non-decreasing for all real numbers. pub fn cbrt_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_CBRT: OnceLock = OnceLock::new(); + +pub fn get_cbrt_doc() -> &'static Documentation { + DOCUMENTATION_CBRT.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the cube root of a number.") + .with_syntax_example("cbrt(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for all real numbers. pub fn ceil_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_CEIL: OnceLock = OnceLock::new(); + +pub fn get_ceil_doc() -> &'static Documentation { + DOCUMENTATION_CEIL.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the nearest integer greater than or equal to a number.", + ) + .with_syntax_example("ceil(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-increasing on \[0, π\] and then non-decreasing on \[π, 2π\]. /// This pattern repeats periodically with a period of 2π. // TODO: Implement ordering rule of the ATAN2 function. @@ -114,6 +257,20 @@ pub fn cos_order(_input: &[ExprProperties]) -> Result { Ok(SortProperties::Unordered) } +static DOCUMENTATION_COS: OnceLock = OnceLock::new(); + +pub fn get_cos_doc() -> &'static Documentation { + DOCUMENTATION_COS.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the cosine of a number.") + .with_syntax_example("cos(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for x ≥ 0 and symmetrically non-increasing for x ≤ 0. 
pub fn cosh_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -130,21 +287,79 @@ pub fn cosh_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_COSH: OnceLock = OnceLock::new(); + +pub fn get_cosh_doc() -> &'static Documentation { + DOCUMENTATION_COSH.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the hyperbolic cosine of a number.") + .with_syntax_example("cosh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing function that converts radians to degrees. pub fn degrees_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_DEGREES: OnceLock = OnceLock::new(); + +pub fn get_degrees_doc() -> &'static Documentation { + DOCUMENTATION_DEGREES.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Converts radians to degrees.") + .with_syntax_example("degrees(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for all real numbers. pub fn exp_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_EXP: OnceLock = OnceLock::new(); + +pub fn get_exp_doc() -> &'static Documentation { + DOCUMENTATION_EXP.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the base-e exponential of a number.") + .with_syntax_example("exp(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for all real numbers. pub fn floor_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_FLOOR: OnceLock = OnceLock::new(); + +pub fn get_floor_doc() -> &'static Documentation { + DOCUMENTATION_FLOOR.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the nearest integer less than or equal to a number.", + ) + .with_syntax_example("floor(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for x ≥ 0, undefined otherwise. pub fn ln_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -159,6 +374,20 @@ pub fn ln_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_LN: OnceLock = OnceLock::new(); + +pub fn get_ln_doc() -> &'static Documentation { + DOCUMENTATION_LN.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the natural logarithm of a number.") + .with_syntax_example("ln(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for x ≥ 0, undefined otherwise. 
pub fn log2_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -173,6 +402,20 @@ pub fn log2_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_LOG2: OnceLock = OnceLock::new(); + +pub fn get_log2_doc() -> &'static Documentation { + DOCUMENTATION_LOG2.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the base-2 logarithm of a number.") + .with_syntax_example("log2(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for x ≥ 0, undefined otherwise. pub fn log10_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -187,11 +430,39 @@ pub fn log10_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_LOG10: OnceLock = OnceLock::new(); + +pub fn get_log10_doc() -> &'static Documentation { + DOCUMENTATION_LOG10.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the base-10 logarithm of a number.") + .with_syntax_example("log10(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for all real numbers x. pub fn radians_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_RADIONS: OnceLock = OnceLock::new(); + +pub fn get_radians_doc() -> &'static Documentation { + DOCUMENTATION_RADIONS.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Converts degrees to radians.") + .with_syntax_example("radians(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing on \[0, π\] and then non-increasing on \[π, 2π\]. /// This pattern repeats periodically with a period of 2π. // TODO: Implement ordering rule of the SIN function. @@ -199,11 +470,39 @@ pub fn sin_order(_input: &[ExprProperties]) -> Result { Ok(SortProperties::Unordered) } +static DOCUMENTATION_SIN: OnceLock = OnceLock::new(); + +pub fn get_sin_doc() -> &'static Documentation { + DOCUMENTATION_SIN.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the sine of a number.") + .with_syntax_example("sin(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for all real numbers. pub fn sinh_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } +static DOCUMENTATION_SINH: OnceLock = OnceLock::new(); + +pub fn get_sinh_doc() -> &'static Documentation { + DOCUMENTATION_SINH.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the hyperbolic sine of a number.") + .with_syntax_example("sinh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for x ≥ 0, undefined otherwise. 
pub fn sqrt_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -218,6 +517,20 @@ pub fn sqrt_order(input: &[ExprProperties]) -> Result { } } +static DOCUMENTATION_SQRT: OnceLock = OnceLock::new(); + +pub fn get_sqrt_doc() -> &'static Documentation { + DOCUMENTATION_SQRT.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the square root of a number.") + .with_syntax_example("sqrt(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing between vertical asymptotes at x = k * π ± π / 2 for any /// integer k. // TODO: Implement ordering rule of the TAN function. @@ -225,7 +538,35 @@ pub fn tan_order(_input: &[ExprProperties]) -> Result { Ok(SortProperties::Unordered) } +static DOCUMENTATION_TAN: OnceLock = OnceLock::new(); + +pub fn get_tan_doc() -> &'static Documentation { + DOCUMENTATION_TAN.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the tangent of a number.") + .with_syntax_example("tan(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + /// Non-decreasing for all real numbers. pub fn tanh_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } + +static DOCUMENTATION_TANH: OnceLock = OnceLock::new(); + +pub fn get_tanh_doc() -> &'static Documentation { + DOCUMENTATION_TANH.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the hyperbolic tangent of a number.") + .with_syntax_example("tanh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index 2bd704a7de2e6..c1dd1aacc35a3 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -17,15 +17,15 @@ //! Math function: `isnan()`. 
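The `isnan` hunk that follows makes the same signature change as the binary-UDF macro earlier: the glob-imported `Exact(...)` becomes the fully qualified `TypeSignature::Exact(...)`. A sketch of the resulting construction on its own, using the `datafusion_expr` items named in this diff (the function is a stand-in, not from the PR):

use arrow::datatypes::DataType;
use datafusion_expr::{Signature, TypeSignature, Volatility};

fn example_signature() -> Signature {
    // Accept exactly one Float32 argument or exactly one Float64 argument.
    // Immutable: the same input always produces the same output.
    Signature::one_of(
        vec![
            TypeSignature::Exact(vec![DataType::Float32]),
            TypeSignature::Exact(vec![DataType::Float64]),
        ],
        Volatility::Immutable,
    )
}

fn main() {
    let _sig = example_signature();
}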
-use arrow::datatypes::DataType; -use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, TypeSignature}; -use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array}; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use arrow::array::{ArrayRef, AsArray, BooleanArray}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct IsNanFunc { @@ -43,7 +43,10 @@ impl IsNanFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Float32]), Exact(vec![Float64])], + vec![ + TypeSignature::Exact(vec![Float32]), + TypeSignature::Exact(vec![Float64]), + ], Volatility::Immutable, ), } @@ -70,20 +73,15 @@ impl ScalarUDFImpl for IsNanFunc { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float64Array, - BooleanArray, - { f64::is_nan } - )), - DataType::Float32 => Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float32Array, - BooleanArray, - { f32::is_nan } - )), + DataType::Float64 => Arc::new(BooleanArray::from_unary( + args[0].as_primitive::(), + f64::is_nan, + )) as ArrayRef, + + DataType::Float32 => Arc::new(BooleanArray::from_unary( + args[0].as_primitive::(), + f32::is_nan, + )) as ArrayRef, other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -93,4 +91,24 @@ impl ScalarUDFImpl for IsNanFunc { }; Ok(ColumnarValue::Array(arr)) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_isnan_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_isnan_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns true if a given number is +NaN or -NaN otherwise returns false.", + ) + .with_syntax_example("isnan(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index d81a690843b63..cfd21256dd961 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -16,18 +16,19 @@ // under the License. 
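The nanvl rewrite below drops `make_function_inputs2!` in favor of `arrow::compute::binary`, which zips two equal-length primitive arrays through a closure and fails only on length mismatch. A standalone sketch with a NaN-fallback closure of the same shape as the diff's `compute_nanvl`:

use arrow::array::Float64Array;
use arrow::datatypes::Float64Type;
use arrow::error::ArrowError;

fn main() -> Result<(), ArrowError> {
    let x = Float64Array::from(vec![f64::NAN, 2.0]);
    let y = Float64Array::from(vec![7.0, 9.0]);
    // Keep x unless it is NaN, in which case fall back to y.
    let out = arrow::compute::binary::<_, _, _, Float64Type>(&x, &y, |x, y| {
        if x.is_nan() {
            y
        } else {
            x
        }
    })?;
    assert_eq!(out.value(0), 7.0);
    assert_eq!(out.value(1), 2.0);
    Ok(())
}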
use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Float32, Float64}; +use crate::utils::make_scalar_function; +use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array}; +use arrow::datatypes::DataType::{Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::make_scalar_function; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct NanvlFunc { @@ -75,6 +76,28 @@ impl ScalarUDFImpl for NanvlFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(nanvl, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_nanvl_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_nanvl_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + r#"Returns the first argument if it's not _NaN_. +Returns the second argument otherwise."#, + ) + .with_syntax_example("nanvl(expression_x, expression_y)") + .with_argument("expression_x", "Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators.") + .with_argument("expression_y", "Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators.") + .build() + .unwrap() + }) } /// Nanvl SQL function @@ -89,14 +112,11 @@ fn nanvl(args: &[ArrayRef]) -> Result { } }; - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float64Array, - { compute_nanvl } - )) as ArrayRef) + let x = args[0].as_primitive() as &Float64Array; + let y = args[1].as_primitive() as &Float64Array; + arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl) + .map(|res| Arc::new(res) as _) + .map_err(DataFusionError::from) } Float32 => { let compute_nanvl = |x: f32, y: f32| { @@ -107,14 +127,11 @@ fn nanvl(args: &[ArrayRef]) -> Result { } }; - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float32Array, - { compute_nanvl } - )) as ArrayRef) + let x = args[0].as_primitive() as &Float32Array; + let y = args[1].as_primitive() as &Float32Array; + arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl) + .map(|res| Arc::new(res) as _) + .map_err(DataFusionError::from) } other => exec_err!("Unsupported data type {other:?} for function nanvl"), } @@ -122,10 +139,12 @@ fn nanvl(args: &[ArrayRef]) -> Result { #[cfg(test)] mod test { + use std::sync::Arc; + use crate::math::nanvl::nanvl; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use std::sync::Arc; #[test] fn test_nanvl_f64() { diff --git a/datafusion/functions/src/math/pi.rs b/datafusion/functions/src/math/pi.rs index c2fe4efb1139c..ea0f331617726 100644 --- a/datafusion/functions/src/math/pi.rs +++ b/datafusion/functions/src/math/pi.rs @@ -16,12 +16,16 @@ // under the License. 
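The nanvl hunk above obtains typed views with `args[0].as_primitive()`; that `AsArray` extension trait is what this PR uses throughout in place of the removed `downcast_arg!` macro. A minimal sketch (it panics if the runtime type does not match, which the surrounding `match` on `data_type()` rules out in the real code):

use std::sync::Arc;

use arrow::array::{ArrayRef, AsArray, Float64Array};
use arrow::datatypes::Float64Type;

fn main() {
    let arr: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.5]));
    // Typed view over the same underlying buffers; nothing is copied.
    let typed = arr.as_primitive::<Float64Type>();
    assert_eq!(typed.value(1), 2.5);
}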
use std::any::Any; +use std::sync::OnceLock; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Float64; use datafusion_common::{not_impl_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct PiFunc { @@ -73,4 +77,21 @@ impl ScalarUDFImpl for PiFunc { // This function returns a constant value. Ok(SortProperties::Singleton) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_pi_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_pi_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns an approximate value of π.") + .with_syntax_example("pi()") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 5b790fb56ddf3..9bb6006d55b91 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -16,24 +16,22 @@ // under the License. //! Math function: `power()`. +use std::any::Any; +use std::sync::{Arc, OnceLock}; -use arrow::datatypes::{ArrowNativeTypeOp, DataType}; +use super::log::LogFunc; +use arrow::array::{ArrayRef, AsArray, Int64Array}; +use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type}; use datafusion_common::{ arrow_datafusion_err, exec_datafusion_err, exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDF}; - -use arrow::array::{ArrayRef, Float64Array, Int64Array}; -use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; - -use super::log::LogFunc; #[derive(Debug)] pub struct PowerFunc { @@ -52,7 +50,10 @@ impl PowerFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], + vec![ + TypeSignature::Exact(vec![Int64, Int64]), + TypeSignature::Exact(vec![Float64, Float64]), + ], Volatility::Immutable, ), aliases: vec![String::from("pow")], @@ -87,15 +88,16 @@ impl ScalarUDFImpl for PowerFunc { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "base", - "exponent", - Float64Array, - { f64::powf } - )), - + DataType::Float64 => { + let bases = args[0].as_primitive::(); + let exponents = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + bases, + exponents, + f64::powf, + )?; + Arc::new(result) as _ + } DataType::Int64 => { let bases = downcast_arg!(&args[0], "base", Int64Array); let exponents = downcast_arg!(&args[1], "exponent", Int64Array); @@ -113,7 +115,7 @@ impl ScalarUDFImpl for PowerFunc { _ => Ok(None), }) .collect::>() - .map(Arc::new)? as ArrayRef + .map(Arc::new)? 
as _ } other => { @@ -162,6 +164,27 @@ impl ScalarUDFImpl for PowerFunc { _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_power_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_power_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns a base expression raised to the power of an exponent.", + ) + .with_syntax_example("power(base, exponent)") + .with_standard_argument("base", Some("Numeric")) + .with_standard_argument("exponent", Some("Exponent numeric")) + .build() + .unwrap() + }) } /// Return true if this function call is a call to `Log` @@ -171,6 +194,7 @@ fn is_log(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { + use arrow::array::Float64Array; use datafusion_common::cast::{as_float64_array, as_int64_array}; use super::*; diff --git a/datafusion/functions/src/math/random.rs b/datafusion/functions/src/math/random.rs index 20591a02a930d..cf564e5328a53 100644 --- a/datafusion/functions/src/math/random.rs +++ b/datafusion/functions/src/math/random.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::Float64Array; use arrow::datatypes::DataType; @@ -24,8 +24,9 @@ use arrow::datatypes::DataType::Float64; use rand::{thread_rng, Rng}; use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] pub struct RandomFunc { @@ -76,4 +77,24 @@ impl ScalarUDFImpl for RandomFunc { Ok(ColumnarValue::Array(Arc::new(array))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_random_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_random_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + r#"Returns a random float value in the range [0, 1). +The random seed is unique to each row."#, + ) + .with_syntax_example("random()") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 89554a76febba..6000e5d765de1 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -16,20 +16,21 @@ // under the License. 
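Because `random()` is volatile and produces a fresh value for every row, its `invoke` fills the output array directly from a thread-local RNG instead of going through a compute kernel. A self-contained sketch of that per-row sampling (assuming the `arrow` and `rand` crates; `random_batch` is an illustrative name):

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Float64Array};
use rand::{thread_rng, Rng};

// One independent sample in [0, 1) per output row; no state is shared
// between rows, which is what "the random seed is unique to each row"
// in the documentation above describes.
fn random_batch(num_rows: usize) -> ArrayRef {
    let mut rng = thread_rng();
    let values = std::iter::repeat_with(|| rng.gen_range(0.0..1.0)).take(num_rows);
    Arc::new(Float64Array::from_iter_values(values))
}

fn main() {
    assert_eq!(random_batch(3).len(), 3);
}
```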
use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, Float32Array, Float64Array, Int32Array}; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Float32, Float64, Int32}; -use datafusion_common::{ - exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, -}; +use arrow::datatypes::{DataType, Float32Type, Float64Type, Int32Type}; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct RoundFunc { @@ -97,6 +98,28 @@ impl ScalarUDFImpl for RoundFunc { Ok(SortProperties::Unordered) } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_round_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_round_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Rounds a number to the nearest integer.") + .with_syntax_example("round(numeric_expression[, decimal_places])") + .with_standard_argument("numeric_expression", Some("Numeric")) + .with_argument( + "decimal_places", + "Optional. The number of decimal places to round to. Defaults to 0.", + ) + .build() + .unwrap() + }) } /// Round SQL function @@ -115,7 +138,7 @@ pub fn round(args: &[ArrayRef]) -> Result { } match args[0].data_type() { - DataType::Float64 => match decimal_places { + Float64 => match decimal_places { ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { let decimal_places: i32 = decimal_places.try_into().map_err(|e| { exec_datafusion_err!( @@ -123,17 +146,13 @@ pub fn round(args: &[ArrayRef]) -> Result { ) })?; - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float64Array, - { - |value: f64| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - } - )) as ArrayRef) + let result = args[0] + .as_primitive::() + .unary::<_, Float64Type>(|value: f64| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + }); + Ok(Arc::new(result) as _) } ColumnarValue::Array(decimal_places) => { let options = CastOptions { @@ -144,45 +163,38 @@ pub fn round(args: &[ArrayRef]) -> Result { .map_err(|e| { exec_datafusion_err!("Invalid values for decimal places: {e}") })?; - Ok(Arc::new(make_function_inputs2!( - &args[0], + + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + values, decimal_places, - "value", - "decimal_places", - Float64Array, - Int32Array, - { - |value: f64, decimal_places: i32| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - } - )) as ArrayRef) + |value, decimal_places| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + }, + )?; + Ok(Arc::new(result) as _) } _ => { exec_err!("round function requires a scalar or array for decimal_places") } }, - DataType::Float32 => match decimal_places { + 
Float32 => match decimal_places { ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { let decimal_places: i32 = decimal_places.try_into().map_err(|e| { exec_datafusion_err!( "Invalid value for decimal places: {decimal_places}: {e}" ) })?; - - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float32Array, - { - |value: f32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - } - )) as ArrayRef) + let result = args[0] + .as_primitive::() + .unary::<_, Float32Type>(|value: f32| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + }); + Ok(Arc::new(result) as _) } ColumnarValue::Array(_) => { let ColumnarValue::Array(decimal_places) = @@ -193,20 +205,17 @@ pub fn round(args: &[ArrayRef]) -> Result { panic!("Unexpected result of ColumnarValue::Array.cast") }; - Ok(Arc::new(make_function_inputs2!( - &args[0], + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result: PrimitiveArray = arrow::compute::binary( + values, decimal_places, - "value", - "decimal_places", - Float32Array, - Int32Array, - { - |value: f32, decimal_places: i32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - } - )) as ArrayRef) + |value, decimal_places| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + }, + )?; + Ok(Arc::new(result) as _) } _ => { exec_err!("round function requires a scalar or array for decimal_places") diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index d2a806a46e136..ac881eb42f269 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -16,16 +16,18 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::DataType::{Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use crate::utils::make_scalar_function; @@ -81,42 +83,60 @@ impl ScalarUDFImpl for SignumFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(signum, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_signum_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_signum_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + r#"Returns the sign of a number. +Negative numbers return `-1`. 
+Zero returns `0`, and positive numbers return `1`."#, + ) + .with_syntax_example("signum(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) } /// signum SQL function pub fn signum(args: &[ArrayRef]) -> Result<ArrayRef> { match args[0].data_type() { - Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "signum", - Float64Array, - Float64Array, - { - |x: f64| { - if x == 0_f64 { - 0_f64 - } else { - x.signum() - } - } - } - )) as ArrayRef), - - Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "signum", - Float32Array, - Float32Array, - { - |x: f32| { - if x == 0_f32 { - 0_f32 - } else { - x.signum() - } - } - } - )) as ArrayRef), + Float64 => Ok(Arc::new( + args[0] + .as_primitive::<Float64Type>() + .unary::<_, Float64Type>( + |x: f64| { + if x == 0_f64 { + 0_f64 + } else { + x.signum() + } + }, + ), + ) as ArrayRef), + + Float32 => Ok(Arc::new( + args[0] + .as_primitive::<Float32Type>() + .unary::<_, Float32Type>( + |x: f32| { + if x == 0_f32 { + 0_f32 + } else { + x.signum() + } + }, + ), + ) as ArrayRef), other => exec_err!("Unsupported data type {other:?} for function signum"), } diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 3344438454c4b..9a05684d238e7 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -16,18 +16,21 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::datatypes::DataType::{Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type}; use datafusion_common::ScalarValue::Int64; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct TruncFunc { @@ -100,6 +103,31 @@ impl ScalarUDFImpl for TruncFunc { Ok(SortProperties::Unordered) } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_trunc_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_trunc_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Truncates a number to a whole number or to the specified number of decimal places.", + ) + .with_syntax_example("trunc(numeric_expression[, decimal_places])") + .with_standard_argument("numeric_expression", Some("Numeric")) + .with_argument("decimal_places", r#"Optional. The number of decimal places to + truncate to. Defaults to 0 (truncate to a whole number). If + `decimal_places` is a positive integer, truncates digits to the + right of the decimal point.
If `decimal_places` is a negative + integer, replaces digits to the left of the decimal point with `0`."#) + .build() + .unwrap() + }) } /// Truncate(numeric, decimalPrecision) and trunc(numeric) SQL function @@ -111,8 +139,8 @@ fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> { ); } - //if only one arg then invoke toolchain trunc(num) and precision = 0 by default - //or then invoke the compute_truncate method to process precision + // If only one arg is given, invoke the native trunc(num) with the default precision of 0; + // otherwise invoke the compute_truncate method to apply the requested precision let num = &args[0]; let precision = if args.len() == 1 { ColumnarValue::Scalar(Int64(Some(0))) @@ -120,35 +148,57 @@ fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> { ColumnarValue::Array(Arc::clone(&args[1])) }; - match args[0].data_type() { + match num.data_type() { Float64 => match precision { - ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( - make_function_scalar_inputs!(num, "num", Float64Array, { f64::trunc }), - ) as ArrayRef), - ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( - num, - precision, - "x", - "y", - Float64Array, - Int64Array, - { compute_truncate64 } - )) as ArrayRef), + ColumnarValue::Scalar(Int64(Some(0))) => { + Ok(Arc::new( + args[0] + .as_primitive::<Float64Type>() + .unary::<_, Float64Type>(|x: f64| { + if x == 0_f64 { + 0_f64 + } else { + x.trunc() + } + }), + ) as ArrayRef) + } + ColumnarValue::Array(precision) => { + let num_array = num.as_primitive::<Float64Type>(); + let precision_array = precision.as_primitive::<Int64Type>(); + let result: PrimitiveArray<Float64Type> = + arrow::compute::binary(num_array, precision_array, |x, y| { + compute_truncate64(x, y) + })?; + + Ok(Arc::new(result) as ArrayRef) + } _ => exec_err!("trunc function requires a scalar or array for precision"), }, Float32 => match precision { - ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( - make_function_scalar_inputs!(num, "num", Float32Array, { f32::trunc }), - ) as ArrayRef), - ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( - num, - precision, - "x", - "y", - Float32Array, - Int64Array, - { compute_truncate32 } - )) as ArrayRef), + ColumnarValue::Scalar(Int64(Some(0))) => { + Ok(Arc::new( + args[0] + .as_primitive::<Float32Type>() + .unary::<_, Float32Type>(|x: f32| { + if x == 0_f32 { + 0_f32 + } else { + x.trunc() + } + }), + ) as ArrayRef) + } + ColumnarValue::Array(precision) => { + let num_array = num.as_primitive::<Float32Type>(); + let precision_array = precision.as_primitive::<Int64Type>(); + let result: PrimitiveArray<Float32Type> = + arrow::compute::binary(num_array, precision_array, |x, y| { + compute_truncate32(x, y) + })?; + + Ok(Arc::new(result) as ArrayRef) + } _ => exec_err!("trunc function requires a scalar or array for precision"), }, other => exec_err!("Unsupported data type {other:?} for function trunc"), diff --git a/datafusion/functions/src/planner.rs b/datafusion/functions/src/planner.rs index ad42c5edd6e60..93edec7ece307 100644 --- a/datafusion/functions/src/planner.rs +++ b/datafusion/functions/src/planner.rs @@ -24,7 +24,7 @@ use datafusion_expr::{ Expr, }; -#[derive(Default)] +#[derive(Default, Debug)] pub struct UserDefinedFunctionPlanner; impl ExprPlanner for UserDefinedFunctionPlanner { diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 4afbe6cbbb89c..803f51e915a9b 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -17,11 +17,15 @@ //!
"regex" DataFusion functions +use std::sync::Arc; + +pub mod regexpcount; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; // create UDFs +make_udf_function!(regexpcount::RegexpCountFunc, REGEXP_COUNT, regexp_count); make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match); make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like); make_udf_function!( @@ -33,6 +37,24 @@ make_udf_function!( pub mod expr_fn { use datafusion_expr::Expr; + /// Returns the number of consecutive occurrences of a regular expression in a string. + pub fn regexp_count( + values: Expr, + regex: Expr, + start: Option, + flags: Option, + ) -> Expr { + let mut args = vec![values, regex]; + if let Some(start) = start { + args.push(start); + }; + + if let Some(flags) = flags { + args.push(flags); + }; + super::regexp_count().call(args) + } + /// Returns a list of regular expression matches in a string. pub fn regexp_match(values: Expr, regex: Expr, flags: Option) -> Expr { let mut args = vec![values, regex]; @@ -67,6 +89,11 @@ pub mod expr_fn { } /// Returns all DataFusion functions defined in this package -pub fn functions() -> Vec> { - vec![regexp_match(), regexp_like(), regexp_replace()] +pub fn functions() -> Vec> { + vec![ + regexp_count(), + regexp_match(), + regexp_like(), + regexp_replace(), + ] } diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs new file mode 100644 index 0000000000000..7f7896ecd923d --- /dev/null +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -0,0 +1,951 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::strings::StringArrayType; +use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array}; +use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{ + DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, +}; +use arrow::error::ArrowError; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact, + TypeSignature::Uniform, Volatility, +}; +use itertools::izip; +use regex::Regex; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::{Arc, OnceLock}; + +#[derive(Debug)] +pub struct RegexpCountFunc { + signature: Signature, +} + +impl Default for RegexpCountFunc { + fn default() -> Self { + Self::new() + } +} + +impl RegexpCountFunc { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + Uniform(2, vec![Utf8View, LargeUtf8, Utf8]), + Exact(vec![Utf8View, Utf8View, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Utf8View]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]), + Exact(vec![Utf8, Utf8, Int64, Utf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RegexpCountFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "regexp_count" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(Int64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + let len = args + .iter() + .fold(Option::<usize>::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .map(|arg| arg.clone().into_array(inferred_length)) + .collect::<Result<Vec<ArrayRef>>>()?; + + let result = regexp_count_func(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_regexp_count_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_regexp_count_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_REGEX) + .with_description("Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.") + .with_syntax_example("regexp_count(str, regexp[, start, flags])") + .with_sql_example(r#"```sql +> select regexp_count('abcAbAbc', 'abc', 2, 'i'); ++---------------------------------------------------------------+ +| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) | ++---------------------------------------------------------------+ +| 1 | ++---------------------------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_standard_argument("regexp", Some("Regular")) + .with_argument("start", "Optional start position (the first position is 1) to search for the regular expression.
Can be a constant, column, or function.") + .with_argument("flags", + r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"#) + .build() + .unwrap() + }) +} + +pub fn regexp_count_func(args: &[ArrayRef]) -> Result { + let args_len = args.len(); + if !(2..=4).contains(&args_len) { + return exec_err!("regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4."); + } + + let values = &args[0]; + match values.data_type() { + Utf8 | LargeUtf8 | Utf8View => (), + other => { + return internal_err!( + "Unsupported data type {other:?} for function regexp_count" + ); + } + } + + regexp_count( + values, + &args[1], + if args_len > 2 { Some(&args[2]) } else { None }, + if args_len > 3 { Some(&args[3]) } else { None }, + ) + .map_err(|e| e.into()) +} + +/// `arrow-rs` style implementation of `regexp_count` function. +/// This function `regexp_count` is responsible for counting the occurrences of a regular expression pattern +/// within a string array. It supports optional start positions and flags for case insensitivity. +/// +/// The function accepts a variable number of arguments: +/// - `values`: The array of strings to search within. +/// - `regex_array`: The array of regular expression patterns to search for. +/// - `start_array` (optional): The array of start positions for the search. +/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity). +/// +/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions, +/// and flags. It uses a cache to store compiled regular expressions for efficiency. +/// +/// # Errors +/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile. 
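The scalar/array combinations described in the doc comment above rest on arrow's `Datum` abstraction: `regexp_count` accepts `&dyn Datum` for the pattern, start, and flags so a single scalar can be broadcast across the values array without materializing it per row. A small sketch of that dispatch (assuming only the `arrow` crate and its blanket `Datum` implementation for arrays):

```rust
use arrow::array::{Array, Datum, Int64Array, Scalar, StringArray};

fn main() {
    // A `Scalar` wraps a one-element array and reports itself as scalar.
    let regex = Scalar::new(StringArray::from(vec!["a"]));
    // A plain array reports one value per row.
    let start = Int64Array::from(vec![1, 2, 3]);

    let (regex_arr, regex_is_scalar) = regex.get();
    let start_datum: &dyn Datum = &start;
    let (start_arr, start_is_scalar) = start_datum.get();

    assert!(regex_is_scalar && regex_arr.len() == 1); // broadcast to every row
    assert!(!start_is_scalar && start_arr.len() == 3); // aligned row by row
}
```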
+pub fn regexp_count( + values: &dyn Array, + regex_array: &dyn Datum, + start_array: Option<&dyn Datum>, + flags_array: Option<&dyn Datum>, +) -> Result { + let (regex_array, is_regex_scalar) = regex_array.get(); + let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| { + let (start, is_start_scalar) = start.get(); + (Some(start), is_start_scalar) + }); + let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| { + let (flags, is_flags_scalar) = flags.get(); + (Some(flags), is_flags_scalar) + }); + + match (values.data_type(), regex_array.data_type(), flags_array) { + (Utf8, Utf8, None) => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + ), + (LargeUtf8, LargeUtf8, None) => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + ), + (Utf8View, Utf8View, None) => regexp_count_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_count_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string_view()), + is_flags_scalar, + ), + _ => Err(ArrowError::ComputeError( + "regexp_count() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(), + )), + } +} + +pub fn regexp_count_inner<'a, S>( + values: S, + regex_array: S, + is_regex_scalar: bool, + start_array: Option<&Int64Array>, + is_start_scalar: bool, + flags_array: Option, + is_flags_scalar: bool, +) -> Result +where + S: StringArrayType<'a>, +{ + let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 { + (Some(regex_array.value(0)), true) + } else { + (None, false) + }; + + let (start_array, start_scalar, is_start_scalar) = + if let Some(start_array) = start_array { + if is_start_scalar || start_array.len() == 1 { + (None, Some(start_array.value(0)), true) + } else { + (Some(start_array), None, false) + } + } else { + (None, Some(1), true) + }; + + let (flags_array, flags_scalar, is_flags_scalar) = + if let Some(flags_array) = flags_array { + if is_flags_scalar || flags_array.len() == 1 { + (None, Some(flags_array.value(0)), true) + } else { + (Some(flags_array), None, false) + } + } else { + (None, None, true) + }; + + let mut regex_cache = HashMap::new(); + + match (is_regex_scalar, is_start_scalar, is_flags_scalar) { + (true, true, true) => { + let regex = match 
regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let pattern = compile_regex(regex, flags_scalar)?; + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .map(|value| count_matches(value, &pattern, start_scalar)) + .collect::, ArrowError>>()?, + ))) + } + (true, true, false) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(flags_array.iter()) + .map(|(value, flags)| { + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + count_matches(value, &pattern, start_scalar) + }) + .collect::, ArrowError>>()?, + ))) + } + (true, false, true) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let pattern = compile_regex(regex, flags_scalar)?; + + let start_array = start_array.unwrap(); + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(start_array.iter()) + .map(|(value, start)| count_matches(value, &pattern, start)) + .collect::, ArrowError>>()?, + ))) + } + (true, false, false) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!( + values.iter(), + start_array.unwrap().iter(), + flags_array.iter() + ) + .map(|(value, start, flags)| { + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + + count_matches(value, &pattern, start) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, true, true) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(regex_array.iter()) + .map(|(value, regex)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = compile_and_cache_regex( + regex, + flags_scalar, + &mut regex_cache, + )?; + count_matches(value, &pattern, start_scalar) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, true, false) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!(values.iter(), regex_array.iter(), flags_array.iter()) + 
.map(|(value, regex, flags)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + + count_matches(value, &pattern, start_scalar) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, false, true) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let start_array = start_array.unwrap(); + if values.len() != start_array.len() { + return Err(ArrowError::ComputeError(format!( + "start_array must be the same length as values array; got {} and {}", + start_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!(values.iter(), regex_array.iter(), start_array.iter()) + .map(|(value, regex, start)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = compile_and_cache_regex( + regex, + flags_scalar, + &mut regex_cache, + )?; + count_matches(value, &pattern, start) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, false, false) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let start_array = start_array.unwrap(); + if values.len() != start_array.len() { + return Err(ArrowError::ComputeError(format!( + "start_array must be the same length as values array; got {} and {}", + start_array.len(), + values.len(), + ))); + } + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!( + values.iter(), + regex_array.iter(), + start_array.iter(), + flags_array.iter() + ) + .map(|(value, regex, start, flags)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + count_matches(value, &pattern, start) + }) + .collect::, ArrowError>>()?, + ))) + } + } +} + +fn compile_and_cache_regex( + regex: &str, + flags: Option<&str>, + regex_cache: &mut HashMap, +) -> Result { + match regex_cache.entry(regex.to_string()) { + Entry::Vacant(entry) => { + let compiled = compile_regex(regex, flags)?; + entry.insert(compiled.clone()); + Ok(compiled) + } + Entry::Occupied(entry) => Ok(entry.get().to_owned()), + } +} + +fn compile_regex(regex: &str, flags: Option<&str>) -> Result { + let pattern = match flags { + None | Some("") => regex.to_string(), + Some(flags) => { + if flags.contains("g") { + return Err(ArrowError::ComputeError( + "regexp_count() does not support global flag".to_string(), + )); + } + format!("(?{}){}", flags, regex) + } + }; + + Regex::new(&pattern).map_err(|_| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {}", + pattern + )) + }) +} + +fn count_matches( + value: Option<&str>, + pattern: &Regex, + start: Option, +) -> Result { + let value = match value { + None | Some("") => return Ok(0), + Some(value) => value, + }; + + if let Some(start) = start { + if start < 1 { + return Err(ArrowError::ComputeError( + "regexp_count() requires start to be 1 based".to_string(), + )); 
+ } + + let find_slice = value.chars().skip(start as usize - 1).collect::(); + let count = pattern.find_iter(find_slice.as_str()).count(); + Ok(count as i64) + } else { + let count = pattern.find_iter(value).count(); + Ok(count as i64) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{GenericStringArray, StringViewArray}; + + #[test] + fn test_regexp_count() { + test_case_sensitive_regexp_count_scalar(); + test_case_sensitive_regexp_count_scalar_start(); + test_case_insensitive_regexp_count_scalar_flags(); + test_case_sensitive_regexp_count_start_scalar_complex(); + + test_case_sensitive_regexp_count_array::>(); + test_case_sensitive_regexp_count_array::>(); + test_case_sensitive_regexp_count_array::(); + + test_case_sensitive_regexp_count_array_start::>(); + test_case_sensitive_regexp_count_array_start::>(); + test_case_sensitive_regexp_count_array_start::(); + + test_case_insensitive_regexp_count_array_flags::>(); + test_case_insensitive_regexp_count_array_flags::>(); + test_case_insensitive_regexp_count_array_flags::(); + + test_case_sensitive_regexp_count_array_complex::>(); + test_case_sensitive_regexp_count_array_complex::>(); + test_case_sensitive_regexp_count_array_complex::(); + } + + fn test_case_sensitive_regexp_count_scalar() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let expected: Vec = vec![0, 1, 2, 1, 3]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new() + .invoke(&[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + + let re = RegexpCountFunc::new() + .invoke(&[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + + let re = RegexpCountFunc::new() + .invoke(&[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_scalar_start() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let start = 2; + let expected: Vec = vec![0, 1, 1, 0, 2]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } 
+ _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_insensitive_regexp_count_scalar_flags() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let start = 1; + let flags = "i"; + let expected: Vec = vec![0, 1, 2, 2, 3]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + let flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_array() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 2]); + + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn 
test_case_sensitive_regexp_count_array_start() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); + + let expected = Int64Array::from(vec![0, 0, 1, 1, 0]); + + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_insensitive_regexp_count_array_flags() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1]); + let flags = A::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); + + let re = regexp_count_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_start_scalar_complex() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = ["", "abc", "a", "bc", "ab"]; + let start = 5; + let flags = ["", "i", "", "", "i"]; + let expected: Vec = vec![0, 0, 0, 1, 1]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(regex.get(pos).map(|s| s.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string())); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); + let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); + let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_array_complex() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); 
+ let flags = A::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 1, 1, 1]); + + let re = regexp_count_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } +} diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 20029ba005c49..13de7888aa5fe 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -15,42 +15,95 @@ // specific language governing permissions and limitations // under the License. -//! Regx expressions -use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; +//! Regex expressions + +use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray}; use arrow::compute::kernels::regexp; use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::exec_err; use datafusion_common::ScalarValue; use datafusion_common::{arrow_datafusion_err, plan_err}; -use datafusion_common::{ - cast::as_generic_string_array, internal_err, DataFusionError, Result, -}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignature}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct RegexpLikeFunc { signature: Signature, } + impl Default for RegexpLikeFunc { fn default() -> Self { Self::new() } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_regexp_like_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_REGEX) + .with_description("Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise.") + .with_syntax_example("regexp_like(str, regexp[, flags])") + .with_sql_example(r#"```sql +select regexp_like('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); ++--------------------------------------------------------+ +| regexp_like(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | ++--------------------------------------------------------+ +| true | ++--------------------------------------------------------+ +SELECT regexp_like('aBc', '(b|d)', 'i'); ++--------------------------------------------------+ +| regexp_like(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | ++--------------------------------------------------+ +| true | ++--------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +"#) + .with_standard_argument("str", Some("String")) + .with_standard_argument("regexp", Some("Regular")) + .with_argument("flags", + r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . 
to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"#) + .build() + .unwrap() + }) +} + impl RegexpLikeFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8View, Utf8]), + TypeSignature::Exact(vec![Utf8View, Utf8View]), + TypeSignature::Exact(vec![Utf8View, LargeUtf8]), + TypeSignature::Exact(vec![Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8, Utf8View]), + TypeSignature::Exact(vec![Utf8, LargeUtf8]), + TypeSignature::Exact(vec![LargeUtf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, Utf8View]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]), + TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8]), + TypeSignature::Exact(vec![Utf8View, LargeUtf8, Utf8]), + TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8, Utf8View, Utf8]), + TypeSignature::Exact(vec![Utf8, LargeUtf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, Utf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, Utf8View, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Utf8]), ], Volatility::Immutable, ), @@ -81,6 +134,7 @@ impl ScalarUDFImpl for RegexpLikeFunc { _ => Boolean, }) } + fn invoke(&self, args: &[ColumnarValue]) -> Result { let len = args .iter() @@ -96,7 +150,7 @@ impl ScalarUDFImpl for RegexpLikeFunc { .map(|arg| arg.clone().into_array(inferred_length)) .collect::>>()?; - let result = regexp_like_func(&args); + let result = regexp_like(&args); if is_scalar { // If all inputs are scalar, keeps output as scalar let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); @@ -105,16 +159,12 @@ impl ScalarUDFImpl for RegexpLikeFunc { result.map(ColumnarValue::Array) } } -} -fn regexp_like_func(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Utf8 => regexp_like::(args), - DataType::LargeUtf8 => regexp_like::(args), - other => { - internal_err!("Unsupported data type {other:?} for function regexp_like") - } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_regexp_like_doc()) } } + /// Tests a string using a regular expression returning true if at /// least one match, false otherwise. /// @@ -157,46 +207,114 @@ fn regexp_like_func(args: &[ArrayRef]) -> Result { /// # Ok(()) /// # } /// ``` -pub fn regexp_like(args: &[ArrayRef]) -> Result { +pub fn regexp_like(args: &[ArrayRef]) -> Result { match args.len() { - 2 => { - let values = as_generic_string_array::(&args[0])?; - let regex = as_generic_string_array::(&args[1])?; - let array = regexp::regexp_is_match_utf8(values, regex, None) - .map_err(|e| arrow_datafusion_err!(e))?; - - Ok(Arc::new(array) as ArrayRef) - } + 2 => handle_regexp_like(&args[0], &args[1], None), 3 => { - let values = as_generic_string_array::(&args[0])?; - let regex = as_generic_string_array::(&args[1])?; - let flags = as_generic_string_array::(&args[2])?; + let flags = args[2].as_string::(); if flags.iter().any(|s| s == Some("g")) { return plan_err!("regexp_like() does not support the \"global\" option"); } - let array = regexp::regexp_is_match_utf8(values, regex, Some(flags)) - .map_err(|e| arrow_datafusion_err!(e))?; - - Ok(Arc::new(array) as ArrayRef) - } + handle_regexp_like(&args[0], &args[1], Some(flags)) + }, other => exec_err!( - "regexp_like was called with {other} arguments. 
It requires at least 2 and at most 3." + "`regexp_like` was called with {other} arguments. It requires at least 2 and at most 3." + ), } } + +fn handle_regexp_like( + values: &ArrayRef, + patterns: &ArrayRef, + flags: Option<&GenericStringArray<i32>>, +) -> Result<ArrayRef> { + let array = match (values.data_type(), patterns.data_type()) { + (Utf8View, Utf8) => { + let value = values.as_string_view(); + let pattern = patterns.as_string::<i32>(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8View, Utf8View) => { + let value = values.as_string_view(); + let pattern = patterns.as_string_view(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8View, LargeUtf8) => { + let value = values.as_string_view(); + let pattern = patterns.as_string::<i64>(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8, Utf8) => { + let value = values.as_string::<i32>(); + let pattern = patterns.as_string::<i32>(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8, Utf8View) => { + let value = values.as_string::<i32>(); + let pattern = patterns.as_string_view(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8, LargeUtf8) => { + let value = values.as_string::<i32>(); + let pattern = patterns.as_string::<i64>(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (LargeUtf8, Utf8) => { + let value = values.as_string::<i64>(); + let pattern = patterns.as_string::<i32>(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (LargeUtf8, Utf8View) => { + let value = values.as_string::<i64>(); + let pattern = patterns.as_string_view(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (LargeUtf8, LargeUtf8) => { + let value = values.as_string::<i64>(); + let pattern = patterns.as_string::<i64>(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))?
+ } + other => { + return internal_err!( + "Unsupported data type {other:?} for function `regexp_like`" + ) + } + }; + + Ok(Arc::new(array) as ArrayRef) +} + #[cfg(test)] mod tests { use std::sync::Arc; - use arrow::array::BooleanBuilder; use arrow::array::StringArray; + use arrow::array::{BooleanBuilder, StringViewArray}; use crate::regex::regexplike::regexp_like; #[test] - fn test_case_sensitive_regexp_like() { + fn test_case_sensitive_regexp_like_utf8() { let values = StringArray::from(vec!["abc"; 5]); let patterns = @@ -210,13 +328,33 @@ mod tests { expected_builder.append_value(false); let expected = expected_builder.finish(); - let re = regexp_like::(&[Arc::new(values), Arc::new(patterns)]).unwrap(); + let re = regexp_like(&[Arc::new(values), Arc::new(patterns)]).unwrap(); assert_eq!(re.as_ref(), &expected); } #[test] - fn test_case_insensitive_regexp_like() { + fn test_case_sensitive_regexp_like_utf8view() { + let values = StringViewArray::from(vec!["abc"; 5]); + + let patterns = + StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); + + let mut expected_builder: BooleanBuilder = BooleanBuilder::new(); + expected_builder.append_value(true); + expected_builder.append_value(false); + expected_builder.append_value(true); + expected_builder.append_value(false); + expected_builder.append_value(false); + let expected = expected_builder.finish(); + + let re = regexp_like(&[Arc::new(values), Arc::new(patterns)]).unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + + #[test] + fn test_case_insensitive_regexp_like_utf8() { let values = StringArray::from(vec!["abc"; 5]); let patterns = StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); @@ -230,9 +368,29 @@ mod tests { expected_builder.append_value(false); let expected = expected_builder.finish(); - let re = - regexp_like::(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) - .unwrap(); + let re = regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + .unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + + #[test] + fn test_case_insensitive_regexp_like_utf8view() { + let values = StringViewArray::from(vec!["abc"; 5]); + let patterns = + StringViewArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); + let flags = StringArray::from(vec!["i"; 5]); + + let mut expected_builder: BooleanBuilder = BooleanBuilder::new(); + expected_builder.append_value(true); + expected_builder.append_value(true); + expected_builder.append_value(true); + expected_builder.append_value(true); + expected_builder.append_value(false); + let expected = expected_builder.finish(); + + let re = regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + .unwrap(); assert_eq!(re.as_ref(), &expected); } @@ -244,7 +402,7 @@ mod tests { let flags = StringArray::from(vec!["g"]); let re_err = - regexp_like::(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) .expect_err("unsupported flag should have failed"); assert_eq!( diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index bf40eff11d30f..019666bd7b2d4 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -26,11 +26,11 @@ use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; -use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; +use 
diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs
index bf40eff11d30f..019666bd7b2d4 100644
--- a/datafusion/functions/src/regex/regexpmatch.rs
+++ b/datafusion/functions/src/regex/regexpmatch.rs
@@ -26,11 +26,11 @@ use datafusion_common::{arrow_datafusion_err, plan_err};
 use datafusion_common::{
     cast::as_generic_string_array, internal_err, DataFusionError, Result,
 };
-use datafusion_expr::ColumnarValue;
-use datafusion_expr::TypeSignature::*;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX;
+use datafusion_expr::{ColumnarValue, Documentation, TypeSignature};
 use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 #[derive(Debug)]
 pub struct RegexpMatchFunc {
@@ -53,10 +53,10 @@ impl RegexpMatchFunc {
                     // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8, Utf8)`.
                     // If that fails, it proceeds to `(LargeUtf8, Utf8)`.
                     // TODO: Native support Utf8View for regexp_match.
-                    Exact(vec![Utf8, Utf8]),
-                    Exact(vec![LargeUtf8, Utf8]),
-                    Exact(vec![Utf8, Utf8, Utf8]),
-                    Exact(vec![LargeUtf8, Utf8, Utf8]),
+                    TypeSignature::Exact(vec![Utf8, Utf8]),
+                    TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
+                    TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
+                    TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
                 ],
                 Volatility::Immutable,
             ),
@@ -107,7 +107,51 @@ impl ScalarUDFImpl for RegexpMatchFunc {
             result.map(ColumnarValue::Array)
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_regexp_match_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_regexp_match_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_REGEX)
+            .with_description("Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string.")
+            .with_syntax_example("regexp_match(str, regexp[, flags])")
+            .with_sql_example(r#"```sql
+> select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}');
++---------------------------------------------------------+
+| regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) |
++---------------------------------------------------------+
+| [Köln]                                                  |
++---------------------------------------------------------+
+SELECT regexp_match('aBc', '(b|d)', 'i');
++---------------------------------------------------+
+| regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) |
++---------------------------------------------------+
+| [B]                                               |
++---------------------------------------------------+
+```
+Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs)
+"#)
+            .with_standard_argument("str", Some("String"))
+            .with_argument("regexp","Regular expression to match against.
+  Can be a constant, column, or function.")
+            .with_argument("flags",
+                           r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
+  - **i**: case-insensitive: letters match both upper and lower case
+  - **m**: multi-line mode: ^ and $ match begin/end of line
+  - **s**: allow . to match \n
+  - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
+  - **U**: swap the meaning of x* and x*?"#)
+            .build()
+            .unwrap()
+    })
 }
+
 fn regexp_match_func(args: &[ArrayRef]) -> Result<ArrayRef> {
     match args[0].data_type() {
         DataType::Utf8 => regexp_match::<i32>(args),
@@ -131,7 +175,7 @@ pub fn regexp_match<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
         let flags = as_generic_string_array::<T>(&args[2])?;
 
         if flags.iter().any(|s| s == Some("g")) {
-            return plan_err!("regexp_match() does not support the \"global\" option")
+            return plan_err!("regexp_match() does not support the \"global\" option");
         }
 
         regexp::regexp_match(values, regex, Some(flags))
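Every `documentation()` override added in this PR uses the same lazy-initialization idiom: a module-level `OnceLock<Documentation>` plus a `get_*_doc()` accessor, so the builder closure runs at most once per process. A minimal, self-contained sketch of the idiom follows; the `Doc` struct is a simplified stand-in for the real `datafusion_expr::Documentation`, not its actual definition.

```rust
use std::sync::OnceLock;

// Simplified stand-in for `datafusion_expr::Documentation`; fields are illustrative.
struct Doc {
    description: &'static str,
    syntax_example: &'static str,
}

static DOCUMENTATION: OnceLock<Doc> = OnceLock::new();

// The closure runs at most once per process; later calls return the cached reference.
fn get_doc() -> &'static Doc {
    DOCUMENTATION.get_or_init(|| Doc {
        description: "Returns the first regular expression match in a string.",
        syntax_example: "regexp_match(str, regexp[, flags])",
    })
}

fn main() {
    // Both calls observe the same cached allocation.
    assert!(std::ptr::eq(get_doc(), get_doc()));
}
```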
diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs
index 3eb72a1fb5f5e..4d8e5e5fe3e3b 100644
--- a/datafusion/functions/src/regex/regexpreplace.rs
+++ b/datafusion/functions/src/regex/regexpreplace.rs
@@ -32,14 +32,15 @@ use datafusion_common::{
     cast::as_generic_string_array, internal_err, DataFusionError, Result,
 };
 use datafusion_expr::function::Hint;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX;
 use datafusion_expr::ColumnarValue;
-use datafusion_expr::TypeSignature::*;
-use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::TypeSignature;
+use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility};
 use regex::Regex;
 use std::any::Any;
 use std::collections::HashMap;
-use std::sync::Arc;
-use std::sync::OnceLock;
+use std::sync::{Arc, OnceLock};
+
 #[derive(Debug)]
 pub struct RegexpReplaceFunc {
     signature: Signature,
@@ -56,10 +57,10 @@ impl RegexpReplaceFunc {
         Self {
             signature: Signature::one_of(
                 vec![
-                    Exact(vec![Utf8, Utf8, Utf8]),
-                    Exact(vec![Utf8View, Utf8, Utf8]),
-                    Exact(vec![Utf8, Utf8, Utf8, Utf8]),
-                    Exact(vec![Utf8View, Utf8, Utf8, Utf8]),
+                    TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
+                    TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]),
+                    TypeSignature::Exact(vec![Utf8, Utf8, Utf8, Utf8]),
+                    TypeSignature::Exact(vec![Utf8View, Utf8, Utf8, Utf8]),
                 ],
                 Volatility::Immutable,
             ),
@@ -123,6 +124,51 @@ impl ScalarUDFImpl for RegexpReplaceFunc {
             result.map(ColumnarValue::Array)
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_regexp_replace_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_regexp_replace_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_REGEX)
+            .with_description("Replaces substrings in a string that match a [regular expression](https://docs.rs/regex/latest/regex/#syntax).")
+            .with_syntax_example("regexp_replace(str, regexp, replacement[, flags])")
+            .with_sql_example(r#"```sql
+> select regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g');
++------------------------------------------------------------------------+
+| regexp_replace(Utf8("foobarbaz"),Utf8("b(..)"),Utf8("X\1Y"),Utf8("g")) |
++------------------------------------------------------------------------+
+| fooXarYXazY                                                            |
++------------------------------------------------------------------------+
+SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i');
++-------------------------------------------------------------------+
+| regexp_replace(Utf8("aBc"),Utf8("(b|d)"),Utf8("Ab\1a"),Utf8("i")) |
++-------------------------------------------------------------------+
+| aAbBac                                                            |
++-------------------------------------------------------------------+
+```
+Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs)
+"#)
+            .with_standard_argument("str", Some("String"))
+            .with_argument("regexp","Regular expression to match against.
+  Can be a constant, column, or function.")
+            .with_standard_argument("replacement", Some("Replacement string"))
+            .with_argument("flags",
+                           r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
+- **g**: (global) Search globally and don't return after the first match
+- **i**: case-insensitive: letters match both upper and lower case
+- **m**: multi-line mode: ^ and $ match begin/end of line
+- **s**: allow . to match \n
+- **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
+- **U**: swap the meaning of x* and x*?"#)
+            .build()
+            .unwrap()
+})
 }
 
 fn regexp_replace_func(args: &[ColumnarValue]) -> Result<ArrayRef> {
diff --git a/datafusion/functions/src/regexp_common.rs b/datafusion/functions/src/regexp_common.rs
deleted file mode 100644
index 748c1a294f972..0000000000000
--- a/datafusion/functions/src/regexp_common.rs
+++ /dev/null
@@ -1,123 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Common utilities for implementing regex functions
-
-use crate::string::common::StringArrayType;
-
-use arrow::array::{Array, ArrayDataBuilder, BooleanArray};
-use arrow::datatypes::DataType;
-use arrow_buffer::{BooleanBufferBuilder, NullBuffer};
-use datafusion_common::DataFusionError;
-use regex::Regex;
-
-use std::collections::HashMap;
-
-#[cfg(doc)]
-use arrow::array::{LargeStringArray, StringArray, StringViewArray};
-
-/// Perform SQL `array ~ regex_array` operation on
-/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`].
-///
-/// If `regex_array` element has an empty value, the corresponding result value is always true.
-///
-/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag,
-/// which allow special search modes, such as case-insensitive and multi-line mode.
-/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
-/// for more information.
-///
-/// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs].
-///
-/// Can remove when is implemented upstream
-///
-/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37
-pub fn regexp_is_match_utf8<'a, S1, S2, S3>(
-    array: &'a S1,
-    regex_array: &'a S2,
-    flags_array: Option<&'a S3>,
-) -> datafusion_common::Result<BooleanArray>
-where
-    &'a S1: StringArrayType<'a>,
-    &'a S2: StringArrayType<'a>,
-    &'a S3: StringArrayType<'a>,
-{
-    if array.len() != regex_array.len() {
-        return Err(DataFusionError::Execution(
-            "Cannot perform comparison operation on arrays of different length"
-                .to_string(),
-        ));
-    }
-
-    let nulls = NullBuffer::union(array.nulls(), regex_array.nulls());
-
-    let mut patterns: HashMap<String, Regex> = HashMap::new();
-    let mut result = BooleanBufferBuilder::new(array.len());
-
-    let complete_pattern = match flags_array {
-        Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map(
-            |(pattern, flags)| {
-                pattern.map(|pattern| match flags {
-                    Some(flag) => format!("(?{flag}){pattern}"),
-                    None => pattern.to_string(),
-                })
-            },
-        )) as Box<dyn Iterator<Item = Option<String>>>,
-        None => Box::new(
-            regex_array
-                .iter()
-                .map(|pattern| pattern.map(|pattern| pattern.to_string())),
-        ),
-    };
-
-    array
-        .iter()
-        .zip(complete_pattern)
-        .map(|(value, pattern)| {
-            match (value, pattern) {
-                (Some(_), Some(pattern)) if pattern == *"" => {
-                    result.append(true);
-                }
-                (Some(value), Some(pattern)) => {
-                    let existing_pattern = patterns.get(&pattern);
-                    let re = match existing_pattern {
-                        Some(re) => re,
-                        None => {
-                            let re = Regex::new(pattern.as_str()).map_err(|e| {
-                                DataFusionError::Execution(format!(
-                                    "Regular expression did not compile: {e:?}"
-                                ))
-                            })?;
-                            patterns.entry(pattern).or_insert(re)
-                        }
-                    };
-                    result.append(re.is_match(value));
-                }
-                _ => result.append(false),
-            }
-            Ok(())
-        })
-        .collect::<Result<Vec<()>, DataFusionError>>()?;
-
-    let data = unsafe {
-        ArrayDataBuilder::new(DataType::Boolean)
-            .len(array.len())
-            .buffers(vec![result.into()])
-            .nulls(nulls)
-            .build_unchecked()
-    };
-
-    Ok(BooleanArray::from(data))
-}
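The deleted helper combined an optional per-row flags value with each pattern by prepending an inline flag group (`format!("(?{flag}){pattern}")`) and cached the compiled regexes in a `HashMap`. A standalone sketch of that flag-prepending trick using the `regex` crate directly, outside DataFusion:

```rust
use regex::Regex;

fn main() {
    let pattern = "^(a)";

    // "(?i)" enables case-insensitive matching; this is how the removed
    // `regexp_is_match_utf8` applied a flags column to each pattern.
    let insensitive = Regex::new(&format!("(?i){pattern}")).unwrap();
    let sensitive = Regex::new(pattern).unwrap();

    assert!(insensitive.is_match("Abc"));
    assert!(!sensitive.is_match("Abc"));
}
```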
diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs
index 68ba3f5ff15f5..b76d70d7e9d26 100644
--- a/datafusion/functions/src/string/ascii.rs
+++ b/datafusion/functions/src/string/ascii.rs
@@ -20,10 +20,11 @@ use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array};
 use arrow::datatypes::DataType;
 use arrow::error::ArrowError;
 use datafusion_common::{internal_err, Result};
-use datafusion_expr::ColumnarValue;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation};
 use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 #[derive(Debug)]
 pub struct AsciiFunc {
@@ -38,13 +39,8 @@ impl Default for AsciiFunc {
 
 impl AsciiFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
-            signature: Signature::uniform(
-                1,
-                vec![Utf8, LargeUtf8, Utf8View],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(1, Volatility::Immutable),
         }
     }
 }
@@ -71,6 +67,43 @@ impl ScalarUDFImpl for AsciiFunc {
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         make_scalar_function(ascii, vec![])(args)
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_ascii_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_ascii_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description(
+                "Returns the Unicode character code of the first character in a string.",
+            )
+            .with_syntax_example("ascii(str)")
+            .with_sql_example(
+                r#"```sql
+> select ascii('abc');
++--------------------+
+| ascii(Utf8("abc")) |
++--------------------+
+| 97                 |
++--------------------+
+> select ascii('🚀');
++-------------------+
+| ascii(Utf8("🚀")) |
++-------------------+
+| 128640            |
++-------------------+
+```"#,
+            )
+            .with_standard_argument("str", Some("String"))
+            .with_related_udf("chr")
+            .build()
+            .unwrap()
+    })
 }
 
 fn calculate_ascii<'a, V>(array: V) -> Result<ArrayRef, ArrowError>
"Returns the Unicode character code of the first character in a string.", + ) + .with_syntax_example("ascii(str)") + .with_sql_example( + r#"```sql +> select ascii('abc'); ++--------------------+ +| ascii(Utf8("abc")) | ++--------------------+ +| 97 | ++--------------------+ +> select ascii('🚀'); ++-------------------+ +| ascii(Utf8("🚀")) | ++-------------------+ +| 128640 | ++-------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_related_udf("chr") + .build() + .unwrap() + }) } fn calculate_ascii<'a, V>(array: V) -> Result diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 65ec1a4a77346..25b56341fcaa3 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -15,17 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::compute::kernels::length::bit_length; use arrow::datatypes::DataType; +use std::any::Any; +use std::sync::OnceLock; +use crate::utils::utf8_to_int_type; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::utils::utf8_to_int_type; - #[derive(Debug)] pub struct BitLengthFunc { signature: Signature, @@ -39,13 +39,8 @@ impl Default for BitLengthFunc { impl BitLengthFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8, LargeUtf8], - Volatility::Immutable, - ), + signature: Signature::string(1, Volatility::Immutable), } } } @@ -88,4 +83,34 @@ impl ScalarUDFImpl for BitLengthFunc { }, } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_bit_length_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_bit_length_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns the bit length of a string.") + .with_syntax_example("bit_length(str)") + .with_sql_example( + r#"```sql +> select bit_length('datafusion'); ++--------------------------------+ +| bit_length(Utf8("datafusion")) | ++--------------------------------+ +| 80 | ++--------------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_related_udf("length") + .with_related_udf("octet_length") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index 0e992ff27fd3b..e215b18d9c3ce 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -15,18 +15,18 @@ // specific language governing permissions and limitations // under the License. 
diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs
index 0e992ff27fd3b..e215b18d9c3ce 100644
--- a/datafusion/functions/src/string/btrim.rs
+++ b/datafusion/functions/src/string/btrim.rs
@@ -15,18 +15,18 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::string::common::*;
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use arrow::array::{ArrayRef, OffsetSizeTrait};
 use arrow::datatypes::DataType;
-use std::any::Any;
-
 use datafusion_common::{exec_err, Result};
 use datafusion_expr::function::Hint;
-use datafusion_expr::TypeSignature::*;
-use datafusion_expr::{ColumnarValue, Volatility};
-use datafusion_expr::{ScalarUDFImpl, Signature};
-
-use crate::string::common::*;
-use crate::utils::{make_scalar_function, utf8_to_str_type};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use std::any::Any;
+use std::sync::OnceLock;
 
 /// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed.
 /// btrim('xyxtrimyyx', 'xyz') = 'trim'
@@ -49,18 +49,9 @@ impl Default for BTrimFunc {
 
 impl BTrimFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
             signature: Signature::one_of(
-                vec![
-                    // Planner attempts coercion to the target type starting with the most preferred candidate.
-                    // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`.
-                    // If that fails, it proceeds to `(Utf8, Utf8)`.
-                    Exact(vec![Utf8View, Utf8View]),
-                    Exact(vec![Utf8, Utf8]),
-                    Exact(vec![Utf8View]),
-                    Exact(vec![Utf8]),
-                ],
+                vec![TypeSignature::String(2), TypeSignature::String(1)],
                 Volatility::Immutable,
             ),
             aliases: vec![String::from("trim")],
@@ -109,6 +100,37 @@ impl ScalarUDFImpl for BTrimFunc {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_btrim_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_btrim_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string.")
+            .with_syntax_example("btrim(str[, trim_str])")
+            .with_sql_example(r#"```sql
+> select btrim('__datafusion____', '_');
++-------------------------------------------+
+| btrim(Utf8("__datafusion____"),Utf8("_")) |
++-------------------------------------------+
+| datafusion                                |
++-------------------------------------------+
+```"#)
+            .with_standard_argument("str", Some("String"))
+            .with_argument("trim_str", "String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._")
+            .with_alternative_syntax("trim(BOTH trim_str FROM str)")
+            .with_alternative_syntax("trim(trim_str FROM str)")
+            .with_related_udf("ltrim")
+            .with_related_udf("rtrim")
+            .build()
+            .unwrap()
+    })
 }
 
 #[cfg(test)]
diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs
index 4da7dc01594d1..0d94cab08d913 100644
--- a/datafusion/functions/src/string/chr.rs
+++ b/datafusion/functions/src/string/chr.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::ArrayRef;
 use arrow::array::StringArray;
@@ -24,13 +24,13 @@ use arrow::datatypes::DataType;
 use arrow::datatypes::DataType::Int64;
 use arrow::datatypes::DataType::Utf8;
 
+use crate::utils::make_scalar_function;
 use datafusion_common::cast::as_int64_array;
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::utils::make_scalar_function;
-
 /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character.
 /// chr(65) = 'A'
 pub fn chr(args: &[ArrayRef]) -> Result<ArrayRef> {
@@ -99,4 +99,35 @@ impl ScalarUDFImpl for ChrFunc {
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         make_scalar_function(chr, vec![])(args)
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_chr_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_chr_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description(
+                "Returns the character with the specified ASCII or Unicode code value.",
+            )
+            .with_syntax_example("chr(expression)")
+            .with_sql_example(
+                r#"```sql
+> select chr(128640);
++--------------------+
+| chr(Int64(128640)) |
++--------------------+
+| 🚀                 |
++--------------------+
+```"#,
+            )
+            .with_standard_argument("expression", Some("String"))
+            .with_related_udf("ascii")
+            .build()
+            .unwrap()
+    })
 }
diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs
index 72447bc68f4f4..0d1f90eb22b94 100644
--- a/datafusion/functions/src/string/common.rs
+++ b/datafusion/functions/src/string/common.rs
@@ -20,12 +20,12 @@
 use std::fmt::{Display, Formatter};
 use std::sync::Arc;
 
+use crate::strings::make_and_append_view;
 use arrow::array::{
-    make_view, new_null_array, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter,
-    ArrayRef, ByteView, GenericStringArray, GenericStringBuilder, LargeStringArray,
-    OffsetSizeTrait, StringArray, StringBuilder, StringViewArray, StringViewBuilder,
+    new_null_array, Array, ArrayRef, GenericStringArray, GenericStringBuilder,
+    OffsetSizeTrait, StringBuilder, StringViewArray,
 };
-use arrow::buffer::{Buffer, MutableBuffer, NullBuffer};
+use arrow::buffer::Buffer;
 use arrow::datatypes::DataType;
 use arrow_buffer::{NullBufferBuilder, ScalarBuffer};
 use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
@@ -33,42 +33,6 @@ use datafusion_common::Result;
 use datafusion_common::{exec_err, ScalarValue};
 use datafusion_expr::ColumnarValue;
 
-/// Append a new view to the views buffer with the given substr
-///
-/// # Safety
-///
-/// original_view must be a valid view (the format described on
-/// [`GenericByteViewArray`](arrow::array::GenericByteViewArray).
-///
-/// # Arguments
-/// - views_buffer: The buffer to append the new view to
-/// - null_builder: The buffer to append the null value to
-/// - original_view: The original view value
-/// - substr: The substring to append. Must be a valid substring of the original view
-/// - start_offset: The start offset of the substring in the view
-pub(crate) fn make_and_append_view(
-    views_buffer: &mut Vec<u128>,
-    null_builder: &mut NullBufferBuilder,
-    original_view: &u128,
-    substr: &str,
-    start_offset: u32,
-) {
-    let substr_len = substr.len();
-    let sub_view = if substr_len > 12 {
-        let view = ByteView::from(*original_view);
-        make_view(
-            substr.as_bytes(),
-            view.buffer_index,
-            view.offset + start_offset,
-        )
-    } else {
-        // inline value does not need block id or offset
-        make_view(substr.as_bytes(), 0, 0)
-    };
-    views_buffer.push(sub_view);
-    null_builder.append_non_null();
-}
-
 pub(crate) enum TrimType {
     Left,
     Right,
@@ -399,370 +363,6 @@ where
     }
 }
 
-#[derive(Debug)]
-pub(crate) enum ColumnarValueRef<'a> {
-    Scalar(&'a [u8]),
-    NullableArray(&'a StringArray),
-    NonNullableArray(&'a StringArray),
-    NullableLargeStringArray(&'a LargeStringArray),
-    NonNullableLargeStringArray(&'a LargeStringArray),
-    NullableStringViewArray(&'a StringViewArray),
-    NonNullableStringViewArray(&'a StringViewArray),
-}
-
-impl<'a> ColumnarValueRef<'a> {
-    #[inline]
-    pub fn is_valid(&self, i: usize) -> bool {
-        match &self {
-            Self::Scalar(_)
-            | Self::NonNullableArray(_)
-            | Self::NonNullableLargeStringArray(_)
-            | Self::NonNullableStringViewArray(_) => true,
-            Self::NullableArray(array) => array.is_valid(i),
-            Self::NullableStringViewArray(array) => array.is_valid(i),
-            Self::NullableLargeStringArray(array) => array.is_valid(i),
-        }
-    }
-
-    #[inline]
-    pub fn nulls(&self) -> Option<NullBuffer> {
-        match &self {
-            Self::Scalar(_)
-            | Self::NonNullableArray(_)
-            | Self::NonNullableStringViewArray(_)
-            | Self::NonNullableLargeStringArray(_) => None,
-            Self::NullableArray(array) => array.nulls().cloned(),
-            Self::NullableStringViewArray(array) => array.nulls().cloned(),
-            Self::NullableLargeStringArray(array) => array.nulls().cloned(),
-        }
-    }
-}
-
-/// Abstracts iteration over different types of string arrays.
-///
-/// The [`StringArrayType`] trait helps write generic code for string functions that can work with
-/// different types of string arrays.
-///
-/// Currently three types are supported:
-/// - [`StringArray`]
-/// - [`LargeStringArray`]
-/// - [`StringViewArray`]
-///
-/// It is inspired / copied from [arrow-rs].
-///
-/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/bf0ea9129e617e4a3cf915a900b747cc5485315f/arrow-string/src/like.rs#L151-L157
-///
-/// # Examples
-/// Generic function that works for [`StringArray`], [`LargeStringArray`]
-/// and [`StringViewArray`]:
-/// ```
-/// # use arrow::array::{StringArray, LargeStringArray, StringViewArray};
-/// # use datafusion_functions::string::common::StringArrayType;
-///
-/// /// Combines string values for any StringArrayType type. It can be invoked on
-/// /// any combination of `StringArray`, `LargeStringArray` or `StringViewArray`
-/// fn combine_values<'a, S1, S2>(array1: S1, array2: S2) -> Vec<String>
-/// where S1: StringArrayType<'a>, S2: StringArrayType<'a>
-/// {
-///     // iterate over the elements of the 2 arrays in parallel
-///     array1
-///         .iter()
-///         .zip(array2.iter())
-///         .map(|(s1, s2)| {
-///             // if both values are non null, combine them
-///             if let (Some(s1), Some(s2)) = (s1, s2) {
-///                 format!("{s1}{s2}")
-///             } else {
-///                 "None".to_string()
-///             }
-///         })
-///         .collect()
-/// }
-///
-/// let string_array = StringArray::from(vec!["foo", "bar"]);
-/// let large_string_array = LargeStringArray::from(vec!["foo2", "bar2"]);
-/// let string_view_array = StringViewArray::from(vec!["foo3", "bar3"]);
-///
-/// // can invoke this function with a string array and large string array
-/// assert_eq!(
-///     combine_values(&string_array, &large_string_array),
-///     vec![String::from("foofoo2"), String::from("barbar2")]
-/// );
-///
-/// // Can call the same function with string array and string view array
-/// assert_eq!(
-///     combine_values(&string_array, &string_view_array),
-///     vec![String::from("foofoo3"), String::from("barbar3")]
-/// );
-/// ```
-///
-/// [`LargeStringArray`]: arrow::array::LargeStringArray
-pub trait StringArrayType<'a>: ArrayAccessor<Item = &'a str> + Sized {
-    /// Return an [`ArrayIter`] over the values of the array.
-    ///
-    /// This iterator returns `Option<&str>` for each item in the array.
-    fn iter(&self) -> ArrayIter<Self>;
-
-    /// Check if the array is ASCII only.
-    fn is_ascii(&self) -> bool;
-}
-
-impl<'a, T: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray<T> {
-    fn iter(&self) -> ArrayIter<Self> {
-        GenericStringArray::<T>::iter(self)
-    }
-
-    fn is_ascii(&self) -> bool {
-        GenericStringArray::<T>::is_ascii(self)
-    }
-}
-
-impl<'a> StringArrayType<'a> for &'a StringViewArray {
-    fn iter(&self) -> ArrayIter<Self> {
-        StringViewArray::iter(self)
-    }
-
-    fn is_ascii(&self) -> bool {
-        StringViewArray::is_ascii(self)
-    }
-}
-
-/// Optimized version of the StringBuilder in Arrow that:
-/// 1. Precalculating the expected length of the result, avoiding reallocations.
-/// 2. Avoids creating / incrementally creating a `NullBufferBuilder`
-pub(crate) struct StringArrayBuilder {
-    offsets_buffer: MutableBuffer,
-    value_buffer: MutableBuffer,
-}
-
-impl StringArrayBuilder {
-    pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self {
-        let mut offsets_buffer = MutableBuffer::with_capacity(
-            (item_capacity + 1) * std::mem::size_of::<i32>(),
-        );
-        // SAFETY: the first offset value is definitely not going to exceed the bounds.
-        unsafe { offsets_buffer.push_unchecked(0_i32) };
-        Self {
-            offsets_buffer,
-            value_buffer: MutableBuffer::with_capacity(data_capacity),
-        }
-    }
-
-    pub fn write<const CHECK_VALID: bool>(
-        &mut self,
-        column: &ColumnarValueRef,
-        i: usize,
-    ) {
-        match column {
-            ColumnarValueRef::Scalar(s) => {
-                self.value_buffer.extend_from_slice(s);
-            }
-            ColumnarValueRef::NullableArray(array) => {
-                if !CHECK_VALID || array.is_valid(i) {
-                    self.value_buffer
-                        .extend_from_slice(array.value(i).as_bytes());
-                }
-            }
-            ColumnarValueRef::NullableLargeStringArray(array) => {
-                if !CHECK_VALID || array.is_valid(i) {
-                    self.value_buffer
-                        .extend_from_slice(array.value(i).as_bytes());
-                }
-            }
-            ColumnarValueRef::NullableStringViewArray(array) => {
-                if !CHECK_VALID || array.is_valid(i) {
-                    self.value_buffer
-                        .extend_from_slice(array.value(i).as_bytes());
-                }
-            }
-            ColumnarValueRef::NonNullableArray(array) => {
-                self.value_buffer
-                    .extend_from_slice(array.value(i).as_bytes());
-            }
-            ColumnarValueRef::NonNullableLargeStringArray(array) => {
-                self.value_buffer
-                    .extend_from_slice(array.value(i).as_bytes());
-            }
-            ColumnarValueRef::NonNullableStringViewArray(array) => {
-                self.value_buffer
-                    .extend_from_slice(array.value(i).as_bytes());
-            }
-        }
-    }
-
-    pub fn append_offset(&mut self) {
-        let next_offset: i32 = self
-            .value_buffer
-            .len()
-            .try_into()
-            .expect("byte array offset overflow");
-        unsafe { self.offsets_buffer.push_unchecked(next_offset) };
-    }
-
-    pub fn finish(self, null_buffer: Option<NullBuffer>) -> StringArray {
-        let array_builder = ArrayDataBuilder::new(DataType::Utf8)
-            .len(self.offsets_buffer.len() / std::mem::size_of::<i32>() - 1)
-            .add_buffer(self.offsets_buffer.into())
-            .add_buffer(self.value_buffer.into())
-            .nulls(null_buffer);
-        // SAFETY: all data that was appended was valid UTF8 and the values
-        // and offsets were created correctly
-        let array_data = unsafe { array_builder.build_unchecked() };
-        StringArray::from(array_data)
-    }
-}
-
-pub(crate) struct StringViewArrayBuilder {
-    builder: StringViewBuilder,
-    block: String,
-}
-
-impl StringViewArrayBuilder {
-    pub fn with_capacity(_item_capacity: usize, data_capacity: usize) -> Self {
-        let builder = StringViewBuilder::with_capacity(data_capacity);
-        Self {
-            builder,
-            block: String::new(),
-        }
-    }
-
-    pub fn write<const CHECK_VALID: bool>(
-        &mut self,
-        column: &ColumnarValueRef,
-        i: usize,
-    ) {
-        match column {
-            ColumnarValueRef::Scalar(s) => {
-                self.block.push_str(std::str::from_utf8(s).unwrap());
-            }
-            ColumnarValueRef::NullableArray(array) => {
-                if !CHECK_VALID || array.is_valid(i) {
-                    self.block.push_str(
-                        std::str::from_utf8(array.value(i).as_bytes()).unwrap(),
-                    );
-                }
-            }
-            ColumnarValueRef::NullableLargeStringArray(array) => {
-                if !CHECK_VALID || array.is_valid(i) {
-                    self.block.push_str(
-                        std::str::from_utf8(array.value(i).as_bytes()).unwrap(),
-                    );
-                }
-            }
-            ColumnarValueRef::NullableStringViewArray(array) => {
-                if !CHECK_VALID || array.is_valid(i) {
-                    self.block.push_str(
-                        std::str::from_utf8(array.value(i).as_bytes()).unwrap(),
-                    );
-                }
-            }
-            ColumnarValueRef::NonNullableArray(array) => {
-                self.block
-                    .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap());
-            }
-            ColumnarValueRef::NonNullableLargeStringArray(array) => {
-                self.block
-                    .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap());
-            }
-            ColumnarValueRef::NonNullableStringViewArray(array) => {
-                self.block
-                    .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap());
-            }
-        }
-    }
-
-    pub fn append_offset(&mut self) {
-        self.builder.append_value(&self.block);
-        self.block = String::new();
-    }
-
-    pub fn finish(mut self) -> StringViewArray {
-        self.builder.finish()
-    }
-}
-
-pub(crate) struct LargeStringArrayBuilder {
-    offsets_buffer: MutableBuffer,
-    value_buffer: MutableBuffer,
-}
-
-impl LargeStringArrayBuilder {
-    pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self {
-        let mut offsets_buffer = MutableBuffer::with_capacity(
-            (item_capacity + 1) * std::mem::size_of::<i64>(),
-        );
-        // SAFETY: the first offset value is definitely not going to exceed the bounds.
-        unsafe { offsets_buffer.push_unchecked(0_i64) };
-        Self {
-            offsets_buffer,
-            value_buffer: MutableBuffer::with_capacity(data_capacity),
-        }
-    }
-
-    pub fn write<const CHECK_VALID: bool>(
-        &mut self,
-        column: &ColumnarValueRef,
-        i: usize,
-    ) {
-        match column {
-            ColumnarValueRef::Scalar(s) => {
-                self.value_buffer.extend_from_slice(s);
-            }
-            ColumnarValueRef::NullableArray(array) => {
-                if !CHECK_VALID || array.is_valid(i) {
-                    self.value_buffer
-                        .extend_from_slice(array.value(i).as_bytes());
-                }
-            }
-            ColumnarValueRef::NullableLargeStringArray(array) => {
-                if !CHECK_VALID || array.is_valid(i) {
-                    self.value_buffer
-                        .extend_from_slice(array.value(i).as_bytes());
-                }
-            }
-            ColumnarValueRef::NullableStringViewArray(array) => {
-                if !CHECK_VALID || array.is_valid(i) {
-                    self.value_buffer
-                        .extend_from_slice(array.value(i).as_bytes());
-                }
-            }
-            ColumnarValueRef::NonNullableArray(array) => {
-                self.value_buffer
-                    .extend_from_slice(array.value(i).as_bytes());
-            }
-            ColumnarValueRef::NonNullableLargeStringArray(array) => {
-                self.value_buffer
-                    .extend_from_slice(array.value(i).as_bytes());
-            }
-            ColumnarValueRef::NonNullableStringViewArray(array) => {
-                self.value_buffer
-                    .extend_from_slice(array.value(i).as_bytes());
-            }
-        }
-    }
-
-    pub fn append_offset(&mut self) {
-        let next_offset: i64 = self
-            .value_buffer
-            .len()
-            .try_into()
-            .expect("byte array offset overflow");
-        unsafe { self.offsets_buffer.push_unchecked(next_offset) };
-    }
-
-    pub fn finish(self, null_buffer: Option<NullBuffer>) -> LargeStringArray {
-        let array_builder = ArrayDataBuilder::new(DataType::LargeUtf8)
-            .len(self.offsets_buffer.len() / std::mem::size_of::<i64>() - 1)
-            .add_buffer(self.offsets_buffer.into())
-            .add_buffer(self.value_buffer.into())
-            .nulls(null_buffer);
-        // SAFETY: all data that was appended was valid Large UTF8 and the values
-        // and offsets were created correctly
-        let array_data = unsafe { array_builder.build_unchecked() };
-        LargeStringArray::from(array_data)
-    }
-}
-
 fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
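The builders removed above (this PR relocates them to a `strings` module) share one layout idea: preallocate an offsets buffer seeded with 0 and a values buffer, append raw UTF-8 bytes per row, and push the running byte length as the row's closing offset. A simplified sketch of that offsets/values encoding using plain `Vec`s rather than the original `MutableBuffer`s:

```rust
// Offsets/values encoding used by Utf8 arrays: row i spans
// values[offsets[i]..offsets[i + 1]].
struct TinyStringBuilder {
    offsets: Vec<i32>, // starts with 0, gains one entry per appended row
    values: Vec<u8>,   // concatenated UTF-8 bytes of every row
}

impl TinyStringBuilder {
    fn new() -> Self {
        Self { offsets: vec![0], values: Vec::new() }
    }

    // Append the row's bytes, then record where the row ends.
    fn append(&mut self, s: &str) {
        self.values.extend_from_slice(s.as_bytes());
        let end: i32 = self.values.len().try_into().expect("offset overflow");
        self.offsets.push(end);
    }

    fn value(&self, i: usize) -> &str {
        let start = self.offsets[i] as usize;
        let end = self.offsets[i + 1] as usize;
        std::str::from_utf8(&self.values[start..end]).unwrap()
    }
}

fn main() {
    let mut b = TinyStringBuilder::new();
    b.append("foo");
    b.append("barbaz");
    assert_eq!(b.value(0), "foo");
    assert_eq!(b.value(1), "barbaz");
}
```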
diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs
index 98f57efef90d4..a4218c39e7b28 100644
--- a/datafusion/functions/src/string/concat.rs
+++ b/datafusion/functions/src/string/concat.rs
@@ -18,18 +18,20 @@
 use arrow::array::{as_largestring_array, Array};
 use arrow::datatypes::DataType;
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
+use crate::string::concat;
+use crate::strings::{
+    ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder,
+};
 use datafusion_common::cast::{as_string_array, as_string_view_array};
 use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
 use datafusion_expr::expr::ScalarFunction;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{lit, ColumnarValue, Expr, Volatility};
+use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::string::common::*;
-use crate::string::concat;
-
 #[derive(Debug)]
 pub struct ConcatFunc {
     signature: Signature,
@@ -244,6 +246,36 @@ impl ScalarUDFImpl for ConcatFunc {
     ) -> Result<ExprSimplifyResult> {
         simplify_concat(args)
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_concat_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_concat_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Concatenates multiple strings together.")
+            .with_syntax_example("concat(str[, ..., str_n])")
+            .with_sql_example(
+                r#"```sql
+> select concat('data', 'f', 'us', 'ion');
++-------------------------------------------------------+
+| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) |
++-------------------------------------------------------+
+| datafusion                                            |
++-------------------------------------------------------+
+```"#,
+            )
+            .with_standard_argument("str", Some("String"))
+            .with_argument("str_n", "Subsequent string expressions to concatenate.")
+            .with_related_udf("concat_ws")
+            .build()
+            .unwrap()
+    })
 }
 
 pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
.with_related_udf("concat") + .build() + .unwrap() + }) } fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index c319f80661c3a..d0e63bb0f353f 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -15,21 +15,20 @@ // specific language governing permissions and limitations // under the License. -use crate::regexp_common::regexp_is_match_utf8; use crate::utils::make_scalar_function; - -use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray}; +use arrow::array::{Array, ArrayRef, AsArray}; +use arrow::compute::contains as arrow_contains; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View}; use datafusion_common::exec_err; use datafusion_common::DataFusionError; use datafusion_common::Result; -use datafusion_expr::ScalarUDFImpl; -use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, Signature, Volatility}; - +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct ContainsFunc { @@ -44,22 +43,8 @@ impl Default for ContainsFunc { impl ContainsFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![ - Exact(vec![Utf8View, Utf8View]), - Exact(vec![Utf8View, Utf8]), - Exact(vec![Utf8View, LargeUtf8]), - Exact(vec![Utf8, Utf8View]), - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8View]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - Volatility::Immutable, - ), + signature: Signature::string(2, Volatility::Immutable), } } } @@ -84,108 +69,58 @@ impl ScalarUDFImpl for ContainsFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(contains, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_contains_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_contains_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description( + "Return true if search_str is found within string (case-sensitive).", + ) + .with_syntax_example("contains(str, search_str)") + .with_sql_example( + r#"```sql +> select contains('the quick brown fox', 'row'); ++---------------------------------------------------+ +| contains(Utf8("the quick brown fox"),Utf8("row")) | ++---------------------------------------------------+ +| true | ++---------------------------------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_argument("search_str", "The string to search for in str.") + .build() + .unwrap() + }) } -/// use regexp_is_match_utf8_scalar to do the calculation for contains +/// use `arrow::compute::contains` to do the calculation for contains pub fn contains(args: &[ArrayRef]) -> Result { match (args[0].data_type(), args[1].data_type()) { (Utf8View, Utf8View) => { let mod_str = args[0].as_string_view(); let match_str = args[1].as_string_view(); - let res = regexp_is_match_utf8::< - StringViewArray, - StringViewArray, - GenericStringArray, - >(mod_str, match_str, None)?; - - Ok(Arc::new(res) as ArrayRef) - } - (Utf8View, Utf8) => { - let mod_str 
diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs
index c319f80661c3a..d0e63bb0f353f 100644
--- a/datafusion/functions/src/string/contains.rs
+++ b/datafusion/functions/src/string/contains.rs
@@ -15,21 +15,20 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::regexp_common::regexp_is_match_utf8;
 use crate::utils::make_scalar_function;
-
-use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray};
+use arrow::array::{Array, ArrayRef, AsArray};
+use arrow::compute::contains as arrow_contains;
 use arrow::datatypes::DataType;
 use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View};
 use datafusion_common::exec_err;
 use datafusion_common::DataFusionError;
 use datafusion_common::Result;
-use datafusion_expr::ScalarUDFImpl;
-use datafusion_expr::TypeSignature::Exact;
-use datafusion_expr::{ColumnarValue, Signature, Volatility};
-
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 #[derive(Debug)]
 pub struct ContainsFunc {
@@ -44,22 +43,8 @@ impl Default for ContainsFunc {
 
 impl ContainsFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
-            signature: Signature::one_of(
-                vec![
-                    Exact(vec![Utf8View, Utf8View]),
-                    Exact(vec![Utf8View, Utf8]),
-                    Exact(vec![Utf8View, LargeUtf8]),
-                    Exact(vec![Utf8, Utf8View]),
-                    Exact(vec![Utf8, Utf8]),
-                    Exact(vec![Utf8, LargeUtf8]),
-                    Exact(vec![LargeUtf8, Utf8View]),
-                    Exact(vec![LargeUtf8, Utf8]),
-                    Exact(vec![LargeUtf8, LargeUtf8]),
-                ],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(2, Volatility::Immutable),
         }
     }
 }
@@ -84,108 +69,58 @@ impl ScalarUDFImpl for ContainsFunc {
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         make_scalar_function(contains, vec![])(args)
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_contains_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_contains_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description(
+                "Return true if search_str is found within string (case-sensitive).",
+            )
+            .with_syntax_example("contains(str, search_str)")
+            .with_sql_example(
+                r#"```sql
+> select contains('the quick brown fox', 'row');
++---------------------------------------------------+
+| contains(Utf8("the quick brown fox"),Utf8("row")) |
++---------------------------------------------------+
+| true                                              |
++---------------------------------------------------+
+```"#,
+            )
+            .with_standard_argument("str", Some("String"))
+            .with_argument("search_str", "The string to search for in str.")
+            .build()
+            .unwrap()
+    })
 }
 
-/// use regexp_is_match_utf8_scalar to do the calculation for contains
+/// use `arrow::compute::contains` to do the calculation for contains
 pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef> {
     match (args[0].data_type(), args[1].data_type()) {
         (Utf8View, Utf8View) => {
             let mod_str = args[0].as_string_view();
             let match_str = args[1].as_string_view();
-            let res = regexp_is_match_utf8::<
-                StringViewArray,
-                StringViewArray,
-                GenericStringArray<i32>,
-            >(mod_str, match_str, None)?;
-
-            Ok(Arc::new(res) as ArrayRef)
-        }
-        (Utf8View, Utf8) => {
-            let mod_str = args[0].as_string_view();
-            let match_str = args[1].as_string::<i32>();
-            let res = regexp_is_match_utf8::<
-                StringViewArray,
-                GenericStringArray<i32>,
-                GenericStringArray<i32>,
-            >(mod_str, match_str, None)?;
-
-            Ok(Arc::new(res) as ArrayRef)
-        }
-        (Utf8View, LargeUtf8) => {
-            let mod_str = args[0].as_string_view();
-            let match_str = args[1].as_string::<i64>();
-            let res = regexp_is_match_utf8::<
-                StringViewArray,
-                GenericStringArray<i64>,
-                GenericStringArray<i32>,
-            >(mod_str, match_str, None)?;
-
-            Ok(Arc::new(res) as ArrayRef)
-        }
-        (Utf8, Utf8View) => {
-            let mod_str = args[0].as_string::<i32>();
-            let match_str = args[1].as_string_view();
-            let res = regexp_is_match_utf8::<
-                GenericStringArray<i32>,
-                StringViewArray,
-                GenericStringArray<i32>,
-            >(mod_str, match_str, None)?;
-
+            let res = arrow_contains(mod_str, match_str)?;
             Ok(Arc::new(res) as ArrayRef)
         }
         (Utf8, Utf8) => {
             let mod_str = args[0].as_string::<i32>();
             let match_str = args[1].as_string::<i32>();
-            let res = regexp_is_match_utf8::<
-                GenericStringArray<i32>,
-                GenericStringArray<i32>,
-                GenericStringArray<i32>,
-            >(mod_str, match_str, None)?;
-
-            Ok(Arc::new(res) as ArrayRef)
-        }
-        (Utf8, LargeUtf8) => {
-            let mod_str = args[0].as_string::<i32>();
-            let match_str = args[1].as_string::<i64>();
-            let res = regexp_is_match_utf8::<
-                GenericStringArray<i32>,
-                GenericStringArray<i64>,
-                GenericStringArray<i32>,
-            >(mod_str, match_str, None)?;
-
-            Ok(Arc::new(res) as ArrayRef)
-        }
-        (LargeUtf8, Utf8View) => {
-            let mod_str = args[0].as_string::<i64>();
-            let match_str = args[1].as_string_view();
-            let res = regexp_is_match_utf8::<
-                GenericStringArray<i64>,
-                StringViewArray,
-                GenericStringArray<i32>,
-            >(mod_str, match_str, None)?;
-
-            Ok(Arc::new(res) as ArrayRef)
-        }
-        (LargeUtf8, Utf8) => {
-            let mod_str = args[0].as_string::<i64>();
-            let match_str = args[1].as_string::<i32>();
-            let res = regexp_is_match_utf8::<
-                GenericStringArray<i64>,
-                GenericStringArray<i32>,
-                GenericStringArray<i32>,
-            >(mod_str, match_str, None)?;
-
+            let res = arrow_contains(mod_str, match_str)?;
             Ok(Arc::new(res) as ArrayRef)
         }
         (LargeUtf8, LargeUtf8) => {
             let mod_str = args[0].as_string::<i64>();
             let match_str = args[1].as_string::<i64>();
-            let res = regexp_is_match_utf8::<
-                GenericStringArray<i64>,
-                GenericStringArray<i64>,
-                GenericStringArray<i32>,
-            >(mod_str, match_str, None)?;
-
+            let res = arrow_contains(mod_str, match_str)?;
             Ok(Arc::new(res) as ArrayRef)
         }
         other => {
@@ -195,93 +130,29 @@
 }
 
 #[cfg(test)]
-mod tests {
-    use crate::string::contains::ContainsFunc;
-    use crate::utils::test::test_function;
-    use arrow::array::Array;
-    use arrow::{array::BooleanArray, datatypes::DataType::Boolean};
-    use datafusion_common::Result;
+mod test {
+    use super::ContainsFunc;
+    use arrow::array::{BooleanArray, StringArray};
     use datafusion_common::ScalarValue;
-    use datafusion_expr::ColumnarValue;
-    use datafusion_expr::ScalarUDFImpl;
+    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+    use std::sync::Arc;
 
-    #[test]
-    fn test_functions() -> Result<()> {
-        test_function!(
-            ContainsFunc::new(),
-            &[
-                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
-                ColumnarValue::Scalar(ScalarValue::from("alph")),
-            ],
-            Ok(Some(true)),
-            bool,
-            Boolean,
-            BooleanArray
-        );
-        test_function!(
-            ContainsFunc::new(),
-            &[
-                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
-                ColumnarValue::Scalar(ScalarValue::from("dddddd")),
-            ],
-            Ok(Some(false)),
-            bool,
-            Boolean,
-            BooleanArray
-        );
-        test_function!(
-            ContainsFunc::new(),
-            &[
-                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
-                ColumnarValue::Scalar(ScalarValue::from("pha")),
-            ],
-            Ok(Some(true)),
-            bool,
-            Boolean,
-            BooleanArray
-        );
-        test_function!(
-            ContainsFunc::new(),
-            &[
-                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
-                    "Apache"
-                )))),
-                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("pac")))),
-            ],
-            Ok(Some(true)),
-            bool,
-            Boolean,
-            BooleanArray
-        );
-        test_function!(
-            ContainsFunc::new(),
-            &[
-                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
-                    "Apache"
-                )))),
-                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ap")))),
-            ],
-            Ok(Some(false)),
-            bool,
-            Boolean,
-            BooleanArray
-        );
-        test_function!(
-            ContainsFunc::new(),
-            &[
-                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
-                    "Apache"
-                )))),
-                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(
-                    "DataFusion"
-                )))),
-            ],
-            Ok(Some(false)),
-            bool,
-            Boolean,
-            BooleanArray
+    #[test]
+    fn test_contains_udf() {
+        let udf = ContainsFunc::new();
+        let array = ColumnarValue::Array(Arc::new(StringArray::from(vec![
+            Some("xxx?()"),
+            Some("yyy?()"),
+        ])));
+        let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string())));
+        let actual = udf.invoke(&[array, scalar]).unwrap();
+        let expect = ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
+            Some(true),
+            Some(false),
+        ])));
+        assert_eq!(
+            *actual.into_array(2).unwrap(),
+            *expect.into_array(2).unwrap()
         );
-
-        Ok(())
     }
 }
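After this rewrite, `contains` defers to arrow's `contains` kernel (imported above as `arrow_contains`) instead of the regex-based helper, and the nine type-combination arms collapse to the three same-type arms the coerced signature can produce. A hedged sketch of calling the kernel directly, assuming the arrow 53.2 API used by this diff:

```rust
use arrow::array::{BooleanArray, StringArray};
use arrow::compute::contains;

fn main() {
    let haystack = StringArray::from(vec!["the quick brown fox", "datafusion"]);
    let needle = StringArray::from(vec!["row", "spark"]);

    // Element-wise, case-sensitive substring search; no regex involved.
    let result = contains(&haystack, &needle).unwrap();
    assert_eq!(result, BooleanArray::from(vec![true, false]));
}
```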
diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs
index 03a1795954d03..88978a35c0b7f 100644
--- a/datafusion/functions/src/string/ends_with.rs
+++ b/datafusion/functions/src/string/ends_with.rs
@@ -16,18 +16,17 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::ArrayRef;
 use arrow::datatypes::DataType;
 
+use crate::utils::make_scalar_function;
 use datafusion_common::{internal_err, Result};
-use datafusion_expr::TypeSignature::*;
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::utils::make_scalar_function;
-
 #[derive(Debug)]
 pub struct EndsWithFunc {
     signature: Signature,
@@ -42,17 +41,7 @@ impl Default for EndsWithFunc {
 impl EndsWithFunc {
     pub fn new() -> Self {
         Self {
-            signature: Signature::one_of(
-                vec![
-                    // Planner attempts coercion to the target type starting with the most preferred candidate.
-                    // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`.
-                    // If that fails, it proceeds to `(Utf8, Utf8)`.
-                    Exact(vec![DataType::Utf8View, DataType::Utf8View]),
-                    Exact(vec![DataType::Utf8, DataType::Utf8]),
-                    Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
-                ],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(2, Volatility::Immutable),
         }
     }
 }
@@ -84,6 +73,41 @@ impl ScalarUDFImpl for EndsWithFunc {
             }
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_ends_with_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_ends_with_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Tests if a string ends with a substring.")
+            .with_syntax_example("ends_with(str, substr)")
+            .with_sql_example(
+                r#"```sql
+> select ends_with('datafusion', 'soin');
++--------------------------------------------+
+| ends_with(Utf8("datafusion"),Utf8("soin")) |
++--------------------------------------------+
+| false                                      |
++--------------------------------------------+
+> select ends_with('datafusion', 'sion');
++--------------------------------------------+
+| ends_with(Utf8("datafusion"),Utf8("sion")) |
++--------------------------------------------+
+| true                                       |
++--------------------------------------------+
+```"#,
+            )
+            .with_standard_argument("str", Some("String"))
+            .with_argument("substr", "Substring to test for.")
+            .build()
+            .unwrap()
+    })
 }
 
 /// Returns true if string ends with suffix.
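The comments deleted here (still present in the `regexp_match` signature earlier in this diff) describe how the planner walks `Signature::one_of` candidates in order and takes the first list the inputs can coerce to. A sketch of encoding that preference order, using only types that appear in this diff:

```rust
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
use datafusion_expr::{Signature, TypeSignature, Volatility};

fn main() {
    // Candidates are tried in order: given `(Utf8View, Utf8)` the planner
    // first tries coercing to `(Utf8View, Utf8View)`; only if that fails
    // does it fall back to `(Utf8, Utf8)`, then `(LargeUtf8, LargeUtf8)`.
    let signature = Signature::one_of(
        vec![
            TypeSignature::Exact(vec![Utf8View, Utf8View]),
            TypeSignature::Exact(vec![Utf8, Utf8]),
            TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
        ],
        Volatility::Immutable,
    );
    let _ = signature;
}
```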
diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/string/initcap.rs
index 4e1eb213ef57d..5fd1e7929881a 100644
--- a/datafusion/functions/src/string/initcap.rs
+++ b/datafusion/functions/src/string/initcap.rs
@@ -16,18 +16,18 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray};
 use arrow::datatypes::DataType;
 
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::utils::{make_scalar_function, utf8_to_str_type};
-
 #[derive(Debug)]
 pub struct InitcapFunc {
     signature: Signature,
@@ -41,13 +41,8 @@ impl Default for InitcapFunc {
 
 impl InitcapFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
-            signature: Signature::uniform(
-                1,
-                vec![Utf8, LargeUtf8, Utf8View],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(1, Volatility::Immutable),
        }
    }
 }
@@ -79,6 +74,34 @@ impl ScalarUDFImpl for InitcapFunc {
             }
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_initcap_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_initcap_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Capitalizes the first character in each word in the input string. Words are delimited by non-alphanumeric characters.")
+            .with_syntax_example("initcap(str)")
+            .with_sql_example(r#"```sql
+> select initcap('apache datafusion');
++------------------------------------+
+| initcap(Utf8("apache datafusion")) |
++------------------------------------+
+| Apache Datafusion                  |
++------------------------------------+
+```"#)
+            .with_standard_argument("str", Some("String"))
+            .with_related_udf("lower")
+            .with_related_udf("upper")
+            .build()
+            .unwrap()
+    })
 }
 
 /// Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.
diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs
index 430c402a50c54..558e71239f84e 100644
--- a/datafusion/functions/src/string/levenshtein.rs
+++ b/datafusion/functions/src/string/levenshtein.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait};
 use arrow::datatypes::DataType;
@@ -25,8 +25,8 @@ use crate::utils::{make_scalar_function, utf8_to_int_type};
 use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
 use datafusion_common::utils::datafusion_strsim;
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::ColumnarValue;
-use datafusion_expr::TypeSignature::*;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation};
 use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
 
 #[derive(Debug)]
@@ -43,14 +43,7 @@ impl Default for LevenshteinFunc {
 impl LevenshteinFunc {
     pub fn new() -> Self {
         Self {
-            signature: Signature::one_of(
-                vec![
-                    Exact(vec![DataType::Utf8View, DataType::Utf8View]),
-                    Exact(vec![DataType::Utf8, DataType::Utf8]),
-                    Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
-                ],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(2, Volatility::Immutable),
         }
     }
 }
@@ -83,6 +76,33 @@ impl ScalarUDFImpl for LevenshteinFunc {
             }
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_levenshtein_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_levenshtein_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Returns the [`Levenshtein distance`](https://en.wikipedia.org/wiki/Levenshtein_distance) between the two given strings.")
+            .with_syntax_example("levenshtein(str1, str2)")
+            .with_sql_example(r#"```sql
+> select levenshtein('kitten', 'sitting');
++---------------------------------------------+
+| levenshtein(Utf8("kitten"),Utf8("sitting")) |
++---------------------------------------------+
+| 3                                           |
++---------------------------------------------+
+```"#)
+            .with_argument("str1", "String expression to compute Levenshtein distance with str2.")
+            .with_argument("str2", "String expression to compute Levenshtein distance with str1.")
+            .build()
+            .unwrap()
+    })
 }
 
 ///Returns the Levenshtein distance between the two given strings.
diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs
index ca324e69c0d23..b07189a832dc8 100644
--- a/datafusion/functions/src/string/lower.rs
+++ b/datafusion/functions/src/string/lower.rs
@@ -15,16 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::any::Any;
-
 use arrow::datatypes::DataType;
-
-use datafusion_common::Result;
-use datafusion_expr::ColumnarValue;
-use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
+use std::any::Any;
+use std::sync::OnceLock;
 
 use crate::string::common::to_lower;
 use crate::utils::utf8_to_str_type;
+use datafusion_common::Result;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation};
+use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
 
 #[derive(Debug)]
 pub struct LowerFunc {
@@ -39,13 +39,8 @@ impl Default for LowerFunc {
 
 impl LowerFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
-            signature: Signature::uniform(
-                1,
-                vec![Utf8, LargeUtf8, Utf8View],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(1, Volatility::Immutable),
        }
    }
 }
@@ -70,8 +65,37 @@ impl ScalarUDFImpl for LowerFunc {
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         to_lower(args, "lower")
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_lower_doc())
+    }
 }
 
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_lower_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Converts a string to lower-case.")
+            .with_syntax_example("lower(str)")
+            .with_sql_example(
+                r#"```sql
+> select lower('Ångström');
++-------------------------+
+| lower(Utf8("Ångström")) |
++-------------------------+
+| ångström                |
++-------------------------+
+```"#,
+            )
+            .with_standard_argument("str", Some("String"))
+            .with_related_udf("initcap")
+            .with_related_udf("upper")
+            .build()
+            .unwrap()
+    })
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
- Exact(vec![Utf8View, Utf8View]), - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8View]), - Exact(vec![Utf8]), - ], + vec![TypeSignature::String(2), TypeSignature::String(1)], Volatility::Immutable, ), } @@ -104,6 +94,42 @@ impl ScalarUDFImpl for LtrimFunc { ), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_ltrim_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_ltrim_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string.") + .with_syntax_example("ltrim(str[, trim_str])") + .with_sql_example(r#"```sql +> select ltrim(' datafusion '); ++-------------------------------+ +| ltrim(Utf8(" datafusion ")) | ++-------------------------------+ +| datafusion | ++-------------------------------+ +> select ltrim('___datafusion___', '_'); ++-------------------------------------------+ +| ltrim(Utf8("___datafusion___"),Utf8("_")) | ++-------------------------------------------+ +| datafusion___ | ++-------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("trim_str", "String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._") + .with_alternative_syntax("trim(LEADING trim_str FROM str)") + .with_related_udf("btrim") + .with_related_udf("rtrim") + .build() + .unwrap() + }) } #[cfg(test)] diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index f792914d862e4..2ac2bf70da231 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -15,17 +15,17 @@ // specific language governing permissions and limitations // under the License. 
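Both signature forms that replace the hand-written `Exact` permutations above recur throughout this change. A sketch of the two styles, assuming `datafusion_expr` is available as a dependency (both constructors are taken verbatim from the hunks above):

```rust
use datafusion_expr::{Signature, TypeSignature, Volatility};

// ltrim/rtrim style: the function accepts either one or two string arguments.
fn trim_style() -> Signature {
    Signature::one_of(
        vec![TypeSignature::String(2), TypeSignature::String(1)],
        Volatility::Immutable,
    )
}

// lower/levenshtein/replace style: a fixed arity of string arguments,
// covering Utf8, LargeUtf8 and Utf8View without listing each combination.
fn fixed_arity_style(arity: usize) -> Signature {
    Signature::string(arity, Volatility::Immutable)
}

fn main() {
    let _ = trim_style();
    let _ = fixed_arity_style(1);
}
```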
-use std::any::Any; - use arrow::compute::kernels::length::length; use arrow::datatypes::DataType; +use std::any::Any; +use std::sync::OnceLock; +use crate::utils::utf8_to_int_type; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::utils::utf8_to_int_type; - #[derive(Debug)] pub struct OctetLengthFunc { signature: Signature, @@ -39,13 +39,8 @@ impl Default for OctetLengthFunc { impl OctetLengthFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8, LargeUtf8, Utf8View], - Volatility::Immutable, - ), + signature: Signature::string(1, Volatility::Immutable), } } } @@ -91,6 +86,36 @@ impl ScalarUDFImpl for OctetLengthFunc { }, } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_octet_length_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_octet_length_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns the length of a string in bytes.") + .with_syntax_example("octet_length(str)") + .with_sql_example( + r#"```sql +> select octet_length('Ångström'); ++--------------------------------+ +| octet_length(Utf8("Ångström")) | ++--------------------------------+ +| 10 | ++--------------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_related_udf("bit_length") + .with_related_udf("length") + .build() + .unwrap() + }) } #[cfg(test)] diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs index e285bd85b197b..796776304f4ae 100644 --- a/datafusion/functions/src/string/overlay.rs +++ b/datafusion/functions/src/string/overlay.rs @@ -16,21 +16,20 @@ // under the License. 
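The 10 in the `octet_length('Ångström')` example above is a byte count, not a character count. The same distinction in plain Rust, for intuition:

```rust
fn main() {
    // 'Å' and 'ö' each take two bytes in UTF-8: 8 characters, 10 bytes.
    assert_eq!("Ångström".len(), 10); // bytes, as reported by octet_length
    assert_eq!("Ångström".chars().count(), 8); // characters, cf. character_length
}
```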
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
 use arrow::datatypes::DataType;
 
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use datafusion_common::cast::{
     as_generic_string_array, as_int64_array, as_string_view_array,
 };
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::TypeSignature::*;
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::utils::{make_scalar_function, utf8_to_str_type};
-
 #[derive(Debug)]
 pub struct OverlayFunc {
     signature: Signature,
@@ -48,12 +47,12 @@ impl OverlayFunc {
         Self {
             signature: Signature::one_of(
                 vec![
-                    Exact(vec![Utf8View, Utf8View, Int64, Int64]),
-                    Exact(vec![Utf8, Utf8, Int64, Int64]),
-                    Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
-                    Exact(vec![Utf8View, Utf8View, Int64]),
-                    Exact(vec![Utf8, Utf8, Int64]),
-                    Exact(vec![LargeUtf8, LargeUtf8, Int64]),
+                    TypeSignature::Exact(vec![Utf8View, Utf8View, Int64, Int64]),
+                    TypeSignature::Exact(vec![Utf8, Utf8, Int64, Int64]),
+                    TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
+                    TypeSignature::Exact(vec![Utf8View, Utf8View, Int64]),
+                    TypeSignature::Exact(vec![Utf8, Utf8, Int64]),
+                    TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]),
                 ],
                 Volatility::Immutable,
             ),
@@ -87,6 +86,35 @@ impl ScalarUDFImpl for OverlayFunc {
             other => exec_err!("Unsupported data type {other:?} for function overlay"),
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_overlay_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_overlay_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Returns the input string with a substring replaced: starting at a specified position, a specified count of characters is overwritten by another string.")
+            .with_syntax_example("overlay(str PLACING substr FROM pos [FOR count])")
+            .with_sql_example(r#"```sql
+> select overlay('Txxxxas' placing 'hom' from 2 for 4);
++--------------------------------------------------------+
+| overlay(Utf8("Txxxxas"),Utf8("hom"),Int64(2),Int64(4)) |
++--------------------------------------------------------+
+| Thomas                                                 |
++--------------------------------------------------------+
+```"#)
+            .with_standard_argument("str", Some("String"))
+            .with_argument("substr", "Substring to replace in str.")
+            .with_argument("pos", "The start position of the replacement in str.")
+            .with_argument("count", "The count of characters to be replaced from the start position of str. If not specified, the length of substr is used instead.")
+            .build()
+            .unwrap()
+    })
 }
 
 macro_rules! process_overlay {
diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs
index 20e4462784b82..aa69f9c6609ad 100644
--- a/datafusion/functions/src/string/repeat.rs
+++ b/datafusion/functions/src/string/repeat.rs
@@ -16,24 +16,22 @@
 // under the License.
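A minimal sketch of the overlay semantics documented above, over plain `&str` values rather than Arrow arrays. Positions are 1-based and counted in characters; the bounds handling here is illustrative, not the kernel's code:

```rust
fn overlay(s: &str, substr: &str, pos: usize, count: usize) -> String {
    let chars: Vec<char> = s.chars().collect();
    // keep everything before the 1-based start position
    let mut out: String = chars[..pos - 1].iter().collect();
    // splice in the replacement
    out.push_str(substr);
    // skip `count` characters of the original, keep the rest
    out.extend(chars[(pos - 1 + count).min(chars.len())..].iter());
    out
}

fn main() {
    // matches the SQL example: overlay('Txxxxas' placing 'hom' from 2 for 4)
    assert_eq!(overlay("Txxxxas", "hom", 2, 4), "Thomas");
}
```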
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
+use crate::strings::StringArrayType;
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use arrow::array::{
     ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
     OffsetSizeTrait, StringViewArray,
 };
 use arrow::datatypes::DataType;
 use arrow::datatypes::DataType::{Int64, LargeUtf8, Utf8, Utf8View};
-
 use datafusion_common::cast::as_int64_array;
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::TypeSignature::*;
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::string::common::StringArrayType;
-use crate::utils::{make_scalar_function, utf8_to_str_type};
-
 #[derive(Debug)]
 pub struct RepeatFunc {
     signature: Signature,
@@ -53,9 +51,9 @@ impl RepeatFunc {
                     // Planner attempts coercion to the target type starting with the most preferred candidate.
                     // For example, given input `(Utf8View, Int64)`, it first tries coercing to `(Utf8View, Int64)`.
                     // If that fails, it proceeds to `(Utf8, Int64)`.
-                    Exact(vec![Utf8View, Int64]),
-                    Exact(vec![Utf8, Int64]),
-                    Exact(vec![LargeUtf8, Int64]),
+                    TypeSignature::Exact(vec![Utf8View, Int64]),
+                    TypeSignature::Exact(vec![Utf8, Int64]),
+                    TypeSignature::Exact(vec![LargeUtf8, Int64]),
                 ],
                 Volatility::Immutable,
             ),
@@ -83,6 +81,37 @@ impl ScalarUDFImpl for RepeatFunc {
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         make_scalar_function(repeat, vec![])(args)
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_repeat_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_repeat_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description(
+                "Returns a string with an input string repeated a specified number of times.",
+            )
+            .with_syntax_example("repeat(str, n)")
+            .with_sql_example(
+                r#"```sql
+> select repeat('data', 3);
++-------------------------------+
+| repeat(Utf8("data"),Int64(3)) |
++-------------------------------+
+| datadatadata                  |
++-------------------------------+
+```"#,
+            )
+            .with_standard_argument("str", Some("String"))
+            .with_argument("n", "Number of times to repeat the input string.")
+            .build()
+            .unwrap()
+    })
 }
 
 /// Repeats string the specified number of times.
diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs
index 13fa3d55672dd..91abc39da058a 100644
--- a/datafusion/functions/src/string/replace.rs
+++ b/datafusion/functions/src/string/replace.rs
@@ -16,19 +16,18 @@
 // under the License.
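Per row, the repeat kernel's behavior matches Rust's own `str::repeat`:

```rust
fn main() {
    // matches the SQL example: repeat('data', 3)
    assert_eq!("data".repeat(3), "datadatadata");
}
```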
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray};
 use arrow::datatypes::DataType;
 
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::TypeSignature::*;
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::utils::{make_scalar_function, utf8_to_str_type};
-
 #[derive(Debug)]
 pub struct ReplaceFunc {
     signature: Signature,
@@ -42,16 +41,8 @@ impl Default for ReplaceFunc {
 impl ReplaceFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
-            signature: Signature::one_of(
-                vec![
-                    Exact(vec![Utf8View, Utf8View, Utf8View]),
-                    Exact(vec![Utf8, Utf8, Utf8]),
-                    Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
-                ],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(3, Volatility::Immutable),
         }
     }
 }
@@ -83,6 +74,34 @@ impl ScalarUDFImpl for ReplaceFunc {
             }
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_replace_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_replace_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Replaces all occurrences of a specified substring in a string with a new substring.")
+            .with_syntax_example("replace(str, substr, replacement)")
+            .with_sql_example(r#"```sql
+> select replace('ABabbaBA', 'ab', 'cd');
++-------------------------------------------------+
+| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) |
++-------------------------------------------------+
+| ABcdbaBA                                        |
++-------------------------------------------------+
+```"#)
+            .with_standard_argument("str", Some("String"))
+            .with_argument("substr", "Substring expression to replace in the input string.")
+            .with_argument("replacement", "Replacement substring expression.")
+            .build()
+            .unwrap()
+    })
 }
 
 fn replace_view(args: &[ArrayRef]) -> Result<ArrayRef> {
diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs
index a1aa5568babb0..e934147efbbe3 100644
--- a/datafusion/functions/src/string/rtrim.rs
+++ b/datafusion/functions/src/string/rtrim.rs
@@ -16,19 +16,18 @@
 // under the License.
 
 use arrow::array::{ArrayRef, OffsetSizeTrait};
-use std::any::Any;
-
 use arrow::datatypes::DataType;
+use std::any::Any;
+use std::sync::OnceLock;
 
+use crate::string::common::*;
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use datafusion_common::{exec_err, Result};
 use datafusion_expr::function::Hint;
-use datafusion_expr::TypeSignature::*;
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::string::common::*;
-use crate::utils::{make_scalar_function, utf8_to_str_type};
-
 /// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed.
/// rtrim('testxxzx', 'xyz') = 'test' fn rtrim(args: &[ArrayRef]) -> Result { @@ -49,18 +48,9 @@ impl Default for RtrimFunc { impl RtrimFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( - vec![ - // Planner attempts coercion to the target type starting with the most preferred candidate. - // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`. - // If that fails, it proceeds to `(Utf8, Utf8)`. - Exact(vec![Utf8View, Utf8View]), - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8View]), - Exact(vec![Utf8]), - ], + vec![TypeSignature::String(2), TypeSignature::String(1)], Volatility::Immutable, ), } @@ -104,6 +94,42 @@ impl ScalarUDFImpl for RtrimFunc { ), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_rtrim_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_rtrim_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string.") + .with_syntax_example("rtrim(str[, trim_str])") + .with_sql_example(r#"```sql +> select rtrim(' datafusion '); ++-------------------------------+ +| rtrim(Utf8(" datafusion ")) | ++-------------------------------+ +| datafusion | ++-------------------------------+ +> select rtrim('___datafusion___', '_'); ++-------------------------------------------+ +| rtrim(Utf8("___datafusion___"),Utf8("_")) | ++-------------------------------------------+ +| ___datafusion | ++-------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("trim_str", "String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._") + .with_alternative_syntax("trim(TRAILING trim_str FROM str)") + .with_related_udf("btrim") + .with_related_udf("ltrim") + .build() + .unwrap() + }) } #[cfg(test)] diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 8d292315a35ac..ea01cb1f56f9a 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
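The doc comments above (`ltrim('zzzytest', 'xyz') = 'test'`, `rtrim('testxxzx', 'xyz') = 'test'`) show that the second argument acts as a character set, not a literal prefix or suffix. A sketch of that semantics on plain `&str`:

```rust
// Every leading/trailing character that appears anywhere in trim_str is removed.
fn ltrim_set<'a>(s: &'a str, trim_str: &str) -> &'a str {
    s.trim_start_matches(|c| trim_str.contains(c))
}

fn rtrim_set<'a>(s: &'a str, trim_str: &str) -> &'a str {
    s.trim_end_matches(|c| trim_str.contains(c))
}

fn main() {
    assert_eq!(ltrim_set("zzzytest", "xyz"), "test");
    assert_eq!(ltrim_set("___datafusion___", "_"), "datafusion___");
    assert_eq!(rtrim_set("testxxzx", "xyz"), "test");
}
```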
+use crate::strings::StringArrayType; +use crate::utils::utf8_to_str_type; use arrow::array::{ ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, StringViewArray, }; @@ -23,15 +25,11 @@ use arrow::datatypes::DataType; use datafusion_common::cast::as_int64_array; use datafusion_common::ScalarValue; use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use std::any::Any; -use std::sync::Arc; - -use crate::utils::utf8_to_str_type; - -use super::common::StringArrayType; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct SplitPartFunc { @@ -50,15 +48,15 @@ impl SplitPartFunc { Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8View, Utf8View, Int64]), - Exact(vec![Utf8View, Utf8, Int64]), - Exact(vec![Utf8View, LargeUtf8, Int64]), - Exact(vec![Utf8, Utf8View, Int64]), - Exact(vec![Utf8, Utf8, Int64]), - Exact(vec![LargeUtf8, Utf8View, Int64]), - Exact(vec![LargeUtf8, Utf8, Int64]), - Exact(vec![Utf8, LargeUtf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64]), + TypeSignature::Exact(vec![Utf8View, Utf8View, Int64]), + TypeSignature::Exact(vec![Utf8View, Utf8, Int64]), + TypeSignature::Exact(vec![Utf8View, LargeUtf8, Int64]), + TypeSignature::Exact(vec![Utf8, Utf8View, Int64]), + TypeSignature::Exact(vec![Utf8, Utf8, Int64]), + TypeSignature::Exact(vec![LargeUtf8, Utf8View, Int64]), + TypeSignature::Exact(vec![LargeUtf8, Utf8, Int64]), + TypeSignature::Exact(vec![Utf8, LargeUtf8, Int64]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]), ], Volatility::Immutable, ), @@ -178,6 +176,34 @@ impl ScalarUDFImpl for SplitPartFunc { result.map(ColumnarValue::Array) } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_split_part_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_split_part_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Splits a string based on a specified delimiter and returns the substring in the specified position.") + .with_syntax_example("split_part(str, delimiter, pos)") + .with_sql_example(r#"```sql +> select split_part('1.2.3.4.5', '.', 3); ++--------------------------------------------------+ +| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) | ++--------------------------------------------------+ +| 3 | ++--------------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("delimiter", "String or character to split on.") + .with_argument("pos", "Position of the part to return.") + .build() + .unwrap() + }) } /// impl diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 8450697cbf303..dce161a2e14bd 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -16,18 +16,17 @@ // under the License. 
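For the positive-position case documented above, split_part reduces to an indexed split. A simplified sketch for a positive 1-based position; error handling and the Arrow array plumbing of the real kernel are omitted here:

```rust
fn split_part<'a>(s: &'a str, delim: &str, pos: usize) -> Option<&'a str> {
    // nth is 0-based; SQL positions are 1-based
    s.split(delim).nth(pos - 1)
}

fn main() {
    // matches the SQL example: split_part('1.2.3.4.5', '.', 3)
    assert_eq!(split_part("1.2.3.4.5", ".", 3), Some("3"));
}
```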
use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::ArrayRef; use arrow::datatypes::DataType; +use crate::utils::make_scalar_function; use datafusion_common::{internal_err, Result}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use crate::utils::make_scalar_function; - /// Returns true if string starts with prefix. /// starts_with('alphabet', 'alph') = 't' pub fn starts_with(args: &[ArrayRef]) -> Result { @@ -49,17 +48,7 @@ impl Default for StartsWithFunc { impl StartsWithFunc { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![ - // Planner attempts coercion to the target type starting with the most preferred candidate. - // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`. - // If that fails, it proceeds to `(Utf8, Utf8)`. - Exact(vec![DataType::Utf8View, DataType::Utf8View]), - Exact(vec![DataType::Utf8, DataType::Utf8]), - Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), - ], - Volatility::Immutable, - ), + signature: Signature::string(2, Volatility::Immutable), } } } @@ -89,6 +78,35 @@ impl ScalarUDFImpl for StartsWithFunc { _ => internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")?, } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_starts_with_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_starts_with_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Tests if a string starts with a substring.") + .with_syntax_example("starts_with(str, substr)") + .with_sql_example( + r#"```sql +> select starts_with('datafusion','data'); ++----------------------------------------------+ +| starts_with(Utf8("datafusion"),Utf8("data")) | ++----------------------------------------------+ +| true | ++----------------------------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_argument("substr", "Substring to test for.") + .build() + .unwrap() + }) } #[cfg(test)] diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 79aa9254f9b16..e0033d2d1cb03 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -16,21 +16,21 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, }; +use crate::utils::make_scalar_function; use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use crate::utils::make_scalar_function; - /// Converts the number to its equivalent hexadecimal representation. 
/// to_hex(2147483647) = '7fffffff' pub fn to_hex(args: &[ArrayRef]) -> Result @@ -110,6 +110,34 @@ impl ScalarUDFImpl for ToHexFunc { other => exec_err!("Unsupported data type {other:?} for function to_hex"), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_hex_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_to_hex_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Converts an integer to a hexadecimal string.") + .with_syntax_example("to_hex(int)") + .with_sql_example( + r#"```sql +> select to_hex(12345689); ++-------------------------+ +| to_hex(Int64(12345689)) | ++-------------------------+ +| bc6159 | ++-------------------------+ +```"#, + ) + .with_standard_argument("int", Some("Integer")) + .build() + .unwrap() + }) } #[cfg(test)] diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 593e33ab6bb48..042c26b2e3daf 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -19,9 +19,11 @@ use crate::string::common::to_upper; use crate::utils::utf8_to_str_type; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::ColumnarValue; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct UpperFunc { @@ -36,13 +38,8 @@ impl Default for UpperFunc { impl UpperFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8, LargeUtf8, Utf8View], - Volatility::Immutable, - ), + signature: Signature::string(1, Volatility::Immutable), } } } @@ -67,6 +64,36 @@ impl ScalarUDFImpl for UpperFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { to_upper(args, "upper") } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_upper_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_upper_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Converts a string to upper-case.") + .with_syntax_example("upper(str)") + .with_sql_example( + r#"```sql +> select upper('dataFusion'); ++---------------------------+ +| upper(Utf8("dataFusion")) | ++---------------------------+ +| DATAFUSION | ++---------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_related_udf("initcap") + .with_related_udf("lower") + .build() + .unwrap() + }) } #[cfg(test)] diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index 3ddc320fcec1f..0fbdce16ccd13 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -16,7 +16,7 @@ // under the License. 
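Both values in the to_hex examples can be checked with Rust's lower-hex formatter, which matches the output shown above:

```rust
fn main() {
    // /// to_hex(2147483647) = '7fffffff'
    assert_eq!(format!("{:x}", 2147483647i64), "7fffffff");
    // the SQL example: to_hex(12345689) = 'bc6159'
    assert_eq!(format!("{:x}", 12345689i64), "bc6159");
}
```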
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::GenericStringArray;
 use arrow::datatypes::DataType;
@@ -24,7 +24,8 @@ use arrow::datatypes::DataType::Utf8;
 use uuid::Uuid;
 
 use datafusion_common::{not_impl_err, Result};
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
 #[derive(Debug)]
@@ -74,4 +75,29 @@ impl ScalarUDFImpl for UuidFunc {
         let array = GenericStringArray::<i32>::from_iter_values(values);
         Ok(ColumnarValue::Array(Arc::new(array)))
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_uuid_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_uuid_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Returns a [`UUID v4`](https://en.wikipedia.org/wiki/Universally_unique_identifier#Version_4_(random)) string value which is unique per row.")
+            .with_syntax_example("uuid()")
+            .with_sql_example(r#"```sql
+> select uuid();
++--------------------------------------+
+| uuid()                               |
++--------------------------------------+
+| 6ec17ef8-1934-41cc-8d59-d0c8f9eea1f0 |
++--------------------------------------+
+```"#)
+            .build()
+            .unwrap()
+    })
 }
diff --git a/datafusion/functions/src/strings.rs b/datafusion/functions/src/strings.rs
new file mode 100644
index 0000000000000..e0cec3cb5756f
--- /dev/null
+++ b/datafusion/functions/src/strings.rs
@@ -0,0 +1,424 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::mem::size_of;
+
+use arrow::array::{
+    make_view, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ByteView,
+    GenericStringArray, LargeStringArray, OffsetSizeTrait, StringArray, StringViewArray,
+    StringViewBuilder,
+};
+use arrow::datatypes::DataType;
+use arrow_buffer::{MutableBuffer, NullBuffer, NullBufferBuilder};
+
+/// Abstracts iteration over different types of string arrays.
+///
+/// The [`StringArrayType`] trait helps write generic code for string functions that can work with
+/// different types of string arrays.
+///
+/// Currently three types are supported:
+/// - [`StringArray`]
+/// - [`LargeStringArray`]
+/// - [`StringViewArray`]
+///
+/// It is inspired / copied from [arrow-rs].
+///
+/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/bf0ea9129e617e4a3cf915a900b747cc5485315f/arrow-string/src/like.rs#L151-L157
+///
+/// # Examples
+/// Generic function that works for [`StringArray`], [`LargeStringArray`]
+/// and [`StringViewArray`]:
+/// ```
+/// # use arrow::array::{StringArray, LargeStringArray, StringViewArray};
+/// # use datafusion_functions::strings::StringArrayType;
+///
+/// /// Combines string values for any StringArrayType type. It can be invoked on
+/// /// any combination of `StringArray`, `LargeStringArray` or `StringViewArray`
+/// fn combine_values<'a, S1, S2>(array1: S1, array2: S2) -> Vec<String>
+///   where S1: StringArrayType<'a>, S2: StringArrayType<'a>
+/// {
+///   // iterate over the elements of the 2 arrays in parallel
+///   array1
+///     .iter()
+///     .zip(array2.iter())
+///     .map(|(s1, s2)| {
+///       // if both values are non null, combine them
+///       if let (Some(s1), Some(s2)) = (s1, s2) {
+///         format!("{s1}{s2}")
+///       } else {
+///         "None".to_string()
+///       }
+///     })
+///     .collect()
+/// }
+///
+/// let string_array = StringArray::from(vec!["foo", "bar"]);
+/// let large_string_array = LargeStringArray::from(vec!["foo2", "bar2"]);
+/// let string_view_array = StringViewArray::from(vec!["foo3", "bar3"]);
+///
+/// // can invoke this function with a string array and large string array
+/// assert_eq!(
+///   combine_values(&string_array, &large_string_array),
+///   vec![String::from("foofoo2"), String::from("barbar2")]
+/// );
+///
+/// // Can call the same function with string array and string view array
+/// assert_eq!(
+///   combine_values(&string_array, &string_view_array),
+///   vec![String::from("foofoo3"), String::from("barbar3")]
+/// );
+/// ```
+///
+/// [`LargeStringArray`]: arrow::array::LargeStringArray
+pub trait StringArrayType<'a>: ArrayAccessor<Item = &'a str> + Sized {
+    /// Return an [`ArrayIter`] over the values of the array.
+    ///
+    /// This iterator returns `Option<&str>` for each item in the array.
+    fn iter(&self) -> ArrayIter<Self>;
+
+    /// Check if the array is ASCII only.
+    fn is_ascii(&self) -> bool;
+}
+
+impl<'a, T: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray<T> {
+    fn iter(&self) -> ArrayIter<Self> {
+        GenericStringArray::<T>::iter(self)
+    }
+
+    fn is_ascii(&self) -> bool {
+        GenericStringArray::<T>::is_ascii(self)
+    }
+}
+
+impl<'a> StringArrayType<'a> for &'a StringViewArray {
+    fn iter(&self) -> ArrayIter<Self> {
+        StringViewArray::iter(self)
+    }
+
+    fn is_ascii(&self) -> bool {
+        StringViewArray::is_ascii(self)
+    }
+}
+
+/// Optimized version of the StringBuilder in Arrow that:
+/// 1. Precalculates the expected length of the result, avoiding reallocations.
+/// 2. Avoids creating / incrementally creating a `NullBufferBuilder`
+pub struct StringArrayBuilder {
+    offsets_buffer: MutableBuffer,
+    value_buffer: MutableBuffer,
+}
+
+impl StringArrayBuilder {
+    pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self {
+        let mut offsets_buffer =
+            MutableBuffer::with_capacity((item_capacity + 1) * size_of::<i32>());
+        // SAFETY: the first offset value is definitely not going to exceed the bounds.
+        unsafe { offsets_buffer.push_unchecked(0_i32) };
+        Self {
+            offsets_buffer,
+            value_buffer: MutableBuffer::with_capacity(data_capacity),
+        }
+    }
+
+    pub fn write<const CHECK_VALID: bool>(
+        &mut self,
+        column: &ColumnarValueRef,
+        i: usize,
+    ) {
+        match column {
+            ColumnarValueRef::Scalar(s) => {
+                self.value_buffer.extend_from_slice(s);
+            }
+            ColumnarValueRef::NullableArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NullableLargeStringArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NullableStringViewArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NonNullableArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+            ColumnarValueRef::NonNullableLargeStringArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+            ColumnarValueRef::NonNullableStringViewArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+        }
+    }
+
+    pub fn append_offset(&mut self) {
+        let next_offset: i32 = self
+            .value_buffer
+            .len()
+            .try_into()
+            .expect("byte array offset overflow");
+        unsafe { self.offsets_buffer.push_unchecked(next_offset) };
+    }
+
+    pub fn finish(self, null_buffer: Option<NullBuffer>) -> StringArray {
+        let array_builder = ArrayDataBuilder::new(DataType::Utf8)
+            .len(self.offsets_buffer.len() / size_of::<i32>() - 1)
+            .add_buffer(self.offsets_buffer.into())
+            .add_buffer(self.value_buffer.into())
+            .nulls(null_buffer);
+        // SAFETY: all data that was appended was valid UTF8 and the values
+        // and offsets were created correctly
+        let array_data = unsafe { array_builder.build_unchecked() };
+        StringArray::from(array_data)
+    }
+}
+
+pub struct StringViewArrayBuilder {
+    builder: StringViewBuilder,
+    block: String,
+}
+
+impl StringViewArrayBuilder {
+    pub fn with_capacity(_item_capacity: usize, data_capacity: usize) -> Self {
+        let builder = StringViewBuilder::with_capacity(data_capacity);
+        Self {
+            builder,
+            block: String::new(),
+        }
+    }
+
+    pub fn write<const CHECK_VALID: bool>(
+        &mut self,
+        column: &ColumnarValueRef,
+        i: usize,
+    ) {
+        match column {
+            ColumnarValueRef::Scalar(s) => {
+                self.block.push_str(std::str::from_utf8(s).unwrap());
+            }
+            ColumnarValueRef::NullableArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.block.push_str(
+                        std::str::from_utf8(array.value(i).as_bytes()).unwrap(),
+                    );
+                }
+            }
+            ColumnarValueRef::NullableLargeStringArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.block.push_str(
+                        std::str::from_utf8(array.value(i).as_bytes()).unwrap(),
+                    );
+                }
+            }
+            ColumnarValueRef::NullableStringViewArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.block.push_str(
+                        std::str::from_utf8(array.value(i).as_bytes()).unwrap(),
+                    );
+                }
+            }
+            ColumnarValueRef::NonNullableArray(array) => {
+                self.block
+                    .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap());
+            }
+            ColumnarValueRef::NonNullableLargeStringArray(array) => {
+                self.block
+                    .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap());
+            }
+            ColumnarValueRef::NonNullableStringViewArray(array) => {
+                self.block
+                    .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap());
+            }
+        }
+    }
+
+    pub fn append_offset(&mut self) {
+        self.builder.append_value(&self.block);
+        self.block = String::new();
+    }
+
+    pub fn finish(mut self) -> StringViewArray {
+        self.builder.finish()
+    }
+}
+
+pub struct LargeStringArrayBuilder {
+    offsets_buffer: MutableBuffer,
+    value_buffer: MutableBuffer,
+}
+
+impl LargeStringArrayBuilder {
+    pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self {
+        let mut offsets_buffer =
+            MutableBuffer::with_capacity((item_capacity + 1) * size_of::<i64>());
+        // SAFETY: the first offset value is definitely not going to exceed the bounds.
+        unsafe { offsets_buffer.push_unchecked(0_i64) };
+        Self {
+            offsets_buffer,
+            value_buffer: MutableBuffer::with_capacity(data_capacity),
+        }
+    }
+
+    pub fn write<const CHECK_VALID: bool>(
+        &mut self,
+        column: &ColumnarValueRef,
+        i: usize,
+    ) {
+        match column {
+            ColumnarValueRef::Scalar(s) => {
+                self.value_buffer.extend_from_slice(s);
+            }
+            ColumnarValueRef::NullableArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NullableLargeStringArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NullableStringViewArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NonNullableArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+            ColumnarValueRef::NonNullableLargeStringArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+            ColumnarValueRef::NonNullableStringViewArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+        }
+    }
+
+    pub fn append_offset(&mut self) {
+        let next_offset: i64 = self
+            .value_buffer
+            .len()
+            .try_into()
+            .expect("byte array offset overflow");
+        unsafe { self.offsets_buffer.push_unchecked(next_offset) };
+    }
+
+    pub fn finish(self, null_buffer: Option<NullBuffer>) -> LargeStringArray {
+        let array_builder = ArrayDataBuilder::new(DataType::LargeUtf8)
+            .len(self.offsets_buffer.len() / size_of::<i64>() - 1)
+            .add_buffer(self.offsets_buffer.into())
+            .add_buffer(self.value_buffer.into())
+            .nulls(null_buffer);
+        // SAFETY: all data that was appended was valid Large UTF8 and the values
+        // and offsets were created correctly
+        let array_data = unsafe { array_builder.build_unchecked() };
+        LargeStringArray::from(array_data)
+    }
+}
+
+/// Append a new view to the views buffer with the given substr
+///
+/// # Safety
+///
+/// original_view must be a valid view (the format described on
+/// [`GenericByteViewArray`](arrow::array::GenericByteViewArray)).
+///
+/// # Arguments
+/// - views_buffer: The buffer to append the new view to
+/// - null_builder: The buffer to append the null value to
+/// - original_view: The original view value
+/// - substr: The substring to append. Must be a valid substring of the original view
+/// - start_offset: The start offset of the substring in the view
+pub fn make_and_append_view(
+    views_buffer: &mut Vec<u128>,
+    null_builder: &mut NullBufferBuilder,
+    original_view: &u128,
+    substr: &str,
+    start_offset: u32,
+) {
+    let substr_len = substr.len();
+    let sub_view = if substr_len > 12 {
+        let view = ByteView::from(*original_view);
+        make_view(
+            substr.as_bytes(),
+            view.buffer_index,
+            view.offset + start_offset,
+        )
+    } else {
+        // inline value does not need block id or offset
+        make_view(substr.as_bytes(), 0, 0)
+    };
+    views_buffer.push(sub_view);
+    null_builder.append_non_null();
+}
+
+#[derive(Debug)]
+pub enum ColumnarValueRef<'a> {
+    Scalar(&'a [u8]),
+    NullableArray(&'a StringArray),
+    NonNullableArray(&'a StringArray),
+    NullableLargeStringArray(&'a LargeStringArray),
+    NonNullableLargeStringArray(&'a LargeStringArray),
+    NullableStringViewArray(&'a StringViewArray),
+    NonNullableStringViewArray(&'a StringViewArray),
+}
+
+impl<'a> ColumnarValueRef<'a> {
+    #[inline]
+    pub fn is_valid(&self, i: usize) -> bool {
+        match &self {
+            Self::Scalar(_)
+            | Self::NonNullableArray(_)
+            | Self::NonNullableLargeStringArray(_)
+            | Self::NonNullableStringViewArray(_) => true,
+            Self::NullableArray(array) => array.is_valid(i),
+            Self::NullableStringViewArray(array) => array.is_valid(i),
+            Self::NullableLargeStringArray(array) => array.is_valid(i),
+        }
+    }
+
+    #[inline]
+    pub fn nulls(&self) -> Option<NullBuffer> {
+        match &self {
+            Self::Scalar(_)
+            | Self::NonNullableArray(_)
+            | Self::NonNullableStringViewArray(_)
+            | Self::NonNullableLargeStringArray(_) => None,
+            Self::NullableArray(array) => array.nulls().cloned(),
+            Self::NullableStringViewArray(array) => array.nulls().cloned(),
+            Self::NullableLargeStringArray(array) => array.nulls().cloned(),
+        }
+    }
+}
diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs
index c9dc96b2a9350..7858a59664d3e 100644
--- a/datafusion/functions/src/unicode/character_length.rs
+++ b/datafusion/functions/src/unicode/character_length.rs
@@ -15,16 +15,19 @@
 // specific language governing permissions and limitations
 // under the License.
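A hedged usage sketch for `StringArrayBuilder`, assuming the exports above and two non-nullable input columns. This mirrors one plausible way a concat-style kernel drives the builder: one `write` per column per row, one `append_offset` per row, then `finish`:

```rust
use arrow::array::{Array, StringArray};
use datafusion_functions::strings::{ColumnarValueRef, StringArrayBuilder};

fn concat_columns(left: &StringArray, right: &StringArray) -> StringArray {
    // rough data-capacity estimate; real kernels can precompute exact sizes
    let mut builder = StringArrayBuilder::with_capacity(left.len(), 1024);
    let l = ColumnarValueRef::NonNullableArray(left);
    let r = ColumnarValueRef::NonNullableArray(right);
    for i in 0..left.len() {
        // CHECK_VALID = false: both inputs are known to be non-nullable here
        builder.write::<false>(&l, i);
        builder.write::<false>(&r, i);
        builder.append_offset();
    }
    // no null buffer: every row is valid
    builder.finish(None)
}

fn main() {
    let a = StringArray::from(vec!["foo", "bar"]);
    let b = StringArray::from(vec!["1", "2"]);
    let out = concat_columns(&a, &b);
    assert_eq!(out.value(0), "foo1");
    assert_eq!(out.value(1), "bar2");
}
```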
-use crate::string::common::StringArrayType; +use crate::strings::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct CharacterLengthFunc { @@ -76,6 +79,36 @@ impl ScalarUDFImpl for CharacterLengthFunc { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_character_length_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_character_length_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns the number of characters in a string.") + .with_syntax_example("character_length(str)") + .with_sql_example( + r#"```sql +> select character_length('Ångström'); ++------------------------------------+ +| character_length(Utf8("Ångström")) | ++------------------------------------+ +| 8 | ++------------------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_related_udf("bit_length") + .with_related_udf("octet_length") + .build() + .unwrap() + }) } /// Returns number of characters in the string. diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index 41a2b9d9e72de..cad860e41088f 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -16,7 +16,7 @@ // under the License. 
use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, @@ -24,11 +24,13 @@ use arrow::array::{ }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; +use crate::utils::{make_scalar_function, utf8_to_int_type}; use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::{make_scalar_function, utf8_to_int_type}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct FindInSetFunc { @@ -77,6 +79,33 @@ impl ScalarUDFImpl for FindInSetFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(find_in_set, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_find_in_set_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_find_in_set_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.") + .with_syntax_example("find_in_set(str, strlist)") + .with_sql_example(r#"```sql +> select find_in_set('b', 'a,b,c,d'); ++----------------------------------------+ +| find_in_set(Utf8("b"),Utf8("a,b,c,d")) | ++----------------------------------------+ +| 2 | ++----------------------------------------+ +```"#) + .with_argument("str", "String expression to find in strlist.") + .with_argument("strlist", "A string list is a string composed of substrings separated by , characters.") + .build() + .unwrap() + }) } ///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs index c49784948dd0d..a6c2b9768f0bc 100644 --- a/datafusion/functions/src/unicode/left.rs +++ b/datafusion/functions/src/unicode/left.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::cmp::Ordering; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array, @@ -25,15 +25,17 @@ use arrow::array::{ }; use arrow::datatypes::DataType; +use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::{ as_generic_string_array, as_int64_array, as_string_view_array, }; use datafusion_common::exec_err; use datafusion_common::Result; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct LeftFunc { @@ -91,6 +93,34 @@ impl ScalarUDFImpl for LeftFunc { ), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_left_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_left_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns a specified number of characters from 
the left side of a string.") + .with_syntax_example("left(str, n)") + .with_sql_example(r#"```sql +> select left('datafusion', 4); ++-----------------------------------+ +| left(Utf8("datafusion"),Int64(4)) | ++-----------------------------------+ +| data | ++-----------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("n", "Number of characters to return.") + .with_related_udf("right") + .build() + .unwrap() + }) } /// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index e102673c42530..767eda203c8fe 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::fmt::Write; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, @@ -27,13 +27,15 @@ use arrow::datatypes::DataType; use unicode_segmentation::UnicodeSegmentation; use DataType::{LargeUtf8, Utf8, Utf8View}; +use crate::strings::StringArrayType; +use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use crate::string::common::StringArrayType; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct LPadFunc { @@ -95,6 +97,35 @@ impl ScalarUDFImpl for LPadFunc { other => exec_err!("Unsupported data type {other:?} for function lpad"), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_lpad_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_lpad_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Pads the left side of a string with another string to a specified string length.") + .with_syntax_example("lpad(str, n[, padding_str])") + .with_sql_example(r#"```sql +> select lpad('Dolly', 10, 'hello'); ++---------------------------------------------+ +| lpad(Utf8("Dolly"),Int64(10),Utf8("hello")) | ++---------------------------------------------+ +| helloDolly | ++---------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("n", "String length to pad to.") + .with_argument("padding_str", "Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._") + .with_related_udf("rpad") + .build() + .unwrap() + }) } /// Extends the string to length 'length' by prepending the characters fill (a space by default). diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs index da16d3ee37520..baf3b56636e2e 100644 --- a/datafusion/functions/src/unicode/reverse.rs +++ b/datafusion/functions/src/unicode/reverse.rs @@ -16,19 +16,21 @@ // under the License. 
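A minimal sketch of the lpad behavior shown above, counting characters rather than bytes. The truncation-when-longer and empty-pad cases are assumptions here (Postgres-style), not taken from the docs above:

```rust
fn lpad(s: &str, n: usize, pad: &str) -> String {
    let len = s.chars().count();
    if len >= n {
        // assumed truncation to n characters when the input is already longer
        return s.chars().take(n).collect();
    }
    if pad.is_empty() {
        return s.to_string();
    }
    // cycle through the padding string until the deficit is filled
    let mut out: String = pad.chars().cycle().take(n - len).collect();
    out.push_str(s);
    out
}

fn main() {
    // matches the SQL example: lpad('Dolly', 10, 'hello')
    assert_eq!(lpad("Dolly", 10, "hello"), "helloDolly");
}
```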
use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; +use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait, }; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use DataType::{LargeUtf8, Utf8, Utf8View}; -use crate::utils::{make_scalar_function, utf8_to_str_type}; - #[derive(Debug)] pub struct ReverseFunc { signature: Signature, @@ -79,6 +81,34 @@ impl ScalarUDFImpl for ReverseFunc { } } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_reverse_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_reverse_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Reverses the character order of a string.") + .with_syntax_example("reverse(str)") + .with_sql_example( + r#"```sql +> select reverse('datafusion'); ++-----------------------------+ +| reverse(Utf8("datafusion")) | ++-----------------------------+ +| noisufatad | ++-----------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .build() + .unwrap() + }) } /// Reverses the order of the characters in the string. diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs index 9d542bb2c0065..ab3b7ba1a27e6 100644 --- a/datafusion/functions/src/unicode/right.rs +++ b/datafusion/functions/src/unicode/right.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::cmp::{max, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array, @@ -31,8 +31,11 @@ use datafusion_common::cast::{ }; use datafusion_common::exec_err; use datafusion_common::Result; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct RightFunc { @@ -90,6 +93,34 @@ impl ScalarUDFImpl for RightFunc { ), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_right_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_right_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns a specified number of characters from the right side of a string.") + .with_syntax_example("right(str, n)") + .with_sql_example(r#"```sql +> select right('datafusion', 6); ++------------------------------------+ +| right(Utf8("datafusion"),Int64(6)) | ++------------------------------------+ +| fusion | ++------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("n", "Number of characters to return") + .with_related_udf("left") + .build() + .unwrap() + }) } /// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. 
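The negative-n rule stated in the comment above ("returns all but first |n| characters") in a plain-Rust sketch, character-based like the kernel:

```rust
fn right(s: &str, n: i64) -> String {
    let chars: Vec<char> = s.chars().collect();
    let start = if n >= 0 {
        chars.len().saturating_sub(n as usize) // last n characters
    } else {
        ((-n) as usize).min(chars.len()) // all but the first |n|
    };
    chars[start..].iter().collect()
}

fn main() {
    // matches the SQL example: right('datafusion', 6)
    assert_eq!(right("datafusion", 6), "fusion");
    // negative n: drop the first 4 characters
    assert_eq!(right("datafusion", -4), "fusion");
}
```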
diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index c1d6f327928f2..bd9d625105e9f 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::StringArrayType; +use crate::strings::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, @@ -25,11 +25,14 @@ use arrow::datatypes::DataType; use datafusion_common::cast::as_int64_array; use datafusion_common::DataFusionError; use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; use std::fmt::Write; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use unicode_segmentation::UnicodeSegmentation; use DataType::{LargeUtf8, Utf8, Utf8View}; @@ -113,6 +116,39 @@ impl ScalarUDFImpl for RPadFunc { } } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_rpad_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_rpad_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Pads the right side of a string with another string to a specified string length.") + .with_syntax_example("rpad(str, n[, padding_str])") + .with_sql_example(r#"```sql +> select rpad('datafusion', 20, '_-'); ++-----------------------------------------------+ +| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) | ++-----------------------------------------------+ +| datafusion_-_-_-_-_- | ++-----------------------------------------------+ +```"#) + .with_standard_argument( + "str", + Some("String"), + ) + .with_argument("n", "String length to pad to.") + .with_argument("padding_str", + "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._") + .with_related_udf("lpad") + .build() + .unwrap() + }) } pub fn rpad( diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 3879f779eb713..9c84590f7f94e 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -16,15 +16,17 @@ // under the License. 
use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; +use crate::strings::StringArrayType; +use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray}; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; - -use crate::string::common::StringArrayType; -use crate::utils::{make_scalar_function, utf8_to_int_type}; -use datafusion_common::{exec_err, plan_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct StrposFunc { @@ -41,7 +43,7 @@ impl Default for StrposFunc { impl StrposFunc { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::string(2, Volatility::Immutable), aliases: vec![String::from("instr"), String::from("position")], } } @@ -72,26 +74,35 @@ impl ScalarUDFImpl for StrposFunc { &self.aliases } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - match arg_types { - [first, second ] => { - match (first, second) { - (DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8, DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8) => Ok(arg_types.to_vec()), - (DataType::Null, DataType::Null) => Ok(vec![DataType::Utf8, DataType::Utf8]), - (DataType::Null, _) => Ok(vec![DataType::Utf8, second.to_owned()]), - (_, DataType::Null) => Ok(vec![first.to_owned(), DataType::Utf8]), - (DataType::Dictionary(_, value_type), DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8) => match **value_type { - DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 | DataType::Null | DataType::Binary => Ok(vec![*value_type.clone(), second.to_owned()]), - _ => plan_err!("The STRPOS/INSTR/POSITION function can only accept strings, but got {:?}.", **value_type), - }, - _ => plan_err!("The STRPOS/INSTR/POSITION function can only accept strings, but got {:?}.", arg_types) - } - }, - _ => plan_err!("The STRPOS/INSTR/POSITION function can only accept strings, but got {:?}", arg_types) - } + fn documentation(&self) -> Option<&Documentation> { + Some(get_strpos_doc()) } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_strpos_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns the starting position of a specified substring in a string. Positions begin at 1. 
If the substring does not exist in the string, the function returns 0.") + .with_syntax_example("strpos(str, substr)") + .with_sql_example(r#"```sql +> select strpos('datafusion', 'fus'); ++----------------------------------------+ +| strpos(Utf8("datafusion"),Utf8("fus")) | ++----------------------------------------+ +| 5 | ++----------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("substr", "Substring expression to search for.") + .with_alternative_syntax("position(substr in origstr)") + .build() + .unwrap() + }) +} + fn strpos(args: &[ArrayRef]) -> Result { match (args[0].data_type(), args[1].data_type()) { (DataType::Utf8, DataType::Utf8) => { diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 205de0b30b9c1..edfe57210b711 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -16,9 +16,9 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use crate::string::common::{make_and_append_view, StringArrayType}; +use crate::strings::{make_and_append_view, StringArrayType}; use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ Array, ArrayIter, ArrayRef, AsArray, GenericStringArray, Int64Array, OffsetSizeTrait, @@ -28,7 +28,10 @@ use arrow::datatypes::DataType; use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, plan_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct SubstrFunc { @@ -81,6 +84,13 @@ impl ScalarUDFImpl for SubstrFunc { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() < 2 || arg_types.len() > 3 { + return plan_err!( + "The {} function requires 2 or 3 arguments, but got {}.", + self.name(), + arg_types.len() + ); + } let first_data_type = match &arg_types[0] { DataType::Null => Ok(DataType::Utf8), DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(arg_types[0].clone()), @@ -138,6 +148,35 @@ impl ScalarUDFImpl for SubstrFunc { ]) } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_substr_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_substr_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Extracts a substring of a specified number of characters from a specific starting position in a string.") + .with_syntax_example("substr(str, start_pos[, length])") + .with_sql_example(r#"```sql +> select substr('datafusion', 5, 3); ++----------------------------------------------+ +| substr(Utf8("datafusion"),Int64(5),Int64(3)) | ++----------------------------------------------+ +| fus | ++----------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("start_pos", "Character position to start the substring at. The first character in the string has a position of 1.") + .with_argument("length", "Number of characters to extract. 
If not specified, returns the rest of the string after the start position.") + .with_alternative_syntax("substring(str from start_pos for length)") + .build() + .unwrap() + }) } /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index 6591ee26403aa..c04839783f58f 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, @@ -24,11 +24,13 @@ use arrow::array::{ }; use arrow::datatypes::{DataType, Int32Type, Int64Type}; +use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct SubstrIndexFunc { @@ -83,6 +85,42 @@ impl ScalarUDFImpl for SubstrIndexFunc { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_substr_index_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_substr_index_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description(r#"Returns the substring from str before count occurrences of the delimiter delim. +If count is positive, everything to the left of the final delimiter (counting from the left) is returned. +If count is negative, everything to the right of the final delimiter (counting from the right) is returned."#) + .with_syntax_example("substr_index(str, delim, count)") + .with_sql_example(r#"```sql +> select substr_index('www.apache.org', '.', 1); ++---------------------------------------------------------+ +| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) | ++---------------------------------------------------------+ +| www | ++---------------------------------------------------------+ +> select substr_index('www.apache.org', '.', -1); ++----------------------------------------------------------+ +| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) | ++----------------------------------------------------------+ +| org | ++----------------------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("delim", "The string to find in str to split str.") + .with_argument("count", "The number of times to search for the delimiter. Can be either a positive or negative number.") + .build() + .unwrap() + }) } /// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. 
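The positive/negative `count` convention in that doc comment is easy to get wrong, so here is a pure-Rust sketch of the described semantics over `&str`. This is illustrative only: the actual kernel above operates on Arrow arrays, and the behavior for `count = 0`, an empty delimiter, and fewer delimiters than `|count|` are assumptions (matching MySQL's behavior) rather than something the patch states:

```rust
/// Illustrative reimplementation of the substr_index semantics described above.
fn substr_index(s: &str, delim: &str, count: i64) -> String {
    if count == 0 || delim.is_empty() {
        return String::new(); // assumed edge-case behavior
    }
    let occurrences: Vec<usize> = s.match_indices(delim).map(|(i, _)| i).collect();
    let n = count.unsigned_abs() as usize;
    if occurrences.len() < n {
        // Fewer delimiters than requested: return the whole string (assumed).
        return s.to_string();
    }
    if count > 0 {
        // Everything to the left of the n-th delimiter, counting from the left.
        s[..occurrences[n - 1]].to_string()
    } else {
        // Everything to the right of the n-th delimiter, counting from the right.
        let idx = occurrences[occurrences.len() - n];
        s[idx + delim.len()..].to_string()
    }
}

fn main() {
    assert_eq!(substr_index("www.apache.org", ".", 1), "www");
    assert_eq!(substr_index("www.apache.org", ".", 2), "www.apache");
    assert_eq!(substr_index("www.apache.org", ".", -1), "org");
}
```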
diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index a42b9c6cb8578..fa626b396b3be 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait, @@ -27,8 +27,11 @@ use unicode_segmentation::UnicodeSegmentation; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct TranslateFunc { @@ -76,6 +79,34 @@ impl ScalarUDFImpl for TranslateFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(invoke_translate, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_translate_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_translate_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Translates characters in a string to specified translation characters.") + .with_syntax_example("translate(str, chars, translation)") + .with_sql_example(r#"```sql +> select translate('twice', 'wic', 'her'); ++--------------------------------------------------+ +| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) | ++--------------------------------------------------+ +| there | ++--------------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("chars", "Characters to translate.") + .with_argument("translation", "Translation characters. Translation characters replace only characters at the same position in the **chars** string.") + .build() + .unwrap() + }) } fn invoke_translate(args: &[ArrayRef]) -> Result { diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 818b4c64bd20c..4d6574d2bd6cd 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -107,7 +107,7 @@ where }; arg.clone().into_array(expansion_len) }) - .collect::>>()?; + .collect::>>()?; let result = (inner)(&args); if is_scalar { diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 337a24ffae206..79a5bb24e9187 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -35,10 +35,6 @@ workspace = true name = "datafusion_optimizer" path = "src/lib.rs" -[features] -default = ["regex_expressions"] -regex_expressions = ["datafusion-physical-expr/regex_expressions"] - [dependencies] arrow = { workspace = true } async-trait = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 86520b3587cdc..454afa24b628c 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -48,13 +48,7 @@ impl AnalyzerRule for CountWildcardRule { } fn is_wildcard(expr: &Expr) -> bool { - matches!( - expr, - Expr::Wildcard { - qualifier: None, - .. 
- } - ) + matches!(expr, Expr::Wildcard { .. }) } fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { @@ -107,7 +101,7 @@ mod tests { use datafusion_expr::expr::Sort; use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ - col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, out_ref_col, + col, exists, in_subquery, logical_plan::LogicalPlanBuilder, out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, WindowFrameUnits, }; use datafusion_functions_aggregate::count::count_udaf; @@ -225,7 +219,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .window(vec![Expr::WindowFunction(expr::WindowFunction::new( + .window(vec![Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) diff --git a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs index a26ec4be5c851..9fbe54e1ccb92 100644 --- a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs @@ -26,7 +26,9 @@ use datafusion_expr::expr::PlannedReplaceSelectItem; use datafusion_expr::utils::{ expand_qualified_wildcard, expand_wildcard, find_base_plan, }; -use datafusion_expr::{Expr, LogicalPlan, Projection, SubqueryAlias}; +use datafusion_expr::{ + Distinct, DistinctOn, Expr, LogicalPlan, Projection, SubqueryAlias, +}; #[derive(Default, Debug)] pub struct ExpandWildcardRule {} @@ -59,12 +61,25 @@ fn expand_internal(plan: LogicalPlan) -> Result> { .map(LogicalPlan::Projection)?, )) } - // Teh schema of the plan should also be updated if the child plan is transformed. + // The schema of the plan should also be updated if the child plan is transformed. LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { Ok(Transformed::yes( SubqueryAlias::try_new(input, alias).map(LogicalPlan::SubqueryAlias)?, )) } + LogicalPlan::Distinct(Distinct::On(distinct_on)) => { + let projected_expr = + expand_exprlist(&distinct_on.input, distinct_on.select_expr)?; + validate_unique_names("Distinct", projected_expr.iter())?; + Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::On( + DistinctOn::try_new( + distinct_on.on_expr, + projected_expr, + distinct_on.sort_expr, + distinct_on.input, + )?, + )))) + } _ => Ok(Transformed::no(plan)), } } @@ -240,6 +255,18 @@ mod tests { assert_plan_eq(plan, expected) } + #[test] + fn test_expand_wildcard_in_distinct_on() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .distinct_on(vec![col("a")], vec![wildcard()], None)? 
+ .build()?; + let expected = "\ + DistinctOn: on_expr=[[test.a]], select_expr=[[test.a, test.b, test.c]], sort_expr=[[]] [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(plan, expected) + } + #[test] fn test_subquery_schema() -> Result<()> { let analyzer = Analyzer::with_rules(vec![Arc::new(ExpandWildcardRule::new())]); diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 4cd891664e7f5..a9fd4900b2f4a 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -34,6 +34,7 @@ use datafusion_expr::{Expr, LogicalPlan}; use crate::analyzer::count_wildcard_rule::CountWildcardRule; use crate::analyzer::expand_wildcard_rule::ExpandWildcardRule; use crate::analyzer::inline_table_scan::InlineTableScan; +use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction; use crate::analyzer::subquery::check_subquery_expr; use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; @@ -44,6 +45,7 @@ pub mod count_wildcard_rule; pub mod expand_wildcard_rule; pub mod function_rewrite; pub mod inline_table_scan; +pub mod resolve_grouping_function; pub mod subquery; pub mod type_coercion; @@ -96,6 +98,7 @@ impl Analyzer { // Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule]. Arc::new(ExpandWildcardRule::new()), // [Expr::Wildcard] should be expanded before [TypeCoercion] + Arc::new(ResolveGroupingFunction::new()), Arc::new(TypeCoercion::new()), Arc::new(CountWildcardRule::new()), ]; diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs new file mode 100644 index 0000000000000..16ebb8cd3972f --- /dev/null +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Analyzer rule that resolves calls to the `grouping` aggregate function by +//! rewriting them into expressions over the internal grouping id.
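One detail worth calling out before the implementation: `group_expr_to_bitmap_index` below assigns bit positions in reverse, so the last grouping expression owns bit 0 of the grouping id. A toy sketch with strings standing in for `Expr` (names invented for illustration):

```rust
use std::collections::HashMap;

// Toy stand-in for group_expr_to_bitmap_index: the last grouping expression
// gets bit 0, the one before it bit 1, and so on.
fn bitmap_index(group_exprs: &[&str]) -> HashMap<String, usize> {
    group_exprs
        .iter()
        .rev()
        .enumerate()
        .map(|(idx, e)| (e.to_string(), idx))
        .collect()
}

fn main() {
    let m = bitmap_index(&["a", "b", "c"]);
    assert_eq!(m["c"], 0); // last expression -> lowest bit
    assert_eq!(m["a"], 2);
}
```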
+ +use std::cmp::Ordering; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::analyzer::AnalyzerRule; + +use arrow::datatypes::DataType; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{ + internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue, +}; +use datafusion_expr::expr::{AggregateFunction, Alias}; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::utils::grouping_set_to_exprlist; +use datafusion_expr::{ + bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate, + Expr, Projection, +}; +use itertools::Itertools; + +/// Replaces the grouping aggregation function with a value derived from the internal grouping id +#[derive(Default, Debug)] +pub struct ResolveGroupingFunction; + +impl ResolveGroupingFunction { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for ResolveGroupingFunction { + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> { + plan.transform_up(analyze_internal).data() + } + + fn name(&self) -> &str { + "resolve_grouping_function" + } +} + +/// Create a map from grouping expr to index in the internal grouping id. +/// +/// For more details on how the grouping id bitmap works, see the documentation for +/// [[Aggregate::INTERNAL_GROUPING_ID]]. +fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result<HashMap<&Expr, usize>> { + Ok(grouping_set_to_exprlist(group_expr)? + .into_iter() + .rev() + .enumerate() + .map(|(idx, v)| (v, idx)) + .collect::<HashMap<_, _>>()) +} + +fn replace_grouping_exprs( + input: Arc<LogicalPlan>, + schema: DFSchemaRef, + group_expr: Vec<Expr>, + aggr_expr: Vec<Expr>, +) -> Result<LogicalPlan> { + // Create HashMap from Expr to index in the grouping_id bitmap + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?; + let columns = schema.columns(); + let mut new_agg_expr = Vec::new(); + let mut projection_exprs = Vec::new(); + let grouping_id_len = if is_grouping_set { 1 } else { 0 }; + let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len; + projection_exprs.extend( + columns + .iter() + .take(group_expr_len) + .map(|column| Expr::Column(column.clone())), + ); + for (expr, column) in aggr_expr + .into_iter() + .zip(columns.into_iter().skip(group_expr_len + grouping_id_len)) + { + match expr { + Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => { + let grouping_expr = grouping_function_on_id( + function, + &group_expr_to_bitmap_index, + is_grouping_set, + )?; + projection_exprs.push(Expr::Alias(Alias::new( + grouping_expr, + column.relation, + column.name, + ))); + } + _ => { + projection_exprs.push(Expr::Column(column)); + new_agg_expr.push(expr); + } + } + } + // Recreate aggregate without grouping functions + let new_aggregate = + LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?); + // Create a projection that computes the grouping functions + let projection = LogicalPlan::Projection(Projection::try_new( + projection_exprs, + new_aggregate.into(), + )?); + Ok(projection) +} + +fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> { + // rewrite any subqueries in the plan first + let transformed_plan = + plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?; + + let transformed_plan = transformed_plan.transform_data(|plan| match plan { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + ..
+ }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes( + replace_grouping_exprs(input, schema, group_expr, aggr_expr)?, + )), + _ => Ok(Transformed::no(plan)), + })?; + + Ok(transformed_plan) +} + +fn is_grouping_function(expr: &Expr) -> bool { + // TODO: Do something better than matching by name here; should grouping be + // a built-in expression? + matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping") +} + +fn contains_grouping_function(exprs: &[Expr]) -> bool { + exprs.iter().any(is_grouping_function) +} + +/// Validate that the arguments to the grouping function are in the group by clause. +fn validate_args( + function: &AggregateFunction, + group_by_expr: &HashMap<&Expr, usize>, +) -> Result<()> { + let expr_not_in_group_by = function + .args + .iter() + .find(|expr| !group_by_expr.contains_key(expr)); + if let Some(expr) = expr_not_in_group_by { + plan_err!( + "Argument {} to grouping function is not in grouping columns {}", + expr, + group_by_expr.keys().map(|e| e.to_string()).join(", ") + ) + } else { + Ok(()) + } +} + +fn grouping_function_on_id( + function: &AggregateFunction, + group_by_expr: &HashMap<&Expr, usize>, + is_grouping_set: bool, +) -> Result<Expr> { + validate_args(function, group_by_expr)?; + let args = &function.args; + + // Postgres allows the grouping function for GROUP BY without grouping sets; the result is then + // always 0 + if !is_grouping_set { + return Ok(Expr::Literal(ScalarValue::from(0i32))); + } + + let group_by_expr_count = group_by_expr.len(); + let literal = |value: usize| { + if group_by_expr_count < 8 { + Expr::Literal(ScalarValue::from(value as u8)) + } else if group_by_expr_count < 16 { + Expr::Literal(ScalarValue::from(value as u16)) + } else if group_by_expr_count < 32 { + Expr::Literal(ScalarValue::from(value as u32)) + } else { + Expr::Literal(ScalarValue::from(value as u64)) + } + }; + + let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); + // The grouping call is exactly our internal grouping id + if args.len() == group_by_expr_count + && args + .iter() + .rev() + .enumerate() + .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) + { + return Ok(cast(grouping_id_column, DataType::Int32)); + } + + args.iter() + .rev() + .enumerate() + .map(|(arg_idx, expr)| { + group_by_expr.get(expr).map(|group_by_idx| { + let group_by_bit = + bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx)); + match group_by_idx.cmp(&arg_idx) { + Ordering::Less => { + bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx)) + } + Ordering::Greater => { + bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx)) + } + Ordering::Equal => group_by_bit, + } + }) + }) + .collect::<Option<Vec<_>>>() + .and_then(|bit_exprs| { + bit_exprs + .into_iter() + .reduce(bitwise_or) + .map(|expr| cast(expr, DataType::Int32)) + }) + .ok_or_else(|| { + internal_datafusion_err!("Grouping sets should contain at least one element") + }) +} diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index c771f31a58b21..0ffc954388f5a 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License.
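Stepping back to `grouping_function_on_id` above: for each argument it masks out that column's bit in the internal grouping id, shifts it from its grouping-id position (`group_by_idx`) to the position the argument order demands (`arg_idx`, counted from the right), and ORs the pieces together. A worked example on plain integers, assuming the usual convention that a set bit means the column is rolled up (null) in that row (the function names here are invented for illustration):

```rust
// Suppose GROUP BY has columns (a, b, c), so the grouping id assigns
// bit 2 to a, bit 1 to b, bit 0 to c (see group_expr_to_bitmap_index).
// grouping(c, a) must put c's flag at result bit 1 and a's flag at bit 0.
fn grouping_c_a(grouping_id: u8) -> u8 {
    let c_bit = grouping_id & (1 << 0); // group_by_idx 0 < arg_idx 1: shift left by 1
    let a_bit = grouping_id & (1 << 2); // group_by_idx 2 > arg_idx 0: shift right by 2
    (c_bit << 1) | (a_bit >> 2)
}

fn main() {
    // A row from the grouping set (a) has id 0b011: b and c rolled up, a present.
    // grouping(c, a) is then 0b10: c rolled up, a present.
    assert_eq!(grouping_c_a(0b011), 0b10);
}
```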
-use std::ops::Deref; - use crate::analyzer::check_plan; use crate::utils::collect_subquery_cols; @@ -24,10 +22,7 @@ use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::utils::split_conjunction; -use datafusion_expr::{ - Aggregate, BinaryExpr, Cast, Expr, Filter, Join, JoinType, LogicalPlan, Operator, - Window, -}; +use datafusion_expr::{Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window}; /// Do necessary check on subquery expressions and fail the invalid plan /// 1) Check whether the outer plan is in the allowed outer plans list to use subquery expressions, @@ -98,7 +93,7 @@ pub fn check_subquery_expr( ) }?; } - check_correlations_in_subquery(inner_plan, true) + check_correlations_in_subquery(inner_plan) } else { if let Expr::InSubquery(subquery) = expr { // InSubquery should only return one column @@ -118,28 +113,22 @@ pub fn check_subquery_expr( | LogicalPlan::Join(_) => Ok(()), _ => plan_err!( "In/Exist subquery can only be used in \ - Projection, Filter, Window functions, Aggregate and Join plan nodes" + Projection, Filter, Window functions, Aggregate and Join plan nodes, \ + but was used in [{}]", + outer_plan.display() ), }?; - check_correlations_in_subquery(inner_plan, false) + check_correlations_in_subquery(inner_plan) } } // Recursively check the unsupported outer references in the sub query plan. -fn check_correlations_in_subquery( - inner_plan: &LogicalPlan, - is_scalar: bool, -) -> Result<()> { - check_inner_plan(inner_plan, is_scalar, false, true) +fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> { + check_inner_plan(inner_plan, true) } // Recursively check the unsupported outer references in the sub query plan. -fn check_inner_plan( - inner_plan: &LogicalPlan, - is_scalar: bool, - is_aggregate: bool, - can_contain_outer_ref: bool, -) -> Result<()> { +fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Result<()> { if !can_contain_outer_ref && inner_plan.contains_outer_reference() { return plan_err!("Accessing outer reference columns is not allowed in the plan"); } @@ -147,32 +136,18 @@ fn check_inner_plan( match inner_plan { LogicalPlan::Aggregate(_) => { inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) } - LogicalPlan::Filter(Filter { - predicate, input, .. - }) => { - let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) - .into_iter() - .partition(|e| e.contains_outer()); - let maybe_unsupported = correlated - .into_iter() - .filter(|expr| !can_pullup_over_aggregation(expr)) - .collect::>(); - if is_aggregate && is_scalar && !maybe_unsupported.is_empty() { - return plan_err!( - "Correlated column is not allowed in predicate: {predicate}" - ); - } - check_inner_plan(input, is_scalar, is_aggregate, can_contain_outer_ref) + LogicalPlan::Filter(Filter { input, .. 
}) => { + check_inner_plan(input, can_contain_outer_ref) } LogicalPlan::Window(window) => { check_mixed_out_refer_in_window(window)?; inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) @@ -180,7 +155,6 @@ fn check_inner_plan( LogicalPlan::Projection(_) | LogicalPlan::Distinct(_) | LogicalPlan::Sort(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Union(_) | LogicalPlan::TableScan(_) | LogicalPlan::EmptyRelation(_) @@ -189,7 +163,7 @@ fn check_inner_plan( | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) => { inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) @@ -202,27 +176,22 @@ fn check_inner_plan( }) => match join_type { JoinType::Inner => { inner_plan.apply_children(|plan| { - check_inner_plan( - plan, - is_scalar, - is_aggregate, - can_contain_outer_ref, - )?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) } JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { - check_inner_plan(left, is_scalar, is_aggregate, can_contain_outer_ref)?; - check_inner_plan(right, is_scalar, is_aggregate, false) + check_inner_plan(left, can_contain_outer_ref)?; + check_inner_plan(right, false) } JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - check_inner_plan(left, is_scalar, is_aggregate, false)?; - check_inner_plan(right, is_scalar, is_aggregate, can_contain_outer_ref) + check_inner_plan(left, false)?; + check_inner_plan(right, can_contain_outer_ref) } JoinType::Full => { inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, is_aggregate, false)?; + check_inner_plan(plan, false)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) @@ -291,34 +260,6 @@ fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { Ok(exprs) } -/// Check whether the expression can pull up over the aggregation without change the result of the query -fn can_pullup_over_aggregation(expr: &Expr) -> bool { - if let Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) = expr - { - match (left.deref(), right.deref()) { - (Expr::Column(_), right) => !right.any_column_refs(), - (left, Expr::Column(_)) => !left.any_column_refs(), - (Expr::Cast(Cast { expr, .. }), right) - if matches!(expr.deref(), Expr::Column(_)) => - { - !right.any_column_refs() - } - (left, Expr::Cast(Cast { expr, .. 
})) - if matches!(expr.deref(), Expr::Column(_)) => - { - !left.any_column_refs() - } - (_, _) => false, - } - } else { - false - } -} - /// Check whether the window expressions contain a mixture of out reference columns and inner columns fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> { let mixed = window @@ -364,7 +305,7 @@ mod test { vec![] } - fn schema(&self) -> &datafusion_common::DFSchemaRef { + fn schema(&self) -> &DFSchemaRef { &self.empty_schema } @@ -385,6 +326,10 @@ mod test { empty_schema: Arc::clone(&self.empty_schema), }) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[test] @@ -395,6 +340,6 @@ mod test { }), }); - check_inner_plan(&plan, false, false, true).unwrap(); + check_inner_plan(&plan, true).unwrap(); } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4dc34284c7198..5d33b58a02411 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -51,8 +51,9 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, LogicalPlan, Operator, - Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits, + AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, Limit, LogicalPlan, + Operator, Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; /// Performs type coercion by determining the schema @@ -169,6 +170,7 @@ impl<'a> TypeCoercionRewriter<'a> { match plan { LogicalPlan::Join(join) => self.coerce_join(join), LogicalPlan::Union(union) => Self::coerce_union(union), + LogicalPlan::Limit(limit) => Self::coerce_limit(limit), _ => Ok(plan), } } @@ -230,6 +232,37 @@ impl<'a> TypeCoercionRewriter<'a> { })) } + /// Coerce the fetch and skip expression to Int64 type. 
+ fn coerce_limit(limit: Limit) -> Result { + fn coerce_limit_expr( + expr: Expr, + schema: &DFSchema, + expr_name: &str, + ) -> Result { + let dt = expr.get_type(schema)?; + if dt.is_integer() || dt.is_null() { + expr.cast_to(&DataType::Int64, schema) + } else { + plan_err!("Expected {expr_name} to be an integer or null, but got {dt:?}") + } + } + + let empty_schema = DFSchema::empty(); + let new_fetch = limit + .fetch + .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT")) + .transpose()?; + let new_skip = limit + .skip + .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET")) + .transpose()?; + Ok(LogicalPlan::Limit(Limit { + input: limit.input, + fetch: new_fetch.map(Box::new), + skip: new_skip.map(Box::new), + })) + } + fn coerce_join_filter(&self, expr: Expr) -> Result { let expr_type = expr.get_type(self.schema)?; match expr_type { @@ -456,7 +489,6 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { self.schema, &func, )?; - let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &func)?; Ok(Transformed::yes(Expr::ScalarFunction( ScalarFunction::new_udf(func, new_expr), ))) @@ -664,20 +696,20 @@ fn coerce_window_frame( expressions: &[Sort], ) -> Result { let mut window_frame = window_frame; - let current_types = expressions - .iter() - .map(|s| s.expr.get_type(schema)) - .collect::>>()?; let target_type = match window_frame.units { WindowFrameUnits::Range => { - if let Some(col_type) = current_types.first() { + let current_types = expressions + .first() + .map(|s| s.expr.get_type(schema)) + .transpose()?; + if let Some(col_type) = current_types { if col_type.is_numeric() - || is_utf8_or_large_utf8(col_type) + || is_utf8_or_large_utf8(&col_type) || matches!(col_type, DataType::Null) { col_type - } else if is_datetime(col_type) { - &DataType::Interval(IntervalUnit::MonthDayNano) + } else if is_datetime(&col_type) { + DataType::Interval(IntervalUnit::MonthDayNano) } else { return internal_err!( "Cannot run range queries on datatype: {col_type:?}" @@ -687,10 +719,11 @@ fn coerce_window_frame( return internal_err!("ORDER BY column cannot be empty"); } } - WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64, + WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64, }; - window_frame.start_bound = coerce_frame_bound(target_type, window_frame.start_bound)?; - window_frame.end_bound = coerce_frame_bound(target_type, window_frame.end_bound)?; + window_frame.start_bound = + coerce_frame_bound(&target_type, window_frame.start_bound)?; + window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?; Ok(window_frame) } @@ -756,30 +789,6 @@ fn coerce_arguments_for_signature_with_aggregate_udf( .collect() } -fn coerce_arguments_for_fun( - expressions: Vec, - schema: &DFSchema, - fun: &Arc, -) -> Result> { - // Cast Fixedsizelist to List for array functions - if fun.name() == "make_array" { - expressions - .into_iter() - .map(|expr| { - let data_type = expr.get_type(schema).unwrap(); - if let DataType::FixedSizeList(field, _) = data_type { - let to_type = DataType::List(Arc::clone(&field)); - expr.cast_to(&to_type, schema) - } else { - Ok(expr) - } - }) - .collect() - } else { - Ok(expressions) - } -} - fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { // Given expressions like: // @@ -1234,7 +1243,7 @@ mod test { } fn return_type(&self, _args: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(Utf8) } fn invoke(&self, _args: &[ColumnarValue]) -> Result { @@ -1437,7 +1446,7 @@ mod test { 
cast(lit("2002-05-08"), DataType::Date32) + lit(ScalarValue::new_interval_ym(0, 1)), ); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); let expected = "Filter: a BETWEEN Utf8(\"2002-05-08\") AND CAST(CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AS Utf8)\ @@ -1453,7 +1462,7 @@ mod test { + lit(ScalarValue::new_interval_ym(0, 1)), lit("2002-12-08"), ); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); // TODO: we should cast col(a). let expected = @@ -1508,7 +1517,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE Utf8(\"abc\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1516,7 +1525,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1536,7 +1545,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE Utf8(\"abc\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1544,7 +1553,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE CAST(NULL AS Utf8)\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1572,7 +1581,7 @@ mod test { let expected = "Projection: a IS UNKNOWN\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); let err = ret.unwrap_err().to_string(); @@ -1590,7 +1599,7 @@ mod test { #[test] fn concat_for_type_coercion() -> Result<()> { - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)]; // concat-type signature @@ -1725,7 +1734,7 @@ mod test { true, ), Field::new("binary", DataType::Binary, true), - Field::new("string", DataType::Utf8, true), + Field::new("string", Utf8, true), 
Field::new("decimal", DataType::Decimal128(10, 10), true), ] .into(), @@ -1742,7 +1751,7 @@ mod test { else_expr: None, }; let case_when_common_type = DataType::Boolean; - let then_else_common_type = DataType::Utf8; + let then_else_common_type = Utf8; let expected = cast_helper( case.clone(), &case_when_common_type, @@ -1761,8 +1770,8 @@ mod test { ], else_expr: Some(Box::new(col("string"))), }; - let case_when_common_type = DataType::Utf8; - let then_else_common_type = DataType::Utf8; + let case_when_common_type = Utf8; + let then_else_common_type = Utf8; let expected = cast_helper( case.clone(), &case_when_common_type, @@ -1852,7 +1861,7 @@ mod test { Some("list"), vec![(Box::new(col("large_list")), Box::new(lit("1")))], DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + Utf8, schema ); @@ -1860,7 +1869,7 @@ mod test { Some("large_list"), vec![(Box::new(col("list")), Box::new(lit("1")))], DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + Utf8, schema ); @@ -1868,7 +1877,7 @@ mod test { Some("list"), vec![(Box::new(col("fixed_list")), Box::new(lit("1")))], DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + Utf8, schema ); @@ -1876,7 +1885,7 @@ mod test { Some("fixed_list"), vec![(Box::new(col("list")), Box::new(lit("1")))], DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + Utf8, schema ); @@ -1884,7 +1893,7 @@ mod test { Some("fixed_list"), vec![(Box::new(col("large_list")), Box::new(lit("1")))], DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + Utf8, schema ); @@ -1892,7 +1901,7 @@ mod test { Some("large_list"), vec![(Box::new(col("fixed_list")), Box::new(lit("1")))], DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + Utf8, schema ); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index c13cb3a8e9734..ee9ae9fb15a79 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -17,8 +17,8 @@ //! [`CommonSubexprEliminate`] to avoid redundant computation of common sub-expressions -use std::collections::{BTreeSet, HashMap}; -use std::hash::{BuildHasher, Hash, Hasher, RandomState}; +use std::collections::BTreeSet; +use std::fmt::Debug; use std::sync::Arc; use crate::{OptimizerConfig, OptimizerRule}; @@ -26,11 +26,9 @@ use crate::{OptimizerConfig, OptimizerRule}; use crate::optimizer::ApplyOrder; use crate::utils::NamePreserver; use datafusion_common::alias::AliasGenerator; -use datafusion_common::hash_utils::combine_hashes; -use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, - TreeNodeVisitor, -}; + +use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::{ @@ -38,81 +36,9 @@ use datafusion_expr::logical_plan::{ }; use datafusion_expr::tree_node::replace_sort_expressions; use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator}; -use indexmap::IndexMap; const CSE_PREFIX: &str = "__common_expr"; -/// Identifier that represents a subexpression tree. 
-/// -/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and -/// "have no collision (as low as possible)" -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -struct Identifier<'n> { - // Hash of `expr` built up incrementally during the first, visiting traversal, but its - // value is not necessarily equal to `expr.hash()`. - hash: u64, - expr: &'n Expr, -} - -impl<'n> Identifier<'n> { - fn new(expr: &'n Expr, random_state: &RandomState) -> Self { - let mut hasher = random_state.build_hasher(); - expr.hash_node(&mut hasher); - let hash = hasher.finish(); - Self { hash, expr } - } - - fn combine(mut self, other: Option) -> Self { - other.map_or(self, |other_id| { - self.hash = combine_hashes(self.hash, other_id.hash); - self - }) - } -} - -impl Hash for Identifier<'_> { - fn hash(&self, state: &mut H) { - state.write_u64(self.hash); - } -} - -/// A cache that contains the postorder index and the identifier of expression tree nodes -/// by the preorder index of the nodes. -/// -/// This cache is filled by `ExprIdentifierVisitor` during the first traversal and is used -/// by `CommonSubexprRewriter` during the second traversal. -/// -/// The purpose of this cache is to quickly find the identifier of a node during the -/// second traversal. -/// -/// Elements in this array are added during `f_down` so the indexes represent the preorder -/// index of expression nodes and thus element 0 belongs to the root of the expression -/// tree. -/// The elements of the array are tuples that contain: -/// - Postorder index that belongs to the preorder index. Assigned during `f_up`, start -/// from 0. -/// - Identifier of the expression. If empty (`""`), expr should not be considered for -/// CSE. -/// -/// # Example -/// An expression like `(a + b)` would have the following `IdArray`: -/// ```text -/// [ -/// (2, "a + b"), -/// (1, "a"), -/// (0, "b") -/// ] -/// ``` -type IdArray<'n> = Vec<(usize, Option>)>; - -/// A map that contains the number of normal and conditional occurrences of expressions by -/// their identifiers. -type ExprStats<'n> = HashMap, (usize, usize)>; - -/// A map that contains the common expressions and their alias extracted during the -/// second, rewriting traversal. -type CommonExprs<'n> = IndexMap, (Expr, String)>; - /// Performs Common Sub-expression Elimination optimization. /// /// This optimization improves query performance by computing expressions that @@ -140,168 +66,11 @@ type CommonExprs<'n> = IndexMap, (Expr, String)>; /// ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once /// ``` #[derive(Debug)] -pub struct CommonSubexprEliminate { - random_state: RandomState, -} - -/// The result of potentially rewriting a list of expressions to eliminate common -/// subexpressions. -#[derive(Debug)] -enum FoundCommonExprs { - /// No common expressions were found - No { original_exprs_list: Vec> }, - /// Common expressions were found - Yes { - /// extracted common expressions - common_exprs: Vec<(Expr, String)>, - /// new expressions with common subexpressions replaced - new_exprs_list: Vec>, - /// original expressions - original_exprs_list: Vec>, - }, -} +pub struct CommonSubexprEliminate {} impl CommonSubexprEliminate { pub fn new() -> Self { - Self { - random_state: RandomState::new(), - } - } - - /// Returns the identifier list for each element in `exprs` and a flag to indicate if - /// rewrite phase of CSE make sense. 
- /// - /// Returns and array with 1 element for each input expr in `exprs` - /// - /// Each element is itself the result of [`CommonSubexprEliminate::expr_to_identifier`] for that expr - /// (e.g. the identifiers for each node in the tree) - fn to_arrays<'n>( - &self, - exprs: &'n [Expr], - expr_stats: &mut ExprStats<'n>, - expr_mask: ExprMask, - ) -> Result<(bool, Vec>)> { - let mut found_common = false; - exprs - .iter() - .map(|e| { - let mut id_array = vec![]; - self.expr_to_identifier(e, expr_stats, &mut id_array, expr_mask) - .map(|fc| { - found_common |= fc; - - id_array - }) - }) - .collect::>>() - .map(|id_arrays| (found_common, id_arrays)) - } - - /// Add an identifier to `id_array` for every subexpression in this tree. - fn expr_to_identifier<'n>( - &self, - expr: &'n Expr, - expr_stats: &mut ExprStats<'n>, - id_array: &mut IdArray<'n>, - expr_mask: ExprMask, - ) -> Result { - let mut visitor = ExprIdentifierVisitor { - expr_stats, - id_array, - visit_stack: vec![], - down_index: 0, - up_index: 0, - expr_mask, - random_state: &self.random_state, - found_common: false, - conditional: false, - }; - expr.visit(&mut visitor)?; - - Ok(visitor.found_common) - } - - /// Rewrites `exprs_list` with common sub-expressions replaced with a new - /// column. - /// - /// `common_exprs` is updated with any sub expressions that were replaced. - /// - /// Returns the rewritten expressions - fn rewrite_exprs_list<'n>( - &self, - exprs_list: Vec>, - arrays_list: &[Vec>], - expr_stats: &ExprStats<'n>, - common_exprs: &mut CommonExprs<'n>, - alias_generator: &AliasGenerator, - ) -> Result>> { - exprs_list - .into_iter() - .zip(arrays_list.iter()) - .map(|(exprs, arrays)| { - exprs - .into_iter() - .zip(arrays.iter()) - .map(|(expr, id_array)| { - replace_common_expr( - expr, - id_array, - expr_stats, - common_exprs, - alias_generator, - ) - }) - .collect::>>() - }) - .collect::>>() - } - - /// Extracts common sub-expressions and rewrites `exprs_list`. - /// - /// Returns `FoundCommonExprs` recording the result of the extraction - fn find_common_exprs( - &self, - exprs_list: Vec>, - config: &dyn OptimizerConfig, - expr_mask: ExprMask, - ) -> Result> { - let mut found_common = false; - let mut expr_stats = ExprStats::new(); - let id_arrays_list = exprs_list - .iter() - .map(|exprs| { - self.to_arrays(exprs, &mut expr_stats, expr_mask).map( - |(fc, id_arrays)| { - found_common |= fc; - - id_arrays - }, - ) - }) - .collect::>>()?; - if found_common { - let mut common_exprs = CommonExprs::new(); - let new_exprs_list = self.rewrite_exprs_list( - // Must clone as Identifiers use references to original expressions so we have - // to keep the original expressions intact. - exprs_list.clone(), - &id_arrays_list, - &expr_stats, - &mut common_exprs, - config.alias_generator().as_ref(), - )?; - assert!(!common_exprs.is_empty()); - - Ok(Transformed::yes(FoundCommonExprs::Yes { - common_exprs: common_exprs.into_values().collect(), - new_exprs_list, - original_exprs_list: exprs_list, - })) - } else { - Ok(Transformed::no(FoundCommonExprs::No { - original_exprs_list: exprs_list, - })) - } + Self {} } fn try_optimize_proj( @@ -372,80 +141,83 @@ impl CommonSubexprEliminate { get_consecutive_window_exprs(window); // Extract common sub-expressions from the list. - self.find_common_exprs(window_expr_list, config, ExprMask::Normal)? 
- .map_data(|common| match common { - // If there are common sub-expressions, then the insert a projection node - // with the common expressions between the new window nodes and the - // original input. - FoundCommonExprs::Yes { - common_exprs, - new_exprs_list, - original_exprs_list, - } => { - build_common_expr_project_plan(input, common_exprs).map(|new_input| { - (new_exprs_list, new_input, Some(original_exprs_list)) + + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), + ExprMask::Normal, + )) + .extract_common_nodes(window_expr_list)? + { + // If there are common sub-expressions, then the insert a projection node + // with the common expressions between the new window nodes and the + // original input. + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: new_exprs_list, + original_nodes_list: original_exprs_list, + } => build_common_expr_project_plan(input, common_exprs).map(|new_input| { + Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list))) + }), + FoundCommonNodes::No { + original_nodes_list: original_exprs_list, + } => Ok(Transformed::no((original_exprs_list, input, None))), + }? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_window_expr_list, new_input, window_expr_list)| { + self.rewrite(new_input, config)?.map_data(|new_input| { + Ok((new_window_expr_list, new_input, window_expr_list)) + }) + })? + // Rebuild the consecutive window nodes. + .map_data(|(new_window_expr_list, new_input, window_expr_list)| { + // If there were common expressions extracted, then we need to make sure + // we restore the original column names. + // TODO: Although `find_common_exprs()` inserts aliases around extracted + // common expressions this doesn't mean that the original column names + // (schema) are preserved due to the inserted aliases are not always at + // the top of the expression. + // Let's consider improving `find_common_exprs()` to always keep column + // names and get rid of additional name preserving logic here. + if let Some(window_expr_list) = window_expr_list { + let name_preserver = NamePreserver::new_for_projection(); + let saved_names = window_expr_list + .iter() + .map(|exprs| { + exprs + .iter() + .map(|expr| name_preserver.save(expr)) + .collect::>() }) - } - FoundCommonExprs::No { - original_exprs_list, - } => Ok((original_exprs_list, input, None)), - })? - // Recurse into the new input. - // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) - .transform_data(|(new_window_expr_list, new_input, window_expr_list)| { - self.rewrite(new_input, config)?.map_data(|new_input| { - Ok((new_window_expr_list, new_input, window_expr_list)) - }) - })? - // Rebuild the consecutive window nodes. - .map_data(|(new_window_expr_list, new_input, window_expr_list)| { - // If there were common expressions extracted, then we need to make sure - // we restore the original column names. - // TODO: Although `find_common_exprs()` inserts aliases around extracted - // common expressions this doesn't mean that the original column names - // (schema) are preserved due to the inserted aliases are not always at - // the top of the expression. - // Let's consider improving `find_common_exprs()` to always keep column - // names and get rid of additional name preserving logic here. 
- if let Some(window_expr_list) = window_expr_list { - let name_preserver = NamePreserver::new_for_projection(); - let saved_names = window_expr_list - .iter() - .map(|exprs| { - exprs - .iter() - .map(|expr| name_preserver.save(expr)) - .collect::>() - }) - .collect::>(); - new_window_expr_list.into_iter().zip(saved_names).try_rfold( - new_input, - |plan, (new_window_expr, saved_names)| { - let new_window_expr = new_window_expr - .into_iter() - .zip(saved_names) - .map(|(new_window_expr, saved_name)| { - saved_name.restore(new_window_expr) - }) - .collect::>(); - Window::try_new(new_window_expr, Arc::new(plan)) - .map(LogicalPlan::Window) - }, - ) - } else { - new_window_expr_list - .into_iter() - .zip(window_schemas) - .try_rfold(new_input, |plan, (new_window_expr, schema)| { - Window::try_new_with_schema( - new_window_expr, - Arc::new(plan), - schema, - ) + .collect::>(); + new_window_expr_list.into_iter().zip(saved_names).try_rfold( + new_input, + |plan, (new_window_expr, saved_names)| { + let new_window_expr = new_window_expr + .into_iter() + .zip(saved_names) + .map(|(new_window_expr, saved_name)| { + saved_name.restore(new_window_expr) + }) + .collect::>(); + Window::try_new(new_window_expr, Arc::new(plan)) .map(LogicalPlan::Window) - }) - } - }) + }, + ) + } else { + new_window_expr_list + .into_iter() + .zip(window_schemas) + .try_rfold(new_input, |plan, (new_window_expr, schema)| { + Window::try_new_with_schema( + new_window_expr, + Arc::new(plan), + schema, + ) + .map(LogicalPlan::Window) + }) + } + }) } fn try_optimize_aggregate( @@ -462,174 +234,175 @@ impl CommonSubexprEliminate { } = aggregate; let input = Arc::unwrap_or_clone(input); // Extract common sub-expressions from the aggregate and grouping expressions. - self.find_common_exprs(vec![group_expr, aggr_expr], config, ExprMask::Normal)? - .map_data(|common| { - match common { - // If there are common sub-expressions, then insert a projection node - // with the common expressions between the new aggregate node and the - // original input. - FoundCommonExprs::Yes { - common_exprs, - mut new_exprs_list, - mut original_exprs_list, - } => { - let new_aggr_expr = new_exprs_list.pop().unwrap(); - let new_group_expr = new_exprs_list.pop().unwrap(); - - build_common_expr_project_plan(input, common_exprs).map( - |new_input| { - let aggr_expr = original_exprs_list.pop().unwrap(); - ( - new_aggr_expr, - new_group_expr, - new_input, - Some(aggr_expr), - ) - }, - ) - } - - FoundCommonExprs::No { - mut original_exprs_list, - } => { - let new_aggr_expr = original_exprs_list.pop().unwrap(); - let new_group_expr = original_exprs_list.pop().unwrap(); - - Ok((new_aggr_expr, new_group_expr, input, None)) - } - } - })? - // Recurse into the new input. - // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) - .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| { - self.rewrite(new_input, config)?.map_data(|new_input| { - Ok(( + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), + ExprMask::Normal, + )) + .extract_common_nodes(vec![group_expr, aggr_expr])? + { + // If there are common sub-expressions, then insert a projection node + // with the common expressions between the new aggregate node and the + // original input. 
+ FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: mut original_exprs_list, + } => { + let new_aggr_expr = new_exprs_list.pop().unwrap(); + let new_group_expr = new_exprs_list.pop().unwrap(); + + build_common_expr_project_plan(input, common_exprs).map(|new_input| { + let aggr_expr = original_exprs_list.pop().unwrap(); + Transformed::yes(( new_aggr_expr, new_group_expr, - aggr_expr, - Arc::new(new_input), + new_input, + Some(aggr_expr), )) }) - })? - // Try extracting common aggregate expressions and rebuild the aggregate node. - .transform_data(|(new_aggr_expr, new_group_expr, aggr_expr, new_input)| { + } + + FoundCommonNodes::No { + original_nodes_list: mut original_exprs_list, + } => { + let new_aggr_expr = original_exprs_list.pop().unwrap(); + let new_group_expr = original_exprs_list.pop().unwrap(); + + Ok(Transformed::no(( + new_aggr_expr, + new_group_expr, + input, + None, + ))) + } + }? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| { + self.rewrite(new_input, config)?.map_data(|new_input| { + Ok(( + new_aggr_expr, + new_group_expr, + aggr_expr, + Arc::new(new_input), + )) + }) + })? + // Try extracting common aggregate expressions and rebuild the aggregate node. + .transform_data( + |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| { // Extract common aggregate sub-expressions from the aggregate expressions. - self.find_common_exprs( - vec![new_aggr_expr], - config, + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), ExprMask::NormalAndAggregates, - )? - .map_data(|common| { - match common { - FoundCommonExprs::Yes { - common_exprs, - mut new_exprs_list, - mut original_exprs_list, - } => { - let rewritten_aggr_expr = new_exprs_list.pop().unwrap(); - let new_aggr_expr = original_exprs_list.pop().unwrap(); - - let mut agg_exprs = common_exprs - .into_iter() - .map(|(expr, expr_alias)| expr.alias(expr_alias)) - .collect::>(); + )) + .extract_common_nodes(vec![new_aggr_expr])? + { + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: mut original_exprs_list, + } => { + let rewritten_aggr_expr = new_exprs_list.pop().unwrap(); + let new_aggr_expr = original_exprs_list.pop().unwrap(); - let mut proj_exprs = vec![]; - for expr in &new_group_expr { - extract_expressions(expr, &mut proj_exprs) - } - for (expr_rewritten, expr_orig) in - rewritten_aggr_expr.into_iter().zip(new_aggr_expr) - { - if expr_rewritten == expr_orig { - if let Expr::Alias(Alias { expr, name, .. 
 }) =
- expr_rewritten
- {
- agg_exprs.push(expr.alias(&name));
- proj_exprs
- .push(Expr::Column(Column::from_name(name)));
- } else {
- let expr_alias =
- config.alias_generator().next(CSE_PREFIX);
- let (qualifier, field_name) =
- expr_rewritten.qualified_name();
- let out_name = qualified_name(
- qualifier.as_ref(),
- &field_name,
- );
-
- agg_exprs.push(expr_rewritten.alias(&expr_alias));
- proj_exprs.push(
- Expr::Column(Column::from_name(expr_alias))
- .alias(out_name),
- );
- }
+ let mut agg_exprs = common_exprs
+ .into_iter()
+ .map(|(expr, expr_alias)| expr.alias(expr_alias))
+ .collect::<Vec<_>>();
+
+ let mut proj_exprs = vec![];
+ for expr in &new_group_expr {
+ extract_expressions(expr, &mut proj_exprs)
+ }
+ for (expr_rewritten, expr_orig) in
+ rewritten_aggr_expr.into_iter().zip(new_aggr_expr)
+ {
+ if expr_rewritten == expr_orig {
+ if let Expr::Alias(Alias { expr, name, .. }) =
+ expr_rewritten
+ {
+ agg_exprs.push(expr.alias(&name));
+ proj_exprs
+ .push(Expr::Column(Column::from_name(name)));
 } else {
- proj_exprs.push(expr_rewritten);
+ let expr_alias =
+ config.alias_generator().next(CSE_PREFIX);
+ let (qualifier, field_name) =
+ expr_rewritten.qualified_name();
+ let out_name =
+ qualified_name(qualifier.as_ref(), &field_name);
+
+ agg_exprs.push(expr_rewritten.alias(&expr_alias));
+ proj_exprs.push(
+ Expr::Column(Column::from_name(expr_alias))
+ .alias(out_name),
+ );
 }
+ } else {
+ proj_exprs.push(expr_rewritten);
 }
-
- let agg = LogicalPlan::Aggregate(Aggregate::try_new(
- new_input,
- new_group_expr,
- agg_exprs,
- )?);
- Projection::try_new(proj_exprs, Arc::new(agg))
- .map(LogicalPlan::Projection)
 }
- // If there aren't any common aggregate sub-expressions, then just
- // rebuild the aggregate node.
- FoundCommonExprs::No {
- mut original_exprs_list,
- } => {
- let rewritten_aggr_expr = original_exprs_list.pop().unwrap();
-
- // If there were common expressions extracted, then we need to
- // make sure we restore the original column names.
- // TODO: Although `find_common_exprs()` inserts aliases around
- // extracted common expressions this doesn't mean that the
- // original column names (schema) are preserved due to the
- // inserted aliases are not always at the top of the
- // expression.
- // Let's consider improving `find_common_exprs()` to always
- // keep column names and get rid of additional name
- // preserving logic here.
- if let Some(aggr_expr) = aggr_expr {
- let name_preserver = NamePreserver::new_for_projection();
- let saved_names = aggr_expr
- .iter()
- .map(|expr| name_preserver.save(expr))
- .collect::<Vec<_>>();
- let new_aggr_expr = rewritten_aggr_expr
- .into_iter()
- .zip(saved_names)
- .map(|(new_expr, saved_name)| {
- saved_name.restore(new_expr)
- })
- .collect::<Vec<_>>();
-
- // Since `group_expr` may have changed, schema may also.
- // Use `try_new()` method.
- Aggregate::try_new(
- new_input,
- new_group_expr,
- new_aggr_expr,
- )
- .map(LogicalPlan::Aggregate)
- } else {
- Aggregate::try_new_with_schema(
- new_input,
- new_group_expr,
- rewritten_aggr_expr,
- schema,
- )
+ let agg = LogicalPlan::Aggregate(Aggregate::try_new(
+ new_input,
+ new_group_expr,
+ agg_exprs,
+ )?);
+ Projection::try_new(proj_exprs, Arc::new(agg))
+ .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
+ }
+
+ // If there aren't any common aggregate sub-expressions, then just
+ // rebuild the aggregate node.
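When an extracted alias replaces a subexpression, the visible output name of the rewritten expression changes, which is why the code around here saves and re-applies names. A small sketch (plain Rust stand-ins; not DataFusion's `NamePreserver`, whose behavior this only approximates) of the save/restore idea:

```rust
#[derive(Debug, Clone)]
enum Expr {
    Col(String),
    Alias(Box<Expr>, String),
}

impl Expr {
    // The name this expression would contribute to an output schema.
    fn schema_name(&self) -> String {
        match self {
            Expr::Col(n) => n.clone(),
            Expr::Alias(_, n) => n.clone(),
        }
    }
}

struct SavedName(String);

// Save the display name before the rewrite...
fn save(e: &Expr) -> SavedName {
    SavedName(e.schema_name())
}

// ...and restore it afterwards by re-aliasing only if the rewrite changed it.
fn restore(e: Expr, saved: SavedName) -> Expr {
    if e.schema_name() == saved.0 {
        e
    } else {
        Expr::Alias(Box::new(e), saved.0)
    }
}

fn main() {
    let original = Expr::Col("a + b".into()); // stands in for a richer expression
    let saved = save(&original);
    let rewritten = Expr::Col("__common_expr_1".into());
    // prints: Alias(Col("__common_expr_1"), "a + b")
    println!("{:?}", restore(rewritten, saved));
}
```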
+ FoundCommonNodes::No {
+ original_nodes_list: mut original_exprs_list,
+ } => {
+ let rewritten_aggr_expr = original_exprs_list.pop().unwrap();
+
+ // If there were common expressions extracted, then we need to
+ // make sure we restore the original column names.
+ // TODO: Although `find_common_exprs()` inserts aliases around
+ // extracted common expressions this doesn't mean that the
+ // original column names (schema) are preserved due to the
+ // inserted aliases are not always at the top of the
+ // expression.
+ // Let's consider improving `find_common_exprs()` to always
+ // keep column names and get rid of additional name
+ // preserving logic here.
+ if let Some(aggr_expr) = aggr_expr {
+ let name_preserver = NamePreserver::new_for_projection();
+ let saved_names = aggr_expr
+ .iter()
+ .map(|expr| name_preserver.save(expr))
+ .collect::<Vec<_>>();
+ let new_aggr_expr = rewritten_aggr_expr
+ .into_iter()
+ .zip(saved_names)
+ .map(|(new_expr, saved_name)| {
+ saved_name.restore(new_expr)
+ })
+ .collect::<Vec<_>>();
+
+ // Since `group_expr` may have changed, schema may also.
+ // Use `try_new()` method.
+ Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
 .map(LogicalPlan::Aggregate)
- }
+ .map(Transformed::no)
+ } else {
+ Aggregate::try_new_with_schema(
+ new_input,
+ new_group_expr,
+ rewritten_aggr_expr,
+ schema,
+ )
+ .map(LogicalPlan::Aggregate)
+ .map(Transformed::no)
 }
 }
- })
- })
+ }
+ },
+ )
 }

 /// Rewrites the expr list and input to remove common subexpressions
@@ -653,30 +426,34 @@ impl CommonSubexprEliminate {
 config: &dyn OptimizerConfig,
 ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
 // Extract common sub-expressions from the expressions.
- self.find_common_exprs(vec![exprs], config, ExprMask::Normal)?
- .map_data(|common| match common {
- FoundCommonExprs::Yes {
- common_exprs,
- mut new_exprs_list,
- original_exprs_list: _,
- } => {
- let new_exprs = new_exprs_list.pop().unwrap();
- build_common_expr_project_plan(input, common_exprs)
- .map(|new_input| (new_exprs, new_input))
- }
- FoundCommonExprs::No {
- mut original_exprs_list,
- } => {
- let new_exprs = original_exprs_list.pop().unwrap();
- Ok((new_exprs, input))
- }
- })?
- // Recurse into the new input.
- // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
- .transform_data(|(new_exprs, new_input)| {
- self.rewrite(new_input, config)?
- .map_data(|new_input| Ok((new_exprs, new_input)))
- })
+ match CSE::new(ExprCSEController::new(
+ config.alias_generator().as_ref(),
+ ExprMask::Normal,
+ ))
+ .extract_common_nodes(vec![exprs])?
+ {
+ FoundCommonNodes::Yes {
+ common_nodes: common_exprs,
+ new_nodes_list: mut new_exprs_list,
+ original_nodes_list: _,
+ } => {
+ let new_exprs = new_exprs_list.pop().unwrap();
+ build_common_expr_project_plan(input, common_exprs)
+ .map(|new_input| Transformed::yes((new_exprs, new_input)))
+ }
+ FoundCommonNodes::No {
+ original_nodes_list: mut original_exprs_list,
+ } => {
+ let new_exprs = original_exprs_list.pop().unwrap();
+ Ok(Transformed::no((new_exprs, input)))
+ }
+ }?
+ // Recurse into the new input.
+ // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
+ .transform_data(|(new_exprs, new_input)| {
+ self.rewrite(new_input, config)?
+ .map_data(|new_input| Ok((new_exprs, new_input)))
+ })
 }
 }

@@ -757,7 +534,6 @@ impl OptimizerRule for CommonSubexprEliminate {
 LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
 LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
 LogicalPlan::Join(_)
- | LogicalPlan::CrossJoin(_)
 | LogicalPlan::Repartition(_)
 | LogicalPlan::Union(_)
 | LogicalPlan::TableScan(_)
@@ -800,71 +576,6 @@ impl OptimizerRule for CommonSubexprEliminate {
 }
 }

-impl Default for CommonSubexprEliminate {
- fn default() -> Self {
- Self::new()
- }
-}
-
-/// Build the "intermediate" projection plan that evaluates the extracted common
-/// expressions.
-///
-/// # Arguments
-/// input: the input plan
-///
-/// common_exprs: which common subexpressions were used (and thus are added to
-/// intermediate projection)
-///
-/// expr_stats: the set of common subexpressions
-fn build_common_expr_project_plan(
- input: LogicalPlan,
- common_exprs: Vec<(Expr, String)>,
-) -> Result<LogicalPlan> {
- let mut fields_set = BTreeSet::new();
- let mut project_exprs = common_exprs
- .into_iter()
- .map(|(expr, expr_alias)| {
- fields_set.insert(expr_alias.clone());
- Ok(expr.alias(expr_alias))
- })
- .collect::<Result<Vec<_>>>()?;
-
- for (qualifier, field) in input.schema().iter() {
- if fields_set.insert(qualified_name(qualifier, field.name())) {
- project_exprs.push(Expr::from((qualifier, field)));
- }
- }
-
- Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
-}
-
-/// Build the projection plan to eliminate unnecessary columns produced by
-/// the "intermediate" projection plan built in [build_common_expr_project_plan].
-///
-/// This is required to keep the schema the same for plans that pass the input
-/// on to the output, such as `Filter` or `Sort`.
-fn build_recover_project_plan(
- schema: &DFSchema,
- input: LogicalPlan,
-) -> Result<LogicalPlan> {
- let col_exprs = schema.iter().map(Expr::from).collect();
- Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
-}
-
-fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
- if let Expr::GroupingSet(groupings) = expr {
- for e in groupings.distinct_expr() {
- let (qualifier, field_name) = e.qualified_name();
- let col = Column::new(qualifier, field_name);
- result.push(Expr::Column(col))
- }
- } else {
- let (qualifier, field_name) = expr.qualified_name();
- let col = Column::new(qualifier, field_name);
- result.push(Expr::Column(col));
- }
-}
-
 /// Which type of [expressions](Expr) should be considered for rewriting?
 #[derive(Debug, Clone, Copy)]
 enum ExprMask {
@@ -882,156 +593,36 @@ enum ExprMask {
 NormalAndAggregates,
 }

-impl ExprMask {
- fn ignores(&self, expr: &Expr) -> bool {
- let is_normal_minus_aggregates = matches!(
- expr,
- Expr::Literal(..)
- | Expr::Column(..)
- | Expr::ScalarVariable(..)
- | Expr::Alias(..)
- | Expr::Wildcard { .. }
- );
-
- let is_aggr = matches!(expr, Expr::AggregateFunction(..));
-
- match self {
- Self::Normal => is_normal_minus_aggregates || is_aggr,
- Self::NormalAndAggregates => is_normal_minus_aggregates,
- }
- }
-}
-
-/// Go through an expression tree and generate identifiers for each subexpression.
-///
-/// An identifier contains information of the expression itself and its sub-expression.
-/// This visitor implementation uses a stack `visit_stack` to track traversal, which
-/// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called
-/// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack.
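The stack protocol of this removed visitor can be summarized with a toy model. This is a sketch under simplifying assumptions (strings stand in for the real hash-based identifiers; the function names mirror the originals but this is not the original implementation): `f_down` pushes an `EnterMark`, each finished child pushes an `ExprItem`, and on `f_up` everything above the nearest `EnterMark` belongs to the node being left.

```rust
enum VisitRecord {
    EnterMark(usize),       // pre-order index of the node being entered
    ExprItem(String, bool), // (accumulated child identifier, still valid for extraction)
}

// Pop items until the matching EnterMark, accumulating the children's
// identifiers and AND-ing their validity flags upwards.
fn pop_enter_mark(stack: &mut Vec<VisitRecord>) -> (usize, String, bool) {
    let mut id = String::new();
    let mut is_valid = true;
    while let Some(item) = stack.pop() {
        match item {
            VisitRecord::EnterMark(down_index) => return (down_index, id, is_valid),
            VisitRecord::ExprItem(sub_id, sub_valid) => {
                id = format!("{sub_id}|{id}");
                is_valid &= sub_valid;
            }
        }
    }
    unreachable!("every f_up has a matching EnterMark pushed in f_down");
}

fn main() {
    // Simulate leaving `a + b`: `+` was entered at pre-order index 0,
    // then both children finished and reported up.
    let mut stack = vec![
        VisitRecord::EnterMark(0),
        VisitRecord::ExprItem("a".into(), true),
        VisitRecord::ExprItem("b".into(), true),
    ];
    let (down_index, children_id, valid) = pop_enter_mark(&mut stack);
    println!("{down_index} {children_id} {valid}"); // 0 a|b| true
}
```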
-/// And try to pop out an `EnterMark` on leaving a node (`f_up()`). All `ExprItem`s
-/// before the first `EnterMark` are considered to be the sub-tree of the leaving node.
-///
-/// This visitor also records identifiers in `id_array`, so that the following traversal
-/// pass can get the identifier of a node without recalculating it. We assign each node
-/// in the expr tree a series number, starting from 1, maintained by `series_number`.
-/// The series number represents the order we left (`f_up()`) a node. It has the property
-/// that a child node's series number is always smaller than the parent's. While `id_array` is
-/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to
-/// get the index of `id_array` for each node.
-///
-/// `Expr` without sub-expr (column, literal etc.) will not have identifier
-/// because they should not be recognized as common sub-expr.
-struct ExprIdentifierVisitor<'a, 'n> {
- // statistics of expressions
- expr_stats: &'a mut ExprStats<'n>,
- // cache to speed up second traversal
- id_array: &'a mut IdArray<'n>,
- // inner states
- visit_stack: Vec<VisitRecord<'n>>,
- // preorder index, start from 0.
- down_index: usize,
- // postorder index, start from 0.
- up_index: usize,
- // which expression should be skipped?
- expr_mask: ExprMask,
- // a `RandomState` to generate hashes during the first traversal
- random_state: &'a RandomState,
- // a flag to indicate that common expression found
- found_common: bool,
- // if we are in a conditional branch. A conditional branch means that the expression
- // might not be executed depending on the runtime values of other expressions, and
- // thus can not be extracted as a common expression.
- conditional: bool,
-}
+struct ExprCSEController<'a> {
+ alias_generator: &'a AliasGenerator,
+ mask: ExprMask,

-/// Record item used when traversing an expression tree.
-enum VisitRecord<'n> {
- /// Marks the beginning of an expression. It contains:
- /// - The post-order index assigned during the first, visiting traversal.
- EnterMark(usize),
-
- /// Marks an accumulated subexpression tree. It contains:
- /// - The accumulated identifier of a subexpression.
- /// - A boolean flag if the expression is valid for subexpression elimination.
- /// The flag is propagated up from children to parent. (E.g. volatile expressions
- /// are not valid and can't be extracted, but non-volatile children of volatile
- /// expressions can be extracted.)
- ExprItem(Identifier<'n>, bool),
+ // how many aliases have we seen so far
+ alias_counter: usize,
 }

-impl<'n> ExprIdentifierVisitor<'_, 'n> {
- /// Find the first `EnterMark` in the stack, and accumulate every `ExprItem` before
- /// it. Returns a tuple that contains:
- /// - The pre-order index of the expression we marked.
- /// - The accumulated identifier of the children of the marked expression.
- /// - An accumulated boolean flag from the children of the marked expression if all
- /// children are valid for subexpression elimination (i.e. it is safe to extract the
- /// expression as a common expression from its children POV).
- /// (E.g. if any of the children of the marked expression is not valid (e.g. is
- /// volatile) then the expression is also not valid, so we can propagate this
- /// information up from children to parents via `visit_stack` during the first,
- /// visiting traversal and no need to test the expression's validity beforehand with
- /// an extra traversal).
- fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n>>, bool) {
- let mut expr_id = None;
- let mut is_valid = true;
-
- while let Some(item) = self.visit_stack.pop() {
- match item {
- VisitRecord::EnterMark(down_index) => {
- return (down_index, expr_id, is_valid);
- }
- VisitRecord::ExprItem(sub_expr_id, sub_expr_is_valid) => {
- expr_id = Some(sub_expr_id.combine(expr_id));
- is_valid &= sub_expr_is_valid;
- }
- }
+impl<'a> ExprCSEController<'a> {
+ fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
+ Self {
+ alias_generator,
+ mask,
+ alias_counter: 0,
 }
- unreachable!("Enter mark should be paired with node number");
- }
-
- /// Save the current `conditional` status and run `f` with `conditional` set to true.
- fn conditionally<F: FnMut(&mut Self) -> Result<()>>(
- &mut self,
- mut f: F,
- ) -> Result<()> {
- let conditional = self.conditional;
- self.conditional = true;
- f(self)?;
- self.conditional = conditional;
-
- Ok(())
 }
 }

-impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
+impl CSEController for ExprCSEController<'_> {
 type Node = Expr;

- fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
- self.id_array.push((0, None));
- self.visit_stack
- .push(VisitRecord::EnterMark(self.down_index));
- self.down_index += 1;
-
- // If an expression can short-circuit then some of its children might not be
- // executed so count the occurrence of subexpressions as conditional in all
- // children.
- Ok(match expr {
- // If we are already in a conditionally evaluated subtree then continue
- // traversal.
- _ if self.conditional => TreeNodeRecursion::Continue,
-
+ fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
+ match node {
 // In case of `ScalarFunction`s we don't know which children are surely
 // executed so start visiting all children conditionally and stop the
 // recursion with `TreeNodeRecursion::Jump`.
 Expr::ScalarFunction(ScalarFunction { func, args })
 if func.short_circuits() =>
 {
- self.conditionally(|visitor| {
- args.iter().try_for_each(|e| e.visit(visitor).map(|_| ()))
- })?;
-
- TreeNodeRecursion::Jump
+ Some((vec![], args.iter().collect()))
 }

 // In case of `And` and `Or` the first child is surely executed, but we
@@ -1040,12 +631,7 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
 left,
 op: Operator::And | Operator::Or,
 right,
- }) => {
- left.visit(self)?;
- self.conditionally(|visitor| right.visit(visitor).map(|_| ()))?;
-
- TreeNodeRecursion::Jump
- }
+ }) => Some((vec![left.as_ref()], vec![right.as_ref()])),

 // In case of `Case` the optional base expression and the first when
 // expressions are surely executed, but we account subexpressions as
@@ -1054,167 +640,151 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
 expr,
 when_then_expr,
 else_expr,
- }) => {
- expr.iter().try_for_each(|e| e.visit(self).map(|_| ()))?;
- when_then_expr.iter().take(1).try_for_each(|(when, then)| {
- when.visit(self)?;
- self.conditionally(|visitor| then.visit(visitor).map(|_| ()))
- })?;
- self.conditionally(|visitor| {
- when_then_expr.iter().skip(1).try_for_each(|(when, then)| {
- when.visit(visitor)?;
- then.visit(visitor).map(|_| ())
- })?;
- else_expr
- .iter()
- .try_for_each(|e| e.visit(visitor).map(|_| ()))
- })?;
-
- TreeNodeRecursion::Jump
- }
+ }) => Some((
+ expr.iter()
+ .map(|e| e.as_ref())
+ .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
+ .collect(),
+ when_then_expr
+ .iter()
+ .take(1)
+ .map(|(_, then)| then.as_ref())
+ .chain(
+ when_then_expr
+ .iter()
+ .skip(1)
+ .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
+ )
+ .chain(else_expr.iter().map(|e| e.as_ref()))
+ .collect(),
+ )),
+ _ => None,
+ }
+ }

- // In case of non-short-circuit expressions continue the traversal.
- _ => TreeNodeRecursion::Continue,
- })
+ fn is_valid(node: &Expr) -> bool {
+ !node.is_volatile_node()
 }

- fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
- let (down_index, sub_expr_id, sub_expr_is_valid) = self.pop_enter_mark();
+ fn is_ignored(&self, node: &Expr) -> bool {
+ let is_normal_minus_aggregates = matches!(
+ node,
+ Expr::Literal(..)
+ | Expr::Column(..)
+ | Expr::ScalarVariable(..)
+ | Expr::Alias(..)
+ | Expr::Wildcard { .. }
+ );

- let expr_id = Identifier::new(expr, self.random_state).combine(sub_expr_id);
- let is_valid = !expr.is_volatile_node() && sub_expr_is_valid;
+ let is_aggr = matches!(node, Expr::AggregateFunction(..));

- self.id_array[down_index].0 = self.up_index;
- if is_valid && !self.expr_mask.ignores(expr) {
- self.id_array[down_index].1 = Some(expr_id);
- let (count, conditional_count) =
- self.expr_stats.entry(expr_id).or_insert((0, 0));
- if self.conditional {
- *conditional_count += 1;
- } else {
- *count += 1;
- }
- if *count > 1 || (*count == 1 && *conditional_count > 0) {
- self.found_common = true;
- }
+ match self.mask {
+ ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
+ ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
 }
- self.visit_stack
- .push(VisitRecord::ExprItem(expr_id, is_valid));
- self.up_index += 1;
-
- Ok(TreeNodeRecursion::Continue)
 }
-}

-/// Rewrite expression by replacing detected common sub-expression with
-/// the corresponding temporary column name. That column contains the
-/// evaluated result of the replaced expression.
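To make the `conditional_children` split above concrete: for a short-circuiting operator only some children are guaranteed to run, so a subexpression seen only on the conditional side must not be hoisted out and evaluated unconditionally. A self-contained sketch with a toy expression type (it mirrors the shape of the code above, not DataFusion's actual `Expr`):

```rust
enum Expr {
    Col(&'static str),
    Or(Box<Expr>, Box<Expr>),
}

// Returns (always-evaluated children, conditionally-evaluated children),
// or None for operators that never short-circuit.
fn conditional_children(e: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
    match e {
        // For `a OR b`, `a` always runs; `b` runs only when `a` is false.
        Expr::Or(left, right) => Some((vec![left.as_ref()], vec![right.as_ref()])),
        _ => None,
    }
}

fn main() {
    let e = Expr::Or(Box::new(Expr::Col("a")), Box::new(Expr::Col("b")));
    if let Some((always, conditional)) = conditional_children(&e) {
        println!("always: {}, conditional: {}", always.len(), conditional.len());
    }
}
```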
-struct CommonSubexprRewriter<'a, 'n> {
- // statistics of expressions
- expr_stats: &'a ExprStats<'n>,
- // cache to speed up second traversal
- id_array: &'a IdArray<'n>,
- // common expressions that are replaced during the second traversal are collected in
- // this map
- common_exprs: &'a mut CommonExprs<'n>,
- // preorder index, starts from 0.
- down_index: usize,
- // how many aliases have we seen so far
- alias_counter: usize,
- // alias generator for extracted common expressions
- alias_generator: &'a AliasGenerator,
-}
+ fn generate_alias(&self) -> String {
+ self.alias_generator.next(CSE_PREFIX)
+ }

-impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
- type Node = Expr;
+ fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
+ // alias the expressions without an `Alias` ancestor node
+ if self.alias_counter > 0 {
+ col(alias)
+ } else {
+ self.alias_counter += 1;
+ col(alias).alias(node.schema_name().to_string())
+ }
+ }

- fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
- if matches!(expr, Expr::Alias(_)) {
+ fn rewrite_f_down(&mut self, node: &Expr) {
+ if matches!(node, Expr::Alias(_)) {
 self.alias_counter += 1;
 }
+ }
+ fn rewrite_f_up(&mut self, node: &Expr) {
+ if matches!(node, Expr::Alias(_)) {
+ self.alias_counter -= 1
+ }
+ }
+}

- let (up_index, expr_id) = self.id_array[self.down_index];
- self.down_index += 1;
+impl Default for CommonSubexprEliminate {
+ fn default() -> Self {
+ Self::new()
+ }
+}

- // Handle `Expr`s with identifiers only
- if let Some(expr_id) = expr_id {
- let (count, conditional_count) = self.expr_stats.get(&expr_id).unwrap();
- if *count > 1 || *count == 1 && *conditional_count > 0 {
- // step index to skip all sub-node (which has smaller series number).
- while self.down_index < self.id_array.len()
- && self.id_array[self.down_index].0 < up_index
- {
- self.down_index += 1;
- }
+/// Build the "intermediate" projection plan that evaluates the extracted common
+/// expressions.
+///
+/// # Arguments
+/// input: the input plan
+///
+/// common_exprs: which common subexpressions were used (and thus are added to
+/// intermediate projection)
+fn build_common_expr_project_plan(
+ input: LogicalPlan,
+ common_exprs: Vec<(Expr, String)>,
+) -> Result<LogicalPlan> {
+ let mut fields_set = BTreeSet::new();
+ let mut project_exprs = common_exprs
+ .into_iter()
+ .map(|(expr, expr_alias)| {
+ fields_set.insert(expr_alias.clone());
+ Ok(expr.alias(expr_alias))
+ })
+ .collect::<Result<Vec<_>>>()?;

- let expr_name = expr.schema_name().to_string();
- let (_, expr_alias) =
- self.common_exprs.entry(expr_id).or_insert_with(|| {
- let expr_alias = self.alias_generator.next(CSE_PREFIX);
- (expr, expr_alias)
- });
-
- // alias the expressions without an `Alias` ancestor node
- let rewritten = if self.alias_counter > 0 {
- col(expr_alias.clone())
- } else {
- self.alias_counter += 1;
- col(expr_alias.clone()).alias(expr_name)
- };
-
- return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
- }
+ for (qualifier, field) in input.schema().iter() {
+ if fields_set.insert(qualified_name(qualifier, field.name())) {
+ project_exprs.push(Expr::from((qualifier, field)));
 }
-
- Ok(Transformed::no(expr))
 }

- fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
- if matches!(expr, Expr::Alias(_)) {
- self.alias_counter -= 1
- }
+ Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
+}

- Ok(Transformed::no(expr))
- }
+/// Build the projection plan to eliminate unnecessary columns produced by
+/// the "intermediate" projection plan built in [build_common_expr_project_plan].
+///
+/// This is required to keep the schema the same for plans that pass the input
+/// on to the output, such as `Filter` or `Sort`.
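The deduplication step in `build_common_expr_project_plan` above can be illustrated with plain strings standing in for expressions and schema fields (a sketch under those simplifying assumptions, not the real function): emit each extracted common expression under its alias, then append every input column whose name is not already taken.

```rust
use std::collections::BTreeSet;

fn project_list(common: Vec<(&str, &str)>, input_cols: &[&str]) -> Vec<String> {
    let mut taken = BTreeSet::new();
    // First the extracted common expressions, each under its alias.
    let mut out: Vec<String> = common
        .into_iter()
        .map(|(expr, alias)| {
            taken.insert(alias.to_string());
            format!("{expr} AS {alias}")
        })
        .collect();
    // Then every input column whose name was not claimed by an alias.
    for col in input_cols {
        if taken.insert(col.to_string()) {
            out.push(col.to_string());
        }
    }
    out
}

fn main() {
    let list = project_list(vec![("a + b", "__common_expr_1")], &["a", "b"]);
    println!("{list:?}"); // ["a + b AS __common_expr_1", "a", "b"]
}
```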
+fn build_recover_project_plan(
+ schema: &DFSchema,
+ input: LogicalPlan,
+) -> Result<LogicalPlan> {
+ let col_exprs = schema.iter().map(Expr::from).collect();
+ Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
 }

-/// Replace common sub-expression in `expr` with the corresponding temporary
-/// column name, updating `common_exprs` with any replaced expressions
-fn replace_common_expr<'n>(
- expr: Expr,
- id_array: &IdArray<'n>,
- expr_stats: &ExprStats<'n>,
- common_exprs: &mut CommonExprs<'n>,
- alias_generator: &AliasGenerator,
-) -> Result<Expr> {
- if id_array.is_empty() {
- Ok(Transformed::no(expr))
+fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
+ if let Expr::GroupingSet(groupings) = expr {
+ for e in groupings.distinct_expr() {
+ let (qualifier, field_name) = e.qualified_name();
+ let col = Column::new(qualifier, field_name);
+ result.push(Expr::Column(col))
+ }
 } else {
- expr.rewrite(&mut CommonSubexprRewriter {
- expr_stats,
- id_array,
- common_exprs,
- down_index: 0,
- alias_counter: 0,
- alias_generator,
- })
+ let (qualifier, field_name) = expr.qualified_name();
+ let col = Column::new(qualifier, field_name);
+ result.push(Expr::Column(col));
 }
- .data()
 }

 #[cfg(test)]
 mod test {
 use std::any::Any;
- use std::collections::HashSet;
 use std::iter;

 use arrow::datatypes::{DataType, Field, Schema};
- use datafusion_expr::expr::AggregateFunction;
 use datafusion_expr::logical_plan::{table_scan, JoinType};
 use datafusion_expr::{
- grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr,
- ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF,
- Volatility,
+ grouping_set, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF,
+ ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility,
 };
 use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};

@@ -1238,154 +808,6 @@ mod test {
 assert_eq!(expected, formatted_plan);
 }

- #[test]
- fn id_array_visitor() -> Result<()> {
- let optimizer = CommonSubexprEliminate::new();
-
- let a_plus_1 = col("a") + lit(1);
- let avg_c = avg(col("c"));
- let sum_a_plus_1 = sum(a_plus_1);
- let sum_a_plus_1_minus_avg_c = sum_a_plus_1 - avg_c;
- let expr = sum_a_plus_1_minus_avg_c * lit(2);
-
- let Expr::BinaryExpr(BinaryExpr {
- left: sum_a_plus_1_minus_avg_c,
- ..
- }) = &expr
- else {
- panic!("Cannot extract subexpression reference")
- };
- let Expr::BinaryExpr(BinaryExpr {
- left: sum_a_plus_1,
- right: avg_c,
- ..
- }) = sum_a_plus_1_minus_avg_c.as_ref()
- else {
- panic!("Cannot extract subexpression reference")
- };
- let Expr::AggregateFunction(AggregateFunction {
- args: a_plus_1_vec, ..
- }) = sum_a_plus_1.as_ref()
- else {
- panic!("Cannot extract subexpression reference")
- };
- let a_plus_1 = &a_plus_1_vec.as_slice()[0];
-
- // skip aggregates
- let mut id_array = vec![];
- optimizer.expr_to_identifier(
- &expr,
- &mut ExprStats::new(),
- &mut id_array,
- ExprMask::Normal,
- )?;
-
- // Collect distinct hashes and set them to 0 in `id_array`
- fn collect_hashes(id_array: &mut IdArray) -> HashSet<u64> {
- id_array
- .iter_mut()
- .flat_map(|(_, expr_id_option)| {
- expr_id_option.as_mut().map(|expr_id| {
- let hash = expr_id.hash;
- expr_id.hash = 0;
- hash
- })
- })
- .collect::<HashSet<_>>()
- }
-
- let hashes = collect_hashes(&mut id_array);
- assert_eq!(hashes.len(), 3);
-
- let expected = vec![
- (
- 8,
- Some(Identifier {
- hash: 0,
- expr: &expr,
- }),
- ),
- (
- 6,
- Some(Identifier {
- hash: 0,
- expr: sum_a_plus_1_minus_avg_c,
- }),
- ),
- (3, None),
- (
- 2,
- Some(Identifier {
- hash: 0,
- expr: a_plus_1,
- }),
- ),
- (0, None),
- (1, None),
- (5, None),
- (4, None),
- (7, None),
- ];
- assert_eq!(expected, id_array);
-
- // include aggregates
- let mut id_array = vec![];
- optimizer.expr_to_identifier(
- &expr,
- &mut ExprStats::new(),
- &mut id_array,
- ExprMask::NormalAndAggregates,
- )?;
-
- let hashes = collect_hashes(&mut id_array);
- assert_eq!(hashes.len(), 5);
-
- let expected = vec![
- (
- 8,
- Some(Identifier {
- hash: 0,
- expr: &expr,
- }),
- ),
- (
- 6,
- Some(Identifier {
- hash: 0,
- expr: sum_a_plus_1_minus_avg_c,
- }),
- ),
- (
- 3,
- Some(Identifier {
- hash: 0,
- expr: sum_a_plus_1,
- }),
- ),
- (
- 2,
- Some(Identifier {
- hash: 0,
- expr: a_plus_1,
- }),
- ),
- (0, None),
- (1, None),
- (
- 5,
- Some(Identifier {
- hash: 0,
- expr: avg_c,
- }),
- ),
- (4, None),
- (7, None),
- ];
- assert_eq!(expected, id_array);
-
- Ok(())
- }
-
 #[test]
 fn tpch_q1_simplified() -> Result<()> {
 // SQL:
diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs
index 7f918c03e3ac3..6aa59b77f7f94 100644
--- a/datafusion/optimizer/src/decorrelate.rs
+++ b/datafusion/optimizer/src/decorrelate.rs
@@ -31,7 +31,10 @@ use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue};
 use datafusion_expr::expr::Alias;
 use datafusion_expr::simplify::SimplifyContext;
 use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction};
-use datafusion_expr::{expr, lit, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder};
+use datafusion_expr::{
+ expr, lit, BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan,
+ LogicalPlanBuilder, Operator,
+};
 use datafusion_physical_expr::execution_props::ExecutionProps;

 /// This struct rewrite the sub query plan by pull up the correlated
@@ -49,6 +52,9 @@ pub struct PullUpCorrelatedExpr {
 pub exists_sub_query: bool,
 /// Can the correlated expressions be pulled up. Defaults to **TRUE**
 pub can_pull_up: bool,
+ /// Indicates if we encounter any correlated expression that can not be pulled up
+ /// above an aggregation without changing the meaning of the query.
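A hedged sketch of the safety condition this flag tracks, with a toy expression type (heavily simplified from the `can_pullup_over_aggregation` helper added later in this diff; the `OuterRef` variant is this sketch's stand-in for correlated references): an equality predicate may move above a GROUP BY only when one side is a plain column and the other side references no local columns, so its value cannot vary within a group.

```rust
enum Expr {
    Col(&'static str),
    OuterRef(&'static str), // correlated reference to the outer query; not a local column
    Eq(Box<Expr>, Box<Expr>),
}

// Does the expression reference any local column?
fn any_column_refs(e: &Expr) -> bool {
    match e {
        Expr::Col(_) => true,
        Expr::OuterRef(_) => false,
        Expr::Eq(l, r) => any_column_refs(l) || any_column_refs(r),
    }
}

fn can_pull_over_aggregation(e: &Expr) -> bool {
    match e {
        Expr::Eq(l, r) => match (l.as_ref(), r.as_ref()) {
            // one side a plain column, the other free of local columns
            (Expr::Col(_), other) | (other, Expr::Col(_)) => !any_column_refs(other),
            _ => false,
        },
        _ => false,
    }
}

fn main() {
    // t2.c = outer.c_custkey -> safe to pull above the aggregation
    let ok = Expr::Eq(Box::new(Expr::Col("c")), Box::new(Expr::OuterRef("c_custkey")));
    // t2.a = t2.b -> depends on local rows, must stay below
    let not_ok = Expr::Eq(Box::new(Expr::Col("a")), Box::new(Expr::Col("b")));
    println!("{} {}", can_pull_over_aggregation(&ok), can_pull_over_aggregation(&not_ok));
}
```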
+ can_pull_over_aggregation: bool, /// Do we need to handle [the Count bug] during the pull up process /// /// [the Count bug]: https://github.com/apache/datafusion/pull/10500 @@ -73,6 +79,7 @@ impl PullUpCorrelatedExpr { in_predicate_opt: None, exists_sub_query: false, can_pull_up: true, + can_pull_over_aggregation: true, need_handle_count_bug: false, collected_count_expr_map: HashMap::new(), pull_up_having_expr: None, @@ -152,6 +159,11 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { match &plan { LogicalPlan::Filter(plan_filter) => { let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); + self.can_pull_over_aggregation = self.can_pull_over_aggregation + && subquery_filter_exprs + .iter() + .filter(|e| e.contains_outer()) + .all(|&e| can_pullup_over_aggregation(e)); let (mut join_filters, subquery_filters) = find_join_exprs(subquery_filter_exprs)?; if let Some(in_predicate) = &self.in_predicate_opt { @@ -257,6 +269,12 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { LogicalPlan::Aggregate(aggregate) if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => { + // If the aggregation is from a distinct it will not change the result for + // exists/in subqueries so we can still pull up all predicates. + let is_distinct = aggregate.aggr_expr.is_empty(); + if !is_distinct { + self.can_pull_up = self.can_pull_up && self.can_pull_over_aggregation; + } let mut local_correlated_cols = BTreeSet::new(); collect_local_correlated_cols( &plan, @@ -327,16 +345,15 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { let new_plan = match (self.exists_sub_query, self.join_filters.is_empty()) { // Correlated exist subquery, remove the limit(so that correlated expressions can pull up) - (true, false) => Transformed::yes( - if limit.fetch.filter(|limit_row| *limit_row == 0).is_some() { + (true, false) => Transformed::yes(match limit.get_fetch_type()? { + FetchType::Literal(Some(0)) => { LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: Arc::clone(limit.input.schema()), }) - } else { - LogicalPlanBuilder::from((*limit.input).clone()).build()? - }, - ), + } + _ => LogicalPlanBuilder::from((*limit.input).clone()).build()?, + }), _ => Transformed::no(plan), }; if let Some(input_map) = input_expr_map { @@ -384,6 +401,33 @@ impl PullUpCorrelatedExpr { } } +fn can_pullup_over_aggregation(expr: &Expr) -> bool { + if let Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) = expr + { + match (left.deref(), right.deref()) { + (Expr::Column(_), right) => !right.any_column_refs(), + (left, Expr::Column(_)) => !left.any_column_refs(), + (Expr::Cast(Cast { expr, .. }), right) + if matches!(expr.deref(), Expr::Column(_)) => + { + !right.any_column_refs() + } + (left, Expr::Cast(Cast { expr, .. })) + if matches!(expr.deref(), Expr::Column(_)) => + { + !left.any_column_refs() + } + (_, _) => false, + } + } else { + false + } +} + fn collect_local_correlated_cols( plan: &LogicalPlan, all_cols_map: &HashMap>, diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index d1ac80003ba71..cc1687cffe921 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -17,6 +17,7 @@ //! 
 [`DecorrelatePredicateSubquery`] converts `IN`/`EXISTS` subquery predicates to `SEMI`/`ANTI` joins
 use std::collections::BTreeSet;
+use std::iter;
 use std::ops::Deref;
 use std::sync::Arc;

@@ -27,16 +28,17 @@ use crate::{OptimizerConfig, OptimizerRule};

 use datafusion_common::alias::AliasGenerator;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_common::{internal_err, plan_err, Result};
+use datafusion_common::{internal_err, plan_err, Column, Result};
 use datafusion_expr::expr::{Exists, InSubquery};
 use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
 use datafusion_expr::logical_plan::{JoinType, Subquery};
-use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned};
+use datafusion_expr::utils::{conjunction, split_conjunction_owned};
 use datafusion_expr::{
- exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
+ exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
 LogicalPlan, LogicalPlanBuilder, Operator,
 };
+use itertools::chain;
 use log::debug;

 /// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left semi/anti joins
@@ -48,79 +50,6 @@ impl DecorrelatePredicateSubquery {
 pub fn new() -> Self {
 Self::default()
 }
-
- fn rewrite_subquery(
- &self,
- mut subquery: Subquery,
- config: &dyn OptimizerConfig,
- ) -> Result<Subquery> {
- subquery.subquery = Arc::new(
- self.rewrite(Arc::unwrap_or_clone(subquery.subquery), config)?
- .data,
- );
- Ok(subquery)
- }
-
- /// Finds expressions that have the predicate subqueries (and recurses when found)
- ///
- /// # Arguments
- ///
- /// * `predicate` - A conjunction to split and search
- /// * `optimizer_config` - For generating unique subquery aliases
- ///
- /// Returns a tuple (subqueries, non-subquery expressions)
- fn extract_subquery_exprs(
- &self,
- predicate: Expr,
- config: &dyn OptimizerConfig,
- ) -> Result<(Vec<SubqueryInfo>, Vec<Expr>)> {
- let filters = split_conjunction_owned(predicate); // TODO: add ExistenceJoin to support disjunctions
-
- let mut subqueries = vec![];
- let mut others = vec![];
- for it in filters.into_iter() {
- match it {
- Expr::Not(not_expr) => match *not_expr {
- Expr::InSubquery(InSubquery {
- expr,
- subquery,
- negated,
- }) => {
- let new_subquery = self.rewrite_subquery(subquery, config)?;
- subqueries.push(SubqueryInfo::new_with_in_expr(
- new_subquery,
- *expr,
- !negated,
- ));
- }
- Expr::Exists(Exists { subquery, negated }) => {
- let new_subquery = self.rewrite_subquery(subquery, config)?;
- subqueries.push(SubqueryInfo::new(new_subquery, !negated));
- }
- expr => others.push(not(expr)),
- },
- Expr::InSubquery(InSubquery {
- expr,
- subquery,
- negated,
- }) => {
- let new_subquery = self.rewrite_subquery(subquery, config)?;
- subqueries.push(SubqueryInfo::new_with_in_expr(
- new_subquery,
- *expr,
- negated,
- ));
- }
- Expr::Exists(Exists { subquery, negated }) => {
- let new_subquery = self.rewrite_subquery(subquery, config)?;
- subqueries.push(SubqueryInfo::new(new_subquery, negated));
- }
- expr => others.push(expr),
- }
- }
-
- Ok((subqueries, others))
- }
 }

 impl OptimizerRule for DecorrelatePredicateSubquery {
@@ -133,69 +62,51 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
 plan: LogicalPlan,
 config: &dyn OptimizerConfig,
 ) -> Result<Transformed<LogicalPlan>> {
+ let plan = plan
+ .map_subqueries(|subquery| {
+ subquery.transform_down(|p| self.rewrite(p, config))
+ })?
+ .data; + let LogicalPlan::Filter(filter) = plan else { return Ok(Transformed::no(plan)); }; - // if there are no subqueries in the predicate, return the original plan - let has_subqueries = - split_conjunction(&filter.predicate) - .iter() - .any(|expr| match expr { - Expr::Not(not_expr) => { - matches!(not_expr.as_ref(), Expr::InSubquery(_) | Expr::Exists(_)) - } - Expr::InSubquery(_) | Expr::Exists(_) => true, - _ => false, - }); - - if !has_subqueries { + if !has_subquery(&filter.predicate) { return Ok(Transformed::no(LogicalPlan::Filter(filter))); } - let Filter { - predicate, input, .. - } = filter; - let (subqueries, mut other_exprs) = - self.extract_subquery_exprs(predicate, config)?; - if subqueries.is_empty() { + let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = + split_conjunction_owned(filter.predicate) + .into_iter() + .partition(has_subquery); + + if with_subqueries.is_empty() { return internal_err!( "can not find expected subqueries in DecorrelatePredicateSubquery" ); } // iterate through all exists clauses in predicate, turning each into a join - let mut cur_input = Arc::unwrap_or_clone(input); - for subquery in subqueries { - if let Some(plan) = - build_join(&subquery, &cur_input, config.alias_generator())? - { - cur_input = plan; - } else { - // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter - let sub_query_expr = match subquery { - SubqueryInfo { - query, - where_in_expr: Some(expr), - negated: false, - } => in_subquery(expr, query.subquery), - SubqueryInfo { - query, - where_in_expr: Some(expr), - negated: true, - } => not_in_subquery(expr, query.subquery), - SubqueryInfo { - query, - where_in_expr: None, - negated: false, - } => exists(query.subquery), - SubqueryInfo { - query, - where_in_expr: None, - negated: true, - } => not_exists(query.subquery), - }; - other_exprs.push(sub_query_expr); + let mut cur_input = Arc::unwrap_or_clone(filter.input); + for subquery_expr in with_subqueries { + match extract_subquery_info(subquery_expr) { + // The subquery expression is at the top level of the filter + SubqueryPredicate::Top(subquery) => { + match build_join_top(&subquery, &cur_input, config.alias_generator())? + { + Some(plan) => cur_input = plan, + // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter + None => other_exprs.push(subquery.expr()), + } + } + // The subquery expression is embedded within another expression + SubqueryPredicate::Embedded(expr) => { + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = plan; + other_exprs.push(expr_without_subqueries); + } } } @@ -216,6 +127,104 @@ impl OptimizerRule for DecorrelatePredicateSubquery { } } +fn rewrite_inner_subqueries( + outer: LogicalPlan, + expr: Expr, + config: &dyn OptimizerConfig, +) -> Result<(LogicalPlan, Expr)> { + let mut cur_input = outer; + let alias = config.alias_generator(); + let expr_without_subqueries = expr.transform(|e| match e { + Expr::Exists(Exists { + subquery: Subquery { subquery, .. }, + negated, + }) => { + match existence_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? + { + Some((plan, exists_expr)) => { + cur_input = plan; + Ok(Transformed::yes(exists_expr)) + } + None if negated => Ok(Transformed::no(not_exists(subquery))), + None => Ok(Transformed::no(exists(subquery))), + } + } + Expr::InSubquery(InSubquery { + expr, + subquery: Subquery { subquery, .. 
 },
+ negated,
+ }) => {
+ let in_predicate = subquery
+ .head_output_expr()?
+ .map_or(plan_err!("single expression required."), |output_expr| {
+ Ok(Expr::eq(*expr.clone(), output_expr))
+ })?;
+ match existence_join(
+ &cur_input,
+ Arc::clone(&subquery),
+ Some(in_predicate),
+ negated,
+ alias,
+ )? {
+ Some((plan, exists_expr)) => {
+ cur_input = plan;
+ Ok(Transformed::yes(exists_expr))
+ }
+ None if negated => Ok(Transformed::no(not_in_subquery(*expr, subquery))),
+ None => Ok(Transformed::no(in_subquery(*expr, subquery))),
+ }
+ }
+ _ => Ok(Transformed::no(e)),
+ })?;
+ Ok((cur_input, expr_without_subqueries.data))
+}
+
+enum SubqueryPredicate {
+ // The subquery expression is at the top level of the filter and can be fully replaced by a
+ // semi/anti join
+ Top(SubqueryInfo),
+ // The subquery expression is embedded within another expression and is replaced using an
+ // existence join
+ Embedded(Expr),
+}
+
+fn extract_subquery_info(expr: Expr) -> SubqueryPredicate {
+ match expr {
+ Expr::Not(not_expr) => match *not_expr {
+ Expr::InSubquery(InSubquery {
+ expr,
+ subquery,
+ negated,
+ }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr(
+ subquery, *expr, !negated,
+ )),
+ Expr::Exists(Exists { subquery, negated }) => {
+ SubqueryPredicate::Top(SubqueryInfo::new(subquery, !negated))
+ }
+ expr => SubqueryPredicate::Embedded(not(expr)),
+ },
+ Expr::InSubquery(InSubquery {
+ expr,
+ subquery,
+ negated,
+ }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr(
+ subquery, *expr, negated,
+ )),
+ Expr::Exists(Exists { subquery, negated }) => {
+ SubqueryPredicate::Top(SubqueryInfo::new(subquery, negated))
+ }
+ expr => SubqueryPredicate::Embedded(expr),
+ }
+}
+
+fn has_subquery(expr: &Expr) -> bool {
+ expr.exists(|e| match e {
+ Expr::InSubquery(_) | Expr::Exists(_) => Ok(true),
+ _ => Ok(false),
+ })
+ .unwrap()
+}
+
 /// Optimize the subquery to left-anti/left-semi join.
 /// If the subquery is a correlated subquery, we need extract the join predicate from the subquery.
 ///
@@ -246,7 +255,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
 /// Projection: t2.id
 /// TableScan: t2
 /// ```
-fn build_join(
+fn build_join_top(
 query_info: &SubqueryInfo,
 left: &LogicalPlan,
 alias: &Arc<AliasGenerator>,
@@ -265,9 +274,70 @@ fn build_join(
 })
 .map_or(Ok(None), |v| v.map(Some))?;

+ let join_type = match query_info.negated {
+ true => JoinType::LeftAnti,
+ false => JoinType::LeftSemi,
+ };
 let subquery = query_info.query.subquery.as_ref();
 let subquery_alias = alias.next("__correlated_sq");

+ build_join(left, subquery, in_predicate_opt, join_type, subquery_alias)
+}
+
+/// Existence join is emulated by adding a non-nullable column to the subquery and using a left join
+/// and checking if the column is null or not. If native support is added for Existence/Mark then
+/// we should use that instead.
+///
+/// This is used to handle the case when the subquery is embedded in a more complex boolean
+/// expression like an OR.
 For example
+///
+/// `select t1.id from t1 where t1.id < 0 OR exists(SELECT t2.id FROM t2 WHERE t1.id = t2.id)`
+///
+/// The optimized plan will be:
+///
+/// ```text
+/// Projection: t1.id
+/// Filter: t1.id < 0 OR __correlated_sq_1.__exists IS NOT NULL
+/// Left Join: Filter: t1.id = __correlated_sq_1.id
+/// TableScan: t1
+/// SubqueryAlias: __correlated_sq_1
+/// Projection: t2.id, true as __exists
+/// TableScan: t2
+/// ```
+fn existence_join(
+ left: &LogicalPlan,
+ subquery: Arc<LogicalPlan>,
+ in_predicate_opt: Option<Expr>,
+ negated: bool,
+ alias_generator: &Arc<AliasGenerator>,
+) -> Result<Option<(LogicalPlan, Expr)>> {
+ // Add non nullable column to emulate existence join
+ let always_true_expr = lit(true).alias("__exists");
+ let cols = chain(
+ subquery.schema().columns().into_iter().map(Expr::Column),
+ iter::once(always_true_expr),
+ );
+ let subquery = LogicalPlanBuilder::from(subquery).project(cols)?.build()?;
+ let alias = alias_generator.next("__correlated_sq");
+
+ let exists_col = Expr::Column(Column::new(Some(alias.clone()), "__exists"));
+ let exists_expr = if negated {
+ exists_col.is_null()
+ } else {
+ exists_col.is_not_null()
+ };
+
+ Ok(
+ build_join(left, &subquery, in_predicate_opt, JoinType::Left, alias)?
+ .map(|plan| (plan, exists_expr)),
+ )
+}
+
+fn build_join(
+ left: &LogicalPlan,
+ subquery: &LogicalPlan,
+ in_predicate_opt: Option<Expr>,
+ join_type: JoinType,
+ alias: String,
+) -> Result<Option<LogicalPlan>> {
 let mut pull_up = PullUpCorrelatedExpr::new()
 .with_in_predicate_opt(in_predicate_opt.clone())
 .with_exists_sub_query(in_predicate_opt.is_none());
@@ -278,7 +348,7 @@ fn build_join(
 }

 let sub_query_alias = LogicalPlanBuilder::from(new_plan)
- .alias(subquery_alias.to_string())?
+ .alias(alias.to_string())?
 .build()?;
 let mut all_correlated_cols = BTreeSet::new();
 pull_up
@@ -287,10 +357,9 @@ fn build_join(
 .for_each(|cols| all_correlated_cols.extend(cols.clone()));

 // alias the join filter
- let join_filter_opt =
- conjunction(pull_up.join_filters).map_or(Ok(None), |filter| {
- replace_qualified_name(filter, &all_correlated_cols, &subquery_alias)
- .map(Option::Some)
+ let join_filter_opt = conjunction(pull_up.join_filters)
+ .map_or(Ok(None), |filter| {
+ replace_qualified_name(filter, &all_correlated_cols, &alias).map(Some)
 })?;

 if let Some(join_filter) = match (join_filter_opt, in_predicate_opt) {
@@ -302,7 +371,7 @@ fn build_join(
 right,
 })),
 ) => {
- let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?;
+ let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
 let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
 Some(in_predicate.and(join_filter))
 }
@@ -315,17 +384,13 @@ fn build_join(
 right,
 })),
 ) => {
- let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?;
+ let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
 let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
 Some(in_predicate)
 }
 _ => None,
 } {
 // join our sub query into the main plan
- let join_type = match query_info.negated {
- true => JoinType::LeftAnti,
- false => JoinType::LeftSemi,
- };
 let new_plan = LogicalPlanBuilder::from(left.clone())
 .join_on(sub_query_alias, join_type, Some(join_filter))?
 .build()?;
@@ -361,6 +426,19 @@ impl SubqueryInfo {
 negated,
 }
 }
+
+ pub fn expr(self) -> Expr {
+ match self.where_in_expr {
+ Some(expr) => match self.negated {
+ true => not_in_subquery(expr, self.query.subquery),
+ false => in_subquery(expr, self.query.subquery),
+ },
+ None => match self.negated {
+ true => not_exists(self.query.subquery),
+ false => exists(self.query.subquery),
+ },
+ }
+ }
 }

 #[cfg(test)]
@@ -371,7 +449,7 @@ mod tests {
 use crate::test::*;

 use arrow::datatypes::{DataType, Field, Schema};
- use datafusion_expr::{and, binary_expr, col, lit, not, or, out_ref_col, table_scan};
+ use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan};

 fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
 assert_optimized_plan_eq_display_indent(
@@ -442,60 +520,6 @@ mod tests {
 assert_optimized_plan_equal(plan, expected)
 }

- /// Test for IN subquery with additional OR filter
- /// filter expression not modified
- #[test]
- fn in_subquery_with_or_filters() -> Result<()> {
- let table_scan = test_table_scan()?;
- let plan = LogicalPlanBuilder::from(table_scan)
- .filter(or(
- and(
- binary_expr(col("a"), Operator::Eq, lit(1_u32)),
- binary_expr(col("b"), Operator::Lt, lit(30_u32)),
- ),
- in_subquery(col("c"), test_subquery_with_name("sq")?),
- ))?
- .project(vec![col("test.b")])?
- .build()?;
-
- let expected = "Projection: test.b [b:UInt32]\
- \n Filter: test.a = UInt32(1) AND test.b < UInt32(30) OR test.c IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
- \n Subquery: [c:UInt32]\
- \n Projection: sq.c [c:UInt32]\
- \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
-
- assert_optimized_plan_equal(plan, expected)
- }
-
- #[test]
- fn in_subquery_with_and_or_filters() -> Result<()> {
- let table_scan = test_table_scan()?;
- let plan = LogicalPlanBuilder::from(table_scan)
- .filter(and(
- or(
- binary_expr(col("a"), Operator::Eq, lit(1_u32)),
- in_subquery(col("b"), test_subquery_with_name("sq1")?),
- ),
- in_subquery(col("c"), test_subquery_with_name("sq2")?),
- ))?
- .project(vec![col("test.b")])?
- .build()?;
-
- let expected = "Projection: test.b [b:UInt32]\
- \n Filter: test.a = UInt32(1) OR test.b IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
- \n Subquery: [c:UInt32]\
- \n Projection: sq1.c [c:UInt32]\
- \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
- \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
- \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
- \n Projection: sq2.c [c:UInt32]\
- \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";
-
- assert_optimized_plan_equal(plan, expected)
- }
-
 /// Test for nested IN subqueries
 #[test]
 fn in_subquery_nested() -> Result<()> {
@@ -512,51 +536,19 @@ mod tests {
 .build()?;

 let expected = "Projection: test.b [b:UInt32]\
- \n LeftSemi Join: Filter: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\
 \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
- \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\
+ \n SubqueryAlias: __correlated_sq_2 [a:UInt32]\
 \n Projection: sq.a [a:UInt32]\
- \n LeftSemi Join: Filter: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
 \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\
- \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\
+ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
 \n Projection: sq_nested.c [c:UInt32]\
 \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]";

 assert_optimized_plan_equal(plan, expected)
 }

- /// Test for filter input modification in case filter not supported
- /// Outer filter expression not modified while inner converted to join
- #[test]
- fn in_subquery_input_modified() -> Result<()> {
- let table_scan = test_table_scan()?;
- let plan = LogicalPlanBuilder::from(table_scan)
- .filter(in_subquery(col("c"), test_subquery_with_name("sq_inner")?))?
- .project(vec![col("b"), col("c")])?
- .alias("wrapped")?
- .filter(or(
- binary_expr(col("b"), Operator::Lt, lit(30_u32)),
- in_subquery(col("c"), test_subquery_with_name("sq_outer")?),
- ))?
- .project(vec![col("b")])?
- .build()?;
-
- let expected = "Projection: wrapped.b [b:UInt32]\
- \n Filter: wrapped.b < UInt32(30) OR wrapped.c IN (<subquery>) [b:UInt32, c:UInt32]\
- \n Subquery: [c:UInt32]\
- \n Projection: sq_outer.c [c:UInt32]\
- \n TableScan: sq_outer [a:UInt32, b:UInt32, c:UInt32]\
- \n SubqueryAlias: wrapped [b:UInt32, c:UInt32]\
- \n Projection: test.b, test.c [b:UInt32, c:UInt32]\
- \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
- \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
- \n Projection: sq_inner.c [c:UInt32]\
- \n TableScan: sq_inner [a:UInt32, b:UInt32, c:UInt32]";
-
- assert_optimized_plan_equal(plan, expected)
- }
-
 /// Test multiple correlated subqueries
 /// See subqueries.rs where_in_multiple()
 #[test]
@@ -630,13 +622,13 @@ mod tests {
 .build()?;

 let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\
 \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
+ \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\
 \n Projection: orders.o_custkey [o_custkey:Int64]\
- \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
 \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\
+ \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\
 \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\
 \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]";

@@ -1003,44 +995,6 @@ mod tests {
 Ok(())
 }

- /// Test for correlated IN subquery filter with disjunctions
- #[test]
- fn in_subquery_disjunction() -> Result<()> {
- let sq = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("orders"))
- .filter(
- out_ref_col(DataType::Int64, "customer.c_custkey")
- .eq(col("orders.o_custkey")),
- )?
- .project(vec![col("orders.o_custkey")])?
- .build()?,
- );
-
- let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
- .filter(
- in_subquery(col("customer.c_custkey"), sq)
- .or(col("customer.c_custkey").eq(lit(1))),
- )?
- .project(vec![col("customer.c_custkey")])?
- .build()?;
-
- // TODO: support disjunction - for now expect unaltered plan
- let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
 Filter: customer.c_custkey IN (<subquery>) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
 Subquery: [o_custkey:Int64]
 Projection: orders.o_custkey [o_custkey:Int64]
 Filter: outer_ref(customer.c_custkey) = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
 TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
 TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
-
- assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelatePredicateSubquery::new()),
- plan,
- expected,
- );
- Ok(())
- }
-
 /// Test for correlated IN subquery filter
 #[test]
 fn in_subquery_correlated() -> Result<()> {
@@ -1407,13 +1361,13 @@ mod tests {
 .build()?;

 let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
 \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
+ \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\
 \n Projection: orders.o_custkey [o_custkey:Int64]\
- \n LeftSemi Join: Filter: __correlated_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n LeftSemi Join: Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
 \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\
+ \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\
 \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\
 \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]";
 assert_optimized_plan_equal(plan, expected)
diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
index 550728ddd3f98..65ebac2106ade 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -22,13 +22,13 @@ use crate::{OptimizerConfig, OptimizerRule};

 use crate::join_key_set::JoinKeySet;
 use datafusion_common::tree_node::{Transformed, TreeNode};
-use datafusion_common::{internal_err, Result};
+use datafusion_common::Result;
 use datafusion_expr::expr::{BinaryExpr, Expr};
 use datafusion_expr::logical_plan::{
- CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
+ Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
 };
 use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
-use datafusion_expr::{build_join_schema, ExprSchemable, Operator};
+use datafusion_expr::{and, build_join_schema, ExprSchemable, Operator};

 #[derive(Default, Debug)]
 pub struct EliminateCrossJoin;
@@ -51,7 +51,7 @@ impl EliminateCrossJoin {
 /// Looks like this:
 /// ```text
 /// Filter(a.x = b.y AND b.xx = 100)
-/// CrossJoin
+/// Cross Join
 /// TableScan a
 /// TableScan b
 /// ```
@@ -88,6 +88,7 @@ impl OptimizerRule for EliminateCrossJoin {
 let plan_schema = Arc::clone(plan.schema());
 let mut possible_join_keys = JoinKeySet::new();
 let mut all_inputs: Vec<LogicalPlan> =
 vec![];
+ let mut all_filters: Vec<Expr> = vec![];

 let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
 // if input isn't a join that can potentially be rewritten
@@ -97,7 +98,7 @@ impl OptimizerRule for EliminateCrossJoin {
 LogicalPlan::Join(Join {
 join_type: JoinType::Inner,
 ..
- }) | LogicalPlan::CrossJoin(_)
+ })
 );

 if !rewriteable {
@@ -116,6 +117,7 @@ impl OptimizerRule for EliminateCrossJoin {
 Arc::unwrap_or_clone(input),
 &mut possible_join_keys,
 &mut all_inputs,
+ &mut all_filters,
 )?;

 extract_possible_join_keys(&predicate, &mut possible_join_keys);
@@ -130,7 +132,12 @@ impl OptimizerRule for EliminateCrossJoin {
 if !can_flatten_join_inputs(&plan) {
 return Ok(Transformed::no(plan));
 }
- flatten_join_inputs(plan, &mut possible_join_keys, &mut all_inputs)?;
+ flatten_join_inputs(
+ plan,
+ &mut possible_join_keys,
+ &mut all_inputs,
+ &mut all_filters,
+ )?;
 None
 } else {
 // recursively try to rewrite children
@@ -158,6 +165,13 @@ impl OptimizerRule for EliminateCrossJoin {
 ));
 }

+ if !all_filters.is_empty() {
+ // Add any filters on top - PushDownFilter can push filters down to applicable join
+ let first = all_filters.swap_remove(0);
+ let predicate = all_filters.into_iter().fold(first, and);
+ left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?);
+ }
+
 let Some(predicate) = parent_predicate else {
 return Ok(Transformed::yes(left));
 };
@@ -206,37 +220,25 @@ fn flatten_join_inputs(
 plan: LogicalPlan,
 possible_join_keys: &mut JoinKeySet,
 all_inputs: &mut Vec<LogicalPlan>,
+ all_filters: &mut Vec<Expr>,
 ) -> Result<()> {
 match plan {
 LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
- // checked in can_flatten_join_inputs
- if join.filter.is_some() {
- return internal_err!(
- "should not have filter in inner join in flatten_join_inputs"
- );
+ if let Some(filter) = join.filter {
+ all_filters.push(filter);
 }
 possible_join_keys.insert_all_owned(join.on);
 flatten_join_inputs(
 Arc::unwrap_or_clone(join.left),
 possible_join_keys,
 all_inputs,
+ all_filters,
 )?;
 flatten_join_inputs(
 Arc::unwrap_or_clone(join.right),
 possible_join_keys,
 all_inputs,
- )?;
- }
- LogicalPlan::CrossJoin(join) => {
- flatten_join_inputs(
- Arc::unwrap_or_clone(join.left),
- possible_join_keys,
- all_inputs,
- )?;
- flatten_join_inputs(
- Arc::unwrap_or_clone(join.right),
- possible_join_keys,
- all_inputs,
+ all_filters,
 )?;
 }
 _ => {
@@ -253,30 +255,19 @@ fn flatten_join_inputs(
 fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
 // can only flatten inner / cross joins
 match plan {
- LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
- // The filter of inner join will lost, skip this rule.
- // issue: https://github.com/apache/datafusion/issues/4844
- if join.filter.is_some() {
- return false;
- }
- }
- LogicalPlan::CrossJoin(_) => {}
+ LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {}
 _ => return false,
 };

 for child in plan.inputs() {
- match child {
- LogicalPlan::Join(Join {
- join_type: JoinType::Inner,
- ..
- })
- | LogicalPlan::CrossJoin(_) => {
- if !can_flatten_join_inputs(child) {
- return false;
- }
+ if let LogicalPlan::Join(Join {
+ join_type: JoinType::Inner,
+ ..
+ }) = child + { + if !can_flatten_join_inputs(child) { + return false; } - // the child is not a join/cross join - _ => (), } } true @@ -351,10 +342,15 @@ fn find_inner_join( &JoinType::Inner, )?); - Ok(LogicalPlan::CrossJoin(CrossJoin { + Ok(LogicalPlan::Join(Join { left: Arc::new(left_input), right: Arc::new(right), schema: join_schema, + on: vec![], + filter: None, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + null_equals_null: false, })) } @@ -462,12 +458,6 @@ mod tests { assert_eq!(&starting_schema, optimized_plan.schema()) } - fn assert_optimization_rule_fails(plan: LogicalPlan) { - let rule = EliminateCrossJoin::new(); - let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); - assert!(!transformed_plan.transformed) - } - #[test] fn eliminate_cross_with_simple_and() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -513,7 +503,7 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -601,7 +591,7 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -627,7 +617,7 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -637,8 +627,7 @@ mod tests { } #[test] - /// See https://github.com/apache/datafusion/issues/7530 - fn eliminate_cross_not_possible_nested_inner_join_with_filter() -> Result<()> { + fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; let t3 = test_table_scan_with_name("t3")?; @@ -655,7 +644,17 @@ mod tests { .filter(col("t1.a").gt(lit(15u32)))? 
.build()?; - assert_optimization_rule_fails(plan); + let expected = vec![ + "Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]" + ]; + + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -843,7 +842,7 @@ mod tests { let expected = vec![ "Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", @@ -924,7 +923,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", " Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -999,7 +998,7 @@ mod tests { "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", @@ -1238,7 +1237,7 @@ mod tests { let expected = vec![ "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index f9b79e036f9b4..789235595dabf 100644 --- 
a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -23,7 +23,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::JoinType::Inner; use datafusion_expr::{ logical_plan::{EmptyRelation, LogicalPlan}, - CrossJoin, Expr, + Expr, }; /// Eliminates joins when join condition is false. @@ -54,13 +54,6 @@ impl OptimizerRule for EliminateJoin { match plan { LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => { match join.filter { - Some(Expr::Literal(ScalarValue::Boolean(Some(true)))) => { - Ok(Transformed::yes(LogicalPlan::CrossJoin(CrossJoin { - left: join.left, - right: join.right, - schema: join.schema, - }))) - } Some(Expr::Literal(ScalarValue::Boolean(Some(false)))) => Ok( Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -105,21 +98,4 @@ mod tests { let expected = "EmptyRelation"; assert_optimized_plan_equal(plan, expected) } - - #[test] - fn join_on_true() -> Result<()> { - let plan = LogicalPlanBuilder::empty(false) - .join_on( - LogicalPlanBuilder::empty(false).build()?, - Inner, - Some(lit(true)), - )? - .build()?; - - let expected = "\ - CrossJoin:\ - \n EmptyRelation\ - \n EmptyRelation"; - assert_optimized_plan_equal(plan, expected) - } } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 25304d4ccafaa..267615c3e0d93 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -20,7 +20,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; -use datafusion_expr::logical_plan::{EmptyRelation, LogicalPlan}; +use datafusion_expr::logical_plan::{EmptyRelation, FetchType, LogicalPlan, SkipType}; use std::sync::Arc; /// Optimizer rule to replace `LIMIT 0` or `LIMIT` whose ancestor LIMIT's skip is @@ -57,14 +57,16 @@ impl OptimizerRule for EliminateLimit { &self, plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result< - datafusion_common::tree_node::Transformed<LogicalPlan>, - datafusion_common::DataFusionError, - > { + ) -> Result<Transformed<LogicalPlan>, datafusion_common::DataFusionError> { match plan { LogicalPlan::Limit(limit) => { - if let Some(fetch) = limit.fetch { - if fetch == 0 { + // Only supports rewriting when fetch is a literal + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + + if let Some(v) = fetch { + if v == 0 { return Ok(Transformed::yes(LogicalPlan::EmptyRelation( EmptyRelation { produce_one_row: false, @@ -72,11 +74,10 @@ impl OptimizerRule for EliminateLimit { }, ))); } - } else if limit.skip == 0 { - // input also can be Limit, so we should apply again. - return Ok(self - .rewrite(Arc::unwrap_or_clone(limit.input), _config) - .unwrap()); + } else if matches!(limit.get_skip_type()?, SkipType::Literal(0)) { + // If fetch is `None` and skip is 0, the Limit has no effect and + // we can remove it. Its input may itself be a Limit, so we apply the rule again.
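(To make the comment above concrete, here is a minimal, self-contained sketch of the recursion, using a hypothetical `Plan` enum rather than DataFusion's `LogicalPlan`: a `Limit` with skip 0 and no fetch is a no-op, and because its input may itself be such a `Limit`, the removal must re-apply itself.)

    #[derive(Debug, PartialEq)]
    enum Plan {
        Limit { skip: usize, fetch: Option<usize>, input: Box<Plan> },
        Scan(&'static str),
    }

    /// Strip every no-op `Limit` (skip == 0, no fetch) from the top of the plan.
    fn eliminate_noop_limits(plan: Plan) -> Plan {
        match plan {
            Plan::Limit { skip: 0, fetch: None, input } => eliminate_noop_limits(*input),
            other => other,
        }
    }

    fn main() {
        // Two stacked `LIMIT ALL OFFSET 0`s collapse to the bare scan.
        let plan = Plan::Limit {
            skip: 0,
            fetch: None,
            input: Box::new(Plan::Limit {
                skip: 0,
                fetch: None,
                input: Box::new(Plan::Scan("t")),
            }),
        };
        assert_eq!(eliminate_noop_limits(plan), Plan::Scan("t"));
    }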
+ return self.rewrite(Arc::unwrap_or_clone(limit.input), _config); } Ok(Transformed::no(LogicalPlan::Limit(limit))) } diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 3b1df3510d2a4..f310838311254 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -51,7 +51,6 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; -pub mod rewrite_disjunctive_predicate; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 5ab427a31699f..42eff7100fbe1 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -367,17 +367,6 @@ fn optimize_projections( right_indices.with_projection_beneficial(), ] } - LogicalPlan::CrossJoin(cross_join) => { - let left_len = cross_join.left.schema().fields().len(); - let (left_indices, right_indices) = - split_join_requirements(left_len, indices, &JoinType::Inner); - // Joins benefit from "small" input tables (lower memory usage). - // Therefore, each child benefits from projection: - vec![ - left_indices.with_projection_beneficial(), - right_indices.with_projection_beneficial(), - ] - } // these nodes are explicitly rewritten in the match statement above LogicalPlan::Projection(_) | LogicalPlan::Aggregate(_) @@ -895,6 +884,10 @@ mod tests { // Since schema is same. Output columns requires their corresponding version in the input columns. Some(vec![output_columns.to_vec()]) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[derive(Debug, Hash, PartialEq, Eq)] @@ -991,6 +984,10 @@ mod tests { } Some(vec![left_reqs, right_reqs]) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[test] diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 08dcefa22f08a..90a790a0e841a 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -28,7 +28,7 @@ use log::{debug, warn}; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; @@ -51,7 +51,6 @@ use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; -use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; @@ -251,11 +250,6 @@ impl Optimizer { Arc::new(DecorrelatePredicateSubquery::new()), Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), - // simplify expressions does not simplify expressions in subqueries, so we - // run it again after running the optimizations that potentially converted - // subqueries to joins - Arc::new(SimplifyExpressions::new()), - 
Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), Arc::new(EliminateCrossJoin::new()), @@ -386,11 +380,9 @@ impl Optimizer { let result = match rule.apply_order() { // optimizer handles recursion - Some(apply_order) => new_plan.rewrite(&mut Rewriter::new( - apply_order, - rule.as_ref(), - config, - )), + Some(apply_order) => new_plan.rewrite_with_subqueries( + &mut Rewriter::new(apply_order, rule.as_ref(), config), + ), // rule handles recursion itself None => optimize_plan_node(new_plan, rule.as_ref(), config), } diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index b5e1077ee5bea..d26df073dc6fd 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -72,19 +72,6 @@ impl OptimizerRule for PropagateEmptyRelation { } Ok(Transformed::no(plan)) } - LogicalPlan::CrossJoin(ref join) => { - let (left_empty, right_empty) = binary_plan_children_is_empty(&plan)?; - if left_empty || right_empty { - return Ok(Transformed::yes(LogicalPlan::EmptyRelation( - EmptyRelation { - produce_one_row: false, - schema: Arc::clone(plan.schema()), - }, - ))); - } - Ok(Transformed::no(LogicalPlan::CrossJoin(join.clone()))) - } - LogicalPlan::Join(ref join) => { // TODO: For Join, more join type need to be careful: // For LeftOut/Full Join, if the right side is empty, the Join can be eliminated with a Projection with left side diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 4e36cc62588e8..a0262d7d95dfe 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -24,23 +24,19 @@ use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{ - internal_err, plan_err, qualified_name, Column, DFSchema, DFSchemaRef, - JoinConstraint, Result, + internal_err, plan_err, qualified_name, Column, DFSchema, Result, }; use datafusion_expr::expr_rewriter::replace_col; -use datafusion_expr::logical_plan::{ - CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union, -}; +use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union}; use datafusion_expr::utils::{ conjunction, expr_to_columns, split_conjunction, split_conjunction_owned, }; use datafusion_expr::{ - and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator, - Projection, TableProviderFilterPushDown, + and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, }; use crate::optimizer::ApplyOrder; -use crate::utils::has_all_column_refs; +use crate::utils::{has_all_column_refs, is_restrict_null_predicate}; use crate::{OptimizerConfig, OptimizerRule}; /// Optimizer rule for pushing (moving) filter expressions down in a plan so @@ -562,10 +558,6 @@ fn infer_join_predicates( predicates: &[Expr], on_filters: &[Expr], ) -> Result<Vec<Expr>> { - if join.join_type != JoinType::Inner { - return Ok(vec![]); - } - // Only allow join keys where both sides are columns. let join_col_keys = join .on .iter() .filter_map(|(l, r)| { Some((l.try_as_col()?, r.try_as_col()?)) }) .collect::<Vec<_>>(); - // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down - // For inner joins, duplicate filters for joined columns so filters can be pushed down - // to both sides.
Take the following query as an example: - // - // ```sql - // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 - // ``` - // - // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while - // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. - // - // Join clauses with `Using` constraints also take advantage of this logic to make sure - // predicates reference the shared join columns are pushed to both sides. - // This logic should also been applied to conditions in JOIN ON clause - predicates - .iter() - .chain(on_filters.iter()) - .filter_map(|predicate| { - let mut join_cols_to_replace = HashMap::new(); - - let columns = predicate.column_refs(); - - for &col in columns.iter() { - for (l, r) in join_col_keys.iter() { - if col == *l { - join_cols_to_replace.insert(col, *r); - break; - } else if col == *r { - join_cols_to_replace.insert(col, *l); - break; - } - } - } + let join_type = join.join_type; - if join_cols_to_replace.is_empty() { - return None; - } + let mut inferred_predicates = InferredPredicates::new(join_type); - let join_side_predicate = - match replace_col(predicate.clone(), &join_cols_to_replace) { - Ok(p) => p, - Err(e) => { - return Some(Err(e)); - } - }; + infer_join_predicates_from_predicates( + &join_col_keys, + predicates, + &mut inferred_predicates, + )?; - Some(Ok(join_side_predicate)) - }) - .collect::<Result<Vec<_>>>() + infer_join_predicates_from_on_filters( + &join_col_keys, + join_type, + on_filters, + &mut inferred_predicates, + )?; + + Ok(inferred_predicates.predicates) +} + +/// Inferred predicates collector. +/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly +/// filter out NULLs; otherwise we ignore it, e.g. +/// ```text +/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL; +/// ``` +/// We cannot infer the predicate `t1.c0 IS NULL`; otherwise it would be pushed down to +/// the left side and produce incorrect results. +struct InferredPredicates { + predicates: Vec<Expr>, + is_inner_join: bool, +} + +impl InferredPredicates { + fn new(join_type: JoinType) -> Self { + Self { + predicates: vec![], + is_inner_join: matches!(join_type, JoinType::Inner), + } + } + + fn try_build_predicate( + &mut self, + predicate: Expr, + replace_map: &HashMap<&Column, &Column>, + ) -> Result<()> { + if self.is_inner_join + || matches!( + is_restrict_null_predicate( + predicate.clone(), + replace_map.keys().cloned() + ), + Ok(true) + ) + { + self.predicates.push(replace_col(predicate, replace_map)?); + } + + Ok(()) + } +} + +/// Infer predicates from the pushed down predicates. +/// +/// Parameters +/// * `join_col_keys` column pairs from the join ON clause +/// +/// * `predicates` the pushed down predicates +/// +/// * `inferred_predicates` the inferred results +/// +fn infer_join_predicates_from_predicates( + join_col_keys: &[(&Column, &Column)], + predicates: &[Expr], + inferred_predicates: &mut InferredPredicates, +) -> Result<()> { + infer_join_predicates_impl::<true, true>( + join_col_keys, + predicates, + inferred_predicates, + ) +} +
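(A small, self-contained illustration — hypothetical helper functions, not the `is_restrict_null_predicate` API itself — of the safety condition the `InferredPredicates` collector above enforces: a predicate may only cross an equi-join key of an outer join if it can never evaluate to true for NULL.)

    // SQL three-valued logic on a nullable column, modeled with Option.
    fn gt_five(v: Option<i64>) -> Option<bool> {
        v.map(|v| v > 5) // NULL in, NULL out: NULL rows are filtered anyway
    }

    fn c0_is_null(v: Option<i64>) -> Option<bool> {
        Some(v.is_none()) // true for NULL: it deliberately keeps NULL rows
    }

    fn main() {
        // `c0 > 5` restricts NULL, so it is safe to infer across `t1.c0 = t2.c0`.
        assert_eq!(gt_five(None), None);
        // `c0 IS NULL` is true for NULL, so inferring it onto the preserved side
        // would drop the very rows an outer join is supposed to keep.
        assert_eq!(c0_is_null(None), Some(true));
    }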
+/// Infer predicates from the join filter. +/// +/// Parameters +/// * `join_col_keys` column pairs from the join ON clause +/// +/// * `join_type` the JoinType of the join +/// +/// * `on_filters` filters from the join ON clause that have not already been +/// identified as join predicates +/// +/// * `inferred_predicates` the inferred results +/// +fn infer_join_predicates_from_on_filters( + join_col_keys: &[(&Column, &Column)], + join_type: JoinType, + on_filters: &[Expr], + inferred_predicates: &mut InferredPredicates, +) -> Result<()> { + match join_type { + JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()), + JoinType::Inner => infer_join_predicates_impl::<true, true>( + join_col_keys, + on_filters, + inferred_predicates, + ), + JoinType::Left | JoinType::LeftSemi => infer_join_predicates_impl::<true, false>( + join_col_keys, + on_filters, + inferred_predicates, + ), + JoinType::Right | JoinType::RightSemi => { + infer_join_predicates_impl::<false, true>( + join_col_keys, + on_filters, + inferred_predicates, + ) + } + } +} + +/// Infer predicates from the given predicates. +/// +/// Parameters +/// * `join_col_keys` column pairs from the join ON clause +/// +/// * `input_predicates` the given predicates. It can be the pushed down predicates, +/// or it can be the filters of the Join +/// +/// * `inferred_predicates` the inferred results +/// +/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can +/// be inferred from the left table related predicate +/// +/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can +/// be inferred from the right table related predicate +/// +fn infer_join_predicates_impl< + const ENABLE_LEFT_TO_RIGHT: bool, + const ENABLE_RIGHT_TO_LEFT: bool, +>( + join_col_keys: &[(&Column, &Column)], + input_predicates: &[Expr], + inferred_predicates: &mut InferredPredicates, +) -> Result<()> { + for predicate in input_predicates { + let mut join_cols_to_replace = HashMap::new(); + + for &col in &predicate.column_refs() { + for (l, r) in join_col_keys.iter() { + if ENABLE_LEFT_TO_RIGHT && col == *l { + join_cols_to_replace.insert(col, *r); + break; + } + if ENABLE_RIGHT_TO_LEFT && col == *r { + join_cols_to_replace.insert(col, *l); + break; + } + } + } + if join_cols_to_replace.is_empty() { + continue; + } + + inferred_predicates + .try_build_predicate(predicate.clone(), &join_cols_to_replace)?; + } + Ok(()) } impl OptimizerRule for PushDownFilter { @@ -867,12 +980,6 @@ impl OptimizerRule for PushDownFilter { }) } LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)), - LogicalPlan::CrossJoin(cross_join) => { - let predicates = split_conjunction_owned(filter.predicate); - let join = convert_cross_join_to_inner_join(cross_join)?; - let plan = push_down_all_join(predicates, vec![], join, vec![])?; - convert_to_cross_join_if_beneficial(plan.data) - } LogicalPlan::TableScan(scan) => { let filter_predicates = split_conjunction(&filter.predicate); let results = scan @@ -1114,48 +1221,6 @@ impl PushDownFilter { } } -/// Converts the given cross join to an inner join with an empty equality -/// predicate and an empty filter condition. -fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result<Join> { - let CrossJoin { left, right, ..
} = cross_join; - let join_schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; - Ok(Join { - left, - right, - join_type: JoinType::Inner, - join_constraint: JoinConstraint::On, - on: vec![], - filter: None, - schema: DFSchemaRef::new(join_schema), - null_equals_null: false, - }) -} - -/// Converts the given inner join with an empty equality predicate and an -/// empty filter condition to a cross join. -fn convert_to_cross_join_if_beneficial( - plan: LogicalPlan, -) -> Result> { - match plan { - // Can be converted back to cross join - LogicalPlan::Join(join) if join.on.is_empty() && join.filter.is_none() => { - LogicalPlanBuilder::from(Arc::unwrap_or_clone(join.left)) - .cross_join(Arc::unwrap_or_clone(join.right))? - .build() - .map(Transformed::yes) - } - LogicalPlan::Filter(filter) => { - convert_to_cross_join_if_beneficial(Arc::unwrap_or_clone(filter.input))? - .transform_data(|child_plan| { - Filter::try_new(filter.predicate, Arc::new(child_plan)) - .map(LogicalPlan::Filter) - .map(Transformed::yes) - }) - } - plan => Ok(Transformed::no(plan)), - } -} - /// replaces columns by its name on the projection. pub fn replace_cols_by_name( e: Expr, @@ -1203,17 +1268,17 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; - use datafusion_common::ScalarValue; + use datafusion_common::{DFSchemaRef, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ - col, in_list, in_subquery, lit, ColumnarValue, Extension, ScalarUDF, - ScalarUDFImpl, Signature, TableSource, TableType, UserDefinedLogicalNodeCore, - Volatility, + col, in_list, in_subquery, lit, ColumnarValue, Extension, LogicalPlanBuilder, + ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType, + UserDefinedLogicalNodeCore, Volatility, }; use crate::optimizer::Optimizer; - use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; + use crate::simplify_expressions::SimplifyExpressions; use crate::test::*; use crate::OptimizerContext; use datafusion_expr::test::function_stub::sum; @@ -1235,7 +1300,7 @@ mod tests { expected: &str, ) -> Result<()> { let optimizer = Optimizer::with_rules(vec![ - Arc::new(RewriteDisjunctivePredicate::new()), + Arc::new(SimplifyExpressions::new()), Arc::new(PushDownFilter::new()), ]); let optimized_plan = @@ -1499,6 +1564,10 @@ mod tests { schema: Arc::clone(&self.schema), }) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[test] @@ -1723,7 +1792,7 @@ mod tests { .build()?; let expected = "Projection: test.a, test1.d\ - \n CrossJoin:\ + \n Cross Join: \ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.d, test1.e, test1.f\ @@ -1750,7 +1819,7 @@ mod tests { .build()?; let expected = "Projection: test.a, test1.a\ - \n CrossJoin:\ + \n Cross Join: \ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.a, test1.b, test1.c\ @@ -2040,7 +2109,7 @@ mod tests { let expected = "\ Filter: test2.a <= Int64(1)\ \n Left Join: Using test.a = test2.a\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2"; assert_optimized_plan_eq(plan, expected) @@ -2080,7 +2149,7 @@ mod tests { \n Right Join: Using test.a = test2.a\ \n TableScan: test\ \n Projection: test2.a\ - \n TableScan: test2"; + \n TableScan: test2, 
full_filters=[test2.a <= Int64(1)]"; assert_optimized_plan_eq(plan, expected) } @@ -2435,7 +2504,7 @@ mod tests { .collect()) } - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } } @@ -2862,6 +2931,46 @@ Projection: a, b assert_optimized_plan_eq(optimized_plan, expected) } + #[test] + fn left_semi_join() -> Result<()> { + let left = test_table_scan_with_name("test1")?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::LeftSemi, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test2.a").lt_eq(lit(1i64)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test2.a <= Int64(1)\ + \n LeftSemi Join: test1.a = test2.a\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2" + ); + + // The predicate `test1.a <= Int64(1)` is inferred and pushed down to the left side. + let expected = "\ + Filter: test2.a <= Int64(1)\ + \n LeftSemi Join: test1.a = test2.a\ + \n TableScan: test1, full_filters=[test1.a <= Int64(1)]\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn left_semi_join_with_filters() -> Result<()> { let left = test_table_scan_with_name("test1")?; @@ -2903,6 +3012,46 @@ Projection: a, b assert_optimized_plan_eq(plan, expected) } + #[test] + fn right_semi_join() -> Result<()> { + let left = test_table_scan_with_name("test1")?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::RightSemi, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test1.a").lt_eq(lit(1i64)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test1.a <= Int64(1)\ + \n RightSemi Join: test1.a = test2.a\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2", + ); + + // The predicate `test2.a <= Int64(1)` is inferred and pushed down to the right side. + let expected = "\ + Filter: test1.a <= Int64(1)\ + \n RightSemi Join: test1.a = test2.a\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn right_semi_join_with_filters() -> Result<()> { let left = test_table_scan_with_name("test1")?; @@ -2944,6 +3093,51 @@ Projection: a, b assert_optimized_plan_eq(plan, expected) } + #[test] + fn left_anti_join() -> Result<()> { + let table_scan = test_table_scan_with_name("test1")?; + let left = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::LeftAnti, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test2.a").gt(lit(2u32)))?
+ .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test2.a > UInt32(2)\ + \n LeftAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2", + ); + + // For a left anti join, a filter on the right side can be inferred for the left side and pushed down. + let expected = "\ + Filter: test2.a > UInt32(2)\ + \n LeftAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1, full_filters=[test1.a > UInt32(2)]\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn left_anti_join_with_filters() -> Result<()> { let table_scan = test_table_scan_with_name("test1")?; @@ -2990,6 +3184,51 @@ Projection: a, b assert_optimized_plan_eq(plan, expected) } + #[test] + fn right_anti_join() -> Result<()> { + let table_scan = test_table_scan_with_name("test1")?; + let left = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::RightAnti, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test1.a").gt(lit(2u32)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test1.a > UInt32(2)\ + \n RightAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2", + ); + + // For a right anti join, a filter on the left side can be inferred for the right side and pushed down. + let expected = "\ + Filter: test1.a > UInt32(2)\ + \n RightAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2, full_filters=[test2.a > UInt32(2)]"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn right_anti_join_with_filters() -> Result<()> { let table_scan = test_table_scan_with_name("test1")?; diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 158c7592df516..ec7a0a1364b6a 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -27,6 +27,7 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::utils::combine_limit; use datafusion_common::Result; use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan}; +use datafusion_expr::{lit, FetchType, SkipType}; /// Optimization rule that tries to push down `LIMIT`. /// @@ -56,16 +57,27 @@ impl OptimizerRule for PushDownLimit { return Ok(Transformed::no(plan)); }; - let Limit { skip, fetch, input } = limit; + // Currently only rewrite if skip and fetch are both literals + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + };
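(As a sketch of the merge arithmetic used by the code that follows — mirroring the intent of `datafusion_common::utils::combine_limit`, whose exact upstream signature may differ: the child limit is applied first, then the parent skips and fetches from the child's output.)

    fn combine_limit(
        parent_skip: usize,
        parent_fetch: Option<usize>,
        child_skip: usize,
        child_fetch: Option<usize>,
    ) -> (usize, Option<usize>) {
        // Rows skipped by the child and by the parent are both gone.
        let combined_skip = child_skip.saturating_add(parent_skip);
        // The parent can see at most `child_fetch - parent_skip` rows.
        let combined_fetch = match (parent_fetch, child_fetch) {
            (Some(pf), Some(cf)) => Some(pf.min(cf.saturating_sub(parent_skip))),
            (Some(pf), None) => Some(pf),
            (None, Some(cf)) => Some(cf.saturating_sub(parent_skip)),
            (None, None) => None,
        };
        (combined_skip, combined_fetch)
    }

    fn main() {
        // LIMIT 500 OFFSET 20 over LIMIT 1000 OFFSET 10 merges to skip=30,
        // fetch=500, matching the `limit_pushdown_multiple_limits` test below.
        assert_eq!(combine_limit(20, Some(500), 10, Some(1000)), (30, Some(500)));
    }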
// Merge the Parent Limit and the Child Limit. - if let LogicalPlan::Limit(child) = input.as_ref() { - let (skip, fetch) = - combine_limit(limit.skip, limit.fetch, child.skip, child.fetch); - + if let LogicalPlan::Limit(child) = limit.input.as_ref() { + let SkipType::Literal(child_skip) = child.get_skip_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + let FetchType::Literal(child_fetch) = child.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + + let (skip, fetch) = combine_limit(skip, fetch, child_skip, child_fetch); let plan = LogicalPlan::Limit(Limit { - skip, - fetch, + skip: Some(Box::new(lit(skip as i64))), + fetch: fetch.map(|f| Box::new(lit(f as i64))), input: Arc::clone(&child.input), }); @@ -75,14 +87,10 @@ impl OptimizerRule for PushDownLimit { // no fetch to push, so return the original plan let Some(fetch) = fetch else { - return Ok(Transformed::no(LogicalPlan::Limit(Limit { - skip, - fetch, - input, - }))); + return Ok(Transformed::no(LogicalPlan::Limit(limit))); }; - match Arc::unwrap_or_clone(input) { + match Arc::unwrap_or_clone(limit.input) { LogicalPlan::TableScan(mut scan) => { let rows_needed = if fetch != 0 { fetch + skip } else { 0 }; let new_fetch = scan @@ -110,13 +118,6 @@ impl OptimizerRule for PushDownLimit { transformed_limit(skip, fetch, LogicalPlan::Union(union)) } - LogicalPlan::CrossJoin(mut cross_join) => { - // push limit to both inputs - cross_join.left = make_arc_limit(0, fetch + skip, cross_join.left); - cross_join.right = make_arc_limit(0, fetch + skip, cross_join.right); - transformed_limit(skip, fetch, LogicalPlan::CrossJoin(cross_join)) - } - LogicalPlan::Join(join) => Ok(push_down_join(join, fetch + skip) .update_data(|join| { make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join))) @@ -153,6 +154,29 @@ impl OptimizerRule for PushDownLimit { subquery_alias.input = Arc::new(new_limit); Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias))) } + LogicalPlan::Extension(extension_plan) + if extension_plan.node.supports_limit_pushdown() => + { + let new_children = extension_plan + .node + .inputs() + .into_iter() + .map(|child| { + LogicalPlan::Limit(Limit { + skip: None, + fetch: Some(Box::new(lit((fetch + skip) as i64))), + input: Arc::new(child.clone()), + }) + }) + .collect::<Vec<_>>(); + + // Create a new extension node with updated inputs + let child_plan = LogicalPlan::Extension(extension_plan); + let new_extension = + child_plan.with_new_exprs(child_plan.expressions(), new_children)?; + + transformed_limit(skip, fetch, new_extension) + } input => original_limit(skip, fetch, input), } } @@ -180,8 +204,8 @@ impl OptimizerRule for PushDownLimit { /// ``` fn make_limit(skip: usize, fetch: usize, input: Arc<LogicalPlan>) -> LogicalPlan { LogicalPlan::Limit(Limit { - skip, - fetch: Some(fetch), + skip: Some(Box::new(lit(skip as i64))), + fetch: Some(Box::new(lit(fetch as i64))), input, }) } @@ -201,11 +225,7 @@ fn original_limit( fetch: usize, input: LogicalPlan, ) -> Result<Transformed<LogicalPlan>> { - Ok(Transformed::no(LogicalPlan::Limit(Limit { - skip, - fetch: Some(fetch), - input: Arc::new(input), - }))) + Ok(Transformed::no(make_limit(skip, fetch, Arc::new(input)))) } /// Returns a transformed limit @@ -214,11 +234,7 @@ fn transformed_limit( fetch: usize, input: LogicalPlan, ) -> Result<Transformed<LogicalPlan>> { - Ok(Transformed::yes(LogicalPlan::Limit(Limit { - skip, - fetch: Some(fetch), - input: Arc::new(input), - }))) + Ok(Transformed::yes(make_limit(skip, fetch, Arc::new(input)))) } /// Adds a limit to the inputs of a join, if possible @@ -231,15 +247,15 @@ fn
push_down_join(mut join: Join, limit: usize) -> Transformed<Join> { let (left_limit, right_limit) = if is_no_join_condition(&join) { match join.join_type { - Left | Right | Full => (Some(limit), Some(limit)), + Left | Right | Full | Inner => (Some(limit), Some(limit)), LeftAnti | LeftSemi => (Some(limit), None), RightAnti | RightSemi => (None, Some(limit)), - Inner => (None, None), } } else { match join.join_type { Left => (Some(limit), None), Right => (None, Some(limit)), + Full => (Some(limit), Some(limit)), _ => (None, None), } }; @@ -258,17 +274,241 @@ fn push_down_join(mut join: Join, limit: usize) -> Transformed<Join> { #[cfg(test)] mod test { + use std::cmp::Ordering; + use std::fmt::{Debug, Formatter}; use std::vec; use super::*; use crate::test::*; - use datafusion_expr::{col, exists, logical_plan::builder::LogicalPlanBuilder}; + + use datafusion_common::DFSchemaRef; + use datafusion_expr::{ + col, exists, logical_plan::builder::LogicalPlanBuilder, Expr, Extension, + UserDefinedLogicalNodeCore, + }; use datafusion_functions_aggregate::expr_fn::max; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) } + #[derive(Debug, PartialEq, Eq, Hash)] + pub struct NoopPlan { + input: Vec<LogicalPlan>, + schema: DFSchemaRef, + } + + // Manual implementation needed because of `schema` field. Comparison excludes this field. + impl PartialOrd for NoopPlan { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + self.input.partial_cmp(&other.input) + } + } + + impl UserDefinedLogicalNodeCore for NoopPlan { + fn name(&self) -> &str { + "NoopPlan" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + self.input.iter().collect() + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec<Expr> { + self.input + .iter() + .flat_map(|child| child.expressions()) + .collect() + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "NoopPlan") + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec<Expr>, + inputs: Vec<LogicalPlan>, + ) -> Result<Self> { + Ok(Self { + input: inputs, + schema: Arc::clone(&self.schema), + }) + } + + fn supports_limit_pushdown(&self) -> bool { + true // Allow limit push-down + } + } + + #[derive(Debug, PartialEq, Eq, Hash)] + struct NoLimitNoopPlan { + input: Vec<LogicalPlan>, + schema: DFSchemaRef, + }
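(The two test nodes above and below differ only in this opt-in flag. A stripped-down sketch of the design, using a hypothetical trait standing in for `UserDefinedLogicalNodeCore`: nodes must explicitly declare that a limit can move below them, so row-count-changing nodes stay safe by default.)

    trait ExtensionNode {
        // Conservative default: a limit must not be pushed below an unknown node.
        fn supports_limit_pushdown(&self) -> bool {
            false
        }
    }

    struct RowPreserving; // e.g. a per-row annotation node
    impl ExtensionNode for RowPreserving {
        fn supports_limit_pushdown(&self) -> bool {
            true
        }
    }

    struct RowChanging; // e.g. a dedup node: a limit below it changes results
    impl ExtensionNode for RowChanging {}

    fn main() {
        assert!(RowPreserving.supports_limit_pushdown());
        assert!(!RowChanging.supports_limit_pushdown());
    }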
+ // Manual implementation needed because of `schema` field. Comparison excludes this field. + impl PartialOrd for NoLimitNoopPlan { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + self.input.partial_cmp(&other.input) + } + } + + impl UserDefinedLogicalNodeCore for NoLimitNoopPlan { + fn name(&self) -> &str { + "NoLimitNoopPlan" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + self.input.iter().collect() + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec<Expr> { + self.input + .iter() + .flat_map(|child| child.expressions()) + .collect() + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "NoLimitNoopPlan") + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec<Expr>, + inputs: Vec<LogicalPlan>, + ) -> Result<Self> { + Ok(Self { + input: inputs, + schema: Arc::clone(&self.schema), + }) + } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } + } + #[test] + fn limit_pushdown_basic() -> Result<()> { + let table_scan = test_table_scan()?; + let noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoopPlan { + input: vec![table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(noop_plan) + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n NoopPlan\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn limit_pushdown_with_skip() -> Result<()> { + let table_scan = test_table_scan()?; + let noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoopPlan { + input: vec![table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(noop_plan) + .limit(10, Some(1000))? + .build()?; + + let expected = "Limit: skip=10, fetch=1000\ + \n NoopPlan\ + \n Limit: skip=0, fetch=1010\ + \n TableScan: test, fetch=1010"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn limit_pushdown_multiple_limits() -> Result<()> { + let table_scan = test_table_scan()?; + let noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoopPlan { + input: vec![table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(noop_plan) + .limit(10, Some(1000))? + .limit(20, Some(500))? + .build()?; + + let expected = "Limit: skip=30, fetch=500\ + \n NoopPlan\ + \n Limit: skip=0, fetch=530\ + \n TableScan: test, fetch=530"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn limit_pushdown_multiple_inputs() -> Result<()> { + let table_scan = test_table_scan()?; + let noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoopPlan { + input: vec![table_scan.clone(), table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(noop_plan) + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n NoopPlan\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn limit_pushdown_disallowed_noop_plan() -> Result<()> { + let table_scan = test_table_scan()?; + let no_limit_noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoLimitNoopPlan { + input: vec![table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(no_limit_noop_plan) + .limit(0, Some(1000))?
+ .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n NoLimitNoopPlan\ + \n TableScan: test"; + + assert_optimized_plan_equal(plan, expected) + } + #[test] fn limit_pushdown_projection_table_provider() -> Result<()> { let table_scan = test_table_scan()?; @@ -868,7 +1108,7 @@ mod test { .build()?; let expected = "Limit: skip=0, fetch=1000\ - \n CrossJoin:\ + \n Cross Join: \ \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000\ \n Limit: skip=0, fetch=1000\ @@ -888,7 +1128,7 @@ mod test { .build()?; let expected = "Limit: skip=1000, fetch=1000\ - \n CrossJoin:\ + \n Cross Join: \ \n Limit: skip=0, fetch=2000\ \n TableScan: test, fetch=2000\ \n Limit: skip=0, fetch=2000\ diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index c026130c426f4..f3e1673e72111 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -16,8 +16,10 @@ // under the License. //! [`ReplaceDistinctWithAggregate`] replaces `DISTINCT ...` with `GROUP BY ...` + use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; +use std::sync::Arc; use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; @@ -110,7 +112,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { let expr_cnt = on_expr.len(); // Construct the aggregation expression to be used to fetch the selected expressions. - let first_value_udaf: std::sync::Arc = + let first_value_udaf: Arc = config.function_registry().unwrap().udaf("first_value")?; let aggr_expr = select_expr.into_iter().map(|e| { if let Some(order_by) = &sort_expr { diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs deleted file mode 100644 index a6b633fdb8fe6..0000000000000 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ /dev/null @@ -1,430 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`RewriteDisjunctivePredicate`] rewrites predicates to reduce redundancy - -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::Transformed; -use datafusion_common::Result; -use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::logical_plan::Filter; -use datafusion_expr::{Expr, LogicalPlan, Operator}; - -/// Optimizer pass that rewrites predicates of the form -/// -/// ```text -/// (A = B AND ) OR (A = B AND ) OR ... (A = B AND ) -/// ``` -/// -/// Into -/// ```text -/// (A = B) AND ( OR OR ... ) -/// ``` -/// -/// Predicates connected by `OR` typically not able to be broken down -/// and distributed as well as those connected by `AND`. 
-/// -/// The idea is to rewrite predicates into `good_predicate1 AND -/// good_predicate2 AND ...` where `good_predicate` means the -/// predicate has special support in the execution engine. -/// -/// Equality join predicates (e.g. `col1 = col2`), or single column -/// expressions (e.g. `col = 5`) are examples of predicates with -/// special support. -/// -/// # TPCH Q19 -/// -/// This optimization is admittedly somewhat of a niche usecase. It's -/// main use is that it appears in TPCH Q19 and is required to avoid a -/// CROSS JOIN. -/// -/// Specifically, Q19 has a WHERE clause that looks like -/// -/// ```sql -/// where -/// ( -/// p_partkey = l_partkey -/// and p_brand = ‘[BRAND1]’ -/// and p_container in ( ‘SM CASE’, ‘SM BOX’, ‘SM PACK’, ‘SM PKG’) -/// and l_quantity >= [QUANTITY1] and l_quantity <= [QUANTITY1] + 10 -/// and p_size between 1 and 5 -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// ) -/// or -/// ( -/// p_partkey = l_partkey -/// and p_brand = ‘[BRAND2]’ -/// and p_container in (‘MED BAG’, ‘MED BOX’, ‘MED PKG’, ‘MED PACK’) -/// and l_quantity >= [QUANTITY2] and l_quantity <= [QUANTITY2] + 10 -/// and p_size between 1 and 10 -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// ) -/// or -/// ( -/// p_partkey = l_partkey -/// and p_brand = ‘[BRAND3]’ -/// and p_container in ( ‘LG CASE’, ‘LG BOX’, ‘LG PACK’, ‘LG PKG’) -/// and l_quantity >= [QUANTITY3] and l_quantity <= [QUANTITY3] + 10 -/// and p_size between 1 and 15 -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// ) -/// ) -/// ``` -/// -/// Naively planning this query will result in a CROSS join with that -/// single large OR filter. However, rewriting it using the rewrite in -/// this pass results in a proper join predicate, `p_partkey = l_partkey`: -/// -/// ```sql -/// where -/// p_partkey = l_partkey -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// and ( -/// ( -/// and p_brand = ‘[BRAND1]’ -/// and p_container in ( ‘SM CASE’, ‘SM BOX’, ‘SM PACK’, ‘SM PKG’) -/// and l_quantity >= [QUANTITY1] and l_quantity <= [QUANTITY1] + 10 -/// and p_size between 1 and 5 -/// ) -/// or -/// ( -/// and p_brand = ‘[BRAND2]’ -/// and p_container in (‘MED BAG’, ‘MED BOX’, ‘MED PKG’, ‘MED PACK’) -/// and l_quantity >= [QUANTITY2] and l_quantity <= [QUANTITY2] + 10 -/// and p_size between 1 and 10 -/// ) -/// or -/// ( -/// and p_brand = ‘[BRAND3]’ -/// and p_container in ( ‘LG CASE’, ‘LG BOX’, ‘LG PACK’, ‘LG PKG’) -/// and l_quantity >= [QUANTITY3] and l_quantity <= [QUANTITY3] + 10 -/// and p_size between 1 and 15 -/// ) -/// ) -/// ``` -/// -#[derive(Default, Debug)] -pub struct RewriteDisjunctivePredicate; - -impl RewriteDisjunctivePredicate { - pub fn new() -> Self { - Self - } -} - -impl OptimizerRule for RewriteDisjunctivePredicate { - fn name(&self) -> &str { - "rewrite_disjunctive_predicate" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - - fn supports_rewrite(&self) -> bool { - true - } - - fn rewrite( - &self, - plan: LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Filter(filter) => { - let predicate = predicate(filter.predicate)?; - let rewritten_predicate = rewrite_predicate(predicate); - let rewritten_expr = normalize_predicate(rewritten_predicate); - Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new( - rewritten_expr, - filter.input, - )?))) - } - _ => 
Ok(Transformed::no(plan)), - } - } -} - -#[derive(Clone, PartialEq, Debug)] -enum Predicate { - And { args: Vec }, - Or { args: Vec }, - Other { expr: Box }, -} - -fn predicate(expr: Expr) -> Result { - match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - Operator::And => { - let args = vec![predicate(*left)?, predicate(*right)?]; - Ok(Predicate::And { args }) - } - Operator::Or => { - let args = vec![predicate(*left)?, predicate(*right)?]; - Ok(Predicate::Or { args }) - } - _ => Ok(Predicate::Other { - expr: Box::new(Expr::BinaryExpr(BinaryExpr::new(left, op, right))), - }), - }, - _ => Ok(Predicate::Other { - expr: Box::new(expr), - }), - } -} - -fn normalize_predicate(predicate: Predicate) -> Expr { - match predicate { - Predicate::And { args } => { - assert!(args.len() >= 2); - args.into_iter() - .map(normalize_predicate) - .reduce(Expr::and) - .expect("had more than one arg") - } - Predicate::Or { args } => { - assert!(args.len() >= 2); - args.into_iter() - .map(normalize_predicate) - .reduce(Expr::or) - .expect("had more than one arg") - } - Predicate::Other { expr } => *expr, - } -} - -fn rewrite_predicate(predicate: Predicate) -> Predicate { - match predicate { - Predicate::And { args } => { - let mut rewritten_args = Vec::with_capacity(args.len()); - for arg in args.into_iter() { - rewritten_args.push(rewrite_predicate(arg)); - } - rewritten_args = flatten_and_predicates(rewritten_args); - Predicate::And { - args: rewritten_args, - } - } - Predicate::Or { args } => { - let mut rewritten_args = vec![]; - for arg in args.into_iter() { - rewritten_args.push(rewrite_predicate(arg)); - } - rewritten_args = flatten_or_predicates(rewritten_args); - delete_duplicate_predicates(rewritten_args) - } - Predicate::Other { expr } => Predicate::Other { expr }, - } -} - -fn flatten_and_predicates( - and_predicates: impl IntoIterator, -) -> Vec { - let mut flattened_predicates = vec![]; - for predicate in and_predicates { - match predicate { - Predicate::And { args } => { - flattened_predicates.append(&mut flatten_and_predicates(args)); - } - _ => { - flattened_predicates.push(predicate); - } - } - } - flattened_predicates -} - -fn flatten_or_predicates( - or_predicates: impl IntoIterator, -) -> Vec { - let mut flattened_predicates = vec![]; - for predicate in or_predicates { - match predicate { - Predicate::Or { args } => { - flattened_predicates.append(&mut flatten_or_predicates(args)); - } - _ => { - flattened_predicates.push(predicate); - } - } - } - flattened_predicates -} - -fn delete_duplicate_predicates(or_predicates: Vec) -> Predicate { - let mut shortest_exprs: Vec = vec![]; - let mut shortest_exprs_len = 0; - // choose the shortest AND predicate - for or_predicate in or_predicates.iter() { - match or_predicate { - Predicate::And { args } => { - let args_num = args.len(); - if shortest_exprs.is_empty() || args_num < shortest_exprs_len { - shortest_exprs.clone_from(args); - shortest_exprs_len = args_num; - } - } - _ => { - // if there is no AND predicate, it must be the shortest expression. - shortest_exprs = vec![or_predicate.clone()]; - break; - } - } - } - - // dedup shortest_exprs - shortest_exprs.dedup(); - - // Check each element in shortest_exprs to see if it's in all the OR arguments. 
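(To see what this — now deleted — dedup step computed, here is a self-contained sketch over plain strings: conjuncts common to every OR branch are factored out in front, so `(A AND B) OR (A AND C)` becomes `A AND (B OR C)`. The same factoring now lives in the expression simplifier; see expr_simplifier.rs below.)

    use std::collections::BTreeSet;

    fn factor_common<'a>(branches: &[Vec<&'a str>]) -> (Vec<&'a str>, Vec<Vec<&'a str>>) {
        // Conjuncts present in every OR branch.
        let mut common: BTreeSet<&str> = branches[0].iter().copied().collect();
        for branch in &branches[1..] {
            let set: BTreeSet<&str> = branch.iter().copied().collect();
            common = common.intersection(&set).copied().collect();
        }
        // What remains of each branch once the common factors are pulled out.
        let residue = branches
            .iter()
            .map(|b| b.iter().copied().filter(|c| !common.contains(c)).collect())
            .collect();
        (common.into_iter().collect(), residue)
    }

    fn main() {
        // (A AND B) OR (A AND C) => common [A], residue [[B], [C]],
        // i.e. A AND (B OR C).
        let (common, residue) = factor_common(&[vec!["A", "B"], vec!["A", "C"]]);
        assert_eq!(common, vec!["A"]);
        assert_eq!(residue, vec![vec!["B"], vec!["C"]]);
    }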
- let mut exist_exprs: Vec = vec![]; - for expr in shortest_exprs.iter() { - let found = or_predicates.iter().all(|or_predicate| match or_predicate { - Predicate::And { args } => args.contains(expr), - _ => or_predicate == expr, - }); - if found { - exist_exprs.push((*expr).clone()); - } - } - if exist_exprs.is_empty() { - return Predicate::Or { - args: or_predicates, - }; - } - - // Rebuild the OR predicate. - // (A AND B) OR A will be optimized to A. - let mut new_or_predicates = vec![]; - for or_predicate in or_predicates.into_iter() { - match or_predicate { - Predicate::And { mut args } => { - args.retain(|expr| !exist_exprs.contains(expr)); - if !args.is_empty() { - if args.len() == 1 { - new_or_predicates.push(args.remove(0)); - } else { - new_or_predicates.push(Predicate::And { args }); - } - } else { - new_or_predicates.clear(); - break; - } - } - _ => { - if exist_exprs.contains(&or_predicate) { - new_or_predicates.clear(); - break; - } - } - } - } - if !new_or_predicates.is_empty() { - if new_or_predicates.len() == 1 { - exist_exprs.push(new_or_predicates.remove(0)); - } else { - exist_exprs.push(Predicate::Or { - args: flatten_or_predicates(new_or_predicates), - }); - } - } - - if exist_exprs.len() == 1 { - exist_exprs.remove(0) - } else { - Predicate::And { - args: flatten_and_predicates(exist_exprs), - } - } -} - -#[cfg(test)] -mod tests { - use crate::rewrite_disjunctive_predicate::{ - normalize_predicate, predicate, rewrite_predicate, Predicate, - }; - - use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{and, col, lit, or}; - - #[test] - fn test_rewrite_predicate() -> Result<()> { - let equi_expr = col("t1.a").eq(col("t2.b")); - let gt_expr = col("t1.c").gt(lit(ScalarValue::Int8(Some(1)))); - let lt_expr = col("t1.d").lt(lit(ScalarValue::Int8(Some(2)))); - let expr = or( - and(equi_expr.clone(), gt_expr.clone()), - and(equi_expr.clone(), lt_expr.clone()), - ); - let predicate = predicate(expr)?; - assert_eq!( - predicate, - Predicate::Or { - args: vec![ - Predicate::And { - args: vec![ - Predicate::Other { - expr: Box::new(equi_expr.clone()) - }, - Predicate::Other { - expr: Box::new(gt_expr.clone()) - }, - ] - }, - Predicate::And { - args: vec![ - Predicate::Other { - expr: Box::new(equi_expr.clone()) - }, - Predicate::Other { - expr: Box::new(lt_expr.clone()) - }, - ] - }, - ] - } - ); - let rewritten_predicate = rewrite_predicate(predicate); - assert_eq!( - rewritten_predicate, - Predicate::And { - args: vec![ - Predicate::Other { - expr: Box::new(equi_expr.clone()) - }, - Predicate::Or { - args: vec![ - Predicate::Other { - expr: Box::new(gt_expr.clone()) - }, - Predicate::Other { - expr: Box::new(lt_expr.clone()) - }, - ] - }, - ] - } - ); - let rewritten_expr = normalize_predicate(rewritten_predicate); - assert_eq!(rewritten_expr, and(equi_expr, or(gt_expr, lt_expr))); - Ok(()) - } -} diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 6409bb9e03f78..2e2c8fb1d6f8c 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -318,8 +318,7 @@ fn build_join( // alias the join filter let join_filter_opt = conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { - replace_qualified_name(filter, &all_correlated_cols, subquery_alias) - .map(Option::Some) + replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some) })?; // join our sub query into the main plan @@ -625,11 +624,21 @@ mod tests { 
.project(vec![col("customer.c_custkey")])? .build()?; - let expected = "check_analyzed_plan\ - \ncaused by\ - \nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) != orders.o_custkey"; + // Unsupported predicate, subquery should not be decorrelated + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ + \n Subquery: [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ + \n Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - assert_analyzer_check_err(vec![], plan, expected); + assert_multi_rules_optimized_plan_eq_display_indent( + vec![Arc::new(ScalarSubqueryToJoin::new())], + plan, + expected, + ); Ok(()) } @@ -652,11 +661,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "check_analyzed_plan\ - \ncaused by\ - \nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) < orders.o_custkey"; + // Unsupported predicate, subquery should not be decorrelated + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ + \n Subquery: [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ + \n Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - assert_analyzer_check_err(vec![], plan, expected); + assert_multi_rules_optimized_plan_eq_display_indent( + vec![Arc::new(ScalarSubqueryToJoin::new())], + plan, + expected, + ); Ok(()) } @@ -680,11 +699,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "check_analyzed_plan\ - \ncaused by\ - \nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1)"; + // Unsupported predicate, subquery should not be decorrelated + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ + \n Subquery: [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ + \n Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - assert_analyzer_check_err(vec![], plan, expected); + assert_multi_rules_optimized_plan_eq_display_indent( + vec![Arc::new(ScalarSubqueryToJoin::new())], + plan, + expected, + ); Ok(()) } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index a78a54a571235..ce6734616b805 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -32,14 +32,18 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::{InList, InSubquery, WindowFunction}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, WindowFunctionDefinition, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; +use datafusion_expr::{ + expr::{InList, InSubquery, WindowFunction}, + utils::{iter_conjunction, iter_conjunction_owned}, +}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; +use indexmap::IndexSet; use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; @@ -838,21 +842,38 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Or, right, }) if expr_contains(&right, &left, Or) => Transformed::yes(*right), - // A OR (A AND B) --> A (if B not null) + // A OR (A AND B) --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&right)? && is_op_with(And, &right, &left) => { - Transformed::yes(*left) - } - // (A AND B) OR A --> A (if B not null) + }) if is_op_with(And, &right, &left) => Transformed::yes(*left), + // (A AND B) OR A --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&left)? 
&& is_op_with(And, &left, &right) => {
-                Transformed::yes(*right)
+            }) if is_op_with(And, &left, &right) => Transformed::yes(*right),
+            // Eliminate common factors in conjunctions e.g
+            // (A AND B) OR (A AND C) -> A AND (B OR C)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Or,
+                right,
+            }) if has_common_conjunction(&left, &right) => {
+                let lhs: IndexSet<Expr> = iter_conjunction_owned(*left).collect();
+                let (common, rhs): (Vec<_>, Vec<_>) =
+                    iter_conjunction_owned(*right).partition(|e| lhs.contains(e));
+
+                let new_rhs = rhs.into_iter().reduce(and);
+                let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and);
+                let common_conjunction = common.into_iter().reduce(and).unwrap();
+
+                let new_expr = match (new_lhs, new_rhs) {
+                    (Some(lhs), Some(rhs)) => and(common_conjunction, or(lhs, rhs)),
+                    (_, _) => common_conjunction,
+                };
+                Transformed::yes(new_expr)
             }
             //
@@ -911,22 +932,18 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 op: And,
                 right,
             }) if expr_contains(&right, &left, And) => Transformed::yes(*right),
-            // A AND (A OR B) --> A (if B not null)
+            // A AND (A OR B) --> A
             Expr::BinaryExpr(BinaryExpr {
                 left,
                 op: And,
                 right,
-            }) if !info.nullable(&right)? && is_op_with(Or, &right, &left) => {
-                Transformed::yes(*left)
-            }
-            // (A OR B) AND A --> A (if B not null)
+            }) if is_op_with(Or, &right, &left) => Transformed::yes(*left),
+            // (A OR B) AND A --> A
             Expr::BinaryExpr(BinaryExpr {
                 left,
                 op: And,
                 right,
-            }) if !info.nullable(&left)? && is_op_with(Or, &left, &right) => {
-                Transformed::yes(*right)
-            }
+            }) if is_op_with(Or, &left, &right) => Transformed::yes(*right),
             //
             // Rules for Multiply
@@ -1028,7 +1045,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 && !info.get_data_type(&left)?.is_floating()
                 && is_one(&right) =>
             {
-                Transformed::yes(lit(0))
+                Transformed::yes(Expr::Literal(ScalarValue::new_zero(
+                    &info.get_data_type(&left)?,
+                )?))
             }
             //
@@ -1518,7 +1537,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
             // i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)`
             Expr::BinaryExpr(BinaryExpr {
                 left,
-                op: Operator::Or,
+                op: Or,
                 right,
             }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => {
                 let lhs = to_inlist(*left).unwrap();
@@ -1558,7 +1577,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
            // 8.
`a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1578,7 +1597,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1598,7 +1617,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1618,7 +1637,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1638,7 +1657,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::Or, + op: Or, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1662,6 +1681,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } +fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool { + let lhs: HashSet<&Expr> = iter_conjunction(lhs).collect(); + iter_conjunction(rhs).any(|e| lhs.contains(&e)) +} + // TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121 fn are_inlist_and_eq_and_match_neg( left: &Expr, @@ -1789,6 +1813,8 @@ fn inlist_except(mut l1: InList, l2: &InList) -> Result { #[cfg(test)] mod tests { + use crate::simplify_expressions::SimplifyContext; + use crate::test::test_table_scan_with_name; use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ function::{ @@ -1799,15 +1825,13 @@ mod tests { *, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; + use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use std::{ collections::HashMap, ops::{BitAnd, BitOr, BitXor}, sync::Arc, }; - use crate::simplify_expressions::SimplifyContext; - use crate::test::test_table_scan_with_name; - use super::*; // ------------------------------ @@ -2171,11 +2195,11 @@ mod tests { #[test] fn test_simplify_modulo_by_one_non_null() { - let expr = col("c2_non_null") % lit(1); - let expected = lit(0); + let expr = col("c3_non_null") % lit(1); + let expected = lit(0_i64); assert_eq!(simplify(expr), expected); let expr = - col("c2_non_null") % lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)); + col("c3_non_null") % lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)); assert_eq!(simplify(expr), expected); } @@ -2609,15 +2633,11 @@ mod tests { // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) let expr = or(l.clone(), r.clone()); - // no rewrites if c1 can be null - let expected = expr.clone(); + let expected = l.clone(); assert_eq!(simplify(expr), expected); // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) - let expr = or(l, r); - - // no rewrites if c1 can be null - let expected = expr.clone(); + let expr = or(r, l); assert_eq!(simplify(expr), expected); } @@ -2648,13 +2668,11 @@ mod tests { // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5 let expr = and(l.clone(), r.clone()); - // no rewrites if c1 can be null - let expected = expr.clone(); + let expected = l.clone(); assert_eq!(simplify(expr), expected); // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5 - let expr = and(l, r); - let expected = expr.clone(); + let expr = and(r, l); assert_eq!(simplify(expr), expected); } @@ 
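Two of the simplifier changes above deserve a closer look. First, the new `(A AND B) OR (A AND C) -> A AND (B OR C)` arm factors shared conjuncts out of a disjunction. Below is a minimal standalone sketch of the same partitioning logic, using plain string atoms and `std::collections::HashSet` instead of DataFusion's `Expr`/`IndexSet`; `factor_common` is an illustrative name, not a DataFusion API:

    use std::collections::HashSet;

    /// Split `rhs` into the conjuncts it shares with `lhs` and the rest,
    /// mirroring the `partition` call in the simplifier arm above.
    fn factor_common<'a>(
        lhs: &[&'a str],
        rhs: &[&'a str],
    ) -> (Vec<&'a str>, Vec<&'a str>, Vec<&'a str>) {
        let lhs_set: HashSet<&str> = lhs.iter().copied().collect();
        let (common, rhs_only): (Vec<_>, Vec<_>) =
            rhs.iter().copied().partition(|e| lhs_set.contains(e));
        let common_set: HashSet<&str> = common.iter().copied().collect();
        let lhs_only = lhs
            .iter()
            .copied()
            .filter(|e| !common_set.contains(e))
            .collect();
        (common, lhs_only, rhs_only)
    }

    fn main() {
        // (A AND B) OR (A AND C)  =>  common: A, leftovers: B and C,
        // i.e. A AND (B OR C)
        let (common, l, r) = factor_common(&["A", "B"], &["A", "C"]);
        assert_eq!((common, l, r), (vec!["A"], vec!["B"], vec!["C"]));

        // A OR (B AND C AND A): the left side has no leftover conjuncts, so
        // the `match (new_lhs, new_rhs)` in the rule collapses the whole
        // disjunction to the common factor A.
        let (common, l, _r) = factor_common(&["A"], &["B", "C", "A"]);
        assert_eq!(common, vec!["A"]);
        assert!(l.is_empty());
    }

Second, `x % 1` now folds to a zero of `x`'s own type instead of `lit(0)`, which is always an `Int32` literal; that is why the updated test expects `lit(0_i64)` for the `Int64` column `c3_non_null`. A one-line illustration of `ScalarValue::new_zero` from `datafusion_common`:

    use arrow_schema::DataType;
    use datafusion_common::{Result, ScalarValue};

    fn main() -> Result<()> {
        // A zero literal that matches the operand type, unlike `lit(0)`.
        assert_eq!(
            ScalarValue::new_zero(&DataType::Int64)?,
            ScalarValue::Int64(Some(0))
        );
        Ok(())
    }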
-3223,7 +3241,7 @@ mod tests { )], Some(Box::new(col("c2").eq(lit(true)))), )))), - col("c2").or(col("c2").not().and(col("c2"))) // #1716 + col("c2") ); // CASE WHEN ISNULL(c2) THEN true ELSE c2 @@ -3755,11 +3773,52 @@ mod tests { assert_eq!(expr, expected); assert_eq!(num_iter, 2); } + + fn boolean_test_schema() -> DFSchemaRef { + Schema::new(vec![ + Field::new("A", DataType::Boolean, false), + Field::new("B", DataType::Boolean, false), + Field::new("C", DataType::Boolean, false), + Field::new("D", DataType::Boolean, false), + ]) + .to_dfschema_ref() + .unwrap() + } + + #[test] + fn simplify_common_factor_conjuction_in_disjunction() { + let props = ExecutionProps::new(); + let schema = boolean_test_schema(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema)); + + let a = || col("A"); + let b = || col("B"); + let c = || col("C"); + let d = || col("D"); + + // (A AND B) OR (A AND C) -> A AND (B OR C) + let expr = a().and(b()).or(a().and(c())); + let expected = a().and(b().or(c())); + + assert_eq!(expected, simplifier.simplify(expr).unwrap()); + + // (A AND B) OR (A AND C) OR (A AND D) -> A AND (B OR C OR D) + let expr = a().and(b()).or(a().and(c())).or(a().and(d())); + let expected = a().and(b().or(c()).or(d())); + assert_eq!(expected, simplifier.simplify(expr).unwrap()); + + // A OR (B AND C AND A) -> A + let expr = a().or(b().and(c().and(a()))); + let expected = a(); + assert_eq!(expected, simplifier.simplify(expr).unwrap()); + } + #[test] fn test_simplify_udaf() { let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify()); let aggregate_function_expr = - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::AggregateFunction(expr::AggregateFunction::new_udf( udaf.into(), vec![], false, @@ -3773,7 +3832,7 @@ mod tests { let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify()); let aggregate_function_expr = - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::AggregateFunction(expr::AggregateFunction::new_udf( udaf.into(), vec![], false, @@ -3823,7 +3882,7 @@ mod tests { fn accumulator( &self, - _acc_args: function::AccumulatorArgs, + _acc_args: AccumulatorArgs, ) -> Result> { unimplemented!("not needed for tests") } @@ -3853,9 +3912,8 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); - let window_function_expr = Expr::WindowFunction( - datafusion_expr::expr::WindowFunction::new(udwf, vec![]), - ); + let window_function_expr = + Expr::WindowFunction(WindowFunction::new(udwf, vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -3863,9 +3921,8 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); - let window_function_expr = Expr::WindowFunction( - datafusion_expr::expr::WindowFunction::new(udwf, vec![]), - ); + let window_function_expr = + Expr::WindowFunction(WindowFunction::new(udwf, vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); @@ -3910,7 +3967,10 @@ mod tests { } } - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { unimplemented!("not needed for tests") } diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs 
b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index c0142ae0fc5a6..200f1f159d813 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -208,7 +208,7 @@ mod tests { assert_eq!(1, table_scan.schema().fields().len()); assert_fields_eq(&table_scan, vec!["a"]); - let expected = "TableScan: test projection=[a], full_filters=[Boolean(true) AS b IS NOT NULL]"; + let expected = "TableScan: test projection=[a], full_filters=[Boolean(true)]"; assert_optimized_plan_eq(table_scan, expected) } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 1c22c2a4375ad..01875349c922a 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -279,7 +279,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { mod tests { use super::*; use crate::test::*; - use datafusion_expr::expr::{self, GroupingSet}; + use datafusion_expr::expr::GroupingSet; use datafusion_expr::ExprFunctionExt; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use datafusion_functions_aggregate::count::count_udaf; @@ -288,7 +288,7 @@ mod tests { use datafusion_functions_aggregate::sum::sum_udaf; fn max_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::AggregateFunction(AggregateFunction::new_udf( max_udaf(), vec![expr], true, @@ -355,7 +355,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ + let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -373,7 +373,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ + let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -392,7 +392,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ + let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -569,7 +569,7 @@ mod tests { let table_scan = test_table_scan()?; // sum(a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let expr = Expr::AggregateFunction(AggregateFunction::new_udf( sum_udaf(), vec![col("a")], false, @@ -612,7 +612,7 @@ mod tests { let table_scan = test_table_scan()?; // SUM(a ORDER BY a) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let expr = Expr::AggregateFunction(AggregateFunction::new_udf( sum_udaf(), 
vec![col("a")], false, diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index cabeafd8e7dea..94d07a0791b3b 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -133,20 +133,6 @@ pub fn assert_analyzed_plan_with_config_eq( Ok(()) } -pub fn assert_analyzed_plan_ne( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let options = ConfigOptions::default(); - let analyzed_plan = - Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; - let formatted_plan = format!("{analyzed_plan}"); - assert_ne!(formatted_plan, expected); - - Ok(()) -} - pub fn assert_analyzed_plan_eq_display_indent( rule: Arc, plan: LogicalPlan, diff --git a/datafusion/optimizer/src/test/user_defined.rs b/datafusion/optimizer/src/test/user_defined.rs index 814cd0c0cd0a0..a39f90b5da5db 100644 --- a/datafusion/optimizer/src/test/user_defined.rs +++ b/datafusion/optimizer/src/test/user_defined.rs @@ -76,4 +76,8 @@ impl UserDefinedLogicalNodeCore for TestUserDefinedPlanNode { input: inputs.swap_remove(0), }) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 22e3c0ddd076c..31e21d08b569a 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -146,7 +146,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { }; is_supported_type(&left_type) && is_supported_type(&right_type) - && op.is_comparison_operator() + && op.supports_propagation() } => { match (left.as_mut(), right.as_mut()) { diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 6972c16c0ddf8..9f325bc01b1d0 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -21,11 +21,18 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{Column, DFSchema, Result}; +use crate::analyzer::type_coercion::TypeCoercionRewriter; +use arrow::array::{new_null_array, Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::{Column, DFSchema, Result, ScalarValue}; +use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::replace_col; -use datafusion_expr::{logical_plan::LogicalPlan, Expr}; - +use datafusion_expr::{logical_plan::LogicalPlan, ColumnarValue, Expr}; +use datafusion_physical_expr::create_physical_expr; use log::{debug, trace}; +use std::sync::Arc; /// Re-export of `NamesPreserver` for backwards compatibility, /// as it was initially placed here and then moved elsewhere. @@ -117,3 +124,161 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) { debug!("{description}:\n{}\n", plan.display_indent()); trace!("{description}::\n{}\n", plan.display_indent_schema()); } + +/// Determine whether a predicate can restrict NULLs. e.g. +/// `c0 > 8` return true; +/// `c0 IS NULL` return false. 
+pub fn is_restrict_null_predicate<'a>(
+    predicate: Expr,
+    join_cols_of_predicate: impl IntoIterator<Item = &'a Column>,
+) -> Result<bool> {
+    if matches!(predicate, Expr::Column(_)) {
+        return Ok(true);
+    }
+
+    static DUMMY_COL_NAME: &str = "?";
+    let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]);
+    let input_schema = DFSchema::try_from(schema.clone())?;
+    let column = new_null_array(&DataType::Null, 1);
+    let input_batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?;
+    let execution_props = ExecutionProps::default();
+    let null_column = Column::from_name(DUMMY_COL_NAME);
+
+    let join_cols_to_replace = join_cols_of_predicate
+        .into_iter()
+        .map(|column| (column, &null_column))
+        .collect::<HashMap<_, _>>();
+
+    let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?;
+    let coerced_predicate = coerce(replaced_predicate, &input_schema)?;
+    let phys_expr =
+        create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?;
+
+    let result_type = phys_expr.data_type(&schema)?;
+    if !matches!(&result_type, DataType::Boolean) {
+        return Ok(false);
+    }
+
+    // If result is single `true`, return false;
+    // If result is single `NULL` or `false`, return true;
+    Ok(match phys_expr.evaluate(&input_batch)? {
+        ColumnarValue::Array(array) => {
+            if array.len() == 1 {
+                let boolean_array = as_boolean_array(&array)?;
+                boolean_array.is_null(0) || !boolean_array.value(0)
+            } else {
+                false
+            }
+        }
+        ColumnarValue::Scalar(scalar) => matches!(
+            scalar,
+            ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
+        ),
+    })
+}
+
+fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
+    let mut expr_rewrite = TypeCoercionRewriter { schema };
+    expr.rewrite(&mut expr_rewrite).data()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion_expr::{binary_expr, case, col, in_list, is_null, lit, Operator};
+
+    #[test]
+    fn expr_is_restrict_null_predicate() -> Result<()> {
+        let test_cases = vec![
+            // a
+            (col("a"), true),
+            // a IS NULL
+            (is_null(col("a")), false),
+            // a IS NOT NULL
+            (Expr::IsNotNull(Box::new(col("a"))), true),
+            // a = NULL
+            (
+                binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)),
+                true,
+            ),
+            // a > 8
+            (binary_expr(col("a"), Operator::Gt, lit(8i64)), true),
+            // a <= 8
+            (binary_expr(col("a"), Operator::LtEq, lit(8i32)), true),
+            // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END
+            (
+                case(col("a"))
+                    .when(lit(1i64), lit(true))
+                    .when(lit(0i64), lit(false))
+                    .otherwise(lit(ScalarValue::Null))?,
+                true,
+            ),
+            // CASE a WHEN 1 THEN true ELSE false END
+            (
+                case(col("a"))
+                    .when(lit(1i64), lit(true))
+                    .otherwise(lit(false))?,
+                true,
+            ),
+            // CASE a WHEN 0 THEN false ELSE true END
+            (
+                case(col("a"))
+                    .when(lit(0i64), lit(false))
+                    .otherwise(lit(true))?,
+                false,
+            ),
+            // (CASE a WHEN 0 THEN false ELSE true END) OR false
+            (
+                binary_expr(
+                    case(col("a"))
+                        .when(lit(0i64), lit(false))
+                        .otherwise(lit(true))?,
+                    Operator::Or,
+                    lit(false),
+                ),
+                false,
+            ),
+            // (CASE a WHEN 0 THEN true ELSE false END) OR false
+            (
+                binary_expr(
+                    case(col("a"))
+                        .when(lit(0i64), lit(true))
+                        .otherwise(lit(false))?,
+                    Operator::Or,
+                    lit(false),
+                ),
+                true,
+            ),
+            // a IN (1, 2, 3)
+            (
+                in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], false),
+                true,
+            ),
+            // a NOT IN (1, 2, 3)
+            (
+                in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], true),
+                true,
+            ),
+            // a IN (NULL)
+            (
+                in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], false),
+                true,
+            ),
+            // a NOT IN (NULL)
+            (
+                in_list(col("a"),
vec![Expr::Literal(ScalarValue::Null)], true), + true, + ), + ]; + + let column_a = Column::from_name("a"); + for (predicate, expected) in test_cases { + let join_cols_of_predicate = std::iter::once(&column_a); + let actual = + is_restrict_null_predicate(predicate.clone(), join_cols_of_predicate)?; + assert_eq!(actual, expected, "{}", predicate); + } + + Ok(()) + } +} diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 470bd947c7fbc..236167985790d 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -345,7 +345,7 @@ fn select_wildcard_with_repeated_column() { let sql = "SELECT *, col_int32 FROM test"; let err = test_sql(sql).expect_err("query should have failed"); assert_eq!( - "expand_wildcard_rule\ncaused by\nError during planning: Projections require unique expression names but the expression \"test.col_int32\" at position 0 and \"test.col_int32\" at position 7 have the same name. Consider aliasing (\"AS\") one of them.", + "Schema error: Schema contains duplicate qualified field name test.col_int32", err.strip_backtrace() ); } @@ -396,7 +396,7 @@ fn test_sql(sql: &str) -> Result { .with_udaf(count_udaf()) .with_udaf(avg_udaf()); let sql_to_rel = SqlToRel::new(&context_provider); - let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); + let plan = sql_to_rel.sql_statement_to_plan(statement.clone())?; let config = OptimizerContext::new().with_skip_failing_rules(false); let analyzer = Analyzer::new(); diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index f320ebcc06b56..80c4963ae0354 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -31,7 +31,7 @@ use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; use std::any::type_name; use std::fmt::Debug; -use std::mem; +use std::mem::{size_of, swap}; use std::ops::Range; use std::sync::Arc; @@ -104,8 +104,9 @@ impl ArrowBytesSet { /// `Binary`, and `LargeBinary`) values that can produce the set of keys on /// output as `GenericBinaryArray` without copies. /// -/// Equivalent to `HashSet` but with better performance for arrow -/// data. +/// Equivalent to `HashSet` but with better performance if you need +/// to emit the keys as an Arrow `StringArray` / `BinaryArray`. 
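Stepping back to `is_restrict_null_predicate` above: callers hand it a predicate plus the join columns to probe, and it answers by substituting a dummy NULL column, coercing, and evaluating against a one-row all-NULL batch. A hedged usage sketch in the style of the unit test, assuming the helper remains `pub` in `datafusion_optimizer::utils`:

    use datafusion_common::{Column, Result};
    use datafusion_expr::{col, is_null, lit};
    use datafusion_optimizer::utils::is_restrict_null_predicate;

    fn main() -> Result<()> {
        let a = Column::from_name("a");

        // `a > 8` evaluates to NULL when `a` is NULL, so rows where the join
        // column is NULL cannot pass the filter: the predicate restricts NULLs.
        let pred = col("a").gt(lit(8i64));
        assert!(is_restrict_null_predicate(pred, std::iter::once(&a))?);

        // `a IS NULL` evaluates to true for NULL input, so it does not.
        let pred = is_null(col("a"));
        assert!(!is_restrict_null_predicate(pred, std::iter::once(&a))?);
        Ok(())
    }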
For other +/// purposes it is the same as a `HashMap` /// /// # Generic Arguments /// @@ -259,7 +260,7 @@ where /// the same output type pub fn take(&mut self) -> Self { let mut new_self = Self::new(self.output_type); - mem::swap(self, &mut new_self); + swap(self, &mut new_self); new_self } @@ -544,7 +545,7 @@ where /// this set, not including `self` pub fn size(&self) -> usize { self.map_size - + self.buffer.capacity() * mem::size_of::() + + self.buffer.capacity() * size_of::() + self.offsets.allocated_size() + self.hashes_buffer.allocated_size() } @@ -574,7 +575,7 @@ where } /// Maximum size of a value that can be inlined in the hash table -const SHORT_VALUE_LEN: usize = mem::size_of::(); +const SHORT_VALUE_LEN: usize = size_of::(); /// Entry in the hash table -- see [`ArrowBytesMap`] for more details #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index bdcf7bbacc696..c6768a19d30e6 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -88,8 +88,9 @@ impl ArrowBytesViewSet { /// values that can produce the set of keys on /// output as `GenericBinaryViewArray` without copies. /// -/// Equivalent to `HashSet` but with better performance for arrow -/// data. +/// Equivalent to `HashSet` but with better performance if you need +/// to emit the keys as an Arrow `StringViewArray` / `BinaryViewArray`. For other +/// purposes it is the same as a `HashMap` /// /// # Generic Arguments /// diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 704cb291335fa..d825bfe7e2643 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -120,6 +120,13 @@ impl PhysicalSortExpr { } } +/// Access the PhysicalSortExpr as a PhysicalExpr +impl AsRef for PhysicalSortExpr { + fn as_ref(&self) -> &(dyn PhysicalExpr + 'static) { + self.expr.as_ref() + } +} + impl PartialEq for PhysicalSortExpr { fn eq(&self, other: &PhysicalSortExpr) -> bool { self.options == other.options && self.expr.eq(&other.expr) @@ -136,7 +143,7 @@ impl Hash for PhysicalSortExpr { } impl Display for PhysicalSortExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { write!(f, "{} {}", self.expr, to_str(&self.options)) } } @@ -181,7 +188,7 @@ impl PhysicalSortExpr { pub fn format_list(input: &[PhysicalSortExpr]) -> impl Display + '_ { struct DisplayableList<'a>(&'a [PhysicalSortExpr]); impl<'a> Display for DisplayableList<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { let mut first = true; for sort_expr in self.0 { if first { @@ -253,7 +260,7 @@ impl PartialEq for PhysicalSortRequirement { } impl Display for PhysicalSortRequirement { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { let opts_string = self.options.as_ref().map_or("NA", to_str); write!(f, "{} {}", self.expr, opts_string) } diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 41e53af61bb9a..079e7d42e93e2 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -35,14 +35,6 @@ workspace = true name = "datafusion_physical_expr" path = "src/lib.rs" 
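The `AsRef<dyn PhysicalExpr>` impl added to `PhysicalSortExpr` in `sort_expr.rs` above lets helper functions accept either a sort expression or a plain physical expression through a single generic bound; the `ConstExpr::eq_expr` method later in this patch is apparently written against exactly that bound. A standalone model of the pattern with toy types (not DataFusion's):

    use std::sync::Arc;

    trait PhysicalExpr {
        fn name(&self) -> &str;
    }

    struct Column(String);
    impl PhysicalExpr for Column {
        fn name(&self) -> &str {
            &self.0
        }
    }

    struct PhysicalSortExpr {
        expr: Arc<dyn PhysicalExpr>,
    }

    // The pattern from the diff: expose the inner expression via AsRef so
    // that `impl AsRef<dyn PhysicalExpr>` bounds accept sort expressions.
    impl AsRef<dyn PhysicalExpr> for PhysicalSortExpr {
        fn as_ref(&self) -> &(dyn PhysicalExpr + 'static) {
            self.expr.as_ref()
        }
    }

    // A helper in the spirit of `ConstExpr::eq_expr`: anything convertible
    // to `&dyn PhysicalExpr` can be inspected uniformly.
    fn expr_name(e: impl AsRef<dyn PhysicalExpr>) -> String {
        e.as_ref().name().to_string()
    }

    fn main() {
        let sort = PhysicalSortExpr {
            expr: Arc::new(Column("a".into())),
        };
        assert_eq!(expr_name(sort), "a");
    }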
-[features] -default = [ - "regex_expressions", - "encoding_expressions", -] -encoding_expressions = ["base64", "hex"] -regex_expressions = ["regex"] - [dependencies] ahash = { workspace = true } arrow = { workspace = true } @@ -51,23 +43,20 @@ arrow-buffer = { workspace = true } arrow-ord = { workspace = true } arrow-schema = { workspace = true } arrow-string = { workspace = true } -base64 = { version = "0.22", optional = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } -datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } hashbrown = { workspace = true } -hex = { version = "0.4", optional = true } indexmap = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "^1.0" petgraph = "0.6.2" -regex = { workspace = true, optional = true } +regex = { workspace = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index 866596d0b6901..6330c240241a3 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -328,7 +328,7 @@ impl AggregateFunctionExpr { /// not implement the method, returns an error. Order insensitive and hard /// requirement aggregators return `Ok(None)`. pub fn with_beneficial_ordering( - self, + self: Arc, beneficial_ordering: bool, ) -> Result> { let Some(updated_fn) = self diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 00708b4540aa6..c1851ddb22b53 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -30,7 +30,6 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::JoinType; use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; -#[derive(Debug, Clone)] /// A structure representing a expression known to be constant in a physical execution plan. /// /// The `ConstExpr` struct encapsulates an expression that is constant during the execution @@ -41,9 +40,10 @@ use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; /// /// - `expr`: Constant expression for a node in the physical plan. /// -/// - `across_partitions`: A boolean flag indicating whether the constant expression is -/// valid across partitions. If set to `true`, the constant expression has same value for all partitions. -/// If set to `false`, the constant expression may have different values for different partitions. +/// - `across_partitions`: A boolean flag indicating whether the constant +/// expression is the same across partitions. If set to `true`, the constant +/// expression has same value for all partitions. If set to `false`, the +/// constant expression may have different values for different partitions. /// /// # Example /// @@ -56,11 +56,22 @@ use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; /// // create a constant expression from a physical expression /// let const_expr = ConstExpr::from(col); /// ``` +#[derive(Debug, Clone)] pub struct ConstExpr { + /// The expression that is known to be constant (e.g. 
a `Column`) expr: Arc, + /// Does the constant have the same value across all partitions? See + /// struct docs for more details across_partitions: bool, } +impl PartialEq for ConstExpr { + fn eq(&self, other: &Self) -> bool { + self.across_partitions == other.across_partitions + && self.expr.eq(other.expr.as_any()) + } +} + impl ConstExpr { /// Create a new constant expression from a physical expression. /// @@ -74,11 +85,17 @@ impl ConstExpr { } } + /// Set the `across_partitions` flag + /// + /// See struct docs for more details pub fn with_across_partitions(mut self, across_partitions: bool) -> Self { self.across_partitions = across_partitions; self } + /// Is the expression the same across all partitions? + /// + /// See struct docs for more details pub fn across_partitions(&self) -> bool { self.across_partitions } @@ -101,6 +118,31 @@ impl ConstExpr { across_partitions: self.across_partitions, }) } + + /// Returns true if this constant expression is equal to the given expression + pub fn eq_expr(&self, other: impl AsRef) -> bool { + self.expr.eq(other.as_ref().as_any()) + } + + /// Returns a [`Display`]able list of `ConstExpr`. + pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ { + struct DisplayableList<'a>(&'a [ConstExpr]); + impl<'a> Display for DisplayableList<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let mut first = true; + for const_expr in self.0 { + if first { + first = false; + } else { + write!(f, ",")?; + } + write!(f, "{}", const_expr)?; + } + Ok(()) + } + } + DisplayableList(input) + } } /// Display implementation for `ConstExpr` diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 38647f7ca1d4b..95bb93d6ca57f 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -72,21 +72,15 @@ pub fn add_offset_to_expr( #[cfg(test)] mod tests { + use super::*; use crate::expressions::col; use crate::PhysicalSortExpr; - use arrow::compute::{lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field, Schema}; - use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::{plan_datafusion_err, Result}; - use itertools::izip; - use rand::rngs::StdRng; - use rand::seq::SliceRandom; - use rand::{Rng, SeedableRng}; - pub fn output_schema( mapping: &ProjectionMapping, input_schema: &Arc, @@ -175,67 +169,6 @@ mod tests { Ok((test_schema, eq_properties)) } - // Generate a schema which consists of 6 columns (a, b, c, d, e, f) - fn create_test_schema_2() -> Result { - let a = Field::new("a", DataType::Float64, true); - let b = Field::new("b", DataType::Float64, true); - let c = Field::new("c", DataType::Float64, true); - let d = Field::new("d", DataType::Float64, true); - let e = Field::new("e", DataType::Float64, true); - let f = Field::new("f", DataType::Float64, true); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); - - Ok(schema) - } - - /// Construct a schema with random ordering - /// among column a, b, c, d - /// where - /// Column [a=f] (e.g they are aliases). - /// Column e is constant. 
- pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { - let test_schema = create_test_schema_2()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; - - let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); - // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f)?; - // Column e has constant value. - eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); - - // Randomly order columns for sorting - let mut rng = StdRng::seed_from_u64(seed); - let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted - - let options_asc = SortOptions { - descending: false, - nulls_first: false, - }; - - while !remaining_exprs.is_empty() { - let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); - remaining_exprs.shuffle(&mut rng); - - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: options_asc, - }) - .collect(); - - eq_properties.add_new_orderings([ordering]); - } - - Ok((test_schema, eq_properties)) - } - // Convert each tuple to PhysicalSortRequirement pub fn convert_to_sort_reqs( in_data: &[(&Arc, Option)], @@ -294,33 +227,6 @@ mod tests { .collect() } - // Apply projection to the input_data, return projected equivalence properties and record batch - pub fn apply_projection( - proj_exprs: Vec<(Arc, String)>, - input_data: &RecordBatch, - input_eq_properties: &EquivalenceProperties, - ) -> Result<(RecordBatch, EquivalenceProperties)> { - let input_schema = input_data.schema(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; - - let output_schema = output_schema(&projection_mapping, &input_schema)?; - let num_rows = input_data.num_rows(); - // Apply projection to the input record batch. - let projected_values = projection_mapping - .iter() - .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) - .collect::>>()?; - let projected_batch = if projected_values.is_empty() { - RecordBatch::new_empty(Arc::clone(&output_schema)) - } else { - RecordBatch::try_new(Arc::clone(&output_schema), projected_values)? - }; - - let projected_eq = - input_eq_properties.project(&projection_mapping, output_schema); - Ok((projected_batch, projected_eq)) - } - #[test] fn add_equal_conditions_test() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -378,168 +284,4 @@ mod tests { Ok(()) } - - /// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. - /// - /// The function works by adding a unique column of ascending integers to the original table. This column ensures - /// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can - /// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce - /// deterministic sorting results. - /// - /// If the table remains the same after sorting with the added unique column, it indicates that the table was - /// already sorted according to `required_ordering` to begin with. 
- pub fn is_table_same_after_sort( - mut required_ordering: Vec, - batch: RecordBatch, - ) -> Result { - // Clone the original schema and columns - let original_schema = batch.schema(); - let mut columns = batch.columns().to_vec(); - - // Create a new unique column - let n_row = batch.num_rows(); - let vals: Vec = (0..n_row).collect::>(); - let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); - let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; - columns.push(Arc::clone(&unique_col)); - - // Create a new schema with the added unique column - let unique_col_name = "unique"; - let unique_field = - Arc::new(Field::new(unique_col_name, DataType::Float64, false)); - let fields: Vec<_> = original_schema - .fields() - .iter() - .cloned() - .chain(std::iter::once(unique_field)) - .collect(); - let schema = Arc::new(Schema::new(fields)); - - // Create a new batch with the added column - let new_batch = RecordBatch::try_new(Arc::clone(&schema), columns)?; - - // Add the unique column to the required ordering to ensure deterministic results - required_ordering.push(PhysicalSortExpr { - expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), - options: Default::default(), - }); - - // Convert the required ordering to a list of SortColumn - let sort_columns = required_ordering - .iter() - .map(|order_expr| { - let expr_result = order_expr.expr.evaluate(&new_batch)?; - let values = expr_result.into_array(new_batch.num_rows())?; - Ok(SortColumn { - values, - options: Some(order_expr.options), - }) - }) - .collect::>>()?; - - // Check if the indices after sorting match the initial ordering - let sorted_indices = lexsort_to_indices(&sort_columns, None)?; - let original_indices = UInt32Array::from_iter_values(0..n_row as u32); - - Ok(sorted_indices == original_indices) - } - - // If we already generated a random result for one of the - // expressions in the equivalence classes. For other expressions in the same - // equivalence class use same result. This util gets already calculated result, when available. - fn get_representative_arr( - eq_group: &EquivalenceClass, - existing_vec: &[Option], - schema: SchemaRef, - ) -> Option { - for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - if let Some(res) = &existing_vec[idx] { - return Some(Arc::clone(res)); - } - } - None - } - - // Generate a table that satisfies the given equivalence properties; i.e. - // equivalences, ordering equivalences, and constants. 
- pub fn generate_table_for_eq_properties( - eq_properties: &EquivalenceProperties, - n_elem: usize, - n_distinct: usize, - ) -> Result { - let mut rng = StdRng::seed_from_u64(23); - - let schema = eq_properties.schema(); - let mut schema_vec = vec![None; schema.fields.len()]; - - // Utility closure to generate random array - let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { - let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) - .collect(); - Arc::new(Float64Array::from_iter_values(values)) - }; - - // Fill constant columns - for constant in &eq_properties.constants { - let col = constant.expr().as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) - as ArrayRef; - schema_vec[idx] = Some(arr); - } - - // Fill columns based on ordering equivalences - for ordering in eq_properties.oeq_class.iter() { - let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering - .iter() - .map(|PhysicalSortExpr { expr, options }| { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = generate_random_array(n_elem, n_distinct); - ( - SortColumn { - values: arr, - options: Some(*options), - }, - idx, - ) - }) - .unzip(); - - let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; - for (idx, arr) in izip!(indices, sort_arrs) { - schema_vec[idx] = Some(arr); - } - } - - // Fill columns based on equivalence groups - for eq_group in eq_properties.eq_group.iter() { - let representative_array = - get_representative_arr(eq_group, &schema_vec, Arc::clone(schema)) - .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); - - for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - schema_vec[idx] = Some(Arc::clone(&representative_array)); - } - } - - let res: Vec<_> = schema_vec - .into_iter() - .zip(schema.fields.iter()) - .map(|(elem, field)| { - ( - field.name(), - // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) - elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), - ) - }) - .collect(); - - Ok(RecordBatch::try_from_iter(res)?) - } } diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 65423033d5e01..d71f3b037fb19 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -18,6 +18,7 @@ use std::fmt::Display; use std::hash::Hash; use std::sync::Arc; +use std::vec::IntoIter; use crate::equivalence::add_offset_to_expr; use crate::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; @@ -36,7 +37,7 @@ use arrow_schema::SortOptions; /// /// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table /// ordering. In this case, we say that these orderings are equivalent. -#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] pub struct OrderingEquivalenceClass { pub orderings: Vec, } @@ -44,7 +45,7 @@ pub struct OrderingEquivalenceClass { impl OrderingEquivalenceClass { /// Creates new empty ordering equivalence class. pub fn empty() -> Self { - Self { orderings: vec![] } + Default::default() } /// Clears (empties) this ordering equivalence class. 
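The ordering.rs hunk just below derives `Default` for `OrderingEquivalenceClass`, reduces `empty()` to `Default::default()`, and adds a consuming `IntoIterator` impl. A toy model of why that combination is convenient, using stand-in types rather than DataFusion's `LexOrdering`:

    // `orderings` stands in for `Vec<LexOrdering>`.
    #[derive(Debug, Default)]
    struct OrderingEquivalenceClass {
        orderings: Vec<Vec<&'static str>>,
    }

    impl OrderingEquivalenceClass {
        // With `Default` derived, `empty()` no longer spells out the fields.
        fn empty() -> Self {
            Default::default()
        }
    }

    // Consuming iteration lets callers write `for ordering in class { .. }`
    // and take ownership of each lexicographic ordering.
    impl IntoIterator for OrderingEquivalenceClass {
        type Item = Vec<&'static str>;
        type IntoIter = std::vec::IntoIter<Vec<&'static str>>;
        fn into_iter(self) -> Self::IntoIter {
            self.orderings.into_iter()
        }
    }

    fn main() {
        let mut class = OrderingEquivalenceClass::empty();
        class.orderings.push(vec!["a ASC", "b ASC"]);
        class.orderings.push(vec!["c DESC"]);
        let collected: Vec<_> = class.into_iter().collect();
        assert_eq!(collected.len(), 2);
    }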
@@ -197,6 +198,15 @@ impl OrderingEquivalenceClass {
     }
 }
 
+impl IntoIterator for OrderingEquivalenceClass {
+    type Item = LexOrdering;
+    type IntoIter = IntoIter<LexOrdering>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.orderings.into_iter()
+    }
+}
+
 /// This function constructs a duplicate-free `LexOrdering` by filtering out
 /// duplicate entries that have same physical expression inside. For example,
 /// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`.
@@ -229,10 +239,10 @@ impl Display for OrderingEquivalenceClass {
         write!(f, "[")?;
         let mut iter = self.orderings.iter();
         if let Some(ordering) = iter.next() {
-            write!(f, "{}", PhysicalSortExpr::format_list(ordering))?;
+            write!(f, "[{}]", PhysicalSortExpr::format_list(ordering))?;
         }
         for ordering in iter {
-            write!(f, "{}", PhysicalSortExpr::format_list(ordering))?;
+            write!(f, ", [{}]", PhysicalSortExpr::format_list(ordering))?;
         }
         write!(f, "]")?;
         Ok(())
@@ -244,9 +254,7 @@ mod tests {
     use std::sync::Arc;
 
     use crate::equivalence::tests::{
-        convert_to_orderings, convert_to_sort_exprs, create_random_schema,
-        create_test_params, create_test_schema, generate_table_for_eq_properties,
-        is_table_same_after_sort,
+        convert_to_orderings, convert_to_sort_exprs, create_test_schema,
     };
     use crate::equivalence::{
         EquivalenceClass, EquivalenceGroup, EquivalenceProperties,
@@ -261,8 +269,6 @@ mod tests {
     use datafusion_common::{DFSchema, Result};
     use datafusion_expr::{Operator, ScalarUDF};
 
-    use itertools::Itertools;
-
     #[test]
     fn test_ordering_satisfy() -> Result<()> {
         let input_schema = Arc::new(Schema::new(vec![
@@ -593,305 +599,6 @@ mod tests {
         Ok(())
     }
 
-    #[test]
-    fn test_ordering_satisfy_with_equivalence() -> Result<()> {
-        // Schema satisfies following orderings:
-        // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC]
-        // and
-        // Column [a=c] (e.g they are aliases).
- let (test_schema, eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_g = &col("g", &test_schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, 625, 5)?; - - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it - (vec![(col_a, option_asc)], true), - (vec![(col_a, option_desc)], false), - // Test whether equivalence works as expected - (vec![(col_c, option_asc)], true), - (vec![(col_c, option_desc)], false), - // Test whether ordering equivalence works as expected - (vec![(col_d, option_asc)], true), - (vec![(col_d, option_asc), (col_b, option_asc)], true), - (vec![(col_d, option_desc), (col_b, option_asc)], false), - ( - vec![ - (col_e, option_desc), - (col_f, option_asc), - (col_g, option_asc), - ], - true, - ), - (vec![(col_e, option_desc), (col_f, option_asc)], true), - (vec![(col_e, option_asc), (col_f, option_asc)], false), - (vec![(col_e, option_desc), (col_b, option_asc)], false), - (vec![(col_e, option_asc), (col_b, option_asc)], false), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_f, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_f, option_asc), - ], - false, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_b, option_asc), - ], - false, - ), - (vec![(col_d, option_asc), (col_e, option_desc)], true), - ( - vec![ - (col_d, option_asc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_f, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_b, option_asc), - (col_f, option_asc), - ], - true, - ), - ]; - - for (cols, expected) in requirements { - let err_msg = format!("Error in test case:{cols:?}"); - let required = cols - .into_iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(expr), - options, - }) - .collect::>(); - - // Check expected result with experimental result. 
- assert_eq!( - is_table_same_after_sort( - required.clone(), - table_data_with_properties.clone() - )?, - expected - ); - assert_eq!( - eq_properties.ordering_satisfy(&required), - expected, - "{err_msg}" - ); - } - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 5; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let col_exprs = [ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - ]; - - for n_req in 0..=col_exprs.len() { - for exprs in col_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - // Check whether ordering_satisfy API result and - // experimental result matches. - assert_eq!( - eq_properties.ordering_satisfy(&requirement), - expected, - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 100; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = crate::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let exprs = [ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - floor_a, - a_plus_b, - ]; - - for n_req in 0..=exprs.len() { - for exprs in exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, 
expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - // Check whether ordering_satisfy API result and - // experimental result matches. - - assert_eq!( - eq_properties.ordering_satisfy(&requirement), - (expected | false), - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - #[test] fn test_ordering_satisfy_different_lengths() -> Result<()> { let test_schema = create_test_schema()?; diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index ebf26d3262aa2..25a05a2a5918b 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -139,23 +139,18 @@ fn project_index_to_exprs( mod tests { use super::*; use crate::equivalence::tests::{ - apply_projection, convert_to_orderings, convert_to_orderings_owned, - create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, - output_schema, + convert_to_orderings, convert_to_orderings_owned, output_schema, }; use crate::equivalence::EquivalenceProperties; use crate::expressions::{col, BinaryExpr}; use crate::udf::create_physical_expr; use crate::utils::tests::TestScalarUDF; - use crate::PhysicalSortExpr; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{SortOptions, TimeUnit}; use datafusion_common::DFSchema; use datafusion_expr::{Operator, ScalarUDF}; - use itertools::Itertools; - #[test] fn project_orderings() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -987,174 +982,4 @@ mod tests { Ok(()) } - - #[test] - fn project_orderings_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 20; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - // Floor(a) - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - // a + b - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let proj_exprs = vec![ - (col("a", &test_schema)?, "a_new"), - (col("b", &test_schema)?, "b_new"), - (col("c", &test_schema)?, "c_new"), - (col("d", &test_schema)?, "d_new"), - (col("e", &test_schema)?, "e_new"), - (col("f", &test_schema)?, "f_new"), - (floor_a, "floor(a)"), - (a_plus_b, "a+b"), - ]; - - for n_req in 0..=proj_exprs.len() { - for proj_exprs in proj_exprs.iter().combinations(n_req) { - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name.to_string())) - .collect::>(); - let (projected_batch, projected_eq) = apply_projection( - proj_exprs.clone(), - &table_data_with_properties, - &eq_properties, - )?; - - // Make sure each ordering after projection is valid. 
- for ordering in projected_eq.oeq_class().iter() { - let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs - ); - // Since ordered section satisfies schema, we expect - // that result will be same after sort (e.g sort was unnecessary). - assert!( - is_table_same_after_sort( - ordering.clone(), - projected_batch.clone(), - )?, - "{}", - err_msg - ); - } - } - } - } - - Ok(()) - } - - #[test] - fn ordering_satisfy_after_projection_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 20; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - // Floor(a) - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - // a + b - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let proj_exprs = vec![ - (col("a", &test_schema)?, "a_new"), - (col("b", &test_schema)?, "b_new"), - (col("c", &test_schema)?, "c_new"), - (col("d", &test_schema)?, "d_new"), - (col("e", &test_schema)?, "e_new"), - (col("f", &test_schema)?, "f_new"), - (floor_a, "floor(a)"), - (a_plus_b, "a+b"), - ]; - - for n_req in 0..=proj_exprs.len() { - for proj_exprs in proj_exprs.iter().combinations(n_req) { - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name.to_string())) - .collect::>(); - let (projected_batch, projected_eq) = apply_projection( - proj_exprs.clone(), - &table_data_with_properties, - &eq_properties, - )?; - - let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &test_schema)?; - - let projected_exprs = projection_mapping - .iter() - .map(|(_source, target)| Arc::clone(target)) - .collect::>(); - - for n_req in 0..=projected_exprs.len() { - for exprs in projected_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - projected_batch.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping - ); - // Check whether ordering_satisfy API result and - // experimental result matches. 
- assert_eq!( - projected_eq.ordering_satisfy(&requirement), - expected, - "{}", - err_msg - ); - } - } - } - } - } - - Ok(()) - } } diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index dc59a1eb835b1..9a16b205ae25b 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::fmt; use std::fmt::Display; use std::hash::{Hash, Hasher}; +use std::iter::Peekable; +use std::slice::Iter; use std::sync::Arc; use super::ordering::collapse_lex_ordering; @@ -34,7 +37,7 @@ use crate::{ use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{plan_err, JoinSide, JoinType, Result}; +use datafusion_common::{internal_err, plan_err, JoinSide, JoinType, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_physical_expr_common::utils::ExprPropertiesNode; @@ -118,7 +121,7 @@ use itertools::Itertools; /// PhysicalSortExpr::new_default(col_c).desc(), /// ]); /// -/// assert_eq!(eq_properties.to_string(), "order: [a@0 ASC,c@2 DESC], const: [b@1]") +/// assert_eq!(eq_properties.to_string(), "order: [[a@0 ASC,c@2 DESC]], const: [b@1]") /// ``` #[derive(Debug, Clone)] pub struct EquivalenceProperties { @@ -279,6 +282,12 @@ impl EquivalenceProperties { self.with_constants(constants) } + /// Remove the specified constant + pub fn remove_constant(mut self, c: &ConstExpr) -> Self { + self.constants.retain(|existing| existing != c); + self + } + /// Track/register physical expressions with constant values. 
    pub fn with_constants(
        mut self,
@@ -701,7 +710,7 @@ impl EquivalenceProperties {
     ///  c ASC: Node {None, HashSet{a ASC}}
     /// ```
     fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap {
-        let mut dependency_map = IndexMap::new();
+        let mut dependency_map = DependencyMap::new();
         for ordering in self.normalized_oeq_class().iter() {
             for (idx, sort_expr) in ordering.iter().enumerate() {
                 let target_sort_expr =
@@ -723,13 +732,11 @@ impl EquivalenceProperties {
                 let dependency = idx.checked_sub(1).map(|a| &ordering[a]);
                 // Add sort expressions that can be projected or referred to
                 // by any of the projection expressions to the dependency map:
-                dependency_map
-                    .entry(sort_expr.clone())
-                    .or_insert_with(|| DependencyNode {
-                        target_sort_expr: target_sort_expr.clone(),
-                        dependencies: IndexSet::new(),
-                    })
-                    .insert_dependency(dependency);
+                dependency_map.insert(
+                    sort_expr,
+                    target_sort_expr.as_ref(),
+                    dependency,
+                );
             }
             if !is_projected {
                 // If we can not project, stop constructing the dependency
@@ -1106,7 +1113,7 @@ impl EquivalenceProperties {
 /// order: [[a ASC, b ASC], [a ASC, c ASC]], eq: [[a = b], [a = c]], const: [a = 1]
 /// ```
 impl Display for EquivalenceProperties {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         if self.eq_group.is_empty()
             && self.oeq_class.is_empty()
             && self.constants.is_empty()
@@ -1120,15 +1127,7 @@ impl Display for EquivalenceProperties {
             write!(f, ", eq: {}", self.eq_group)?;
         }
         if !self.constants.is_empty() {
-            write!(f, ", const: [")?;
-            let mut iter = self.constants.iter();
-            if let Some(c) = iter.next() {
-                write!(f, "{}", c)?;
-            }
-            for c in iter {
-                write!(f, ", {}", c)?;
-            }
-            write!(f, "]")?;
+            write!(f, ", const: [{}]", ConstExpr::format_list(&self.constants))?;
         }
         Ok(())
     }
@@ -1257,7 +1256,7 @@ fn referred_dependencies(
     // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them:
     let mut expr_to_sort_exprs = IndexMap::<ExprWrapper, Dependencies>::new();
     for sort_expr in dependency_map
-        .keys()
+        .sort_exprs()
        .filter(|sort_expr| expr_refers(source, &sort_expr.expr))
     {
         let key = ExprWrapper(Arc::clone(&sort_expr.expr));
@@ -1270,10 +1269,16 @@
     // Generate all valid dependencies for the source. For example, if the source
     // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get
     // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`.
-    expr_to_sort_exprs
-        .values()
+    let dependencies = expr_to_sort_exprs
+        .into_values()
+        .map(Dependencies::into_inner)
+        .collect::<Vec<_>>();
+    dependencies
+        .iter()
         .multi_cartesian_product()
-        .map(|referred_deps| referred_deps.into_iter().cloned().collect())
+        .map(|referred_deps| {
+            Dependencies::new_from_iter(referred_deps.into_iter().cloned())
+        })
         .collect()
 }

@@ -1295,21 +1300,32 @@
 fn construct_prefix_orderings(
     relevant_sort_expr: &PhysicalSortExpr,
     dependency_map: &DependencyMap,
 ) -> Vec<LexOrdering> {
-    dependency_map[relevant_sort_expr]
+    let mut dep_enumerator = DependencyEnumerator::new();
+    dependency_map
+        .get(relevant_sort_expr)
+        .expect("no relevant sort expr found")
         .dependencies
         .iter()
-        .flat_map(|dep| construct_orderings(dep, dependency_map))
+        .flat_map(|dep| dep_enumerator.construct_orderings(dep, dependency_map))
         .collect()
 }
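To make the dependency map that `construct_dependency_map` builds above more concrete, here is a hedged, self-contained sketch. It models sort expressions as plain strings and uses the `indexmap` crate directly, so it is illustrative only and not DataFusion's actual types:

```rust
use indexmap::{IndexMap, IndexSet};

fn main() {
    // One known ordering: a ASC, b ASC, c ASC
    let ordering = ["a ASC", "b ASC", "c ASC"];
    let mut dependency_map: IndexMap<&str, IndexSet<&str>> = IndexMap::new();
    for (idx, &sort_expr) in ordering.iter().enumerate() {
        // Each sort expression depends on its predecessor in the ordering.
        let dependency = idx.checked_sub(1).map(|i| ordering[i]);
        let deps = dependency_map.entry(sort_expr).or_insert_with(IndexSet::new);
        if let Some(dep) = dependency {
            deps.insert(dep);
        }
    }
    // Prints: "a ASC" -> {}, "b ASC" -> {"a ASC"}, "c ASC" -> {"b ASC"}
    for (expr, deps) in &dependency_map {
        println!("{expr:?} -> {deps:?}");
    }
}
```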
-/// Given a set of relevant dependencies (`relevant_deps`) and a map of dependencies
-/// (`dependency_map`), this function generates all possible prefix orderings
-/// based on the given dependencies.
+/// Generates all possible orderings where dependencies are satisfied for the
+/// current projection expression.
+///
+/// # Example
+///
+/// If `dependencies` is `a + b ASC` and the dependency map holds dependencies
+/// * `a ASC` --> `[c ASC]`
+/// * `b ASC` --> `[d DESC]`,
+///
+/// then this function generates these two sort orders:
+/// * `[c ASC, d DESC, a + b ASC]`
+/// * `[d DESC, c ASC, a + b ASC]`
 ///
 /// # Parameters
 ///
-/// * `dependencies` - A reference to the dependencies.
-/// * `dependency_map` - A reference to the map of dependencies for expressions.
+/// * `dependencies` - Set of relevant expressions.
+/// * `dependency_map` - Map of dependencies for expressions that may appear in `dependencies`.
 ///
 /// # Returns
 ///
@@ -1335,11 +1351,6 @@ fn generate_dependency_orderings(
         return vec![vec![]];
     }

-    // Generate all possible orderings where dependencies are satisfied for the
-    // current projection expression. For example, if expression is `a + b ASC`,
-    // and the dependency for `a ASC` is `[c ASC]`, the dependency for `b ASC`
-    // is `[d DESC]`, then we generate `[c ASC, d DESC, a + b ASC]` and
-    // `[d DESC, c ASC, a + b ASC]`.
     relevant_prefixes
         .into_iter()
         .multi_cartesian_product()
@@ -1421,7 +1432,7 @@ struct DependencyNode {
 }

 impl DependencyNode {
-    // Insert dependency to the state (if exists).
+    /// Insert the given dependency into this node's state (if one exists).
     fn insert_dependency(&mut self, dependency: Option<&PhysicalSortExpr>) {
         if let Some(dep) = dependency {
             self.dependencies.insert(dep.clone());
@@ -1429,46 +1440,229 @@ impl DependencyNode {
     }
 }

-// Using `IndexMap` and `IndexSet` makes sure to generate consistent results across different executions for the same query.
-// We could have used `HashSet`, `HashMap` in place of them without any loss of functionality.
-// As an example, if existing orderings are `[a ASC, b ASC]`, `[c ASC]` for output ordering
-// both `[a ASC, b ASC, c ASC]` and `[c ASC, a ASC, b ASC]` are valid (e.g. concatenated version of the alternative orderings).
-// When using `HashSet`, `HashMap` it is not guaranteed to generate consistent result, among the possible 2 results in the example above.
-type DependencyMap = IndexMap<PhysicalSortExpr, DependencyNode>;
-type Dependencies = IndexSet<PhysicalSortExpr>;
-
-/// This function recursively analyzes the dependencies of the given sort
-/// expression within the given dependency map to construct lexicographical
-/// orderings that include the sort expression and its dependencies.
+impl Display for DependencyNode {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        if let Some(target) = &self.target_sort_expr {
+            write!(f, "(target: {}, ", target)?;
+        } else {
+            write!(f, "(")?;
+        }
+        write!(f, "dependencies: [{}])", self.dependencies)
+    }
+}
+
+/// Maps an expression --> DependencyNode
 ///
-/// # Parameters
+/// # Debugging / displaying the `DependencyMap`
 ///
-/// - `referred_sort_expr`: A reference to the sort expression (`PhysicalSortExpr`)
-/// for which lexicographical orderings satisfying its dependencies are to be
-/// constructed.
-/// - `dependency_map`: A reference to the `DependencyMap` that contains
-/// dependencies for different `PhysicalSortExpr`s.
+/// This structure implements `Display` to assist debugging. For example:
 ///
-/// # Returns
+/// ```text
+/// DependencyMap: {
+///   a@0 ASC --> (target: a@0 ASC, dependencies: [[]])
+///   b@1 ASC --> (target: b@1 ASC, dependencies: [[a@0 ASC, c@2 ASC]])
+///   c@2 ASC --> (target: c@2 ASC, dependencies: [[b@1 ASC, a@0 ASC]])
+///   d@3 ASC --> (target: d@3 ASC, dependencies: [[c@2 ASC, b@1 ASC]])
+/// }
+/// ```
 ///
-/// A vector of lexicographical orderings (`Vec<LexOrdering>`) based on the given
-/// sort expression and its dependencies.
-fn construct_orderings(
-    referred_sort_expr: &PhysicalSortExpr,
-    dependency_map: &DependencyMap,
-) -> Vec<LexOrdering> {
-    // We are sure that `referred_sort_expr` is inside `dependency_map`.
-    let node = &dependency_map[referred_sort_expr];
-    // Since we work on intermediate nodes, we are sure `val.target_sort_expr`
-    // exists.
-    let target_sort_expr = node.target_sort_expr.clone().unwrap();
-    if node.dependencies.is_empty() {
-        vec![vec![target_sort_expr]]
-    } else {
+/// # Note on IndexMap Rationale
+///
+/// We use `IndexMap` (which preserves insert order) to ensure consistent results
+/// across different executions for the same query. We could have used `HashSet`
+/// and `HashMap` instead without any loss of functionality.
+///
+/// As an example, if the existing orderings are
+/// 1. `[a ASC, b ASC]`
+/// 2. `[c ASC]`
+///
+/// then both of the following output orderings are valid
+/// 1. `[a ASC, b ASC, c ASC]`
+/// 2. `[c ASC, a ASC, b ASC]`
+///
+/// (these are both valid as they are concatenated versions of the alternative
+/// orderings). When using `HashSet` or `HashMap`, it is not guaranteed that the
+/// same one of the two possible results in the example above is generated on
+/// every run.
+#[derive(Debug)]
+struct DependencyMap {
+    inner: IndexMap<PhysicalSortExpr, DependencyNode>,
+}
+
+impl DependencyMap {
+    fn new() -> Self {
+        Self {
+            inner: IndexMap::new(),
+        }
+    }
+
+    /// Insert a new dependency `sort_expr` --> `dependency` into the map.
+    ///
+    /// If `target_sort_expr` is `None`, a new entry is created with empty dependencies.
+    fn insert(
+        &mut self,
+        sort_expr: &PhysicalSortExpr,
+        target_sort_expr: Option<&PhysicalSortExpr>,
+        dependency: Option<&PhysicalSortExpr>,
+    ) {
+        self.inner
+            .entry(sort_expr.clone())
+            .or_insert_with(|| DependencyNode {
+                target_sort_expr: target_sort_expr.cloned(),
+                dependencies: Dependencies::new(),
+            })
+            .insert_dependency(dependency)
+    }
+
+    /// Iterator over (sort_expr, DependencyNode) pairs
+    fn iter(&self) -> impl Iterator<Item = (&PhysicalSortExpr, &DependencyNode)> {
+        self.inner.iter()
+    }
+
+    /// Iterator over all sort exprs
+    fn sort_exprs(&self) -> impl Iterator<Item = &PhysicalSortExpr> {
+        self.inner.keys()
+    }
+
+    /// Return the dependency node for the given sort expression, if any
+    fn get(&self, sort_expr: &PhysicalSortExpr) -> Option<&DependencyNode> {
+        self.inner.get(sort_expr)
+    }
+}
+
+impl Display for DependencyMap {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        writeln!(f, "DependencyMap: {{")?;
+        for (sort_expr, node) in self.inner.iter() {
+            writeln!(f, "  {sort_expr} --> {node}")?;
+        }
+        writeln!(f, "}}")
+    }
+}
+
+/// A list of sort expressions that can be calculated from a known set of
+/// dependencies.
+#[derive(Debug, Default, Clone, PartialEq, Eq)]
+struct Dependencies {
+    inner: IndexSet<PhysicalSortExpr>,
+}
+
+impl Display for Dependencies {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "[")?;
+        let mut iter = self.inner.iter();
+        if let Some(dep) = iter.next() {
+            write!(f, "{}", dep)?;
+        }
+        for dep in iter {
+            write!(f, ", {}", dep)?;
+        }
+        write!(f, "]")
+    }
+}
+
+impl Dependencies {
+    /// Create a new empty `Dependencies` instance.
+    fn new() -> Self {
+        Self {
+            inner: IndexSet::new(),
+        }
+    }
+
+    /// Create a new `Dependencies` from an iterator of `PhysicalSortExpr`.
+    fn new_from_iter(iter: impl IntoIterator<Item = PhysicalSortExpr>) -> Self {
+        Self {
+            inner: iter.into_iter().collect(),
+        }
+    }
+
+    /// Insert a new dependency into the set.
+    fn insert(&mut self, sort_expr: PhysicalSortExpr) {
+        self.inner.insert(sort_expr);
+    }
+
+    /// Iterator over dependencies in the set
+    fn iter(&self) -> impl Iterator<Item = &PhysicalSortExpr> + Clone {
+        self.inner.iter()
+    }
+
+    /// Return the inner set of dependencies
+    fn into_inner(self) -> IndexSet<PhysicalSortExpr> {
+        self.inner
+    }
+
+    /// Returns true if there are no dependencies
+    fn is_empty(&self) -> bool {
+        self.inner.is_empty()
+    }
+}
+
+/// Contains a mapping of all dependencies we have processed for each sort expr
+struct DependencyEnumerator<'a> {
+    /// Maps `expr` --> `[exprs]` that have previously been processed
+    seen: IndexMap<&'a PhysicalSortExpr, IndexSet<&'a PhysicalSortExpr>>,
+}
+
+impl<'a> DependencyEnumerator<'a> {
+    fn new() -> Self {
+        Self {
+            seen: IndexMap::new(),
+        }
+    }
+
+    /// Insert a new dependency.
+    ///
+    /// Returns `false` if the dependency was already in the map, and `true`
+    /// if it was newly inserted.
+    fn insert(
+        &mut self,
+        target: &'a PhysicalSortExpr,
+        dep: &'a PhysicalSortExpr,
+    ) -> bool {
+        self.seen.entry(target).or_default().insert(dep)
+    }
+
+    /// This function recursively analyzes the dependencies of the given sort
+    /// expression within the given dependency map to construct lexicographical
+    /// orderings that include the sort expression and its dependencies.
+    ///
+    /// # Parameters
+    ///
+    /// - `referred_sort_expr`: A reference to the sort expression (`PhysicalSortExpr`)
+    ///   for which lexicographical orderings satisfying its dependencies are to be
+    ///   constructed.
+    /// - `dependency_map`: A reference to the `DependencyMap` that contains
+    ///   dependencies for different `PhysicalSortExpr`s.
+    ///
+    /// # Returns
+    ///
+    /// A vector of lexicographical orderings (`Vec<LexOrdering>`) based on the given
+    /// sort expression and its dependencies.
+    fn construct_orderings(
+        &mut self,
+        referred_sort_expr: &'a PhysicalSortExpr,
+        dependency_map: &'a DependencyMap,
+    ) -> Vec<LexOrdering> {
+        let node = dependency_map
+            .get(referred_sort_expr)
+            .expect("`referred_sort_expr` should be inside `dependency_map`");
+        // Since we work on intermediate nodes, we are sure `val.target_sort_expr`
+        // exists.
+        let target_sort_expr = node.target_sort_expr.as_ref().unwrap();
+        // An empty dependency means that `referred_sort_expr` represents a global
+        // ordering. Return its projected version, which is the target expression.
+        if node.dependencies.is_empty() {
+            return vec![vec![target_sort_expr.clone()]];
+        };
+        node.dependencies
             .iter()
             .flat_map(|dep| {
-                let mut orderings = construct_orderings(dep, dependency_map);
+                let mut orderings = if self.insert(target_sort_expr, dep) {
+                    self.construct_orderings(dep, dependency_map)
+                } else {
+                    vec![]
+                };
+
                 for ordering in orderings.iter_mut() {
                     ordering.push(target_sort_expr.clone())
                 }
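The `seen` map in `DependencyEnumerator` above is what breaks recursion on cyclic dependencies. A hedged, minimal sketch of that guard (plain strings and the `indexmap` crate; illustrative only, not DataFusion's implementation):

```rust
use indexmap::{IndexMap, IndexSet};

struct Enumerator<'a> {
    // target --> set of dependencies already processed for it
    seen: IndexMap<&'a str, IndexSet<&'a str>>,
}

impl<'a> Enumerator<'a> {
    // Returns true only the first time a (target, dep) edge is seen, so a
    // recursive walk visits each edge at most once, even if the graph cycles.
    fn insert(&mut self, target: &'a str, dep: &'a str) -> bool {
        self.seen.entry(target).or_default().insert(dep)
    }
}

fn main() {
    let mut e = Enumerator { seen: IndexMap::new() };
    assert!(e.insert("a ASC", "b ASC"));  // first visit: descend
    assert!(!e.insert("a ASC", "b ASC")); // already seen: stop
}
```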
@@ -1611,58 +1805,62 @@ impl Hash for ExprWrapper {

 /// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties`
 /// of `lhs` and `rhs` according to the schema of `lhs`.
+///
+/// Rule: `UnionExec` does not interleave its inputs; instead, it passes each
+/// input partition from the children through as one of its own output
+/// partitions.
+///
+/// Since `UnionExec`'s output partitions are exactly its input partitions, a
+/// property that is true for *all* output partitions is the same as one that
+/// is true for all *input* partitions.
 fn calculate_union_binary(
-    lhs: EquivalenceProperties,
+    mut lhs: EquivalenceProperties,
     mut rhs: EquivalenceProperties,
 ) -> Result<EquivalenceProperties> {
-    // TODO: In some cases, we should be able to preserve some equivalence
-    // classes. Add support for such cases.
-
     // Harmonize the schema of the rhs with the schema of the lhs (which is the accumulator schema):
     if !rhs.schema.eq(&lhs.schema) {
         rhs = rhs.with_new_schema(Arc::clone(&lhs.schema))?;
     }

-    // First, calculate valid constants for the union. A quantity is constant
-    // after the union if it is constant in both sides.
-    let constants = lhs
+    // First, calculate valid constants for the union. An expression is constant
+    // at the output of the union if it is constant on both sides.
+    let constants: Vec<_> = lhs
        .constants()
        .iter()
        .filter(|const_expr| const_exprs_contains(rhs.constants(), const_expr.expr()))
        .map(|const_expr| {
-            // TODO: When both sides' constants are valid across partitions,
-            //       the union's constant should also be valid if values are
-            //       the same. However, we do not have the capability to
-            //       check this yet.
+            // TODO: When both sides have a constant column, and the actual
+            // constant value is the same, then the output properties could
+            // reflect that the constant is valid across all partitions. However,
+            // we don't track the actual value that the ConstExpr takes on, so we
+            // can't determine that yet.
            ConstExpr::new(Arc::clone(const_expr.expr())).with_across_partitions(false)
        })
        .collect();

+    // Remove any constants that are shared by both sides (to avoid double
+    // counting them):
+    for c in &constants {
+        lhs = lhs.remove_constant(c);
+        rhs = rhs.remove_constant(c);
+    }
+
     // Next, calculate valid orderings for the union by searching for prefixes
     // in both sides.
-    let mut orderings = vec![];
-    for mut ordering in lhs.normalized_oeq_class().orderings {
-        // Progressively shorten the ordering to search for a satisfied prefix:
-        while !rhs.ordering_satisfy(&ordering) {
-            ordering.pop();
-        }
-        // There is a non-trivial satisfied prefix, add it as a valid ordering:
-        if !ordering.is_empty() {
-            orderings.push(ordering);
-        }
-    }
-    for mut ordering in rhs.normalized_oeq_class().orderings {
-        // Progressively shorten the ordering to search for a satisfied prefix:
-        while !lhs.ordering_satisfy(&ordering) {
-            ordering.pop();
-        }
-        // There is a non-trivial satisfied prefix, add it as a valid ordering:
-        if !ordering.is_empty() {
-            orderings.push(ordering);
-        }
-    }
-    let mut eq_properties = EquivalenceProperties::new(lhs.schema);
-    eq_properties.constants = constants;
+    let mut orderings = UnionEquivalentOrderingBuilder::new();
+    orderings.add_satisfied_orderings(
+        lhs.normalized_oeq_class().orderings,
+        lhs.constants(),
+        &rhs,
+    );
+    orderings.add_satisfied_orderings(
+        rhs.normalized_oeq_class().orderings,
+        rhs.constants(),
+        &lhs,
+    );
+    let orderings = orderings.build();
+
+    let mut eq_properties =
+        EquivalenceProperties::new(lhs.schema).with_constants(constants);
+
+    eq_properties.add_new_orderings(orderings);
     Ok(eq_properties)
 }
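As a worked micro-example of the constant rule implemented above (expressions reduced to plain column names; `union_constants` is a hypothetical helper for illustration, not DataFusion API):

```rust
// An expression is constant at the union output only if it is constant in
// both inputs, i.e. the intersection of the two constant sets.
fn union_constants<'a>(lhs: &[&'a str], rhs: &[&'a str]) -> Vec<&'a str> {
    lhs.iter().copied().filter(|c| rhs.contains(c)).collect()
}

fn main() {
    // lhs constants: [b, c]; rhs constants: [a, c] => union constants: [c]
    assert_eq!(union_constants(&["b", "c"], &["a", "c"]), vec!["c"]);
}
```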
@@ -1677,14 +1875,222 @@ pub fn calculate_union(
 ) -> Result<EquivalenceProperties> {
     // TODO: In some cases, we should be able to preserve some equivalence
     // classes. Add support for such cases.
-    let mut init = eqps[0].clone();
+    let mut iter = eqps.into_iter();
+    let Some(mut acc) = iter.next() else {
+        return internal_err!(
+            "Cannot calculate EquivalenceProperties for a union with no inputs"
+        );
+    };
+
     // Harmonize the schema of the init with the schema of the union:
-    if !init.schema.eq(&schema) {
-        init = init.with_new_schema(schema)?;
+    if !acc.schema.eq(&schema) {
+        acc = acc.with_new_schema(schema)?;
+    }
+    // Fold in the rest of the EquivalenceProperties:
+    for props in iter {
+        acc = calculate_union_binary(acc, props)?;
+    }
+    Ok(acc)
+}
+
+#[derive(Debug)]
+enum AddedOrdering {
+    /// The ordering was added to the in-progress result
+    Yes,
+    /// The ordering was not added
+    No(LexOrdering),
+}
+
+/// Builds valid output orderings of a `UnionExec`
+#[derive(Debug)]
+struct UnionEquivalentOrderingBuilder {
+    orderings: Vec<LexOrdering>,
+}
+
+impl UnionEquivalentOrderingBuilder {
+    fn new() -> Self {
+        Self { orderings: vec![] }
+    }
+
+    /// Add all orderings from `orderings` that satisfy `properties`,
+    /// potentially augmented with `constants`.
+    ///
+    /// Note: any column that is known to be constant can be inserted into the
+    /// ordering without changing its meaning.
+    ///
+    /// For example:
+    /// * `orderings` contains `[a ASC, c ASC]` and `constants` contains `b`
+    /// * `properties` has required ordering `[a ASC, b ASC]`
+    ///
+    /// Then this will add `[a ASC, b ASC]` to the `orderings` list (as `a` was
+    /// in the sort order and `b` was a constant).
+    fn add_satisfied_orderings(
+        &mut self,
+        orderings: impl IntoIterator<Item = LexOrdering>,
+        constants: &[ConstExpr],
+        properties: &EquivalenceProperties,
+    ) {
+        for mut ordering in orderings.into_iter() {
+            // Progressively shorten the ordering to search for a satisfied prefix:
+            loop {
+                match self.try_add_ordering(ordering, constants, properties) {
+                    AddedOrdering::Yes => break,
+                    AddedOrdering::No(o) => {
+                        ordering = o;
+                        ordering.pop();
+                    }
+                }
+            }
+        }
+    }
+
+    /// Adds `ordering`, potentially augmented with constants, if it satisfies
+    /// the target `properties`.
+    ///
+    /// Returns
+    ///
+    /// * [`AddedOrdering::Yes`] if the ordering was added (either directly or
+    ///   augmented), or was empty.
+    ///
+    /// * [`AddedOrdering::No`] if the ordering was not added
+    fn try_add_ordering(
+        &mut self,
+        ordering: LexOrdering,
+        constants: &[ConstExpr],
+        properties: &EquivalenceProperties,
+    ) -> AddedOrdering {
+        if ordering.is_empty() {
+            AddedOrdering::Yes
+        } else if constants.is_empty() && properties.ordering_satisfy(&ordering) {
+            // If the ordering satisfies the target properties, no need to
+            // augment it with constants.
+            self.orderings.push(ordering);
+            AddedOrdering::Yes
+        } else {
+            // Did not satisfy the target properties; try to augment with
+            // constants to match the properties.
+            if self.try_find_augmented_ordering(&ordering, constants, properties) {
+                AddedOrdering::Yes
+            } else {
+                AddedOrdering::No(ordering)
+            }
+        }
+    }
+
+    /// Attempts to add `constants` to `ordering` to satisfy the properties.
+    ///
+    /// Returns `true` if any orderings were added, `false` otherwise.
+    fn try_find_augmented_ordering(
+        &mut self,
+        ordering: &LexOrdering,
+        constants: &[ConstExpr],
+        properties: &EquivalenceProperties,
+    ) -> bool {
+        // Can't augment if there is nothing to augment with:
+        if constants.is_empty() {
+            return false;
+        }
+        let start_num_orderings = self.orderings.len();
+
+        // For each equivalent ordering in `properties`, try to augment
+        // `ordering` with the constants to match:
+        for existing_ordering in &properties.oeq_class.orderings {
+            if let Some(augmented_ordering) = self.augment_ordering(
+                ordering,
+                constants,
+                existing_ordering,
+                &properties.constants,
+            ) {
+                if !augmented_ordering.is_empty() {
+                    assert!(properties.ordering_satisfy(&augmented_ordering));
+                    self.orderings.push(augmented_ordering);
+                }
+            }
+        }
+
+        self.orderings.len() > start_num_orderings
+    }
+
+    /// Attempts to augment the ordering with constants to match the
+    /// `existing_ordering`.
+    ///
+    /// Returns `Some(ordering)` if an augmented ordering was found, `None` otherwise.
+    fn augment_ordering(
+        &mut self,
+        ordering: &LexOrdering,
+        constants: &[ConstExpr],
+        existing_ordering: &LexOrdering,
+        existing_constants: &[ConstExpr],
+    ) -> Option<LexOrdering> {
+        let mut augmented_ordering = vec![];
+        let mut sort_expr_iter = ordering.iter().peekable();
+        let mut existing_sort_expr_iter = existing_ordering.iter().peekable();
+
+        // Walk in parallel down the two orderings, trying to match them up:
+        while sort_expr_iter.peek().is_some() || existing_sort_expr_iter.peek().is_some()
+        {
+            // If the next expressions are equal, add the next match;
+            // otherwise try to match with a constant.
+            if let Some(expr) =
+                advance_if_match(&mut sort_expr_iter, &mut existing_sort_expr_iter)
+            {
+                augmented_ordering.push(expr);
+            } else if let Some(expr) =
+                advance_if_matches_constant(&mut sort_expr_iter, existing_constants)
+            {
+                augmented_ordering.push(expr);
+            } else if let Some(expr) =
+                advance_if_matches_constant(&mut existing_sort_expr_iter, constants)
+            {
+                augmented_ordering.push(expr);
+            } else {
+                // No match; we can't continue the ordering, so return what we have.
+                break;
+            }
+        }
+
+        Some(augmented_ordering)
+    }
+
+    fn build(self) -> Vec<LexOrdering> {
+        self.orderings
     }
-    eqps.into_iter()
-        .skip(1)
-        .try_fold(init, calculate_union_binary)
+}
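The `advance_if_*` helpers used by `augment_ordering` above can be pictured with this hedged, standalone sketch of the parallel walk (integers instead of sort expressions; illustrative only, not DataFusion's code):

```rust
use std::iter::Peekable;
use std::slice::Iter;

// Advance both iterators and return the head only when the heads match.
fn advance_if_match(
    a: &mut Peekable<Iter<'_, i32>>,
    b: &mut Peekable<Iter<'_, i32>>,
) -> Option<i32> {
    if matches!((a.peek(), b.peek()), (Some(x), Some(y)) if x == y) {
        a.next();
        b.next().copied()
    } else {
        None
    }
}

fn main() {
    let (xs, ys) = (vec![1, 2, 3], vec![1, 9, 3]);
    let (mut a, mut b) = (xs.iter().peekable(), ys.iter().peekable());
    assert_eq!(advance_if_match(&mut a, &mut b), Some(1)); // heads match
    assert_eq!(advance_if_match(&mut a, &mut b), None);    // 2 != 9, no advance
}
```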
+
+/// Advances two iterators in parallel
+///
+/// If the next expressions are equal, the iterators are advanced and the
+/// matched expression is returned.
+///
+/// Otherwise, the iterators are left unchanged and `None` is returned.
+fn advance_if_match(
+    iter1: &mut Peekable<Iter<PhysicalSortExpr>>,
+    iter2: &mut Peekable<Iter<PhysicalSortExpr>>,
+) -> Option<PhysicalSortExpr> {
+    if matches!((iter1.peek(), iter2.peek()), (Some(expr1), Some(expr2)) if expr1.eq(expr2))
+    {
+        iter1.next().unwrap();
+        iter2.next().cloned()
+    } else {
+        None
+    }
+}
+
+/// Advances the iterator past a constant
+///
+/// If the next expression matches one of the constants, the iterator is
+/// advanced and the matched expression is returned.
+///
+/// Otherwise, the iterator is left unchanged and `None` is returned.
+fn advance_if_matches_constant(
+    iter: &mut Peekable<Iter<PhysicalSortExpr>>,
+    constants: &[ConstExpr],
+) -> Option<PhysicalSortExpr> {
+    let expr = iter.peek()?;
+    let const_expr = constants.iter().find(|c| c.eq_expr(expr))?;
+    let found_expr = PhysicalSortExpr::new(Arc::clone(const_expr.expr()), expr.options);
+    iter.next();
+    Some(found_expr)
 }

 #[cfg(test)]
@@ -1695,16 +2101,13 @@ mod tests {
     use crate::equivalence::add_offset_to_expr;
     use crate::equivalence::tests::{
         convert_to_orderings, convert_to_sort_exprs, convert_to_sort_reqs,
-        create_random_schema, create_test_params, create_test_schema,
-        generate_table_for_eq_properties, is_table_same_after_sort, output_schema,
+        create_test_params, create_test_schema, output_schema,
     };
     use crate::expressions::{col, BinaryExpr, Column};
-    use crate::utils::tests::TestScalarUDF;
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow_schema::{Fields, TimeUnit};
-    use datafusion_common::DFSchema;
-    use datafusion_expr::{Operator, ScalarUDF};
+    use datafusion_expr::Operator;

     #[test]
     fn project_equivalence_properties_test() -> Result<()> {
@@ -1755,6 +2158,51 @@ mod tests {
         Ok(())
     }

+    #[test]
+    fn project_equivalence_properties_test_multi() -> Result<()> {
+        // test multiple input orderings with equivalence properties
+        let input_schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int64, true),
+            Field::new("b", DataType::Int64, true),
+            Field::new("c", DataType::Int64, true),
+            Field::new("d", DataType::Int64, true),
+        ]));
+
+        let mut input_properties = EquivalenceProperties::new(Arc::clone(&input_schema));
+        // add equivalent ordering [a, b, c, d]
+        input_properties.add_new_ordering(vec![
+            parse_sort_expr("a", &input_schema),
+            parse_sort_expr("b", &input_schema),
+            parse_sort_expr("c", &input_schema),
+            parse_sort_expr("d", &input_schema),
+        ]);
+
+        // add equivalent ordering [a, c, b, d]
+        input_properties.add_new_ordering(vec![
+            parse_sort_expr("a", &input_schema),
+            parse_sort_expr("c", &input_schema),
+            parse_sort_expr("b", &input_schema), // NB b and c are swapped
+            parse_sort_expr("d", &input_schema),
+        ]);
+
+        // simply project all the columns in order
+        let proj_exprs = vec![
+            (col("a", &input_schema)?, "a".to_string()),
+            (col("b", &input_schema)?, "b".to_string()),
+            (col("c", &input_schema)?, "c".to_string()),
+            (col("d", &input_schema)?, "d".to_string()),
+        ];
+        let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?;
+        let out_properties = input_properties.project(&projection_mapping, input_schema);
+
+        assert_eq!(
+            out_properties.to_string(),
+            "order: [[a@0 ASC,c@2 ASC,b@1 ASC,d@3 ASC], [a@0 ASC,b@1 ASC,c@2 ASC,d@3 ASC]]"
+        );
+
+        Ok(())
+    }
+
     #[test]
     fn test_join_equivalence_properties() -> Result<()> {
         let schema = create_test_schema()?;
@@ -2170,83 +2618,6 @@ mod tests {
         Ok(())
     }

-    #[test]
-    fn test_find_longest_permutation_random() -> Result<()> {
-        const N_RANDOM_SCHEMA: usize = 100;
-        const N_ELEMENTS: usize = 125;
-        const N_DISTINCT: usize = 5;
-
-        for seed
in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = crate::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let exprs = [ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - floor_a, - a_plus_b, - ]; - - for n_req in 0..=exprs.len() { - for exprs in exprs.iter().combinations(n_req) { - let exprs = exprs.into_iter().cloned().collect::>(); - let (ordering, indices) = - eq_properties.find_longest_permutation(&exprs); - // Make sure that find_longest_permutation return values are consistent - let ordering2 = indices - .iter() - .zip(ordering.iter()) - .map(|(&idx, sort_expr)| PhysicalSortExpr { - expr: Arc::clone(&exprs[idx]), - options: sort_expr.options, - }) - .collect::>(); - assert_eq!( - ordering, ordering2, - "indices and lexicographical ordering do not match" - ); - - let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - assert_eq!(ordering.len(), indices.len(), "{}", err_msg); - // Since ordered section satisfies schema, we expect - // that result will be same after sort (e.g sort was unnecessary). 
- assert!( - is_table_same_after_sort( - ordering.clone(), - table_data_with_properties.clone(), - )?, - "{}", - err_msg - ); - } - } - } - - Ok(()) - } #[test] fn test_find_longest_permutation() -> Result<()> { // Schema satisfies following orderings: @@ -2708,379 +3079,503 @@ mod tests { )) } - #[tokio::test] - async fn test_union_equivalence_properties_multi_children() -> Result<()> { - let schema = create_test_schema()?; + #[test] + fn test_union_equivalence_properties_multi_children_1() { + let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); let schema3 = append_fields(&schema, "2"); - let test_cases = vec![ - // --------- TEST CASE 1 ---------- - ( - vec![ - // Children 1 - ( - // Orderings - vec![vec!["a", "b", "c"]], - Arc::clone(&schema), - ), - // Children 2 - ( - // Orderings - vec![vec!["a1", "b1", "c1"]], - Arc::clone(&schema2), - ), - // Children 3 - ( - // Orderings - vec![vec!["a2", "b2"]], - Arc::clone(&schema3), - ), - ], - // Expected - vec![vec!["a", "b"]], - ), - // --------- TEST CASE 2 ---------- - ( - vec![ - // Children 1 - ( - // Orderings - vec![vec!["a", "b", "c"]], - Arc::clone(&schema), - ), - // Children 2 - ( - // Orderings - vec![vec!["a1", "b1", "c1"]], - Arc::clone(&schema2), - ), - // Children 3 - ( - // Orderings - vec![vec!["a2", "b2", "c2"]], - Arc::clone(&schema3), - ), - ], - // Expected - vec![vec!["a", "b", "c"]], - ), - // --------- TEST CASE 3 ---------- - ( - vec![ - // Children 1 - ( - // Orderings - vec![vec!["a", "b"]], - Arc::clone(&schema), - ), - // Children 2 - ( - // Orderings - vec![vec!["a1", "b1", "c1"]], - Arc::clone(&schema2), - ), - // Children 3 - ( - // Orderings - vec![vec!["a2", "b2", "c2"]], - Arc::clone(&schema3), - ), - ], - // Expected + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b", "c"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + // Children 3 + .with_child_sort(vec![vec!["a2", "b2"]], &schema3) + .with_expected_sort(vec![vec!["a", "b"]]) + .run() + } + + #[test] + fn test_union_equivalence_properties_multi_children_2() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + let schema3 = append_fields(&schema, "2"); + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b", "c"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + // Children 3 + .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3) + .with_expected_sort(vec![vec!["a", "b", "c"]]) + .run() + } + + #[test] + fn test_union_equivalence_properties_multi_children_3() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + let schema3 = append_fields(&schema, "2"); + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + // Children 3 + .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3) + .with_expected_sort(vec![vec!["a", "b"]]) + .run() + } + + #[test] + fn test_union_equivalence_properties_multi_children_4() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + let schema3 = append_fields(&schema, "2"); + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1"]], &schema2) + // Children 3 + .with_child_sort(vec![vec!["b2", 
"c2"]], &schema3) + .with_expected_sort(vec![]) + .run() + } + + #[test] + fn test_union_equivalence_properties_multi_children_5() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b"], vec!["c"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1"], vec!["c1"]], &schema2) + .with_expected_sort(vec![vec!["a", "b"], vec!["c"]]) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_common_constants() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child: [a ASC], const [b, c] + vec![vec!["a"]], + vec!["b", "c"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [b ASC], const [a, c] + vec![vec!["b"]], + vec!["a", "c"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union expected orderings: [[a ASC], [b ASC]], const [c] + vec![vec!["a"], vec!["b"]], + vec!["c"], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_prefix() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child: [a ASC], const [] + vec![vec!["a"]], + vec![], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [a ASC, b ASC], const [] vec![vec!["a", "b"]], - ), - // --------- TEST CASE 4 ---------- - ( - vec![ - // Children 1 - ( - // Orderings - vec![vec!["a", "b"]], - Arc::clone(&schema), - ), - // Children 2 - ( - // Orderings - vec![vec!["a1", "b1"]], - Arc::clone(&schema2), - ), - // Children 3 - ( - // Orderings - vec![vec!["b2", "c2"]], - Arc::clone(&schema3), - ), - ], - // Expected vec![], - ), - // --------- TEST CASE 5 ---------- - ( - vec![ - // Children 1 - ( - // Orderings - vec![vec!["a", "b"], vec!["c"]], - Arc::clone(&schema), - ), - // Children 2 - ( - // Orderings - vec![vec!["a1", "b1"], vec!["c1"]], - Arc::clone(&schema2), - ), - ], - // Expected - vec![vec!["a", "b"], vec!["c"]], - ), - ]; - for (children, expected) in test_cases { - let children_eqs = children - .iter() - .map(|(orderings, schema)| { - let orderings = orderings - .iter() - .map(|ordering| { - ordering - .iter() - .map(|name| PhysicalSortExpr { - expr: col(name, schema).unwrap(), - options: SortOptions::default(), - }) - .collect::>() - }) - .collect::>(); - EquivalenceProperties::new_with_orderings( - Arc::clone(schema), - &orderings, - ) - }) - .collect::>(); - let actual = calculate_union(children_eqs, Arc::clone(&schema))?; + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: [a ASC], const [] + vec![vec!["a"]], + vec![], + ) + .run() + } - let expected_ordering = expected - .into_iter() - .map(|ordering| { - ordering - .into_iter() - .map(|name| PhysicalSortExpr { - expr: col(name, &schema).unwrap(), - options: SortOptions::default(), - }) - .collect::>() - }) - .collect::>(); - let expected = EquivalenceProperties::new_with_orderings( - Arc::clone(&schema), - &expected_ordering, - ); - assert_eq_properties_same( - &actual, - &expected, - format!("expected: {:?}, actual: {:?}", expected, actual), - ); - } - Ok(()) + #[test] + fn test_union_equivalence_properties_constants_asc_desc_mismatch() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child: [a ASC], const [] + vec![vec!["a"]], + vec![], + &schema, + ) + 
.with_child_sort_and_const_exprs( + // Second child orderings: [a DESC], const [] + vec![vec!["a DESC"]], + vec![], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union doesn't have any ordering or constant + vec![], + vec![], + ) + .run() } - #[tokio::test] - async fn test_union_equivalence_properties_binary() -> Result<()> { - let schema = create_test_schema()?; + #[test] + fn test_union_equivalence_properties_constants_different_schemas() { + let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_a1 = &col("a1", &schema2)?; - let col_b1 = &col("b1", &schema2)?; - let options = SortOptions::default(); - let options_desc = !SortOptions::default(); - let test_cases = [ - //-----------TEST CASE 1----------// - ( - ( - // First child orderings - vec![ - // [a ASC] - (vec![(col_a, options)]), - ], - // First child constants - vec![col_b, col_c], - Arc::clone(&schema), - ), - ( - // Second child orderings - vec![ - // [b ASC] - (vec![(col_b, options)]), - ], - // Second child constants - vec![col_a, col_c], - Arc::clone(&schema), - ), - ( - // Union expected orderings - vec![ - // [a ASC] - vec![(col_a, options)], - // [b ASC] - vec![(col_b, options)], - ], - // Union - vec![col_c], - ), - ), - //-----------TEST CASE 2----------// - // Meet ordering between [a ASC], [a ASC, b ASC] should be [a ASC] - ( - ( - // First child orderings - vec![ - // [a ASC] - vec![(col_a, options)], - ], - // No constant - vec![], - Arc::clone(&schema), - ), - ( - // Second child orderings - vec![ - // [a ASC, b ASC] - vec![(col_a, options), (col_b, options)], - ], - // No constant - vec![], - Arc::clone(&schema), - ), - ( - // Union orderings - vec![ - // [a ASC] - vec![(col_a, options)], - ], - // No constant - vec![], - ), - ), - //-----------TEST CASE 3----------// - // Meet ordering between [a ASC], [a DESC] should be [] - ( - ( - // First child orderings - vec![ - // [a ASC] - vec![(col_a, options)], - ], - // No constant - vec![], - Arc::clone(&schema), - ), - ( - // Second child orderings - vec![ - // [a DESC] - vec![(col_a, options_desc)], - ], - // No constant - vec![], - Arc::clone(&schema), - ), - ( - // Union doesn't have any ordering - vec![], - // No constant - vec![], - ), - ), - //-----------TEST CASE 4----------// - // Meet ordering between [a ASC], [a1 ASC, b1 ASC] should be [a ASC] - // Where a, and a1 ath the same index for their corresponding schemas. - ( - ( - // First child orderings - vec![ - // [a ASC] - vec![(col_a, options)], - ], - // No constant - vec![], - Arc::clone(&schema), - ), - ( - // Second child orderings - vec![ - // [a1 ASC, b1 ASC] - vec![(col_a1, options), (col_b1, options)], - ], - // No constant - vec![], - Arc::clone(&schema2), - ), - ( - // Union orderings - vec![ - // [a ASC] - vec![(col_a, options)], - ], - // No constant - vec![], - ), - ), - ]; + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child orderings: [a ASC], const [] + vec![vec!["a"]], + vec![], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child orderings: [a1 ASC, b1 ASC], const [] + vec![vec!["a1", "b1"]], + vec![], + &schema2, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: [a ASC] + // + // Note that a, and a1 are at the same index for their + // corresponding schemas. 
+            vec![vec!["a"]],
+            vec![],
+        )
+        .run()
+    }

-        for (
-            test_idx,
-            (
-                (first_child_orderings, first_child_constants, first_schema),
-                (second_child_orderings, second_child_constants, second_schema),
-                (union_orderings, union_constants),
-            ),
-        ) in test_cases.iter().enumerate()
-        {
-            let first_orderings = first_child_orderings
-                .iter()
-                .map(|ordering| convert_to_sort_exprs(ordering))
-                .collect::<Vec<_>>();
-            let first_constants = first_child_constants
-                .iter()
-                .map(|expr| ConstExpr::new(Arc::clone(expr)))
-                .collect::<Vec<_>>();
-            let mut lhs = EquivalenceProperties::new(Arc::clone(first_schema));
-            lhs = lhs.with_constants(first_constants);
-            lhs.add_new_orderings(first_orderings);
+    #[test]
+    fn test_union_equivalence_properties_constants_fill_gaps() {
+        let schema = create_test_schema().unwrap();
+        UnionEquivalenceTest::new(&schema)
+            .with_child_sort_and_const_exprs(
+                // First child orderings: [a ASC, c ASC], const [b]
+                vec![vec!["a", "c"]],
+                vec!["b"],
+                &schema,
+            )
+            .with_child_sort_and_const_exprs(
+                // Second child orderings: [b ASC, c ASC], const [a]
+                vec![vec!["b", "c"]],
+                vec!["a"],
+                &schema,
+            )
+            .with_expected_sort_and_const_exprs(
+                // Union orderings: [
+                //   [a ASC, b ASC, c ASC],
+                //   [b ASC, a ASC, c ASC]
+                // ], const []
+                vec![vec!["a", "b", "c"], vec!["b", "a", "c"]],
+                vec![],
+            )
+            .run()
+    }

-            let second_orderings = second_child_orderings
-                .iter()
-                .map(|ordering| convert_to_sort_exprs(ordering))
-                .collect::<Vec<_>>();
-            let second_constants = second_child_constants
+    #[test]
+    fn test_union_equivalence_properties_constants_no_fill_gaps() {
+        let schema = create_test_schema().unwrap();
+        UnionEquivalenceTest::new(&schema)
+            .with_child_sort_and_const_exprs(
+                // First child orderings: [a ASC, c ASC], const [d] // some other constant
+                vec![vec!["a", "c"]],
+                vec!["d"],
+                &schema,
+            )
+            .with_child_sort_and_const_exprs(
+                // Second child orderings: [b ASC, c ASC], const [a]
+                vec![vec!["b", "c"]],
+                vec!["a"],
+                &schema,
+            )
+            .with_expected_sort_and_const_exprs(
+                // Union orderings: [[a]] (only a is constant)
+                vec![vec!["a"]],
+                vec![],
+            )
+            .run()
+    }
+
+    #[test]
+    fn test_union_equivalence_properties_constants_fill_some_gaps() {
+        let schema = create_test_schema().unwrap();
+        UnionEquivalenceTest::new(&schema)
+            .with_child_sort_and_const_exprs(
+                // First child orderings: [c ASC], const [a, b]
+                vec![vec!["c"]],
+                vec!["a", "b"],
+                &schema,
+            )
+            .with_child_sort_and_const_exprs(
+                // Second child orderings: [a DESC, b], const []
+                vec![vec!["a DESC", "b"]],
+                vec![],
+                &schema,
+            )
+            .with_expected_sort_and_const_exprs(
+                // Union orderings: [[a, b]] (can fill in the a/b with constants)
+                vec![vec!["a DESC", "b"]],
+                vec![],
+            )
+            .run()
+    }
+
+    #[test]
+    fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() {
+        let schema = create_test_schema().unwrap();
+        UnionEquivalenceTest::new(&schema)
+            .with_child_sort_and_const_exprs(
+                // First child orderings: [a ASC, c ASC], const [b]
+                vec![vec!["a", "c"]],
+                vec!["b"],
+                &schema,
+            )
+            .with_child_sort_and_const_exprs(
+                // Second child orderings: [b DESC, c ASC], const [a]
+                vec![vec!["b DESC", "c"]],
+                vec!["a"],
+                &schema,
+            )
+            .with_expected_sort_and_const_exprs(
+                // Union orderings: [
+                //   [a ASC, b DESC, c ASC],
+                //   [b DESC, a ASC, c ASC]
+                // ], const []
+                vec![vec!["a", "b DESC", "c"], vec!["b DESC", "a", "c"]],
+                vec![],
+            )
+            .run()
+    }
+
+    #[test]
+    fn test_union_equivalence_properties_constants_gap_fill_symmetric() {
+        let schema =
create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child: [a ASC, b ASC, d ASC], const [c] + vec![vec!["a", "b", "d"]], + vec!["c"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [a ASC, c ASC, d ASC], const [b] + vec![vec!["a", "c", "d"]], + vec!["b"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: + // [a, b, c, d] + // [a, c, b, d] + vec![vec!["a", "c", "b", "d"], vec!["a", "b", "c", "d"]], + vec![], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_gap_fill_and_common() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child: [a DESC, d ASC], const [b, c] + vec![vec!["a DESC", "d"]], + vec!["b", "c"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [a DESC, c ASC, d ASC], const [b] + vec![vec!["a DESC", "c", "d"]], + vec!["b"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: + // [a DESC, c, d] [b] + vec![vec!["a DESC", "c", "d"]], + vec!["b"], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_middle_desc() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // NB `b DESC` in the first child + // + // First child: [a ASC, b DESC, d ASC], const [c] + vec![vec!["a", "b DESC", "d"]], + vec!["c"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [a ASC, c ASC, d ASC], const [b] + vec![vec!["a", "c", "d"]], + vec!["b"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: + // [a, b, d] (c constant) + // [a, c, d] (b constant) + vec![vec!["a", "c", "b DESC", "d"], vec!["a", "b DESC", "c", "d"]], + vec![], + ) + .run() + } + + // TODO tests with multiple constants + + #[derive(Debug)] + struct UnionEquivalenceTest { + /// The schema of the output of the Union + output_schema: SchemaRef, + /// The equivalence properties of each child to the union + child_properties: Vec, + /// The expected output properties of the union. 
Must be set before
+        /// calling `run`
+        expected_properties: Option<EquivalenceProperties>,
+    }
+
+    impl UnionEquivalenceTest {
+        fn new(output_schema: &SchemaRef) -> Self {
+            Self {
+                output_schema: Arc::clone(output_schema),
+                child_properties: vec![],
+                expected_properties: None,
+            }
+        }
+
+        /// Add a union input with the specified orderings
+        ///
+        /// See [`Self::make_props`] for the format of the strings in `orderings`
+        fn with_child_sort(
+            mut self,
+            orderings: Vec<Vec<&str>>,
+            schema: &SchemaRef,
+        ) -> Self {
+            let properties = self.make_props(orderings, vec![], schema);
+            self.child_properties.push(properties);
+            self
+        }
+
+        /// Add a union input with the specified orderings and constant
+        /// equivalences
+        ///
+        /// See [`Self::make_props`] for the format of the strings in
+        /// `orderings` and `constants`
+        fn with_child_sort_and_const_exprs(
+            mut self,
+            orderings: Vec<Vec<&str>>,
+            constants: Vec<&str>,
+            schema: &SchemaRef,
+        ) -> Self {
+            let properties = self.make_props(orderings, constants, schema);
+            self.child_properties.push(properties);
+            self
+        }
+
+        /// Set the expected output sort order for the union of the children
+        ///
+        /// See [`Self::make_props`] for the format of the strings in `orderings`
+        fn with_expected_sort(mut self, orderings: Vec<Vec<&str>>) -> Self {
+            let properties = self.make_props(orderings, vec![], &self.output_schema);
+            self.expected_properties = Some(properties);
+            self
+        }
+
+        /// Set the expected output sort order and constant expressions for the
+        /// union of the children
+        ///
+        /// See [`Self::make_props`] for the format of the strings in
+        /// `orderings` and `constants`.
+        fn with_expected_sort_and_const_exprs(
+            mut self,
+            orderings: Vec<Vec<&str>>,
+            constants: Vec<&str>,
+        ) -> Self {
+            let properties = self.make_props(orderings, constants, &self.output_schema);
+            self.expected_properties = Some(properties);
+            self
+        }
+
+        /// Compute the union's output equivalence properties from the child
+        /// properties, and compare them to the expected properties
+        fn run(self) {
+            let Self {
+                output_schema,
+                child_properties,
+                expected_properties,
+            } = self;
+
+            let expected_properties =
+                expected_properties.expect("expected_properties not set");
+
+            // Try all permutations of the children,
+            // as the code treats lhs and rhs differently
+            for child_properties in child_properties
                 .iter()
-                .map(|expr| ConstExpr::new(Arc::clone(expr)))
-                .collect::<Vec<_>>();
-            let mut rhs = EquivalenceProperties::new(Arc::clone(second_schema));
-            rhs = rhs.with_constants(second_constants);
-            rhs.add_new_orderings(second_orderings);
+                .cloned()
+                .permutations(child_properties.len())
+            {
+                println!("--- permutation ---");
+                for c in &child_properties {
+                    println!("{c}");
+                }
+                let actual_properties =
+                    calculate_union(child_properties, Arc::clone(&output_schema))
+                        .expect("failed to calculate union equivalence properties");
+                assert_eq_properties_same(
+                    &actual_properties,
+                    &expected_properties,
+                    format!(
+                        "expected: {expected_properties:?}\nactual: {actual_properties:?}"
+                    ),
+                );
+            }
+        }

-            let union_expected_orderings = union_orderings
+        /// Make equivalence properties for the specified columns named in orderings and constants
+        ///
+        /// orderings: strings formatted like `"a"` or `"a DESC"`. See [`parse_sort_expr`]
+        /// constants: strings formatted like `"a"`.
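+        ///
+        /// As a hedged, standalone illustration (a hypothetical simplified
+        /// parser, not this helper's actual code), the accepted format can be
+        /// parsed like this:
+        ///
+        /// ```
+        /// // returns (column name, descending?)
+        /// fn parse(s: &str) -> (&str, bool) {
+        ///     let mut parts = s.split_whitespace();
+        ///     let name = parts.next().expect("empty sort expression");
+        ///     let descending = matches!(parts.next(), Some("DESC"));
+        ///     (name, descending)
+        /// }
+        /// assert_eq!(parse("a"), ("a", false));
+        /// assert_eq!(parse("b DESC"), ("b", true));
+        /// ```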
+ fn make_props( + &self, + orderings: Vec>, + constants: Vec<&str>, + schema: &SchemaRef, + ) -> EquivalenceProperties { + let orderings = orderings .iter() - .map(|ordering| convert_to_sort_exprs(ordering)) + .map(|ordering| { + ordering + .iter() + .map(|name| parse_sort_expr(name, schema)) + .collect::>() + }) .collect::>(); - let union_constants = union_constants + + let constants = constants .iter() - .map(|expr| ConstExpr::new(Arc::clone(expr))) + .map(|col_name| ConstExpr::new(col(col_name, schema).unwrap())) .collect::>(); - let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema)); - union_expected_eq = union_expected_eq.with_constants(union_constants); - union_expected_eq.add_new_orderings(union_expected_orderings); - let actual_union_eq = calculate_union_binary(lhs, rhs)?; - let err_msg = format!( - "Error in test id: {:?}, test case: {:?}", - test_idx, test_cases[test_idx] - ); - assert_eq_properties_same(&actual_union_eq, &union_expected_eq, err_msg); + EquivalenceProperties::new_with_orderings(Arc::clone(schema), &orderings) + .with_constants(constants) } - Ok(()) } fn assert_eq_properties_same( @@ -3091,21 +3586,63 @@ mod tests { // Check whether constants are same let lhs_constants = lhs.constants(); let rhs_constants = rhs.constants(); - assert_eq!(lhs_constants.len(), rhs_constants.len(), "{}", err_msg); for rhs_constant in rhs_constants { assert!( const_exprs_contains(lhs_constants, rhs_constant.expr()), - "{}", - err_msg + "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" ); } + assert_eq!( + lhs_constants.len(), + rhs_constants.len(), + "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" + ); // Check whether orderings are same. let lhs_orderings = lhs.oeq_class(); let rhs_orderings = &rhs.oeq_class.orderings; - assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{}", err_msg); for rhs_ordering in rhs_orderings { - assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg); + assert!( + lhs_orderings.contains(rhs_ordering), + "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" + ); + } + assert_eq!( + lhs_orderings.len(), + rhs_orderings.len(), + "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" + ); + } + + /// Converts a string to a physical sort expression + /// + /// # Example + /// * `"a"` -> (`"a"`, `SortOptions::default()`) + /// * `"a ASC"` -> (`"a"`, `SortOptions { descending: false, nulls_first: false }`) + fn parse_sort_expr(name: &str, schema: &SchemaRef) -> PhysicalSortExpr { + let mut parts = name.split_whitespace(); + let name = parts.next().expect("empty sort expression"); + let mut sort_expr = PhysicalSortExpr::new( + col(name, schema).expect("invalid column name"), + SortOptions::default(), + ); + + if let Some(options) = parts.next() { + sort_expr = match options { + "ASC" => sort_expr.asc(), + "DESC" => sort_expr.desc(), + _ => panic!( + "unknown sort options. Expected 'ASC' or 'DESC', got {}", + options + ), + } } + + assert!( + parts.next().is_none(), + "unexpected tokens in column name. 
Expected 'name' / 'name ASC' / 'name DESC' but got '{name}'" + ); + + sort_expr } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 236b24dd4094a..47b04a876b379 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -27,9 +27,7 @@ use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; use arrow::compute::kernels::cmp::*; -use arrow::compute::kernels::comparison::{ - regexp_is_match_utf8, regexp_is_match_utf8_scalar, -}; +use arrow::compute::kernels::comparison::{regexp_is_match, regexp_is_match_scalar}; use arrow::compute::kernels::concat_elements::concat_elements_utf8; use arrow::compute::{cast, ilike, like, nilike, nlike}; use arrow::datatypes::*; @@ -179,7 +177,7 @@ macro_rules! compute_utf8_flag_op { } else { None }; - let mut array = paste::expr! {[<$OP _utf8>]}(&ll, &rr, flag.as_ref())?; + let mut array = $OP(ll, rr, flag.as_ref())?; if $NOT { array = not(&array).unwrap(); } @@ -188,7 +186,9 @@ macro_rules! compute_utf8_flag_op { } macro_rules! binary_string_array_flag_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ + ($LEFT:ident, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ + // This macro is slightly different from binary_string_array_flag_op because, when comparing with a scalar value, + // the query can be optimized in such a way that operands will be dicts, so we need to support it here let result: Result> = match $LEFT.data_type() { DataType::Utf8View | DataType::Utf8 => { compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG) @@ -196,6 +196,27 @@ macro_rules! binary_string_array_flag_op_scalar { DataType::LargeUtf8 => { compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG) }, + DataType::Dictionary(_, _) => { + let values = $LEFT.as_any_dictionary().values(); + + match values.data_type() { + DataType::Utf8View | DataType::Utf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, StringArray, $NOT, $FLAG), + DataType::LargeUtf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG), + other => internal_err!( + "Data type {:?} not supported as a dictionary value type for binary_string_array_flag_op_scalar operation '{}' on string array", + other, stringify!($OP) + ), + }.map( + // downcast_dictionary_array duplicates code per possible key type, so we aim to do all prep work before + |evaluated_values| downcast_dictionary_array! { + $LEFT => { + let unpacked_dict = evaluated_values.take_iter($LEFT.keys().iter().map(|opt| opt.map(|v| v as _))).collect::(); + Arc::new(unpacked_dict) as _ + }, + _ => unreachable!(), + } + ) + }, other => internal_err!( "Data type {:?} not supported for binary_string_array_flag_op_scalar operation '{}' on string array", other, stringify!($OP) @@ -213,20 +234,32 @@ macro_rules! compute_utf8_flag_op_scalar { .downcast_ref::<$ARRAYTYPE>() .expect("compute_utf8_flag_op_scalar failed to downcast array"); - if let ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) = $RIGHT { - let flag = $FLAG.then_some("i"); - let mut array = - paste::expr! 
{[<$OP _utf8_scalar>]}(&ll, &string_value, flag)?; - if $NOT { - array = not(&array).unwrap(); - } - Ok(Arc::new(array)) - } else { - internal_err!( + let string_value = match $RIGHT { + ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value, + ScalarValue::Dictionary(_, value) => { + match *value { + ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value, + other => return internal_err!( + "compute_utf8_flag_op_scalar failed to cast dictionary value {} for operation '{}'", + other, stringify!($OP) + ) + } + }, + _ => return internal_err!( "compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'", $RIGHT, stringify!($OP) ) + + }; + + let flag = $FLAG.then_some("i"); + let mut array = + paste::expr! {[<$OP _scalar>]}(ll, &string_value, flag)?; + if $NOT { + array = not(&array).unwrap(); } + + Ok(Arc::new(array)) }}; } @@ -431,7 +464,7 @@ impl PhysicalExpr for BinaryExpr { // end-points of its children. Ok(Some(vec![])) } - } else if self.op.is_comparison_operator() { + } else if self.op.supports_propagation() { Ok( propagate_comparison(&self.op, interval, left_interval, right_interval)? .map(|(left, right)| vec![left, right]), diff --git a/datafusion/physical-expr/src/expressions/binary/kernels.rs b/datafusion/physical-expr/src/expressions/binary/kernels.rs index 1f9cfed1a44fa..c0685c6decde7 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels.rs @@ -24,7 +24,7 @@ use arrow::compute::kernels::bitwise::{ bitwise_xor, bitwise_xor_scalar, }; use arrow::datatypes::DataType; -use datafusion_common::internal_err; +use datafusion_common::plan_err; use datafusion_common::{Result, ScalarValue}; use arrow_schema::ArrowError; @@ -70,7 +70,7 @@ macro_rules! create_dyn_kernel { DataType::UInt64 => { call_bitwise_kernel!(left, right, $KERNEL, UInt64Array) } - other => internal_err!( + other => plan_err!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", other, stringify!($KERNEL) @@ -116,7 +116,7 @@ macro_rules! 
create_dyn_scalar_kernel { DataType::UInt16 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt16Array, u16), DataType::UInt32 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt32Array, u32), DataType::UInt64 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt64Array, u64), - other => internal_err!( + other => plan_err!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", other, stringify!($KERNEL) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index ffb431b200f28..981e49d73750c 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -1096,16 +1096,15 @@ mod tests { let expr2 = Arc::clone(&expr) .transform(|e| { - let transformed = - match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.value() { - ScalarValue::Utf8(Some(str_value)) => { - Some(lit(str_value.to_uppercase())) - } - _ => None, - }, + let transformed = match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } _ => None, - }; + }, + _ => None, + }; Ok(if let Some(transformed) = transformed { Transformed::yes(transformed) } else { @@ -1117,16 +1116,15 @@ mod tests { let expr3 = Arc::clone(&expr) .transform_down(|e| { - let transformed = - match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.value() { - ScalarValue::Utf8(Some(str_value)) => { - Some(lit(str_value.to_uppercase())) - } - _ => None, - }, + let transformed = match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } _ => None, - }; + }, + _ => None, + }; Ok(if let Some(transformed) = transformed { Transformed::yes(transformed) } else { diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 5621473c4fdb1..457c47097a19a 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -693,7 +693,7 @@ mod tests { let result = cast( col("a", &schema).unwrap(), &schema, - DataType::Interval(IntervalUnit::MonthDayNano), + Interval(IntervalUnit::MonthDayNano), ); result.expect_err("expected Invalid CAST"); } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 4aad959584ac4..3e2d49e9fa693 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -107,7 +107,7 @@ impl std::fmt::Display for Column { impl PhysicalExpr for Column { /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 0a3e5fcefcf6a..cf57ce3e0e21a 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -1102,7 +1102,7 @@ mod tests { let mut phy_exprs = vec![ lit(1i64), expressions::cast(lit(2i32), &schema, DataType::Int64)?, - expressions::try_cast(lit(3.13f32), &schema, DataType::Int64)?, + try_cast(lit(3.13f32), &schema, DataType::Int64)?, ]; let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); @@ -1130,7 +1130,7 @@ mod tests { 
try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); // column - phy_exprs.push(expressions::col("a", &schema)?); + phy_exprs.push(col("a", &schema)?); assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_err()); Ok(()) diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 58559352d44c0..cbab7d0c9d1fc 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -73,7 +73,7 @@ impl PhysicalExpr for IsNotNullExpr { let arg = self.arg.evaluate(batch)?; match arg { ColumnarValue::Array(array) => { - let is_not_null = super::is_null::compute_is_not_null(array)?; + let is_not_null = arrow::compute::is_not_null(&array)?; Ok(ColumnarValue::Array(Arc::new(is_not_null))) } ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 3cdb49bcab42f..1c8597d3fdea8 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -20,14 +20,10 @@ use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use arrow::compute; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use arrow_array::{Array, ArrayRef, BooleanArray, Int8Array, UnionArray}; -use arrow_buffer::{BooleanBuffer, ScalarBuffer}; -use arrow_ord::cmp; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; @@ -77,9 +73,9 @@ impl PhysicalExpr for IsNullExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let arg = self.arg.evaluate(batch)?; match arg { - ColumnarValue::Array(array) => { - Ok(ColumnarValue::Array(Arc::new(compute_is_null(array)?))) - } + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new( + arrow::compute::is_null(&array)?, + ))), ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( ScalarValue::Boolean(Some(scalar.is_null())), )), @@ -103,65 +99,6 @@ impl PhysicalExpr for IsNullExpr { } } -/// workaround , -/// this can be replaced with a direct call to `arrow::compute::is_null` once it's fixed. -pub(crate) fn compute_is_null(array: ArrayRef) -> Result { - if let Some(union_array) = array.as_any().downcast_ref::() { - if let Some(offsets) = union_array.offsets() { - dense_union_is_null(union_array, offsets) - } else { - sparse_union_is_null(union_array) - } - } else { - compute::is_null(array.as_ref()).map_err(Into::into) - } -} - -/// workaround , -/// this can be replaced with a direct call to `arrow::compute::is_not_null` once it's fixed. 
-pub(crate) fn compute_is_not_null(array: ArrayRef) -> Result { - if array.as_any().is::() { - compute::not(&compute_is_null(array)?).map_err(Into::into) - } else { - compute::is_not_null(array.as_ref()).map_err(Into::into) - } -} - -fn dense_union_is_null( - union_array: &UnionArray, - offsets: &ScalarBuffer, -) -> Result { - let child_arrays = (0..union_array.type_names().len()) - .map(|type_id| { - compute::is_null(&union_array.child(type_id as i8)).map_err(Into::into) - }) - .collect::>>()?; - - let buffer: BooleanBuffer = offsets - .iter() - .zip(union_array.type_ids()) - .map(|(offset, type_id)| child_arrays[*type_id as usize].value(*offset as usize)) - .collect(); - - Ok(BooleanArray::new(buffer, None)) -} - -fn sparse_union_is_null(union_array: &UnionArray) -> Result { - let type_ids = Int8Array::new(union_array.type_ids().clone(), None); - - let mut union_is_null = - BooleanArray::new(BooleanBuffer::new_unset(union_array.len()), None); - for type_id in 0..union_array.type_names().len() { - let type_id = type_id as i8; - let union_is_child = cmp::eq(&type_ids, &Int8Array::new_scalar(type_id))?; - let child = union_array.child(type_id); - let child_array_is_null = compute::is_null(&child)?; - let child_is_null = compute::and(&union_is_child, &child_array_is_null)?; - union_is_null = compute::or(&union_is_null, &child_is_null)?; - } - Ok(union_is_null) -} - impl PartialEq for IsNullExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) @@ -184,7 +121,7 @@ mod tests { array::{BooleanArray, StringArray}, datatypes::*, }; - use arrow_array::{Float64Array, Int32Array}; + use arrow_array::{Array, Float64Array, Int32Array, UnionArray}; use arrow_buffer::ScalarBuffer; use datafusion_common::cast::as_boolean_array; @@ -243,8 +180,7 @@ mod tests { let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); - let array_ref = Arc::new(array) as ArrayRef; - let result = compute_is_null(array_ref).unwrap(); + let result = arrow::compute::is_null(&array).unwrap(); let expected = &BooleanArray::from(vec![false, true, false, false, true, true, false]); @@ -272,8 +208,7 @@ mod tests { UnionArray::try_new(union_fields(), type_ids, Some(offsets), children) .unwrap(); - let array_ref = Arc::new(array) as ArrayRef; - let result = compute_is_null(array_ref).unwrap(); + let result = arrow::compute::is_null(&array).unwrap(); let expected = &BooleanArray::from(vec![false, true, false, true, false, true]); assert_eq!(expected, &result); diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 67af634aeb7a0..a5a59399191f5 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -36,11 +36,7 @@ mod unknown_column; /// Module with some convenient methods used in expression building pub use crate::aggregate::stats::StatsType; -pub use crate::window::cume_dist::{cume_dist, CumeDist}; -pub use crate::window::lead_lag::{lag, lead, WindowShift}; pub use crate::window::nth_value::NthValue; -pub use crate::window::ntile::Ntile; -pub use crate::window::rank::{dense_rank, percent_rank, rank, Rank, RankType}; pub use crate::PhysicalSortExpr; pub use binary::{binary, similar_to, BinaryExpr}; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index b5ebc250cb896..399ebde9f726d 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ 
b/datafusion/physical-expr/src/expressions/negative.rs @@ -257,7 +257,7 @@ mod tests { #[test] fn test_negation_valid_types() -> Result<()> { let negatable_types = [ - DataType::Int8, + Int8, DataType::Timestamp(TimeUnit::Second, None), DataType::Interval(IntervalUnit::YearMonth), ]; diff --git a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs index badb00659576c..669fbb7cdecf8 100644 --- a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs +++ b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs @@ -18,15 +18,21 @@ use super::Literal; use arrow::array::ArrayData; use arrow_array::{ - Array, ArrayAccessor, BooleanArray, LargeStringArray, StringArray, StringViewArray, + Array, ArrayAccessor, BooleanArray, LargeStringArray, RecordBatch, StringArray, + StringViewArray, }; use arrow_buffer::BooleanBufferBuilder; use arrow_schema::{DataType, Schema}; -use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, Result as DFResult, ScalarValue}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::physical_expr::{down_cast_any_ref, PhysicalExpr}; use regex::Regex; -use std::{any::Any, hash::Hash, sync::Arc}; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter, Result as FmtResult}, + hash::{Hash, Hasher}, + sync::Arc, +}; /// ScalarRegexMatchExpr /// Only used when evaluating regexp matching with literal pattern. @@ -81,7 +87,7 @@ impl ScalarRegexMatchExpr { } /// Compile regex pattern - fn compile(&mut self) -> datafusion_common::Result<()> { + fn compile(&mut self) -> DFResult<()> { let scalar_pattern = self.pattern .as_any() @@ -108,16 +114,13 @@ impl ScalarRegexMatchExpr { self.compiled = Some(compiled); }) .map_err(|err| { - datafusion_common::DataFusionError::Internal(format!( - "Failed to compile regex: {}", - err - )) + DataFusionError::Internal(format!("Failed to compile regex: {}", err)) }), Some(None) => { self.compiled = None; Ok(()) } - None => Err(datafusion_common::DataFusionError::Internal(format!( + None => Err(DataFusionError::Internal(format!( "Regex pattern({}) isn't literal string", self.pattern ))), @@ -137,10 +140,7 @@ impl ScalarRegexMatchExpr { impl ScalarRegexMatchExpr { /// Evaluate the scalar regex match expression match array value - fn evaluate_array( - &self, - array: &Arc, - ) -> datafusion_common::Result { + fn evaluate_array(&self, array: &Arc) -> DFResult { macro_rules! 
downcast_string_array { ($ARRAY:expr, $ARRAY_TYPE:ident, $ERR_MSG:expr) => { &($ARRAY @@ -175,10 +175,7 @@ impl ScalarRegexMatchExpr { } /// Evaluate the scalar regex match expression match scalar value - fn evaluate_scalar( - &self, - scalar: &ScalarValue, - ) -> datafusion_common::Result { + fn evaluate_scalar(&self, scalar: &ScalarValue) -> DFResult { match scalar { ScalarValue::Null | ScalarValue::Utf8(None) @@ -200,8 +197,8 @@ impl ScalarRegexMatchExpr { } } -impl std::hash::Hash for ScalarRegexMatchExpr { - fn hash(&self, state: &mut H) { +impl Hash for ScalarRegexMatchExpr { + fn hash(&self, state: &mut H) { self.negated.hash(state); self.case_insensitive.hash(state); self.expr.hash(state); @@ -209,7 +206,7 @@ impl std::hash::Hash for ScalarRegexMatchExpr { } } -impl std::cmp::PartialEq for ScalarRegexMatchExpr { +impl PartialEq for ScalarRegexMatchExpr { fn eq(&self, other: &Self) -> bool { self.negated.eq(&other.negated) && self.case_insensitive.eq(&self.case_insensitive) @@ -218,8 +215,8 @@ impl std::cmp::PartialEq for ScalarRegexMatchExpr { } } -impl std::fmt::Debug for ScalarRegexMatchExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Debug for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { f.debug_struct("ScalarRegexMatchExpr") .field("negated", &self.negated) .field("case_insensitive", &self.case_insensitive) @@ -229,35 +226,26 @@ impl std::fmt::Debug for ScalarRegexMatchExpr { } } -impl std::fmt::Display for ScalarRegexMatchExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl Display for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut Formatter) -> FmtResult { write!(f, "{} {} {}", self.expr, self.op_name(), self.pattern) } } impl PhysicalExpr for ScalarRegexMatchExpr { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } - fn data_type( - &self, - _: &arrow_schema::Schema, - ) -> datafusion_common::Result { + fn data_type(&self, _: &Schema) -> DFResult { Ok(DataType::Boolean) } - fn nullable( - &self, - input_schema: &arrow_schema::Schema, - ) -> datafusion_common::Result { + fn nullable(&self, input_schema: &Schema) -> DFResult { Ok(self.expr.nullable(input_schema)? || self.pattern.nullable(input_schema)?) 
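+ // Rationale for the OR above (a sketch, assuming NULL inputs propagate as in
+ // `evaluate_scalar`): a NULL input string or a NULL pattern yields a NULL match
+ // result rather than `false`, so the output is nullable whenever either child is.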
} - fn evaluate( - &self, - batch: &arrow_array::RecordBatch, - ) -> datafusion_common::Result { + fn evaluate(&self, batch: &RecordBatch) -> DFResult { self.expr .evaluate(batch) .and_then(|lhs| { @@ -274,14 +262,14 @@ impl PhysicalExpr for ScalarRegexMatchExpr { .map(ColumnarValue::Array) } - fn children(&self) -> Vec<&std::sync::Arc> { + fn children(&self) -> Vec<&Arc> { vec![&self.expr, &self.pattern] } fn with_new_children( - self: std::sync::Arc, - children: Vec>, - ) -> datafusion_common::Result> { + self: Arc, + children: Vec>, + ) -> DFResult> { Ok(Arc::new(ScalarRegexMatchExpr::new( self.negated, self.case_insensitive, @@ -290,7 +278,7 @@ impl PhysicalExpr for ScalarRegexMatchExpr { ))) } - fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) { + fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } @@ -310,7 +298,7 @@ fn array_regexp_match( array: &dyn ArrayAccessor, regex: &Regex, negated: bool, -) -> datafusion_common::Result { +) -> DFResult { let null_bit_buffer = array.nulls().map(|x| x.inner().sliced()); let mut buffer_builder = BooleanBufferBuilder::new(array.len()); @@ -344,10 +332,7 @@ fn array_regexp_match( bool_array .map_err(|err| { - datafusion_common::DataFusionError::Execution(format!( - "Failed to evaluate regex: {}", - err - )) + DataFusionError::Execution(format!("Failed to evaluate regex: {}", err)) }) .map(|bool_array| ColumnarValue::Array(Arc::new(bool_array))) } @@ -359,7 +344,7 @@ pub fn scalar_regex_match( expr: Arc, pattern: Arc, input_schema: &Schema, -) -> datafusion_common::Result> { +) -> DFResult> { let valid_data_type = |data_type: &DataType| { if !matches!( data_type, @@ -390,11 +375,8 @@ pub fn scalar_regex_match( mod tests { use super::*; use crate::expressions::{col, lit}; - use arrow::record_batch::RecordBatch; - use arrow_array::ArrayRef; - use arrow_array::NullArray; - use arrow_schema::Field; - use arrow_schema::Schema; + use arrow_array::{ArrayRef, NullArray, RecordBatch}; + use arrow_schema::{Field, Schema}; use rstest::rstest; use std::sync::Arc; diff --git a/datafusion/physical-expr/src/expressions/unknown_column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs index cb7221e7fa151..590efd5779638 100644 --- a/datafusion/physical-expr/src/expressions/unknown_column.rs +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -57,7 +57,7 @@ impl std::fmt::Display for UnKnownColumn { impl PhysicalExpr for UnKnownColumn { /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index f05ac3624b8e2..8084a52c78d80 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -19,6 +19,7 @@ use std::collections::HashSet; use std::fmt::{Display, Formatter}; +use std::mem::{size_of, size_of_val}; use std::sync::Arc; use super::utils::{ @@ -128,12 +129,11 @@ impl ExprIntervalGraph { /// Estimate size of bytes including `Self`. 
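+ // Worked example with hypothetical numbers (assuming the elided generics are
+ // `ExprTreeNode<NodeIndex>`, `usize` and `u32`, on a 64-bit target): if
+ // `size_of::<ExprTreeNode<NodeIndex>>()` were 64 bytes, a graph of 10 nodes and
+ // 9 edges would add 10 * (64 + 8) + 9 * (8 + 4 * 2) = 864 bytes on top of
+ // `size_of_val(self)`; actual sizes vary by build.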
pub fn size(&self) -> usize { let node_memory_usage = self.graph.node_count() - * (std::mem::size_of::() - + std::mem::size_of::()); - let edge_memory_usage = self.graph.edge_count() - * (std::mem::size_of::() + std::mem::size_of::() * 2); + * (size_of::() + size_of::()); + let edge_memory_usage = + self.graph.edge_count() * (size_of::() + size_of::() * 2); - std::mem::size_of_val(self) + node_memory_usage + edge_memory_usage + size_of_val(self) + node_memory_usage + edge_memory_usage } } diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 46185712413ef..e7c2b4119c5ae 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -27,7 +27,6 @@ pub mod binary_map { pub mod equivalence; pub mod expressions; pub mod intervals; -pub mod math_expressions; mod partitioning; mod physical_expr; pub mod planner; diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs deleted file mode 100644 index 503565b1e2613..0000000000000 --- a/datafusion/physical-expr/src/math_expressions.rs +++ /dev/null @@ -1,126 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Math expressions - -use std::any::type_name; -use std::sync::Arc; - -use arrow::array::ArrayRef; -use arrow::array::{BooleanArray, Float32Array, Float64Array}; -use arrow::datatypes::DataType; -use arrow_array::Array; - -use datafusion_common::exec_err; -use datafusion_common::{DataFusionError, Result}; - -macro_rules! downcast_arg { - ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ - $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast {} from {} to {}", - $NAME, - $ARG.data_type(), - type_name::<$ARRAY_TYPE>() - )) - })? - }}; -} - -macro_rules! 
make_function_scalar_inputs_return_type { - ($ARG: expr, $NAME:expr, $ARGS_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARGS_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$RETURN_TYPE>() - }}; -} - -/// Isnan SQL function -pub fn isnan(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float64Array, - BooleanArray, - { f64::is_nan } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float32Array, - BooleanArray, - { f32::is_nan } - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function isnan"), - } -} - -#[cfg(test)] -mod tests { - - use datafusion_common::cast::as_boolean_array; - - use super::*; - - #[test] - fn test_isnan_f64() { - let args: Vec = vec![Arc::new(Float64Array::from(vec![ - 1.0, - f64::NAN, - 3.0, - -f64::NAN, - ]))]; - - let result = isnan(&args).expect("failed to initialize function isnan"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function isnan"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } - - #[test] - fn test_isnan_f32() { - let args: Vec = vec![Arc::new(Float32Array::from(vec![ - 1.0, - f32::NAN, - 3.0, - f32::NAN, - ]))]; - - let result = isnan(&args).expect("failed to initialize function isnan"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function isnan"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } -} diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index 01f72a8efd9a5..98c0c864b9f70 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -121,8 +121,8 @@ pub enum Partitioning { UnknownPartitioning(usize), } -impl fmt::Display for Partitioning { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl Display for Partitioning { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Partitioning::RoundRobinBatch(size) => write!(f, "RoundRobinBatch({size})"), Partitioning::Hash(phy_exprs, size) => { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 130c335d1c95e..ab53106f60598 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -39,7 +39,8 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, DFSchema, Result}; +use arrow_array::Array; +use datafusion_common::{internal_err, DFSchema, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; @@ -140,15 +141,23 @@ impl PhysicalExpr for ScalarFunctionExpr { .collect::>>()?; // evaluate the function - let output = match self.args.is_empty() { - true => self.fun.invoke_no_args(batch.num_rows()), - false => self.fun.invoke(&inputs), - }?; + let output = self.fun.invoke_batch(&inputs, batch.num_rows())?; if let 
ColumnarValue::Array(array) = &output { if array.len() != batch.num_rows() { - return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}", - batch.num_rows(), array.len()); + // If the arguments are a non-empty slice of scalar values, we can assume that + // returning a one-element array is equivalent to returning a scalar. + let preserve_scalar = array.len() == 1 + && !inputs.is_empty() + && inputs + .iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + return if preserve_scalar { + ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar) + } else { + internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}", + batch.num_rows(), array.len()) + }; } } Ok(output) diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index cd1597217c83a..fbb59cc92fa05 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -93,18 +93,18 @@ impl LiteralGuarantee { /// Create a new instance of the guarantee if the provided operator is /// supported. Returns None otherwise. See [`LiteralGuarantee::analyze`] to /// create these structures from an predicate (boolean expression). - fn try_new<'a>( + fn new<'a>( column_name: impl Into, guarantee: Guarantee, literals: impl IntoIterator, - ) -> Option { + ) -> Self { let literals: HashSet<_> = literals.into_iter().cloned().collect(); - Some(Self { + Self { column: Column::from_name(column_name), guarantee, literals, - }) + } } /// Return a list of [`LiteralGuarantee`]s that must be satisfied for `expr` @@ -338,13 +338,10 @@ impl<'a> GuaranteeBuilder<'a> { // This is a new guarantee let new_values: HashSet<_> = new_values.into_iter().collect(); - if let Some(guarantee) = - LiteralGuarantee::try_new(col.name(), guarantee, new_values) - { - // add it to the list of guarantees - self.guarantees.push(Some(guarantee)); - self.map.insert(key, self.guarantees.len() - 1); - } + let guarantee = LiteralGuarantee::new(col.name(), guarantee, new_values); + // add it to the list of guarantees + self.guarantees.push(Some(guarantee)); + self.map.insert(key, self.guarantees.len() - 1); } self @@ -851,7 +848,7 @@ mod test { S: Into + 'a, { let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); - LiteralGuarantee::try_new(column, Guarantee::In, literals.iter()).unwrap() + LiteralGuarantee::new(column, Guarantee::In, literals.iter()) } /// Guarantee that the expression is true if the column is NOT any of the specified values @@ -861,7 +858,7 @@ mod test { S: Into + 'a, { let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); - LiteralGuarantee::try_new(column, Guarantee::NotIn, literals.iter()).unwrap() + LiteralGuarantee::new(column, Guarantee::NotIn, literals.iter()) } // Schema for testing diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 4c37db4849a7f..4bd022975ac36 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -86,6 +86,10 @@ pub fn map_columns_before_projection( parent_required: &[Arc], proj_exprs: &[(Arc, String)], ) -> Vec> { + if parent_required.is_empty() { + // No need to build mapping. 
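+ // With no columns required by the parent there is nothing to remap,
+ // so the column mapping built below would never be consulted.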
+ return vec![]; + } let column_mapping = proj_exprs .iter() .filter_map(|(expr, name)| { diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index d012fef93b675..3fe5d842dfd15 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -41,7 +41,7 @@ use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; /// See comments on [`WindowExpr`] for more details. #[derive(Debug)] pub struct PlainAggregateWindowExpr { - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: Vec>, order_by: Vec, window_frame: Arc, @@ -50,7 +50,7 @@ pub struct PlainAggregateWindowExpr { impl PlainAggregateWindowExpr { /// Create a new aggregate window function expression pub fn new( - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, @@ -137,14 +137,14 @@ impl WindowExpr for PlainAggregateWindowExpr { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { Arc::new(PlainAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs deleted file mode 100644 index 9720187ea83dd..0000000000000 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ /dev/null @@ -1,145 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expression for `cume_dist` that can evaluated -//! 
at runtime during query execution - -use crate::window::BuiltInWindowFunctionExpr; -use crate::PhysicalExpr; -use arrow::array::ArrayRef; -use arrow::array::Float64Array; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::Result; -use datafusion_expr::PartitionEvaluator; -use std::any::Any; -use std::iter; -use std::ops::Range; -use std::sync::Arc; - -/// CumeDist calculates the cume_dist in the window function with order by -#[derive(Debug)] -pub struct CumeDist { - name: String, - /// Output data type - data_type: DataType, -} - -/// Create a cume_dist window function -pub fn cume_dist(name: String, data_type: &DataType) -> CumeDist { - CumeDist { - name, - data_type: data_type.clone(), - } -} - -impl BuiltInWindowFunctionExpr for CumeDist { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = false; - Ok(Field::new(self.name(), self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_evaluator(&self) -> Result> { - Ok(Box::new(CumeDistEvaluator {})) - } -} - -#[derive(Debug)] -pub(crate) struct CumeDistEvaluator; - -impl PartitionEvaluator for CumeDistEvaluator { - fn evaluate_all_with_rank( - &self, - num_rows: usize, - ranks_in_partition: &[Range], - ) -> Result { - let scalar = num_rows as f64; - let result = Float64Array::from_iter_values( - ranks_in_partition - .iter() - .scan(0_u64, |acc, range| { - let len = range.end - range.start; - *acc += len as u64; - let value: f64 = (*acc as f64) / scalar; - let result = iter::repeat(value).take(len); - Some(result) - }) - .flatten(), - ); - Ok(Arc::new(result)) - } - - fn include_rank(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::*; - use datafusion_common::cast::as_float64_array; - - fn test_i32_result( - expr: &CumeDist, - num_rows: usize, - ranks: Vec>, - expected: Vec, - ) -> Result<()> { - let result = expr - .create_evaluator()? 
- .evaluate_all_with_rank(num_rows, &ranks)?; - let result = as_float64_array(&result)?; - let result = result.values(); - assert_eq!(expected, *result); - Ok(()) - } - - #[test] - #[allow(clippy::single_range_in_vec_init)] - fn test_cume_dist() -> Result<()> { - let r = cume_dist("arr".into(), &DataType::Float64); - - let expected = vec![0.0; 0]; - test_i32_result(&r, 0, vec![], expected)?; - - let expected = vec![1.0; 1]; - test_i32_result(&r, 1, vec![0..1], expected)?; - - let expected = vec![1.0; 2]; - test_i32_result(&r, 2, vec![0..2], expected)?; - - let expected = vec![0.5, 0.5, 1.0, 1.0]; - test_i32_result(&r, 4, vec![0..2, 2..4], expected)?; - - let expected = vec![0.25, 0.5, 0.75, 1.0]; - test_i32_result(&r, 4, vec![0..1, 1..2, 2..3, 3..4], expected)?; - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 2aeb053331027..3c37fff7a1ba6 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -18,11 +18,7 @@ mod aggregate; mod built_in; mod built_in_window_function_expr; -pub(crate) mod cume_dist; -pub(crate) mod lead_lag; pub(crate) mod nth_value; -pub(crate) mod ntile; -pub(crate) mod rank; mod sliding_aggregate; mod window_expr; diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 87c74579c6392..6ec3a23fc5863 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -30,7 +30,7 @@ use crate::PhysicalExpr; use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use datafusion_common::Result; -use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::ScalarValue; use datafusion_expr::window_state::WindowAggState; use datafusion_expr::PartitionEvaluator; @@ -86,16 +86,13 @@ impl NthValue { n: i64, ignore_nulls: bool, ) -> Result { - match n { - 0 => exec_err!("NTH_VALUE expects n to be non-zero"), - _ => Ok(Self { - name: name.into(), - expr, - data_type, - kind: NthValueKind::Nth(n), - ignore_nulls, - }), - } + Ok(Self { + name: name.into(), + expr, + data_type, + kind: NthValueKind::Nth(n), + ignore_nulls, + }) } /// Get the NTH_VALUE kind @@ -188,10 +185,7 @@ impl PartitionEvaluator for NthValueEvaluator { // Negative index represents reverse direction. (n_range >= reverse_index, true) } - Ordering::Equal => { - // The case n = 0 is not valid for the NTH_VALUE function. - unreachable!(); - } + Ordering::Equal => (false, false), } } }; @@ -298,10 +292,7 @@ impl PartitionEvaluator for NthValueEvaluator { ) } } - Ordering::Equal => { - // The case n = 0 is not valid for the NTH_VALUE function. - unreachable!(); - } + Ordering::Equal => ScalarValue::try_from(arr.data_type()), } } } diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs deleted file mode 100644 index fb7a7ad84fb70..0000000000000 --- a/datafusion/physical-expr/src/window/ntile.rs +++ /dev/null @@ -1,111 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expression for `ntile` that can evaluated -//! at runtime during query execution - -use crate::expressions::Column; -use crate::window::BuiltInWindowFunctionExpr; -use crate::{PhysicalExpr, PhysicalSortExpr}; - -use arrow::array::{ArrayRef, UInt64Array}; -use arrow::datatypes::Field; -use arrow_schema::{DataType, SchemaRef, SortOptions}; -use datafusion_common::Result; -use datafusion_expr::PartitionEvaluator; - -use std::any::Any; -use std::sync::Arc; - -#[derive(Debug)] -pub struct Ntile { - name: String, - n: u64, - /// Output data type - data_type: DataType, -} - -impl Ntile { - pub fn new(name: String, n: u64, data_type: &DataType) -> Self { - Self { - name, - n, - data_type: data_type.clone(), - } - } - - pub fn get_n(&self) -> u64 { - self.n - } -} - -impl BuiltInWindowFunctionExpr for Ntile { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = false; - Ok(Field::new(self.name(), self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_evaluator(&self) -> Result> { - Ok(Box::new(NtileEvaluator { n: self.n })) - } - - fn get_result_ordering(&self, schema: &SchemaRef) -> Option { - // The built-in NTILE window function introduces a new ordering: - schema.column_with_name(self.name()).map(|(idx, field)| { - let expr = Arc::new(Column::new(field.name(), idx)); - let options = SortOptions { - descending: false, - nulls_first: false, - }; // ASC, NULLS LAST - PhysicalSortExpr { expr, options } - }) - } -} - -#[derive(Debug)] -pub(crate) struct NtileEvaluator { - n: u64, -} - -impl PartitionEvaluator for NtileEvaluator { - fn evaluate_all( - &mut self, - _values: &[ArrayRef], - num_rows: usize, - ) -> Result { - let num_rows = num_rows as u64; - let mut vec: Vec = Vec::new(); - let n = u64::min(self.n, num_rows); - for i in 0..num_rows { - let res = i * n / num_rows; - vec.push(res + 1) - } - Ok(Arc::new(UInt64Array::from(vec))) - } -} diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 143d59eb44953..b889ec8c5d984 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -41,7 +41,7 @@ use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; /// See comments on [`WindowExpr`] for more details. #[derive(Debug)] pub struct SlidingAggregateWindowExpr { - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: Vec>, order_by: Vec, window_frame: Arc, @@ -50,7 +50,7 @@ pub struct SlidingAggregateWindowExpr { impl SlidingAggregateWindowExpr { /// Create a new (sliding) aggregate window function expression. 
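+ // A plausible reading of the `Arc` change (not stated in the patch itself): the
+ // constructors now take a shared `Arc<AggregateFunctionExpr>`, so handing the same
+ // aggregate to either a Plain- or a SlidingAggregateWindowExpr is a cheap pointer
+ // clone instead of a clone of the whole expression.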
pub fn new( - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, @@ -121,14 +121,14 @@ impl WindowExpr for SlidingAggregateWindowExpr { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { Arc::new(PlainAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), @@ -159,7 +159,10 @@ impl WindowExpr for SlidingAggregateWindowExpr { }) .collect::>(); Some(Arc::new(SlidingAggregateWindowExpr { - aggregate: self.aggregate.with_new_expressions(args, vec![])?, + aggregate: self + .aggregate + .with_new_expressions(args, vec![]) + .map(Arc::new)?, partition_by: partition_bys, order_by: new_order_by, window_frame: Arc::clone(&self.window_frame), diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 8f6f78df8cb85..46c46fab68c54 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -530,19 +530,6 @@ pub enum WindowFn { Aggregate(Box), } -/// State for the RANK(percent_rank, rank, dense_rank) built-in window function. -#[derive(Debug, Clone, Default)] -pub struct RankState { - /// The last values for rank as these values change, we increase n_rank - pub last_rank_data: Option>, - /// The index where last_rank_boundary is started - pub last_rank_boundary: usize, - /// Keep the number of entries in current rank - pub current_group_count: usize, - /// Rank number kept from the start - pub n_rank: usize, -} - /// Tag to differentiate special use cases of the NTH_VALUE built-in window function. 
#[derive(Debug, Copy, Clone)] pub enum NthValueKind { diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index acf3eee105d4d..e7bf4a80fc450 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -32,9 +32,15 @@ rust-version = { workspace = true } workspace = true [dependencies] +arrow = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-execution = { workspace = true } +datafusion-expr-common = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } itertools = { workspace = true } + +[dev-dependencies] +datafusion-functions-aggregate = { workspace = true } +tokio = { workspace = true } diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index 71f129be984d1..27870c7865f38 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -20,17 +20,15 @@ use std::sync::Arc; use datafusion_common::config::ConfigOptions; use datafusion_common::scalar::ScalarValue; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::{expressions, ExecutionPlan, Statistics}; +use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; +use datafusion_physical_plan::{expressions, ExecutionPlan}; use crate::PhysicalOptimizerRule; -use datafusion_common::stats::Precision; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; -use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; -use datafusion_physical_plan::udaf::AggregateFunctionExpr; /// Optimizer that uses available statistics for aggregate functions #[derive(Default, Debug)] @@ -57,14 +55,19 @@ impl PhysicalOptimizerRule for AggregateStatistics { let stats = partial_agg_exec.input().statistics()?; let mut projections = vec![]; for expr in partial_agg_exec.aggr_expr() { - if let Some((non_null_rows, name)) = - take_optimizable_column_and_table_count(expr, &stats) + let field = expr.field(); + let args = expr.expressions(); + let statistics_args = StatisticsArgs { + statistics: &stats, + return_type: field.data_type(), + is_distinct: expr.is_distinct(), + exprs: args.as_slice(), + }; + if let Some((optimizable_statistic, name)) = + take_optimizable_value_from_statistics(&statistics_args, expr) { - projections.push((expressions::lit(non_null_rows), name.to_owned())); - } else if let Some((min, name)) = take_optimizable_min(expr, &stats) { - projections.push((expressions::lit(min), name.to_owned())); - } else if let Some((max, name)) = take_optimizable_max(expr, &stats) { - projections.push((expressions::lit(max), name.to_owned())); + projections + .push((expressions::lit(optimizable_statistic), name.to_owned())); } else { // TODO: we need all aggr_expr to be resolved (cf TODO fullres) break; @@ -135,160 +138,381 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option> None } -/// If this agg_expr is a count that can be exactly derived from the statistics, return it. 
-fn take_optimizable_column_and_table_count( +/// If this agg_expr's value (e.g. a COUNT, MIN or MAX) is exactly defined by the statistics, return it. +fn take_optimizable_value_from_statistics( + statistics_args: &StatisticsArgs, agg_expr: &AggregateFunctionExpr, - stats: &Statistics, ) -> Option<(ScalarValue, String)> { - let col_stats = &stats.column_statistics; - if is_non_distinct_count(agg_expr) { - if let Precision::Exact(num_rows) = stats.num_rows { - let exprs = agg_expr.expressions(); - if exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = - exprs[0].as_any().downcast_ref::() - { - let current_val = &col_stats[col_expr.index()].null_count; - if let &Precision::Exact(val) = current_val { - return Some(( - ScalarValue::Int64(Some((num_rows - val) as i64)), - agg_expr.name().to_string(), - )); - } - } else if let Some(lit_expr) = - exprs[0].as_any().downcast_ref::() - { - if lit_expr.value() == &COUNT_STAR_EXPANSION { - return Some(( - ScalarValue::Int64(Some(num_rows as i64)), - agg_expr.name().to_string(), - )); - } - } - } - } - } - None + let value = agg_expr.fun().value_from_stats(statistics_args); + value.map(|val| (val, agg_expr.name().to_string())) } -/// If this agg_expr is a min that is exactly defined in the statistics, return it. -fn take_optimizable_min( - agg_expr: &AggregateFunctionExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - if let Precision::Exact(num_rows) = &stats.num_rows { - match *num_rows { - 0 => { - // MIN/MAX with 0 rows is always null - if is_min(agg_expr) { - if let Ok(min_data_type) = - ScalarValue::try_from(agg_expr.field().data_type()) - { - return Some((min_data_type, agg_expr.name().to_string())); - } - } - } - value if value > 0 => { - let col_stats = &stats.column_statistics; - if is_min(agg_expr) { - let exprs = agg_expr.expressions(); - if exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = - exprs[0].as_any().downcast_ref::() - { - if let Precision::Exact(val) = - &col_stats[col_expr.index()].min_value - { - if !val.is_null() { - return Some(( - val.clone(), - agg_expr.name().to_string(), - )); - } - } - } - } - } +#[cfg(test)] +mod tests { + use crate::aggregate_statistics::AggregateStatistics; + use crate::PhysicalOptimizerRule; + use datafusion_common::config::ConfigOptions; + use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; + use datafusion_execution::TaskContext; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_plan::aggregates::AggregateExec; + use datafusion_physical_plan::projection::ProjectionExec; + use datafusion_physical_plan::udaf::AggregateFunctionExpr; + use datafusion_physical_plan::ExecutionPlan; + use std::sync::Arc; + + use datafusion_common::Result; + use datafusion_expr_common::operator::Operator; + + use datafusion_physical_plan::aggregates::PhysicalGroupBy; + use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; + use datafusion_physical_plan::common; + use datafusion_physical_plan::filter::FilterExec; + use datafusion_physical_plan::memory::MemoryExec; + + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::cast::as_int64_array; + use datafusion_physical_expr::expressions::{self, cast}; + use datafusion_physical_plan::aggregates::AggregateMode; + + /// Describe the type of
aggregate being tested + pub enum TestAggregate { + /// Testing COUNT(*) type aggregates + CountStar, + + /// Testing for COUNT(column) aggregate + ColumnA(Arc), + } + + impl TestAggregate { + /// Create a new COUNT(*) aggregate + pub fn new_count_star() -> Self { + Self::CountStar + } + + /// Create a new COUNT(column) aggregate + pub fn new_count_column(schema: &Arc) -> Self { + Self::ColumnA(Arc::clone(schema)) + } + + /// Return appropriate expr depending if COUNT is for col or table (*) + pub fn count_expr(&self, schema: &Schema) -> AggregateFunctionExpr { + AggregateExprBuilder::new(count_udaf(), vec![self.column()]) + .schema(Arc::new(schema.clone())) + .alias(self.column_name()) + .build() + .unwrap() + } + + /// what argument would this aggregate need in the plan? + fn column(&self) -> Arc { + match self { + Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION), + Self::ColumnA(s) => expressions::col("a", s).unwrap(), } - _ => {} } - } - None -} -/// If this agg_expr is a max that is exactly defined in the statistics, return it. -fn take_optimizable_max( - agg_expr: &AggregateFunctionExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - if let Precision::Exact(num_rows) = &stats.num_rows { - match *num_rows { - 0 => { - // MIN/MAX with 0 rows is always null - if is_max(agg_expr) { - if let Ok(max_data_type) = - ScalarValue::try_from(agg_expr.field().data_type()) - { - return Some((max_data_type, agg_expr.name().to_string())); - } - } + /// What name would this aggregate produce in a plan? + pub fn column_name(&self) -> &'static str { + match self { + Self::CountStar => "COUNT(*)", + Self::ColumnA(_) => "COUNT(a)", } - value if value > 0 => { - let col_stats = &stats.column_statistics; - if is_max(agg_expr) { - let exprs = agg_expr.expressions(); - if exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = - exprs[0].as_any().downcast_ref::() - { - if let Precision::Exact(val) = - &col_stats[col_expr.index()].max_value - { - if !val.is_null() { - return Some(( - val.clone(), - agg_expr.name().to_string(), - )); - } - } - } - } - } + } + + /// What is the expected count? 
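+ // These constants follow from `mock_data()` below: one batch of 3 rows in which
+ // column "a" is [1, 2, NULL], so COUNT(*) sees 3 rows while COUNT(a) only
+ // counts the 2 non-null values.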
+ pub fn expected_count(&self) -> i64 { + match self { + TestAggregate::CountStar => 3, + TestAggregate::ColumnA(_) => 2, } - _ => {} } } - None -} -// TODO: Move this check into AggregateUDFImpl -// https://github.com/apache/datafusion/issues/11153 -fn is_non_distinct_count(agg_expr: &AggregateFunctionExpr) -> bool { - if agg_expr.fun().name() == "count" && !agg_expr.is_distinct() { - return true; + } + } - false -} + /// Mock data using a MemoryExec which has an exact count statistic + fn mock_data() -> Result> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2), None])), + Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])), + ], + )?; + + Ok(Arc::new(MemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?)) } -// TODO: Move this check into AggregateUDFImpl -// https://github.com/apache/datafusion/issues/11153 -fn is_min(agg_expr: &AggregateFunctionExpr) -> bool { - if agg_expr.fun().name().to_lowercase() == "min" { - return true; + /// Checks that the count optimization was applied and we still get the right result + async fn assert_count_optim_success( + plan: AggregateExec, + agg: TestAggregate, + ) -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let plan: Arc = Arc::new(plan); + + let config = ConfigOptions::new(); + let optimized = + AggregateStatistics::new().optimize(Arc::clone(&plan), &config)?; + + // A ProjectionExec is a sign that the count optimization was applied + assert!(optimized.as_any().is::()); + + // run both the optimized and non-optimized plan + let optimized_result = + common::collect(optimized.execute(0, Arc::clone(&task_ctx))?).await?; + let nonoptimized_result = common::collect(plan.execute(0, task_ctx)?).await?; + assert_eq!(optimized_result.len(), nonoptimized_result.len()); + + // and validate the results are the same and expected + assert_eq!(optimized_result.len(), 1); + check_batch(optimized_result.into_iter().next().unwrap(), &agg); + // check the non-optimized one too to ensure types and names remain the same + assert_eq!(nonoptimized_result.len(), 1); + check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg); + + Ok(()) } - false -} -// TODO: Move this check into AggregateUDFImpl -// https://github.com/apache/datafusion/issues/11153 -fn is_max(agg_expr: &AggregateFunctionExpr) -> bool { - if agg_expr.fun().name().to_lowercase() == "max" { - return true; + fn check_batch(batch: RecordBatch, agg: &TestAggregate) { + let schema = batch.schema(); + let fields = schema.fields(); + assert_eq!(fields.len(), 1); + + let field = &fields[0]; + assert_eq!(field.name(), agg.column_name()); + assert_eq!(field.data_type(), &DataType::Int64); + // note that nullability differs + + assert_eq!( + as_int64_array(batch.column(0)).unwrap().values(), + &[agg.expected_count()] + ); } - false -} -// See tests in datafusion/core/tests/physical_optimizer + #[tokio::test] + async fn test_count_partial_direct_child() -> Result<()> { + // basic test case with the aggregation applied on a source with exact statistics + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + source, + Arc::clone(&schema), + )?; + + let final_agg =
AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_count_partial_with_nulls_direct_child() -> Result<()> { + // basic test case with the aggregation applied on a source with exact statistics + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + source, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_count_partial_indirect_child() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + source, + Arc::clone(&schema), + )?; + + // We introduce an intermediate optimization step between the partial and final aggregator + let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + Arc::new(coalesce), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_count_partial_with_nulls_indirect_child() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + source, + Arc::clone(&schema), + )?; + + // We introduce an intermediate optimization step between the partial and final aggregator + let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + Arc::new(coalesce), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_count_inexact_stat() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + // adding a filter makes the statistics inexact + let filter = Arc::new(FilterExec::try_new( + expressions::binary( + expressions::col("a", &schema)?, + Operator::Gt, + cast(expressions::lit(1u32), &schema, DataType::Int32)?, + &schema, + )?, + source, + )?); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + filter, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + 
let conf = ConfigOptions::new(); + let optimized = + AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; + + // check that the original ExecutionPlan was not replaced + assert!(optimized.as_any().is::()); + + Ok(()) + } + + #[tokio::test] + async fn test_count_with_nulls_inexact_stat() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); + + // adding a filter makes the statistics inexact + let filter = Arc::new(FilterExec::try_new( + expressions::binary( + expressions::col("a", &schema)?, + Operator::Gt, + cast(expressions::lit(1u32), &schema, DataType::Int32)?, + &schema, + )?, + source, + )?); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + filter, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![Arc::new(agg.count_expr(&schema))], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + let conf = ConfigOptions::new(); + let optimized = + AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; + + // check that the original ExecutionPlan was not replaced + assert!(optimized.as_any().is::()); + + Ok(()) + } +} diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs index 4e352e25b52c9..86f7e73e9e359 100644 --- a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -125,7 +125,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { type GroupExprsRef<'a> = ( &'a PhysicalGroupBy, - &'a [AggregateFunctionExpr], + &'a [Arc], &'a [Option>], ); diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index 804dd165d335c..c8a28ed0ec0ba 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -19,21 +19,17 @@ use std::sync::Arc; -use datafusion_physical_plan::aggregates::AggregateExec; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion_physical_plan::filter::FilterExec; -use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::sorts::sort::SortExec; -use datafusion_physical_plan::ExecutionPlan; - -use arrow_schema::DataType; +use crate::PhysicalOptimizerRule; +use arrow::datatypes::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::PhysicalSortExpr; - -use crate::PhysicalOptimizerRule; +use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::execution_plan::CardinalityEffect; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::ExecutionPlan; use itertools::Itertools; /// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed @@ -48,12 +44,13 @@ impl TopKAggregation { fn transform_agg( aggr: &AggregateExec, - order: &PhysicalSortExpr, + order_by: &str, + order_desc: bool, limit: usize, ) -> Option> { // ensure the sort direction matches aggregate function 
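+        // Worked example (a sketch): `SELECT id, MAX(ts) FROM t GROUP BY id
+        // ORDER BY MAX(ts) DESC LIMIT 10` can push the limit into the aggregate
+        // because the DESC sort agrees with MAX; an ASC sort over MAX (or a DESC
+        // sort over MIN) cannot discard groups early, so no limit hint is passed.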
        let (field, desc) = aggr.get_minmax_desc()?;
-        if desc != order.options.descending {
+        if desc != order_desc {
             return None;
         }
         let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
@@ -66,8 +63,7 @@ impl TopKAggregation {
         }

         // ensure the sort is on the same field as the aggregate output
-        let col = order.expr.as_any().downcast_ref::<Column>()?;
-        if col.name() != field.name() {
+        if order_by != field.name() {
             return None;
         }

@@ -92,16 +88,11 @@ impl TopKAggregation {
         let child = children.into_iter().exactly_one().ok()?;
         let order = sort.properties().output_ordering()?;
         let order = order.iter().exactly_one().ok()?;
+        let order_desc = order.options.descending;
+        let order = order.expr.as_any().downcast_ref::<Column>()?;
+        let mut cur_col_name = order.name().to_string();
         let limit = sort.fetch()?;

-        let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
-            plan.as_any()
-                .downcast_ref::<CoalesceBatchesExec>()
-                .is_some()
-                || plan.as_any().downcast_ref::<RepartitionExec>().is_some()
-                || plan.as_any().downcast_ref::<FilterExec>().is_some()
-        };
-
         let mut cardinality_preserved = true;
         let closure = |plan: Arc<dyn ExecutionPlan>| {
             if !cardinality_preserved {
@@ -109,14 +100,27 @@
             }
             if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
                 // either we run into an Aggregate and transform it
-                match Self::transform_agg(aggr, order, limit) {
+                match Self::transform_agg(aggr, &cur_col_name, order_desc, limit) {
                     None => cardinality_preserved = false,
                     Some(plan) => return Ok(Transformed::yes(plan)),
                 }
+            } else if let Some(proj) = plan.as_any().downcast_ref::<ProjectionExec>() {
+                // track renames due to successive projections
+                for (src_expr, proj_name) in proj.expr() {
+                    let Some(src_col) = src_expr.as_any().downcast_ref::<Column>() else {
+                        continue;
+                    };
+                    if *proj_name == cur_col_name {
+                        cur_col_name = src_col.name().to_string();
+                    }
+                }
             } else {
-                // or we continue down whitelisted nodes of other types
-                if !is_cardinality_preserving(Arc::clone(&plan)) {
-                    cardinality_preserved = false;
+                // or we continue down through types that don't reduce cardinality
+                match plan.cardinality_effect() {
+                    CardinalityEffect::Equal | CardinalityEffect::GreaterEqual => {}
+                    CardinalityEffect::Unknown | CardinalityEffect::LowerEqual => {
+                        cardinality_preserved = false;
+                    }
                 }
             }
             Ok(Transformed::no(plan))
diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml
index c3f1b7eb0e95c..7fcd719539ec7 100644
--- a/datafusion/physical-plan/Cargo.toml
+++ b/datafusion/physical-plan/Cargo.toml
@@ -51,7 +51,6 @@
 datafusion-common = { workspace = true, default-features = true }
 datafusion-common-runtime = { workspace = true, default-features = true }
 datafusion-execution = { workspace = true }
 datafusion-expr = { workspace = true }
-datafusion-functions-aggregate = { workspace = true }
 datafusion-functions-aggregate-common = { workspace = true }
 datafusion-functions-window-common = { workspace = true }
 datafusion-physical-expr = { workspace = true, default-features = true }
@@ -69,6 +68,7 @@
 rand = { workspace = true }
 tokio = { workspace = true }

 [dev-dependencies]
+datafusion-functions-aggregate = { workspace = true }
 rstest = { workspace = true }
 rstest_reuse = "0.7.0"
 tokio = { workspace = true, features = [
diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs
index f789af8b8a024..013c027e7306c 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs
@@ -19,6 +19,7 @@ use
crate::aggregates::group_values::GroupValues; use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch}; use datafusion_expr::EmitTo; use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; +use std::mem::size_of; /// A [`GroupValues`] storing single column of Utf8/LargeUtf8/Binary/LargeBinary values /// @@ -73,7 +74,7 @@ impl GroupValues for GroupValuesByes { } fn size(&self) -> usize { - self.map.size() + std::mem::size_of::() + self.map.size() + size_of::() } fn is_empty(&self) -> bool { diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs index 1a0cb90a16d47..7379b7a538b49 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs @@ -20,6 +20,7 @@ use arrow_array::{Array, ArrayRef, RecordBatch}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap; +use std::mem::size_of; /// A [`GroupValues`] storing single column of Utf8View/BinaryView values /// @@ -74,7 +75,7 @@ impl GroupValues for GroupValuesBytesView { } fn size(&self) -> usize { - self.map.size() + std::mem::size_of::() + self.map.size() + size_of::() } fn is_empty(&self) -> bool { diff --git a/datafusion/physical-plan/src/aggregates/group_values/column.rs b/datafusion/physical-plan/src/aggregates/group_values/column.rs index 977b40922f7cb..958a4b58d8004 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/column.rs @@ -16,14 +16,16 @@ // under the License. use crate::aggregates::group_values::group_column::{ - ByteGroupValueBuilder, GroupColumn, PrimitiveGroupValueBuilder, + ByteGroupValueBuilder, ByteViewGroupValueBuilder, GroupColumn, + PrimitiveGroupValueBuilder, }; use crate::aggregates::group_values::GroupValues; use ahash::RandomState; use arrow::compute::cast; use arrow::datatypes::{ - Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + BinaryViewType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, StringViewType, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, }; use arrow::record_batch::RecordBatch; use arrow_array::{Array, ArrayRef}; @@ -33,10 +35,12 @@ use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; - use hashbrown::raw::RawTable; +use std::mem::size_of; -/// Compare GroupValue Rows column by column +/// A [`GroupValues`] that stores multiple columns of group values. +/// +/// pub struct GroupValuesColumn { /// The output schema schema: SchemaRef, @@ -55,8 +59,13 @@ pub struct GroupValuesColumn { map_size: usize, /// The actual group by values, stored column-wise. Compare from - /// the left to right, each column is stored as `ArrayRowEq`. - /// This is shown faster than the row format + /// the left to right, each column is stored as [`GroupColumn`]. + /// + /// Performance tests showed that this design is faster than using the + /// more general purpose [`GroupValuesRows`]. 
See the ticket for details: + /// + /// + /// [`GroupValuesRows`]: crate::aggregates::group_values::row::GroupValuesRows group_values: Vec>, /// reused buffer to store hashes @@ -112,10 +121,31 @@ impl GroupValuesColumn { | DataType::LargeBinary | DataType::Date32 | DataType::Date64 + | DataType::Utf8View + | DataType::BinaryView ) } } +/// instantiates a [`PrimitiveGroupValueBuilder`] and pushes it into $v +/// +/// Arguments: +/// `$v`: the vector to push the new builder into +/// `$nullable`: whether the input can contains nulls +/// `$t`: the primitive type of the builder +/// +macro_rules! instantiate_primitive { + ($v:expr, $nullable:expr, $t:ty) => { + if $nullable { + let b = PrimitiveGroupValueBuilder::<$t, true>::new(); + $v.push(Box::new(b) as _) + } else { + let b = PrimitiveGroupValueBuilder::<$t, false>::new(); + $v.push(Box::new(b) as _) + } + }; +} + impl GroupValues for GroupValuesColumn { fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { let n_rows = cols[0].len(); @@ -126,54 +156,22 @@ impl GroupValues for GroupValuesColumn { for f in self.schema.fields().iter() { let nullable = f.is_nullable(); match f.data_type() { - &DataType::Int8 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::Int16 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::Int32 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::Int64 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::UInt8 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::UInt16 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::UInt32 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::UInt64 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } + &DataType::Int8 => instantiate_primitive!(v, nullable, Int8Type), + &DataType::Int16 => instantiate_primitive!(v, nullable, Int16Type), + &DataType::Int32 => instantiate_primitive!(v, nullable, Int32Type), + &DataType::Int64 => instantiate_primitive!(v, nullable, Int64Type), + &DataType::UInt8 => instantiate_primitive!(v, nullable, UInt8Type), + &DataType::UInt16 => instantiate_primitive!(v, nullable, UInt16Type), + &DataType::UInt32 => instantiate_primitive!(v, nullable, UInt32Type), + &DataType::UInt64 => instantiate_primitive!(v, nullable, UInt64Type), &DataType::Float32 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) + instantiate_primitive!(v, nullable, Float32Type) } &DataType::Float64 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::Date32 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) - } - &DataType::Date64 => { - let b = PrimitiveGroupValueBuilder::::new(nullable); - v.push(Box::new(b) as _) + instantiate_primitive!(v, nullable, Float64Type) } + &DataType::Date32 => instantiate_primitive!(v, nullable, Date32Type), + &DataType::Date64 => instantiate_primitive!(v, nullable, Date64Type), &DataType::Utf8 => { let b = ByteGroupValueBuilder::::new(OutputType::Utf8); v.push(Box::new(b) as _) @@ -190,6 +188,14 @@ impl GroupValues for GroupValuesColumn { let b = ByteGroupValueBuilder::::new(OutputType::Binary); v.push(Box::new(b) as _) 
} + &DataType::Utf8View => { + let b = ByteViewGroupValueBuilder::::new(); + v.push(Box::new(b) as _) + } + &DataType::BinaryView => { + let b = ByteViewGroupValueBuilder::::new(); + v.push(Box::new(b) as _) + } dt => { return not_impl_err!("{dt} not supported in GroupValuesColumn") } @@ -345,7 +351,7 @@ impl GroupValues for GroupValuesColumn { self.group_values.clear(); self.map.clear(); self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared - self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>(); + self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); self.hashes_buffer.clear(); self.hashes_buffer.shrink_to(count); } diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs index a82e6d856c70c..bba59b6d0caa5 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs @@ -15,37 +15,46 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::BooleanBufferBuilder; +use arrow::array::make_view; use arrow::array::BufferBuilder; +use arrow::array::ByteView; use arrow::array::GenericBinaryArray; use arrow::array::GenericStringArray; use arrow::array::OffsetSizeTrait; use arrow::array::PrimitiveArray; use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray}; -use arrow::buffer::NullBuffer; use arrow::buffer::OffsetBuffer; use arrow::buffer::ScalarBuffer; -use arrow::datatypes::ArrowNativeType; use arrow::datatypes::ByteArrayType; +use arrow::datatypes::ByteViewType; use arrow::datatypes::DataType; use arrow::datatypes::GenericBinaryType; -use arrow::datatypes::GenericStringType; +use arrow_array::GenericByteViewArray; +use arrow_buffer::Buffer; use datafusion_common::utils::proxy::VecAllocExt; +use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow_array::types::GenericStringType; +use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; +use std::marker::PhantomData; +use std::mem::{replace, size_of}; use std::sync::Arc; use std::vec; -use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; +const BYTE_VIEW_MAX_BLOCK_SIZE: usize = 2 * 1024 * 1024; -/// Trait for group values column-wise row comparison +/// Trait for storing a single column of group values in [`GroupValuesColumn`] /// -/// Implementations of this trait store a in-progress collection of group values +/// Implementations of this trait store an in-progress collection of group values /// (similar to various builders in Arrow-rs) that allow for quick comparison to /// incoming rows. 
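///
/// A sketch of the intended call pattern (hypothetical variable names, not
/// text from this diff):
///
/// ```ignore
/// // remember the group value from `array` row 0, then compare later rows to it
/// builder.append_val(&array, 0);
/// if builder.equal_to(0, &array, 1) {
///     // row 1 belongs to the existing group stored at index 0
/// }
/// ```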
/// +/// [`GroupValuesColumn`]: crate::aggregates::group_values::GroupValuesColumn pub trait GroupColumn: Send + Sync { /// Returns equal if the row stored in this builder at `lhs_row` is equal to /// the row in `array` at `rhs_row` + /// + /// Note that this comparison returns true if both elements are NULL fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool; /// Appends the row at `row` in `array` to this builder fn append_val(&mut self, array: &ArrayRef, row: usize); @@ -60,59 +69,62 @@ pub trait GroupColumn: Send + Sync { fn take_n(&mut self, n: usize) -> ArrayRef; } -pub struct PrimitiveGroupValueBuilder { +/// An implementation of [`GroupColumn`] for primitive values +/// +/// Optimized to skip null buffer construction if the input is known to be non nullable +/// +/// # Template parameters +/// +/// `T`: the native Rust type that stores the data +/// `NULLABLE`: if the data can contain any nulls +#[derive(Debug)] +pub struct PrimitiveGroupValueBuilder { group_values: Vec, - nulls: Vec, - // whether the array contains at least one null, for fast non-null path - has_null: bool, - nullable: bool, + nulls: MaybeNullBufferBuilder, } -impl PrimitiveGroupValueBuilder +impl PrimitiveGroupValueBuilder where T: ArrowPrimitiveType, { - pub fn new(nullable: bool) -> Self { + /// Create a new `PrimitiveGroupValueBuilder` + pub fn new() -> Self { Self { group_values: vec![], - nulls: vec![], - has_null: false, - nullable, + nulls: MaybeNullBufferBuilder::new(), } } } -impl GroupColumn for PrimitiveGroupValueBuilder { +impl GroupColumn + for PrimitiveGroupValueBuilder +{ fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { - // non-null fast path - // both non-null - if !self.nullable { - return self.group_values[lhs_row] - == array.as_primitive::().value(rhs_row); - } - - // lhs is non-null - if self.nulls[lhs_row] { - if array.is_null(rhs_row) { - return false; + // Perf: skip null check (by short circuit) if input is not nullable + if NULLABLE { + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; } - - return self.group_values[lhs_row] - == array.as_primitive::().value(rhs_row); + // Otherwise, we need to check their values } - array.is_null(rhs_row) + self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) } fn append_val(&mut self, array: &ArrayRef, row: usize) { - if self.nullable && array.is_null(row) { - self.group_values.push(T::default_value()); - self.nulls.push(false); - self.has_null = true; + // Perf: skip null check if input can't have nulls + if NULLABLE { + if array.is_null(row) { + self.nulls.append(true); + self.group_values.push(T::default_value()); + } else { + self.nulls.append(false); + self.group_values.push(array.as_primitive::().value(row)); + } } else { - let elem = array.as_primitive::().value(row); - self.group_values.push(elem); - self.nulls.push(true); + self.group_values.push(array.as_primitive::().value(row)); } } @@ -125,48 +137,54 @@ impl GroupColumn for PrimitiveGroupValueBuilder { } fn build(self: Box) -> ArrayRef { - if self.has_null { - Arc::new(PrimitiveArray::::new( - ScalarBuffer::from(self.group_values), - Some(NullBuffer::from(self.nulls)), - )) - } else { - Arc::new(PrimitiveArray::::new( - ScalarBuffer::from(self.group_values), - None, - )) + let Self { + group_values, + nulls, + } = *self; + + let nulls = nulls.build(); + if !NULLABLE { + assert!(nulls.is_none(), 
"unexpected nulls in non nullable input"); } + + Arc::new(PrimitiveArray::::new( + ScalarBuffer::from(group_values), + nulls, + )) } fn take_n(&mut self, n: usize) -> ArrayRef { - if self.has_null { - let first_n = self.group_values.drain(0..n).collect::>(); - let first_n_nulls = self.nulls.drain(0..n).collect::>(); - Arc::new(PrimitiveArray::::new( - ScalarBuffer::from(first_n), - Some(NullBuffer::from(first_n_nulls)), - )) - } else { - let first_n = self.group_values.drain(0..n).collect::>(); - self.nulls.truncate(self.nulls.len() - n); - Arc::new(PrimitiveArray::::new(ScalarBuffer::from(first_n), None)) - } + let first_n = self.group_values.drain(0..n).collect::>(); + + let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None }; + + Arc::new(PrimitiveArray::::new( + ScalarBuffer::from(first_n), + first_n_nulls, + )) } } +/// An implementation of [`GroupColumn`] for binary and utf8 types. +/// +/// Stores a collection of binary or utf8 group values in a single buffer +/// in a way that allows: +/// +/// 1. Efficient comparison of incoming rows to existing rows +/// 2. Efficient construction of the final output array pub struct ByteGroupValueBuilder where O: OffsetSizeTrait, { output_type: OutputType, buffer: BufferBuilder, - /// Offsets into `buffer` for each distinct value. These offsets as used + /// Offsets into `buffer` for each distinct value. These offsets as used /// directly to create the final `GenericBinaryArray`. The `i`th string is /// stored in the range `offsets[i]..offsets[i+1]` in `buffer`. Null values /// are stored as a zero length string. offsets: Vec, - /// Null indexes in offsets, if `i` is in nulls, `offsets[i]` should be equals to `offsets[i+1]` - nulls: Vec, + /// Nulls + nulls: MaybeNullBufferBuilder, } impl ByteGroupValueBuilder @@ -178,7 +196,7 @@ where output_type, buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), offsets: vec![O::default()], - nulls: vec![], + nulls: MaybeNullBufferBuilder::new(), } } @@ -188,40 +206,38 @@ where { let arr = array.as_bytes::(); if arr.is_null(row) { - self.nulls.push(self.len()); + self.nulls.append(true); // nulls need a zero length in the offset buffer let offset = self.buffer.len(); - self.offsets.push(O::usize_as(offset)); - return; + } else { + self.nulls.append(false); + let value: &[u8] = arr.value(row).as_ref(); + self.buffer.append_slice(value); + self.offsets.push(O::usize_as(self.buffer.len())); } - - let value: &[u8] = arr.value(row).as_ref(); - self.buffer.append_slice(value); - self.offsets.push(O::usize_as(self.buffer.len())); } fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool where B: ByteArrayType, { - // Handle nulls - let is_lhs_null = self.nulls.iter().any(|null_idx| *null_idx == lhs_row); - let arr = array.as_bytes::(); - if is_lhs_null { - return arr.is_null(rhs_row); - } else if arr.is_null(rhs_row) { - return false; + let array = array.as_bytes::(); + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; } + // Otherwise, we need to check their values + self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8]) + } - let arr = array.as_bytes::(); - let rhs_elem: &[u8] = arr.value(rhs_row).as_ref(); - let rhs_elem_len = arr.value_length(rhs_row).as_usize(); - debug_assert_eq!(rhs_elem_len, rhs_elem.len()); - let l = self.offsets[lhs_row].as_usize(); - let r = self.offsets[lhs_row + 1].as_usize(); - let existing_elem = unsafe { 
self.buffer.as_slice().get_unchecked(l..r) };
-        rhs_elem == existing_elem
+        /// Return the current value of the specified row, irrespective of nulls
+        pub fn value(&self, row: usize) -> &[u8] {
+            let l = self.offsets[row].as_usize();
+            let r = self.offsets[row + 1].as_usize();
+            // Safety: the offsets are constructed correctly and never decrease
+            unsafe { self.buffer.as_slice().get_unchecked(l..r) }
+        }
     }

@@ -276,7 +292,7 @@ where
     }

     fn size(&self) -> usize {
-        self.buffer.capacity() * std::mem::size_of::<u8>()
+        self.buffer.capacity() * size_of::<u8>()
             + self.offsets.allocated_size()
             + self.nulls.allocated_size()
     }

@@ -289,18 +305,7 @@ where
             nulls,
         } = *self;

-        let null_buffer = if nulls.is_empty() {
-            None
-        } else {
-            // Only make a `NullBuffer` if there was a null value
-            let num_values = offsets.len() - 1;
-            let mut bool_builder = BooleanBufferBuilder::new(num_values);
-            bool_builder.append_n(num_values, true);
-            nulls.into_iter().for_each(|null_index| {
-                bool_builder.set_bit(null_index, false);
-            });
-            Some(NullBuffer::from(bool_builder.finish()))
-        };
+        let null_buffer = nulls.build();

         // SAFETY: the offsets were constructed correctly in `insert_if_new` --
         // monotonically increasing, overflows were checked.
@@ -317,9 +322,9 @@ where
                 // SAFETY:
                 // 1. the offsets were constructed safely
                 //
-                // 2. we asserted the input arrays were all the correct type and
-                //    thus since all the values that went in were valid (e.g. utf8)
-                //    so are all the values that come out
+                // 2. the input arrays were all the correct type and thus since
+                //    all the values that went in were valid (e.g. utf8) so are all
+                //    the values that come out
                 Arc::new(unsafe {
                     GenericStringArray::new_unchecked(offsets, values, null_buffer)
                 })
@@ -330,27 +335,7 @@ where
     fn take_n(&mut self, n: usize) -> ArrayRef {
         debug_assert!(self.len() >= n);
-
-        let null_buffer = if self.nulls.is_empty() {
-            None
-        } else {
-            // Only make a `NullBuffer` if there was a null value
-            let mut bool_builder = BooleanBufferBuilder::new(n);
-            bool_builder.append_n(n, true);
-
-            let mut new_nulls = vec![];
-            self.nulls.iter().for_each(|null_index| {
-                if *null_index < n {
-                    bool_builder.set_bit(*null_index, false);
-                } else {
-                    new_nulls.push(null_index - n);
-                }
-            });
-
-            self.nulls = new_nulls;
-            Some(NullBuffer::from(bool_builder.finish()))
-        };
-
+        let null_buffer = self.nulls.take_n(n);
         let first_remaining_offset = O::as_usize(self.offsets[n]);

         // Given offsets like [0, 2, 4, 5] and n = 1, we expect to get
@@ -400,13 +385,455 @@ where
     }
 }

+/// An implementation of [`GroupColumn`] for binary view and utf8 view types.
+///
+/// Stores a collection of binary view or utf8 view group values in a buffer
+/// whose structure is similar to `GenericByteViewArray`, which provides several
+/// benefits:
+///
+/// 1. Efficient comparison of incoming rows to existing rows
+/// 2. Efficient construction of the final output array
+/// 3. Efficient `take_n` compared to using `GenericByteViewBuilder`
+pub struct ByteViewGroupValueBuilder<B: ByteViewType> {
+    /// The views of string values
+    ///
+    /// If string len <= 12, the view's format will be:
+    ///   string(12B) | len(4B)
+    ///
+    /// If string len > 12, its format will be:
+    ///   offset(4B) | buffer_index(4B) | prefix(4B) | len(4B)
+    views: Vec<u128>,
+
+    /// The in-progress block
+    ///
+    /// New values are appended to it until there is not enough remaining
+    /// capacity (see `max_block_size` for details).
+    in_progress: Vec<u8>,
+
+    /// The completed blocks
+    completed: Vec<Buffer>,
+
+    /// The max size of `in_progress`
+    ///
+    /// `in_progress` is flushed into `completed` (and a new `in_progress` is
+    /// created) when its remaining capacity (`max_block_size` -
+    /// `len(in_progress)`) is not enough to store an appended value.
+    ///
+    /// Currently it is fixed at 2MB.
+    max_block_size: usize,
+
+    /// Nulls
+    nulls: MaybeNullBufferBuilder,
+
+    /// phantom data so the type requires `<B>`
+    _phantom: PhantomData<B>,
+}
+
+impl<B: ByteViewType> ByteViewGroupValueBuilder<B> {
+    pub fn new() -> Self {
+        Self {
+            views: Vec::new(),
+            in_progress: Vec::new(),
+            completed: Vec::new(),
+            max_block_size: BYTE_VIEW_MAX_BLOCK_SIZE,
+            nulls: MaybeNullBufferBuilder::new(),
+            _phantom: PhantomData {},
+        }
+    }
+
+    /// Set the max block size
+    fn with_max_block_size(mut self, max_block_size: usize) -> Self {
+        self.max_block_size = max_block_size;
+        self
+    }
+
+    fn append_val_inner(&mut self, array: &ArrayRef, row: usize)
+    where
+        B: ByteViewType,
+    {
+        let arr = array.as_byte_view::<B>();
+
+        // Null row case, set the null mask and return
+        if arr.is_null(row) {
+            self.nulls.append(true);
+            self.views.push(0);
+            return;
+        }
+
+        // Not null row case
+        self.nulls.append(false);
+        let value: &[u8] = arr.value(row).as_ref();
+
+        let value_len = value.len();
+        let view = if value_len <= 12 {
+            make_view(value, 0, 0)
+        } else {
+            // First ensure the block is big enough to hold the value
+            self.ensure_in_progress_big_enough(value_len);
+
+            // Append value
+            let buffer_index = self.completed.len();
+            let offset = self.in_progress.len();
+            self.in_progress.extend_from_slice(value);
+
+            make_view(value, buffer_index as u32, offset as u32)
+        };
+
+        // Append view
+        self.views.push(view);
+    }
+
+    fn ensure_in_progress_big_enough(&mut self, value_len: usize) {
+        debug_assert!(value_len > 12);
+        let require_cap = self.in_progress.len() + value_len;
+
+        // If the current block isn't big enough, flush it and create a new
+        // in-progress block
+        if require_cap > self.max_block_size {
+            let flushed_block = replace(
+                &mut self.in_progress,
+                Vec::with_capacity(self.max_block_size),
+            );
+            let buffer = Buffer::from_vec(flushed_block);
+            self.completed.push(buffer);
+        }
+    }
+
+    fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
+        let array = array.as_byte_view::<B>();
+
+        // First check if the nulls are equal
+        let exist_null = self.nulls.is_null(lhs_row);
+        let input_null = array.is_null(rhs_row);
+        if let Some(result) = nulls_equal_to(exist_null, input_null) {
+            return result;
+        }
+
+        // Otherwise, we need to check their values
+        let exist_view = self.views[lhs_row];
+        let exist_view_len = exist_view as u32;
+
+        let input_view = array.views()[rhs_row];
+        let input_view_len = input_view as u32;
+
+        // The check logic
+        //   - Check len equality
+        //   - If inlined, check the inlined value
+        //   - If non-inlined, check the prefix and then check the value in the
+        //     buffer when needed
+        if exist_view_len != input_view_len {
+            return false;
+        }
+
+        if exist_view_len <= 12 {
+            let exist_inline = unsafe {
+                GenericByteViewArray::<B>::inline_value(
+                    &exist_view,
+                    exist_view_len as usize,
+                )
+            };
+            let input_inline = unsafe {
+                GenericByteViewArray::<B>::inline_value(
+                    &input_view,
+                    input_view_len as usize,
+                )
+            };
+            exist_inline == input_inline
+        } else {
+            let exist_prefix =
+                unsafe { GenericByteViewArray::<B>::inline_value(&exist_view, 4) };
+            let input_prefix =
+                unsafe { GenericByteViewArray::<B>::inline_value(&input_view, 4) };
+
+            if exist_prefix != input_prefix {
+                return false;
+            }
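+            // Prefixes match, so compare the full payloads; the existing value
+            // may live in a flushed `completed` block or still in the
+            // `in_progress` buffer, which is why `value` dispatches on the
+            // view's `buffer_index`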
+ let exist_full = { + let byte_view = ByteView::from(exist_view); + self.value( + byte_view.buffer_index as usize, + byte_view.offset as usize, + byte_view.length as usize, + ) + }; + let input_full: &[u8] = unsafe { array.value_unchecked(rhs_row).as_ref() }; + exist_full == input_full + } + } + + fn value(&self, buffer_index: usize, offset: usize, length: usize) -> &[u8] { + debug_assert!(buffer_index <= self.completed.len()); + + if buffer_index < self.completed.len() { + let block = &self.completed[buffer_index]; + &block[offset..offset + length] + } else { + &self.in_progress[offset..offset + length] + } + } + + fn build_inner(self) -> ArrayRef { + let Self { + views, + in_progress, + mut completed, + nulls, + .. + } = self; + + // Build nulls + let null_buffer = nulls.build(); + + // Build values + // Flush `in_process` firstly + if !in_progress.is_empty() { + let buffer = Buffer::from(in_progress); + completed.push(buffer); + } + + let views = ScalarBuffer::from(views); + + // Safety: + // * all views were correctly made + // * (if utf8): Input was valid Utf8 so buffer contents are + // valid utf8 as well + unsafe { + Arc::new(GenericByteViewArray::::new_unchecked( + views, + completed, + null_buffer, + )) + } + } + + fn take_n_inner(&mut self, n: usize) -> ArrayRef { + debug_assert!(self.len() >= n); + + // The `n == len` case, we need to take all + if self.len() == n { + let new_builder = Self::new().with_max_block_size(self.max_block_size); + let cur_builder = replace(self, new_builder); + return cur_builder.build_inner(); + } + + // The `n < len` case + // Take n for nulls + let null_buffer = self.nulls.take_n(n); + + // Take n for values: + // - Take first n `view`s from `views` + // + // - Find the last non-inlined `view`, if all inlined, + // we can build array and return happily, otherwise we + // we need to continue to process related buffers + // + // - Get the last related `buffer index`(let's name it `buffer index n`) + // from last non-inlined `view` + // + // - Take buffers, the key is that we need to know if we need to take + // the whole last related buffer. 
The logic is a bit complex, you can + // detail in `take_buffers_with_whole_last`, `take_buffers_with_partial_last` + // and other related steps in following + // + // - Shift the `buffer index` of remaining non-inlined `views` + // + let first_n_views = self.views.drain(0..n).collect::>(); + + let last_non_inlined_view = first_n_views + .iter() + .rev() + .find(|view| ((**view) as u32) > 12); + + // All taken views inlined + let Some(view) = last_non_inlined_view else { + let views = ScalarBuffer::from(first_n_views); + + // Safety: + // * all views were correctly made + // * (if utf8): Input was valid Utf8 so buffer contents are + // valid utf8 as well + unsafe { + return Arc::new(GenericByteViewArray::::new_unchecked( + views, + Vec::new(), + null_buffer, + )); + } + }; + + // Unfortunately, some taken views non-inlined + let view = ByteView::from(*view); + let last_remaining_buffer_index = view.buffer_index as usize; + + // Check should we take the whole `last_remaining_buffer_index` buffer + let take_whole_last_buffer = self.should_take_whole_buffer( + last_remaining_buffer_index, + (view.offset + view.length) as usize, + ); + + // Take related buffers + let buffers = if take_whole_last_buffer { + self.take_buffers_with_whole_last(last_remaining_buffer_index) + } else { + self.take_buffers_with_partial_last( + last_remaining_buffer_index, + (view.offset + view.length) as usize, + ) + }; + + // Shift `buffer index`s finally + let shifts = if take_whole_last_buffer { + last_remaining_buffer_index + 1 + } else { + last_remaining_buffer_index + }; + + self.views.iter_mut().for_each(|view| { + if (*view as u32) > 12 { + let mut byte_view = ByteView::from(*view); + byte_view.buffer_index -= shifts as u32; + *view = byte_view.as_u128(); + } + }); + + // Build array and return + let views = ScalarBuffer::from(first_n_views); + + // Safety: + // * all views were correctly made + // * (if utf8): Input was valid Utf8 so buffer contents are + // valid utf8 as well + unsafe { + Arc::new(GenericByteViewArray::::new_unchecked( + views, + buffers, + null_buffer, + )) + } + } + + fn take_buffers_with_whole_last( + &mut self, + last_remaining_buffer_index: usize, + ) -> Vec { + if last_remaining_buffer_index == self.completed.len() { + self.flush_in_progress(); + } + self.completed + .drain(0..last_remaining_buffer_index + 1) + .collect() + } + + fn take_buffers_with_partial_last( + &mut self, + last_remaining_buffer_index: usize, + last_take_len: usize, + ) -> Vec { + let mut take_buffers = Vec::with_capacity(last_remaining_buffer_index + 1); + + // Take `0 ~ last_remaining_buffer_index - 1` buffers + if !self.completed.is_empty() || last_remaining_buffer_index == 0 { + take_buffers.extend(self.completed.drain(0..last_remaining_buffer_index)); + } + + // Process the `last_remaining_buffer_index` buffers + let last_buffer = if last_remaining_buffer_index < self.completed.len() { + // If it is in `completed`, simply clone + self.completed[last_remaining_buffer_index].clone() + } else { + // If it is `in_progress`, copied `0 ~ offset` part + let taken_last_buffer = self.in_progress[0..last_take_len].to_vec(); + Buffer::from_vec(taken_last_buffer) + }; + take_buffers.push(last_buffer); + + take_buffers + } + + #[inline] + fn should_take_whole_buffer(&self, buffer_index: usize, take_len: usize) -> bool { + if buffer_index < self.completed.len() { + take_len == self.completed[buffer_index].len() + } else { + take_len == self.in_progress.len() + } + } + + fn flush_in_progress(&mut self) { + let flushed_block 
= replace( + &mut self.in_progress, + Vec::with_capacity(self.max_block_size), + ); + let buffer = Buffer::from_vec(flushed_block); + self.completed.push(buffer); + } +} + +impl GroupColumn for ByteViewGroupValueBuilder { + fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { + self.equal_to_inner(lhs_row, array, rhs_row) + } + + fn append_val(&mut self, array: &ArrayRef, row: usize) { + self.append_val_inner(array, row) + } + + fn len(&self) -> usize { + self.views.len() + } + + fn size(&self) -> usize { + let buffers_size = self + .completed + .iter() + .map(|buf| buf.capacity() * size_of::()) + .sum::(); + + self.nulls.allocated_size() + + self.views.capacity() * size_of::() + + self.in_progress.capacity() * size_of::() + + buffers_size + + size_of::() + } + + fn build(self: Box) -> ArrayRef { + Self::build_inner(*self) + } + + fn take_n(&mut self, n: usize) -> ArrayRef { + self.take_n_inner(n) + } +} + +/// Determines if the nullability of the existing and new input array can be used +/// to short-circuit the comparison of the two values. +/// +/// Returns `Some(result)` if the result of the comparison can be determined +/// from the nullness of the two values, and `None` if the comparison must be +/// done on the values themselves. +fn nulls_equal_to(lhs_null: bool, rhs_null: bool) -> Option { + match (lhs_null, rhs_null) { + (true, true) => Some(true), + (false, true) | (true, false) => Some(false), + _ => None, + } +} + #[cfg(test)] mod tests { use std::sync::Arc; - use arrow_array::{ArrayRef, StringArray}; + use arrow::{ + array::AsArray, + datatypes::{Int64Type, StringViewType}, + }; + use arrow_array::{ArrayRef, Int64Array, StringArray, StringViewArray}; + use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; use datafusion_physical_expr::binary_map::OutputType; + use crate::aggregates::group_values::group_column::{ + ByteViewGroupValueBuilder, PrimitiveGroupValueBuilder, + }; + use super::{ByteGroupValueBuilder, GroupColumn}; #[test] @@ -453,4 +880,378 @@ mod tests { ])) as ArrayRef; assert_eq!(&output, &array); } + + #[test] + fn test_nullable_primitive_equal_to() { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; values not equal + // - exist not null, input not null; values equal + + // Define PrimitiveGroupValueBuilder + let mut builder = PrimitiveGroupValueBuilder::::new(); + let builder_array = Arc::new(Int64Array::from(vec![ + None, + None, + None, + Some(1), + Some(2), + Some(3), + ])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + builder.append_val(&builder_array, 2); + builder.append_val(&builder_array, 3); + builder.append_val(&builder_array, 4); + builder.append_val(&builder_array, 5); + + // Define input array + let (_nulls, values, _) = + Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some(2) to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let 
input_array = Arc::new(Int64Array::new(values, Some(nulls))) as ArrayRef; + + // Check + assert!(!builder.equal_to(0, &input_array, 0)); + assert!(builder.equal_to(1, &input_array, 1)); + assert!(builder.equal_to(2, &input_array, 2)); + assert!(!builder.equal_to(3, &input_array, 3)); + assert!(!builder.equal_to(4, &input_array, 4)); + assert!(builder.equal_to(5, &input_array, 5)); + } + + #[test] + fn test_not_nullable_primitive_equal_to() { + // Will cover such cases: + // - values equal + // - values not equal + + // Define PrimitiveGroupValueBuilder + let mut builder = PrimitiveGroupValueBuilder::::new(); + let builder_array = + Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + + // Define input array + let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef; + + // Check + assert!(builder.equal_to(0, &input_array, 0)); + assert!(!builder.equal_to(1, &input_array, 1)); + } + + #[test] + fn test_byte_array_equal_to() { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; values not equal + // - exist not null, input not null; values equal + + // Define PrimitiveGroupValueBuilder + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + let builder_array = Arc::new(StringArray::from(vec![ + None, + None, + None, + Some("foo"), + Some("bar"), + Some("baz"), + ])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + builder.append_val(&builder_array, 2); + builder.append_val(&builder_array, 3); + builder.append_val(&builder_array, 4); + builder.append_val(&builder_array, 5); + + // Define input array + let (offsets, buffer, _nulls) = StringArray::from(vec![ + Some("foo"), + Some("bar"), + None, + None, + Some("foo"), + Some("baz"), + ]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some("bar") to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = + Arc::new(StringArray::new(offsets, buffer, Some(nulls))) as ArrayRef; + + // Check + assert!(!builder.equal_to(0, &input_array, 0)); + assert!(builder.equal_to(1, &input_array, 1)); + assert!(builder.equal_to(2, &input_array, 2)); + assert!(!builder.equal_to(3, &input_array, 3)); + assert!(!builder.equal_to(4, &input_array, 4)); + assert!(builder.equal_to(5, &input_array, 5)); + } + + #[test] + fn test_byte_view_append_val() { + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + let builder_array = StringViewArray::from(vec![ + Some("this string is quite long"), // in buffer 0 + Some("foo"), + None, + Some("bar"), + Some("this string is also quite long"), // buffer 0 + Some("this string is quite long"), // buffer 1 + Some("bar"), + ]); + let builder_array: ArrayRef = Arc::new(builder_array); + for row in 0..builder_array.len() { + builder.append_val(&builder_array, row); + } + + let output = Box::new(builder).build(); + // should be 2 output buffers to 
hold all the data + assert_eq!(output.as_string_view().data_buffers().len(), 2,); + assert_eq!(&output, &builder_array) + } + + #[test] + fn test_byte_view_equal_to() { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; value lens not equal + // - exist not null, input not null; value not equal(inlined case) + // - exist not null, input not null; value equal(inlined case) + // + // - exist not null, input not null; value not equal + // (non-inlined case + prefix not equal) + // + // - exist not null, input not null; value not equal + // (non-inlined case + value in `completed`) + // + // - exist not null, input not null; value equal + // (non-inlined case + value in `completed`) + // + // - exist not null, input not null; value not equal + // (non-inlined case + value in `in_progress`) + // + // - exist not null, input not null; value equal + // (non-inlined case + value in `in_progress`) + + // Set the block size to 40 for ensuring some unlined values are in `in_progress`, + // and some are in `completed`, so both two branches in `value` function can be covered. + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + let builder_array = Arc::new(StringViewArray::from(vec![ + None, + None, + None, + Some("foo"), + Some("bazz"), + Some("foo"), + Some("bar"), + Some("I am a long string for test eq in completed"), + Some("I am a long string for test eq in progress"), + ])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + builder.append_val(&builder_array, 2); + builder.append_val(&builder_array, 3); + builder.append_val(&builder_array, 4); + builder.append_val(&builder_array, 5); + builder.append_val(&builder_array, 6); + builder.append_val(&builder_array, 7); + builder.append_val(&builder_array, 8); + + // Define input array + let (views, buffer, _nulls) = StringViewArray::from(vec![ + Some("foo"), + Some("bar"), // set to null + None, + None, + Some("baz"), + Some("oof"), + Some("bar"), + Some("i am a long string for test eq in completed"), + Some("I am a long string for test eq in COMPLETED"), + Some("I am a long string for test eq in completed"), + Some("I am a long string for test eq in PROGRESS"), + Some("I am a long string for test eq in progress"), + ]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(9); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some("bar") to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = + Arc::new(StringViewArray::new(views, buffer, Some(nulls))) as ArrayRef; + + // Check + assert!(!builder.equal_to(0, &input_array, 0)); + assert!(builder.equal_to(1, &input_array, 1)); + assert!(builder.equal_to(2, &input_array, 2)); + assert!(!builder.equal_to(3, &input_array, 3)); + assert!(!builder.equal_to(4, &input_array, 4)); + 
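+        // the remaining checks cover the inlined equal/not-equal cases and
+        // then the non-inlined paths: prefix mismatch, and full-value
+        // comparisons against both the `completed` blocks and the
+        // `in_progress` buffer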
assert!(!builder.equal_to(5, &input_array, 5)); + assert!(builder.equal_to(6, &input_array, 6)); + assert!(!builder.equal_to(7, &input_array, 7)); + assert!(!builder.equal_to(7, &input_array, 8)); + assert!(builder.equal_to(7, &input_array, 9)); + assert!(!builder.equal_to(8, &input_array, 10)); + assert!(builder.equal_to(8, &input_array, 11)); + } + + #[test] + fn test_byte_view_take_n() { + // ####### Define cases and init ####### + + // `take_n` is really complex, we should consider and test following situations: + // 1. Take nulls + // 2. Take all `inlined`s + // 3. Take non-inlined + partial last buffer in `completed` + // 4. Take non-inlined + whole last buffer in `completed` + // 5. Take non-inlined + partial last `in_progress` + // 6. Take non-inlined + whole last buffer in `in_progress` + // 7. Take all views at once + + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + let input_array = StringViewArray::from(vec![ + // Test situation 1 + None, + None, + // Test situation 2 (also test take null together) + None, + Some("foo"), + Some("bar"), + // Test situation 3 (also test take null + inlined) + None, + Some("foo"), + Some("this string is quite long"), + Some("this string is also quite long"), + // Test situation 4 (also test take null + inlined) + None, + Some("bar"), + Some("this string is quite long"), + // Test situation 5 (also test take null + inlined) + None, + Some("foo"), + Some("another string that is is quite long"), + Some("this string not so long"), + // Test situation 6 (also test take null + inlined + insert again after taking) + None, + Some("bar"), + Some("this string is quite long"), + // Insert 4 and just take 3 to ensure it will go the path of situation 6 + None, + // Finally, we create a new builder, insert the whole array and then + // take whole at once for testing situation 7 + ]); + + let input_array: ArrayRef = Arc::new(input_array); + let first_ones_to_append = 16; // For testing situation 1~5 + let second_ones_to_append = 4; // For testing situation 6 + let final_ones_to_append = input_array.len(); // For testing situation 7 + + // ####### Test situation 1~5 ####### + for row in 0..first_ones_to_append { + builder.append_val(&input_array, row); + } + + assert_eq!(builder.completed.len(), 2); + assert_eq!(builder.in_progress.len(), 59); + + // Situation 1 + let taken_array = builder.take_n(2); + assert_eq!(&taken_array, &input_array.slice(0, 2)); + + // Situation 2 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(2, 3)); + + // Situation 3 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(5, 3)); + + let taken_array = builder.take_n(1); + assert_eq!(&taken_array, &input_array.slice(8, 1)); + + // Situation 4 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(9, 3)); + + // Situation 5 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(12, 3)); + + let taken_array = builder.take_n(1); + assert_eq!(&taken_array, &input_array.slice(15, 1)); + + // ####### Test situation 6 ####### + assert!(builder.completed.is_empty()); + assert!(builder.in_progress.is_empty()); + assert!(builder.views.is_empty()); + + for row in first_ones_to_append..first_ones_to_append + second_ones_to_append { + builder.append_val(&input_array, row); + } + + assert!(builder.completed.is_empty()); + assert_eq!(builder.in_progress.len(), 25); + + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, 
&input_array.slice(16, 3)); + + // ####### Test situation 7 ####### + // Create a new builder + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + + for row in 0..final_ones_to_append { + builder.append_val(&input_array, row); + } + + assert_eq!(builder.completed.len(), 3); + assert_eq!(builder.in_progress.len(), 25); + + let taken_array = builder.take_n(final_ones_to_append); + assert_eq!(&taken_array, &input_array); + } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 9256631fa578d..fb7b667750924 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! [`GroupValues`] trait for storing and interning group keys + use arrow::record_batch::RecordBatch; use arrow_array::{downcast_primitive, ArrayRef}; use arrow_schema::{DataType, SchemaRef}; @@ -36,19 +38,63 @@ use bytes::GroupValuesByes; use datafusion_physical_expr::binary_map::OutputType; mod group_column; - -/// An interning store for group keys +mod null_builder; + +/// Stores the group values during hash aggregation. +/// +/// # Background +/// +/// In a query such as `SELECT a, b, count(*) FROM t GROUP BY a, b`, the group values +/// identify each group, and correspond to all the distinct values of `(a,b)`. +/// +/// ```sql +/// -- Input has 4 rows with 3 distinct combinations of (a,b) ("groups") +/// create table t(a int, b varchar) +/// as values (1, 'a'), (2, 'b'), (1, 'a'), (3, 'c'); +/// +/// select a, b, count(*) from t group by a, b; +/// ---- +/// 1 a 2 +/// 2 b 1 +/// 3 c 1 +/// ``` +/// +/// # Design +/// +/// Managing group values is a performance critical operation in hash +/// aggregation. The major operations are: +/// +/// 1. Intern: Quickly finding existing and adding new group values +/// 2. Emit: Returning the group values as an array +/// +/// There are multiple specialized implementations of this trait optimized for +/// different data types and number of columns, optimized for these operations. +/// See [`new_group_values`] for details. +/// +/// # Group Ids +/// +/// Each distinct group in a hash aggregation is identified by a unique group id +/// (usize) which is assigned by instances of this trait. Group ids are +/// continuous without gaps, starting from 0. pub trait GroupValues: Send { - /// Calculates the `groups` for each input row of `cols` + /// Calculates the group id for each input row of `cols`, assigning new + /// group ids as necessary. + /// + /// When the function returns, `groups` must contain the group id for each + /// row in `cols`. + /// + /// If a row has the same value as a previous row, the same group id is + /// assigned. If a row has a new value, the next available group id is + /// assigned. 
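+    /// For example (an illustrative sketch, not text from this PR): interning
+    /// the single column `["a", "b", "a"]` into an empty store yields
+    /// `groups == [0, 1, 0]`:
+    ///
+    /// ```ignore
+    /// let mut groups = vec![];
+    /// group_values.intern(&[col_a], &mut groups)?;
+    /// assert_eq!(groups, vec![0, 1, 0]);
+    /// ```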
    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()>;

-    /// Returns the number of bytes used by this [`GroupValues`]
+    /// Returns the number of bytes of memory used by this [`GroupValues`]
     fn size(&self) -> usize;

     /// Returns true if this [`GroupValues`] is empty
     fn is_empty(&self) -> bool;

-    /// The number of values stored in this [`GroupValues`]
+    /// The number of values (distinct group values) stored in this [`GroupValues`]
     fn len(&self) -> usize;

     /// Emits the group values
@@ -58,6 +104,7 @@ pub trait GroupValues: Send {
     fn clear_shrink(&mut self, batch: &RecordBatch);
 }

+/// Return a specialized implementation of [`GroupValues`] for the given schema.
 pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> {
     if schema.fields.len() == 1 {
         let d = schema.fields[0].data_type();
diff --git a/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs b/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs
new file mode 100644
index 0000000000000..0249390f38cdd
--- /dev/null
+++ b/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_buffer::{BooleanBufferBuilder, NullBuffer};
+
+/// Builder for an (optional) null mask
+///
+/// Optimized to avoid creating the bitmask when all values are non-null
+#[derive(Debug)]
+pub(crate) enum MaybeNullBufferBuilder {
+    /// seen `row_count` rows but no nulls yet
+    NoNulls { row_count: usize },
+    /// have at least one null value
+    ///
+    /// Note this is an Arrow *VALIDITY* buffer (so it is false for nulls, true
+    /// for non-nulls)
+    Nulls(BooleanBufferBuilder),
+}
+
+impl MaybeNullBufferBuilder {
+    /// Create a new builder
+    pub fn new() -> Self {
+        Self::NoNulls { row_count: 0 }
+    }
+
+    /// Return true if the row at index `row` is null
+    pub fn is_null(&self, row: usize) -> bool {
+        match self {
+            Self::NoNulls { .. } => false,
+            // validity mask means an unset bit is NULL
+            Self::Nulls(builder) => !builder.get_bit(row),
+        }
+    }
+
+    /// Set the nullness of the next row to `is_null`
+    ///
+    /// If `is_null` is true, the row is null.
+    /// If `is_null` is false, the row is not null.
+    pub fn append(&mut self, is_null: bool) {
+        match self {
+            Self::NoNulls { row_count } if is_null => {
+                // have seen no nulls so far, this is the first null,
+                // need to create the nulls buffer for all currently valid values;
+                // allocate 2x the needed capacity, since we push a new value
+                // immediately below
+                let mut nulls = BooleanBufferBuilder::new(*row_count * 2);
+                nulls.append_n(*row_count, true);
+                nulls.append(false);
+                *self = Self::Nulls(nulls);
+            }
+            Self::NoNulls { row_count } => {
+                *row_count += 1;
+            }
+            Self::Nulls(builder) => builder.append(!is_null),
+        }
+    }
+
+    /// return the number of heap allocated bytes used by this structure to store boolean values
+    pub fn allocated_size(&self) -> usize {
+        match self {
+            Self::NoNulls { .. } => 0,
+            // BooleanBufferBuilder::capacity returns capacity in bits (not bytes)
+            Self::Nulls(builder) => builder.capacity() / 8,
+        }
+    }
+
+    /// Return a NullBuffer representing the accumulated nulls so far
+    pub fn build(self) -> Option<NullBuffer> {
+        match self {
+            Self::NoNulls { .. } => None,
+            Self::Nulls(mut builder) => Some(NullBuffer::from(builder.finish())),
+        }
+    }
+
+    /// Returns a NullBuffer representing the first `n` rows accumulated so far,
+    /// shifting any remaining rows down by `n`
+    pub fn take_n(&mut self, n: usize) -> Option<NullBuffer> {
+        match self {
+            Self::NoNulls { row_count } => {
+                *row_count -= n;
+                None
+            }
+            Self::Nulls(builder) => {
+                // Copy the values at n..len to the start of a
+                // new builder and leave it in self
+                //
+                // TODO: it would be great to use something like `set_bits` from arrow here.
+                let mut new_builder = BooleanBufferBuilder::new(builder.len());
+                for i in n..builder.len() {
+                    new_builder.append(builder.get_bit(i));
+                }
+                std::mem::swap(&mut new_builder, builder);
+
+                // take only the first n values from the original builder
+                new_builder.truncate(n);
+                Some(NullBuffer::from(new_builder.finish()))
+            }
+        }
+    }
+}
diff --git a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs
index d5b7f1b11ac55..05214ec10d68b 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs
@@ -30,6 +30,7 @@ use datafusion_execution::memory_pool::proxy::VecAllocExt;
 use datafusion_expr::EmitTo;
 use half::f16;
 use hashbrown::raw::RawTable;
+use std::mem::size_of;
 use std::sync::Arc;

 /// A trait to allow hashing of floating point numbers
@@ -151,7 +152,7 @@ where
     }

     fn size(&self) -> usize {
-        self.map.capacity() * std::mem::size_of::<usize>() + self.values.allocated_size()
+        self.map.capacity() * size_of::<usize>() + self.values.allocated_size()
     }

     fn is_empty(&self) -> bool {
diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs
index b252d0008784f..de0ae2e07dd29 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/row.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -27,9 +27,17 @@ use datafusion_common::Result;
 use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
 use datafusion_expr::EmitTo;
 use hashbrown::raw::RawTable;
+use std::mem::size_of;
 use std::sync::Arc;

 /// A [`GroupValues`] making use of [`Rows`]
+///
+/// This is a general implementation of [`GroupValues`] that works for any
+/// combination of data types and number of columns, including nested types such as
+/// structs and lists.
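+///
+/// For example (illustrative): a `GROUP BY` on a `Struct` or `List` column has
+/// no specialized column-wise store, so `new_group_values` falls back to this
+/// row-based implementation.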
+/// +/// It uses the arrow-rs [`Rows`] to store the group values, which is a row-wise +/// representation. pub struct GroupValuesRows { /// The output schema schema: SchemaRef, @@ -220,13 +228,12 @@ impl GroupValues for GroupValuesRows { } }; - // TODO: Materialize dictionaries in group keys (#7647) + // TODO: Materialize dictionaries in group keys + // https://github.com/apache/datafusion/issues/7647 for (field, array) in self.schema.fields.iter().zip(&mut output) { let expected = field.data_type(); - *array = dictionary_encode_if_necessary( - Arc::::clone(array), - expected, - )?; + *array = + dictionary_encode_if_necessary(Arc::::clone(array), expected)?; } self.group_values = Some(group_values); @@ -241,7 +248,7 @@ impl GroupValues for GroupValuesRows { }); self.map.clear(); self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared - self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>(); + self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); self.hashes_buffer.clear(); self.hashes_buffer.shrink_to(count); } @@ -259,7 +266,7 @@ fn dictionary_encode_if_necessary( .zip(struct_array.columns()) .map(|(expected_field, column)| { dictionary_encode_if_necessary( - Arc::::clone(column), + Arc::::clone(column), expected_field.data_type(), ) }) @@ -278,13 +285,13 @@ fn dictionary_encode_if_necessary( Arc::::clone(expected_field), list.offsets().clone(), dictionary_encode_if_necessary( - Arc::::clone(list.values()), + Arc::::clone(list.values()), expected_field.data_type(), )?, list.nulls().cloned(), )?)) } (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?), - (_, _) => Ok(Arc::::clone(&array)), + (_, _) => Ok(Arc::::clone(&array)), } } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 2bdaed4796553..48a03af19dbd5 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -26,6 +26,7 @@ use crate::aggregates::{ topk_stream::GroupedTopKAggregateStream, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::projection::get_field_metadata; use crate::windows::get_ordered_partition_by_indices; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, @@ -35,10 +36,11 @@ use crate::{ use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_array::{UInt16Array, UInt32Array, UInt64Array, UInt8Array}; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_execution::TaskContext; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, Aggregate}; use datafusion_physical_expr::{ equivalence::{collapse_lex_req, ProjectionMapping}, expressions::Column, @@ -46,6 +48,7 @@ use datafusion_physical_expr::{ PhysicalExpr, PhysicalSortRequirement, }; +use crate::execution_plan::CardinalityEffect; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use itertools::Itertools; @@ -210,13 +213,99 @@ impl PhysicalGroupBy { .collect() } + /// The number of expressions in the output schema. + fn num_output_exprs(&self) -> usize { + let mut num_exprs = self.expr.len(); + if !self.is_single() { + num_exprs += 1 + } + num_exprs + } + /// Return grouping expressions as they occur in the output schema. 
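    ///
    /// For example (sketch): a plain `GROUP BY a, b` yields output columns
    /// `a, b`, while a grouping-set query additionally yields the internal
    /// `Aggregate::INTERNAL_GROUPING_ID` column counted by `num_output_exprs`.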
 pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        self.expr
-            .iter()
-            .enumerate()
-            .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _)
-            .collect()
+        let num_output_exprs = self.num_output_exprs();
+        let mut output_exprs = Vec::with_capacity(num_output_exprs);
+        output_exprs.extend(
+            self.expr
+                .iter()
+                .enumerate()
+                .take(num_output_exprs)
+                .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
+        );
+        if !self.is_single() {
+            output_exprs.push(Arc::new(Column::new(
+                Aggregate::INTERNAL_GROUPING_ID,
+                self.expr.len(),
+            )) as _);
+        }
+        output_exprs
+    }
+
+    /// Returns the number of expressions used as grouping keys.
+    fn num_group_exprs(&self) -> usize {
+        if self.is_single() {
+            self.expr.len()
+        } else {
+            self.expr.len() + 1
+        }
+    }
+
+    /// Returns the fields that are used as the grouping keys.
+    fn group_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
+        let mut fields = Vec::with_capacity(self.num_group_exprs());
+        for ((expr, name), group_expr_nullable) in
+            self.expr.iter().zip(self.exprs_nullable().into_iter())
+        {
+            fields.push(
+                Field::new(
+                    name,
+                    expr.data_type(input_schema)?,
+                    group_expr_nullable || expr.nullable(input_schema)?,
+                )
+                .with_metadata(
+                    get_field_metadata(expr, input_schema).unwrap_or_default(),
+                ),
+            );
+        }
+        if !self.is_single() {
+            fields.push(Field::new(
+                Aggregate::INTERNAL_GROUPING_ID,
+                Aggregate::grouping_id_type(self.expr.len()),
+                false,
+            ));
+        }
+        Ok(fields)
+    }
+
+    /// Returns the output fields of the group by.
+    ///
+    /// This can differ from `group_fields`, which may contain internal expressions
+    /// that should not be part of the output schema.
+    fn output_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
+        let mut fields = self.group_fields(input_schema)?;
+        fields.truncate(self.num_output_exprs());
+        Ok(fields)
+    }
+
+    /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial
+    /// aggregation.
+    pub fn as_final(&self) -> PhysicalGroupBy {
+        let expr: Vec<_> =
+            self.output_exprs()
+                .into_iter()
+                .zip(
+                    self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
+                        Aggregate::INTERNAL_GROUPING_ID.to_owned(),
+                    )),
+                )
+                .collect();
+        let num_exprs = expr.len();
+        Self {
+            expr,
+            null_expr: vec![],
+            groups: vec![vec![false; num_exprs]],
+        }
+    }
 }
 
@@ -262,7 +351,7 @@ pub struct AggregateExec {
     /// Group by expressions
     group_by: PhysicalGroupBy,
     /// Aggregate expressions
-    aggr_expr: Vec<AggregateFunctionExpr>,
+    aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
     /// FILTER (WHERE clause) expression for each aggregate expression
     filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
     /// Set if the output of this aggregation is truncated by an upstream sort/limit clause
@@ -289,7 +378,10 @@ impl AggregateExec {
     /// Function used in `OptimizeAggregateOrder` optimizer rule,
     /// where we need parts of the new value, others cloned from the old one
     /// Rewrites aggregate exec with new aggregate expressions.
- pub fn with_new_aggr_exprs(&self, aggr_expr: Vec) -> Self { + pub fn with_new_aggr_exprs( + &self, + aggr_expr: Vec>, + ) -> Self { Self { aggr_expr, // clone the rest of the fields @@ -315,18 +407,12 @@ impl AggregateExec { pub fn try_new( mode: AggregateMode, group_by: PhysicalGroupBy, - aggr_expr: Vec, + aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, ) -> Result { - let schema = create_schema( - &input.schema(), - &group_by.expr, - &aggr_expr, - group_by.exprs_nullable(), - mode, - )?; + let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?; let schema = Arc::new(schema); AggregateExec::try_new_with_schema( @@ -352,7 +438,7 @@ impl AggregateExec { fn try_new_with_schema( mode: AggregateMode, group_by: PhysicalGroupBy, - mut aggr_expr: Vec, + mut aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, @@ -462,7 +548,7 @@ impl AggregateExec { } /// Aggregate expressions - pub fn aggr_expr(&self) -> &[AggregateFunctionExpr] { + pub fn aggr_expr(&self) -> &[Arc] { &self.aggr_expr } @@ -784,26 +870,20 @@ impl ExecutionPlan for AggregateExec { } } } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::LowerEqual + } } fn create_schema( input_schema: &Schema, - group_expr: &[(Arc, String)], - aggr_expr: &[AggregateFunctionExpr], - group_expr_nullable: Vec, + group_by: &PhysicalGroupBy, + aggr_expr: &[Arc], mode: AggregateMode, ) -> Result { - let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len()); - for (index, (expr, name)) in group_expr.iter().enumerate() { - fields.push(Field::new( - name, - expr.data_type(input_schema)?, - // In cases where we have multiple grouping sets, we will use NULL expressions in - // order to align the grouping sets. So the field must be nullable even if the underlying - // schema field is not. - group_expr_nullable[index] || expr.nullable(input_schema)?, - )) - } + let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len()); + fields.extend(group_by.output_fields(input_schema)?); match mode { AggregateMode::Partial => { @@ -823,12 +903,14 @@ fn create_schema( } } - Ok(Schema::new(fields)) + Ok(Schema::new_with_metadata( + fields, + input_schema.metadata().clone(), + )) } -fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { - let group_fields = schema.fields()[0..group_count].to_vec(); - Arc::new(Schema::new(group_fields)) +fn group_schema(input_schema: &Schema, group_by: &PhysicalGroupBy) -> Result { + Ok(Arc::new(Schema::new(group_by.group_fields(input_schema)?))) } /// Determines the lexical ordering requirement for an aggregate expression. @@ -927,7 +1009,7 @@ pub fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { /// A `LexRequirement` instance, which is the requirement that satisfies all the /// aggregate requirements. Returns an error in case of conflicting requirements. pub fn get_finer_aggregate_exprs_requirement( - aggr_exprs: &mut [AggregateFunctionExpr], + aggr_exprs: &mut [Arc], group_by: &PhysicalGroupBy, eq_properties: &EquivalenceProperties, agg_mode: &AggregateMode, @@ -955,7 +1037,7 @@ pub fn get_finer_aggregate_exprs_requirement( // Reverse requirement is satisfied by exiting ordering. // Hence reverse the aggregator requirement = finer_ordering; - *aggr_expr = reverse_aggr_expr; + *aggr_expr = Arc::new(reverse_aggr_expr); continue; } } @@ -979,7 +1061,7 @@ pub fn get_finer_aggregate_exprs_requirement( // There is a requirement that both satisfies existing requirement and reverse // aggregate requirement. 
Use updated requirement requirement = finer_ordering; - *aggr_expr = reverse_aggr_expr; + *aggr_expr = Arc::new(reverse_aggr_expr); continue; } } @@ -1001,7 +1083,7 @@ pub fn get_finer_aggregate_exprs_requirement( /// * Partial: AggregateFunctionExpr::expressions /// * Final: columns of `AggregateFunctionExpr::state_fields()` pub fn aggregate_expressions( - aggr_expr: &[AggregateFunctionExpr], + aggr_expr: &[Arc], mode: &AggregateMode, col_idx_base: usize, ) -> Result>>> { @@ -1056,7 +1138,7 @@ fn merge_expressions( pub type AccumulatorItem = Box; pub fn create_accumulators( - aggr_expr: &[AggregateFunctionExpr], + aggr_expr: &[Arc], ) -> Result> { aggr_expr .iter() @@ -1135,6 +1217,27 @@ fn evaluate_optional( .collect() } +fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result { + if group.len() > 64 { + return not_impl_err!( + "Grouping sets with more than 64 columns are not supported" + ); + } + let group_id = group.iter().fold(0u64, |acc, &is_null| { + (acc << 1) | if is_null { 1 } else { 0 } + }); + let num_rows = batch.num_rows(); + if group.len() <= 8 { + Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows]))) + } else if group.len() <= 16 { + Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows]))) + } else if group.len() <= 32 { + Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows]))) + } else { + Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows]))) + } +} + /// Evaluate a group by expression against a `RecordBatch` /// /// Arguments: @@ -1167,23 +1270,24 @@ pub(crate) fn evaluate_group_by( }) .collect::>>()?; - Ok(group_by + group_by .groups .iter() .map(|group| { - group - .iter() - .enumerate() - .map(|(idx, is_null)| { - if *is_null { - Arc::clone(&null_exprs[idx]) - } else { - Arc::clone(&exprs[idx]) - } - }) - .collect() + let mut group_values = Vec::with_capacity(group_by.num_group_exprs()); + group_values.extend(group.iter().enumerate().map(|(idx, is_null)| { + if *is_null { + Arc::clone(&null_exprs[idx]) + } else { + Arc::clone(&exprs[idx]) + } + })); + if !group_by.is_single() { + group_values.push(group_id_array(group, batch)?); + } + Ok(group_values) }) - .collect()) + .collect() } #[cfg(test)] @@ -1341,26 +1445,28 @@ mod tests { ) -> Result<()> { let input_schema = input.schema(); - let grouping_set = PhysicalGroupBy { - expr: vec![ + let grouping_set = PhysicalGroupBy::new( + vec![ (col("a", &input_schema)?, "a".to_string()), (col("b", &input_schema)?, "b".to_string()), ], - null_expr: vec![ + vec![ (lit(ScalarValue::UInt32(None)), "a".to_string()), (lit(ScalarValue::Float64(None)), "b".to_string()), ], - groups: vec![ + vec![ vec![false, true], // (a, NULL) vec![true, false], // (NULL, b) vec![false, false], // (a,b) ], - }; + ); - let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) - .schema(Arc::clone(&input_schema)) - .alias("COUNT(1)") - .build()?]; + let aggregates = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) + .schema(Arc::clone(&input_schema)) + .alias("COUNT(1)") + .build()?, + )]; let task_ctx = if spill { // adjust the max memory size to have the partial aggregate result for spill mode. 
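Aside: `group_id_array` above encodes which columns of a grouping set are NULL-ed out as a single integer: one bit per grouping column, the first column in the most significant position, and a set bit meaning the column is NULL for that set. The `__grouping_id` values 0, 1, and 2 in the expected tables below follow directly from this fold; a tiny standalone check (the helper name is illustrative):

```rust
/// Same fold as `group_id_array` above: one bit per grouping column,
/// first column in the most significant position, a set bit meaning
/// that column is NULL-ed out for this grouping set.
fn grouping_id(group: &[bool]) -> u64 {
    group
        .iter()
        .fold(0u64, |acc, &is_null| (acc << 1) | u64::from(is_null))
}

fn main() {
    // Mirrors the grouping sets in the test: (a, NULL), (NULL, b), (a, b)
    assert_eq!(grouping_id(&[false, true]), 1); // b is NULL
    assert_eq!(grouping_id(&[true, false]), 2); // a is NULL
    assert_eq!(grouping_id(&[false, false]), 0); // both columns present
}
```

The width of the id column is then chosen from the number of grouping columns (u8 up to 8 columns, u16 up to 16, and so on), which is why the test schema below expects `__grouping_id` as `DataType::UInt8`.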
@@ -1379,69 +1485,62 @@ mod tests { )?); let result = - common::collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; + collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let expected = if spill { // In spill mode, we test with the limited memory, if the mem usage exceeds, // we trigger the early emit rule, which turns out the partial aggregate result. vec![ - "+---+-----+-----------------+", - "| a | b | COUNT(1)[count] |", - "+---+-----+-----------------+", - "| | 1.0 | 1 |", - "| | 1.0 | 1 |", - "| | 2.0 | 1 |", - "| | 2.0 | 1 |", - "| | 3.0 | 1 |", - "| | 3.0 | 1 |", - "| | 4.0 | 1 |", - "| | 4.0 | 1 |", - "| 2 | | 1 |", - "| 2 | | 1 |", - "| 2 | 1.0 | 1 |", - "| 2 | 1.0 | 1 |", - "| 3 | | 1 |", - "| 3 | | 2 |", - "| 3 | 2.0 | 2 |", - "| 3 | 3.0 | 1 |", - "| 4 | | 1 |", - "| 4 | | 2 |", - "| 4 | 3.0 | 1 |", - "| 4 | 4.0 | 2 |", - "+---+-----+-----------------+", + "+---+-----+---------------+-----------------+", + "| a | b | __grouping_id | COUNT(1)[count] |", + "+---+-----+---------------+-----------------+", + "| | 1.0 | 2 | 1 |", + "| | 1.0 | 2 | 1 |", + "| | 2.0 | 2 | 1 |", + "| | 2.0 | 2 | 1 |", + "| | 3.0 | 2 | 1 |", + "| | 3.0 | 2 | 1 |", + "| | 4.0 | 2 | 1 |", + "| | 4.0 | 2 | 1 |", + "| 2 | | 1 | 1 |", + "| 2 | | 1 | 1 |", + "| 2 | 1.0 | 0 | 1 |", + "| 2 | 1.0 | 0 | 1 |", + "| 3 | | 1 | 1 |", + "| 3 | | 1 | 2 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 1 |", + "| 4 | | 1 | 2 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+-----------------+", ] } else { vec![ - "+---+-----+-----------------+", - "| a | b | COUNT(1)[count] |", - "+---+-----+-----------------+", - "| | 1.0 | 2 |", - "| | 2.0 | 2 |", - "| | 3.0 | 2 |", - "| | 4.0 | 2 |", - "| 2 | | 2 |", - "| 2 | 1.0 | 2 |", - "| 3 | | 3 |", - "| 3 | 2.0 | 2 |", - "| 3 | 3.0 | 1 |", - "| 4 | | 3 |", - "| 4 | 3.0 | 1 |", - "| 4 | 4.0 | 2 |", - "+---+-----+-----------------+", + "+---+-----+---------------+-----------------+", + "| a | b | __grouping_id | COUNT(1)[count] |", + "+---+-----+---------------+-----------------+", + "| | 1.0 | 2 | 2 |", + "| | 2.0 | 2 | 2 |", + "| | 3.0 | 2 | 2 |", + "| | 4.0 | 2 | 2 |", + "| 2 | | 1 | 2 |", + "| 2 | 1.0 | 0 | 2 |", + "| 3 | | 1 | 3 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 3 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+-----------------+", ] }; assert_batches_sorted_eq!(expected, &result); - let groups = partial_aggregate.group_expr().expr().to_vec(); - let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); - let final_group: Vec<(Arc, String)> = groups - .iter() - .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone()))) - .collect::>()?; - - let final_grouping_set = PhysicalGroupBy::new_single(final_group); + let final_grouping_set = grouping_set.as_final(); let task_ctx = if spill { new_spill_ctx(4, 3160) @@ -1458,29 +1557,28 @@ mod tests { input_schema, )?); - let result = - common::collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; + let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let batch = concat_batches(&result[0].schema(), &result)?; - assert_eq!(batch.num_columns(), 3); + assert_eq!(batch.num_columns(), 4); assert_eq!(batch.num_rows(), 12); let expected = vec![ - "+---+-----+----------+", - "| a | b | COUNT(1) |", - "+---+-----+----------+", - "| | 1.0 | 2 |", - "| | 2.0 | 2 |", - "| | 3.0 | 2 |", - "| | 4.0 | 2 |", - "| 2 | | 2 |", - "| 2 | 
1.0 | 2 |", - "| 3 | | 3 |", - "| 3 | 2.0 | 2 |", - "| 3 | 3.0 | 1 |", - "| 4 | | 3 |", - "| 4 | 3.0 | 1 |", - "| 4 | 4.0 | 2 |", - "+---+-----+----------+", + "+---+-----+---------------+----------+", + "| a | b | __grouping_id | COUNT(1) |", + "+---+-----+---------------+----------+", + "| | 1.0 | 2 | 2 |", + "| | 2.0 | 2 | 2 |", + "| | 3.0 | 2 | 2 |", + "| | 4.0 | 2 | 2 |", + "| 2 | | 1 | 2 |", + "| 2 | 1.0 | 0 | 2 |", + "| 3 | | 1 | 3 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 3 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+----------+", ]; assert_batches_sorted_eq!(&expected, &result); @@ -1496,19 +1594,18 @@ mod tests { async fn check_aggregates(input: Arc, spill: bool) -> Result<()> { let input_schema = input.schema(); - let grouping_set = PhysicalGroupBy { - expr: vec![(col("a", &input_schema)?, "a".to_string())], - null_expr: vec![], - groups: vec![vec![false]], - }; + let grouping_set = PhysicalGroupBy::new( + vec![(col("a", &input_schema)?, "a".to_string())], + vec![], + vec![vec![false]], + ); - let aggregates: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .schema(Arc::clone(&input_schema)) - .alias("AVG(b)") - .build()?, - ]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .alias("AVG(b)") + .build()?, + )]; let task_ctx = if spill { // set to an appropriate value to trigger spill @@ -1527,7 +1624,7 @@ mod tests { )?); let result = - common::collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; + collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let expected = if spill { vec![ @@ -1556,13 +1653,7 @@ mod tests { let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); - let final_group: Vec<(Arc, String)> = grouping_set - .expr - .iter() - .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone()))) - .collect::>()?; - - let final_grouping_set = PhysicalGroupBy::new_single(final_group); + let final_grouping_set = grouping_set.as_final(); let merged_aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, @@ -1579,7 +1670,7 @@ mod tests { } else { Arc::clone(&task_ctx) }; - let result = common::collect(merged_aggregate.execute(0, task_ctx)?).await?; + let result = collect(merged_aggregate.execute(0, task_ctx)?).await?; let batch = concat_batches(&result[0].schema(), &result)?; assert_eq!(batch.num_columns(), 2); assert_eq!(batch.num_rows(), 3); @@ -1598,12 +1689,24 @@ mod tests { let metrics = merged_aggregate.metrics().unwrap(); let output_rows = metrics.output_rows().unwrap(); + let spill_count = metrics.spill_count().unwrap(); + let spilled_bytes = metrics.spilled_bytes().unwrap(); + let spilled_rows = metrics.spilled_rows().unwrap(); + if spill { // When spilling, the output rows metrics become partial output size + final output size // This is because final aggregation starts while partial aggregation is still emitting assert_eq!(8, output_rows); + + assert!(spill_count > 0); + assert!(spilled_bytes > 0); + assert!(spilled_rows > 0); } else { assert_eq!(3, output_rows); + + assert_eq!(0, spill_count); + assert_eq!(0, spilled_bytes); + assert_eq!(0, spilled_rows); } Ok(()) @@ -1818,24 +1921,23 @@ mod tests { let task_ctx = Arc::new(task_ctx); let groups_none = PhysicalGroupBy::default(); - let groups_some = PhysicalGroupBy { - expr: vec![(col("a", &input_schema)?, "a".to_string())], - null_expr: vec![], - 
groups: vec![vec![false]], - }; + let groups_some = PhysicalGroupBy::new( + vec![(col("a", &input_schema)?, "a".to_string())], + vec![], + vec![vec![false]], + ); // something that allocates within the aggregator - let aggregates_v0: Vec = - vec![test_median_agg_expr(Arc::clone(&input_schema))?]; + let aggregates_v0: Vec> = + vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)]; // use fast-path in `row_hash.rs`. - let aggregates_v2: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .schema(Arc::clone(&input_schema)) - .alias("AVG(b)") - .build()?, - ]; + let aggregates_v2: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .alias("AVG(b)") + .build()?, + )]; for (version, groups, aggregates) in [ (0, groups_none, aggregates_v0), @@ -1868,7 +1970,7 @@ mod tests { } let stream: SendableRecordBatchStream = stream.into(); - let err = common::collect(stream).await.unwrap_err(); + let err = collect(stream).await.unwrap_err(); // error root cause traversal is a bit complicated, see #4172. let err = err.find_root(); @@ -1889,13 +1991,12 @@ mod tests { let groups = PhysicalGroupBy::default(); - let aggregates: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) - .schema(Arc::clone(&schema)) - .alias("AVG(a)") - .build()?, - ]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(a)") + .build()?, + )]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -1929,13 +2030,12 @@ mod tests { let groups = PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); - let aggregates: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) - .alias("AVG(b)") - .build()?, - ]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(b)") + .build()?, + )]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -1980,7 +2080,7 @@ mod tests { fn test_first_value_agg_expr( schema: &Schema, sort_options: SortOptions, - ) -> Result { + ) -> Result> { let ordering_req = [PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, @@ -1992,13 +2092,14 @@ mod tests { .schema(Arc::new(schema.clone())) .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]")) .build() + .map(Arc::new) } // LAST_VALUE(b ORDER BY b ) fn test_last_value_agg_expr( schema: &Schema, sort_options: SortOptions, - ) -> Result { + ) -> Result> { let ordering_req = [PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, @@ -2009,6 +2110,7 @@ mod tests { .schema(Arc::new(schema.clone())) .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]")) .build() + .map(Arc::new) } // This function either constructs the physical plan below, @@ -2053,7 +2155,7 @@ mod tests { descending: false, nulls_first: false, }; - let aggregates: Vec = if is_first_acc { + let aggregates: Vec> = if is_first_acc { vec![test_first_value_agg_expr(&schema, sort_options)?] } else { vec![test_last_value_agg_expr(&schema, sort_options)?] 
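Aside: the repeated `.build().map(Arc::new)?` churn through these tests all comes from one API change: `AggregateExec` now stores `Vec<Arc<AggregateFunctionExpr>>` rather than owned values, so cloning or swapping an aggregate expression is a cheap pointer bump. A condensed sketch of the pattern; the import paths are an assumption based on the crates used here, since the test's `use` block is not shown in this diff:

```rust
use std::sync::Arc;

use arrow_schema::SchemaRef;
use datafusion_common::Result;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use datafusion_physical_expr::expressions::lit;

/// Hypothetical helper mirroring the test pattern above: build COUNT(1)
/// and wrap it in `Arc`, as `AggregateExec::try_new` now expects.
fn count_one(schema: &SchemaRef) -> Result<Arc<AggregateFunctionExpr>> {
    AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
        .schema(Arc::clone(schema))
        .alias("COUNT(1)")
        .build()
        .map(Arc::new)
}
```

The `Arc` is also what lets `get_finer_aggregate_exprs_requirement` earlier in this file replace an entry with `*aggr_expr = Arc::new(reverse_aggr_expr)` without cloning the original expression.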
@@ -2189,6 +2291,7 @@ mod tests { .order_by(ordering_req.to_vec()) .schema(Arc::clone(&test_schema)) .build() + .map(Arc::new) .unwrap() }) .collect::>(); @@ -2218,7 +2321,7 @@ mod tests { }; let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); - let aggregates: Vec = vec![ + let aggregates: Vec> = vec![ test_first_value_agg_expr(&schema, option_desc)?, test_last_value_agg_expr(&schema, option_desc)?, ]; @@ -2276,11 +2379,12 @@ mod tests { ], ); - let aggregates: Vec = + let aggregates: Vec> = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) .schema(Arc::clone(&schema)) .alias("1") - .build()?]; + .build() + .map(Arc::new)?]; let input_batches = (0..4) .map(|_| { @@ -2299,7 +2403,7 @@ mod tests { )?); let aggregate_exec = Arc::new(AggregateExec::try_new( - AggregateMode::Partial, + AggregateMode::Single, groups, aggregates.clone(), vec![None], @@ -2311,13 +2415,13 @@ mod tests { collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?; let expected = [ - "+-----+-----+-------+----------+", - "| a | b | const | 1[count] |", - "+-----+-----+-------+----------+", - "| | 0.0 | | 32768 |", - "| 0.0 | | | 32768 |", - "| | | 1 | 32768 |", - "+-----+-----+-------+----------+", + "+-----+-----+-------+---------------+-------+", + "| a | b | const | __grouping_id | 1 |", + "+-----+-----+-------+---------------+-------+", + "| | | 1 | 6 | 32768 |", + "| | 0.0 | | 5 | 32768 |", + "| 0.0 | | | 3 | 32768 |", + "+-----+-----+-------+---------------+-------+", ]; assert_batches_sorted_eq!(expected, &output); @@ -2412,11 +2516,12 @@ mod tests { ) .schema(Arc::clone(&batch.schema())) .alias(String::from("SUM(value)")) - .build()?]; + .build() + .map(Arc::new)?]; let input = Arc::new(MemoryExec::try_new( &[vec![batch.clone()]], - Arc::::clone(&batch.schema()), + Arc::::clone(&batch.schema()), None, )?); let aggregate_exec = Arc::new(AggregateExec::try_new( @@ -2460,7 +2565,8 @@ mod tests { AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) .schema(Arc::clone(&schema)) .alias(String::from("COUNT(val)")) - .build()?, + .build() + .map(Arc::new)?, ]; let input_data = vec![ @@ -2541,7 +2647,8 @@ mod tests { AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) .schema(Arc::clone(&schema)) .alias(String::from("COUNT(val)")) - .build()?, + .build() + .map(Arc::new)?, ]; let input_data = vec![ @@ -2628,33 +2735,34 @@ mod tests { AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?]) .schema(Arc::clone(&input_schema)) .alias("COUNT(a)") - .build()?, + .build() + .map(Arc::new)?, ]; - let grouping_set = PhysicalGroupBy { - expr: vec![ + let grouping_set = PhysicalGroupBy::new( + vec![ (col("a", &input_schema)?, "a".to_string()), (col("b", &input_schema)?, "b".to_string()), ], - null_expr: vec![ + vec![ (lit(ScalarValue::Float32(None)), "a".to_string()), (lit(ScalarValue::Float32(None)), "b".to_string()), ], - groups: vec![ + vec![ vec![false, true], // (a, NULL) vec![false, false], // (a,b) ], - }; + ); let aggr_schema = create_schema( &input_schema, - &grouping_set.expr, + &grouping_set, &aggr_expr, - grouping_set.exprs_nullable(), AggregateMode::Final, )?; let expected_schema = Schema::new(vec![ Field::new("a", DataType::Float32, false), Field::new("b", DataType::Float32, true), + Field::new("__grouping_id", DataType::UInt8, false), Field::new("COUNT(a)", DataType::Int64, false), ]); assert_eq!(aggr_schema, expected_schema); diff --git a/datafusion/physical-plan/src/aggregates/order/full.rs 
b/datafusion/physical-plan/src/aggregates/order/full.rs index d64c99ba1bee3..218855459b1e2 100644 --- a/datafusion/physical-plan/src/aggregates/order/full.rs +++ b/datafusion/physical-plan/src/aggregates/order/full.rs @@ -16,6 +16,7 @@ // under the License. use datafusion_expr::EmitTo; +use std::mem::size_of; /// Tracks grouping state when the data is ordered entirely by its /// group keys @@ -139,7 +140,7 @@ impl GroupOrderingFull { } pub(crate) fn size(&self) -> usize { - std::mem::size_of::() + size_of::() } } diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index 483150ee61af6..accb2fda11316 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -20,6 +20,7 @@ use arrow_schema::Schema; use datafusion_common::Result; use datafusion_expr::EmitTo; use datafusion_physical_expr::PhysicalSortExpr; +use std::mem::size_of; mod full; mod partial; @@ -118,7 +119,7 @@ impl GroupOrdering { /// Return the size of memory used by the ordering state, in bytes pub fn size(&self) -> usize { - std::mem::size_of::() + size_of::() + match self { GroupOrdering::None => 0, GroupOrdering::Partial(partial) => partial.size(), diff --git a/datafusion/physical-plan/src/aggregates/order/partial.rs b/datafusion/physical-plan/src/aggregates/order/partial.rs index 2cbe3bbb784ec..2dd1ea8a5449e 100644 --- a/datafusion/physical-plan/src/aggregates/order/partial.rs +++ b/datafusion/physical-plan/src/aggregates/order/partial.rs @@ -22,6 +22,7 @@ use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; use datafusion_physical_expr::PhysicalSortExpr; +use std::mem::size_of; use std::sync::Arc; /// Tracks grouping state when the data is ordered by some subset of @@ -244,7 +245,7 @@ impl GroupOrderingPartial { /// Return the size of memory allocated by this structure pub(crate) fn size(&self) -> usize { - std::mem::size_of::() + size_of::() + self.order_indices.allocated_size() + self.row_converter.size() } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index d4dbdf0f029d4..7d21cc2f1944b 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -29,7 +29,7 @@ use crate::aggregates::{ }; use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput}; use crate::sorts::sort::sort_batch; -use crate::sorts::streaming_merge; +use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::spill::{read_spill_as_stream, spill_record_batch_by_size}; use crate::stream::RecordBatchStreamAdapter; use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr}; @@ -38,7 +38,7 @@ use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; use arrow::datatypes::SchemaRef; use arrow_schema::SortOptions; -use datafusion_common::{internal_datafusion_err, DataFusionError, Result}; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -102,6 +102,19 @@ struct SpillState { /// true when streaming merge is in progress is_stream_merging: bool, + + // ======================================================================== + // METRICS: + // 
======================================================================== + /// Peak memory used for buffered data. + /// Calculated as sum of peak memory values across partitions + peak_mem_used: metrics::Gauge, + /// count of spill files during the execution of the operator + spill_count: metrics::Count, + /// total spilled bytes during the execution of the operator + spilled_bytes: metrics::Count, + /// total spilled rows during the execution of the operator + spilled_rows: metrics::Count, } /// Tracks if the aggregate should skip partial aggregations @@ -138,6 +151,9 @@ struct SkipAggregationProbe { /// make any effect (set either while probing or on probing completion) is_locked: bool, + // ======================================================================== + // METRICS: + // ======================================================================== /// Number of rows where state was output without aggregation. /// /// * If 0, all input rows were aggregated (should_skip was always false) @@ -449,13 +465,13 @@ impl GroupedHashAggregateStream { let aggregate_arguments = aggregates::aggregate_expressions( &agg.aggr_expr, &agg.mode, - agg_group_by.expr.len(), + agg_group_by.num_group_exprs(), )?; // arguments for aggregating spilled data is the same as the one for final aggregation let merging_aggregate_arguments = aggregates::aggregate_expressions( &agg.aggr_expr, &AggregateMode::Final, - agg_group_by.expr.len(), + agg_group_by.num_group_exprs(), )?; let filter_expressions = match agg.mode { @@ -473,7 +489,7 @@ impl GroupedHashAggregateStream { .map(create_group_accumulator) .collect::>()?; - let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); + let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?; let spill_expr = group_schema .fields .into_iter() @@ -510,6 +526,11 @@ impl GroupedHashAggregateStream { is_stream_merging: false, merging_aggregate_arguments, merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()), + peak_mem_used: MetricBuilder::new(&agg.metrics) + .gauge("peak_mem_used", partition), + spill_count: MetricBuilder::new(&agg.metrics).spill_count(partition), + spilled_bytes: MetricBuilder::new(&agg.metrics).spilled_bytes(partition), + spilled_rows: MetricBuilder::new(&agg.metrics).spilled_rows(partition), }; // Skip aggregation is supported if: @@ -570,7 +591,7 @@ impl GroupedHashAggregateStream { /// that is supported by the aggregate, or a /// [`GroupsAccumulatorAdapter`] if not. 
pub(crate) fn create_group_accumulator( - agg_expr: &AggregateFunctionExpr, + agg_expr: &Arc, ) -> Result> { if agg_expr.groups_accumulator_supported() { agg_expr.create_groups_accumulator() @@ -580,7 +601,7 @@ pub(crate) fn create_group_accumulator( "Creating GroupsAccumulatorAdapter for {}: {agg_expr:?}", agg_expr.name() ); - let agg_expr_captured = agg_expr.clone(); + let agg_expr_captured = Arc::clone(agg_expr); let factory = move || agg_expr_captured.create_accumulator(); Ok(Box::new(GroupsAccumulatorAdapter::new(factory))) } @@ -609,14 +630,11 @@ impl Stream for GroupedHashAggregateStream { match &self.exec_state { ExecutionState::ReadingInput => 'reading_input: { match ready!(self.input.poll_next_unpin(cx)) { - // new batch to aggregate - Some(Ok(batch)) => { + // New batch to aggregate in partial aggregation operator + Some(Ok(batch)) if self.mode == AggregateMode::Partial => { let timer = elapsed_compute.timer(); let input_rows = batch.num_rows(); - // Make sure we have enough capacity for `batch`, otherwise spill - extract_ok!(self.spill_previous_if_necessary(&batch)); - // Do the grouping extract_ok!(self.group_aggregate_batch(batch)); @@ -649,10 +667,49 @@ impl Stream for GroupedHashAggregateStream { timer.done(); } + + // New batch to aggregate in terminal aggregation operator + // (Final/FinalPartitioned/Single/SinglePartitioned) + Some(Ok(batch)) => { + let timer = elapsed_compute.timer(); + + // Make sure we have enough capacity for `batch`, otherwise spill + extract_ok!(self.spill_previous_if_necessary(&batch)); + + // Do the grouping + extract_ok!(self.group_aggregate_batch(batch)); + + // If we can begin emitting rows, do so, + // otherwise keep consuming input + assert!(!self.input_done); + + // If the number of group values equals or exceeds the soft limit, + // emit all groups and switch to producing output + if self.hit_soft_group_limit() { + timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + if let Some(to_emit) = self.group_ordering.emit_to() { + let batch = extract_ok!(self.emit(to_emit, false)); + self.exec_state = ExecutionState::ProducingOutput(batch); + timer.done(); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + timer.done(); + } + + // Found error from input stream Some(Err(e)) => { // inner had error, return to caller return Poll::Ready(Some(Err(e))); } + + // Found end from input stream None => { // inner is done, emit all rows and switch to producing output extract_ok!(self.set_input_done_and_produce_output()); @@ -691,7 +748,12 @@ impl Stream for GroupedHashAggregateStream { ( if self.input_done { ExecutionState::Done - } else if self.should_skip_aggregation() { + } + // In Partial aggregation, we also need to check + // if we should trigger partial skipping + else if self.mode == AggregateMode::Partial + && self.should_skip_aggregation() + { ExecutionState::SkippingAggregation } else { ExecutionState::ReadingInput @@ -824,11 +886,19 @@ impl GroupedHashAggregateStream { fn update_memory_reservation(&mut self) -> Result<()> { let acc = self.accumulators.iter().map(|x| x.size()).sum::(); - self.reservation.try_resize( + let reservation_result = self.reservation.try_resize( acc + self.group_values.size() + self.group_ordering.size() + self.current_group_indices.allocated_size(), - ) + ); + + if reservation_result.is_ok() { + self.spill_state + .peak_mem_used + .set_max(self.reservation.size()); 
+ } + + reservation_result } /// Create an output RecordBatch with the group keys and @@ -879,10 +949,10 @@ impl GroupedHashAggregateStream { if self.group_values.len() > 0 && batch.num_rows() > 0 && matches!(self.group_ordering, GroupOrdering::None) - && !matches!(self.mode, AggregateMode::Partial) && !self.spill_state.is_stream_merging && self.update_memory_reservation().is_err() { + assert_ne!(self.mode, AggregateMode::Partial); // Use input batch (Partial mode) schema for spilling because // the spilled data will be merged and re-evaluated later. self.spill_state.spill_schema = batch.schema(); @@ -905,6 +975,14 @@ impl GroupedHashAggregateStream { self.batch_size, )?; self.spill_state.spills.push(spillfile); + + // Update metrics + self.spill_state.spill_count.add(1); + self.spill_state + .spilled_bytes + .add(sorted.get_array_memory_size()); + self.spill_state.spilled_rows.add(sorted.num_rows()); + Ok(()) } @@ -927,9 +1005,9 @@ impl GroupedHashAggregateStream { fn emit_early_if_necessary(&mut self) -> Result<()> { if self.group_values.len() >= self.batch_size && matches!(self.group_ordering, GroupOrdering::None) - && matches!(self.mode, AggregateMode::Partial) && self.update_memory_reservation().is_err() { + assert_eq!(self.mode, AggregateMode::Partial); let n = self.group_values.len() / self.batch_size * self.batch_size; let batch = self.emit(EmitTo::First(n), false)?; self.exec_state = ExecutionState::ProducingOutput(batch); @@ -960,15 +1038,14 @@ impl GroupedHashAggregateStream { streams.push(stream); } self.spill_state.is_stream_merging = true; - self.input = streaming_merge( - streams, - schema, - &self.spill_state.spill_expr, - self.baseline_metrics.clone(), - self.batch_size, - None, - self.reservation.new_empty(), - )?; + self.input = StreamingMergeBuilder::new() + .with_streams(streams) + .with_schema(schema) + .with_expressions(&self.spill_state.spill_expr) + .with_metrics(self.baseline_metrics.clone()) + .with_batch_size(self.batch_size) + .with_reservation(self.reservation.new_empty()) + .build()?; self.input_done = false; self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new()); Ok(()) @@ -1002,6 +1079,8 @@ impl GroupedHashAggregateStream { } /// Updates skip aggregation probe state. + /// + /// Notice: It should only be called in Partial aggregation fn update_skip_aggregation_probe(&mut self, input_rows: usize) { if let Some(probe) = self.skip_aggregation_probe.as_mut() { // Skip aggregation probe is not supported if stream has any spills, @@ -1013,6 +1092,8 @@ impl GroupedHashAggregateStream { /// In case the probe indicates that aggregation may be /// skipped, forces stream to produce currently accumulated output. + /// + /// Notice: It should only be called in Partial aggregation fn switch_to_skip_aggregation(&mut self) -> Result<()> { if let Some(probe) = self.skip_aggregation_probe.as_mut() { if probe.should_skip() { @@ -1026,6 +1107,8 @@ impl GroupedHashAggregateStream { /// Returns true if the aggregation probe indicates that aggregation /// should be skipped. 
+ /// + /// Notice: It should only be called in Partial aggregation fn should_skip_aggregation(&self) -> bool { self.skip_aggregation_probe .as_ref() @@ -1034,13 +1117,14 @@ impl GroupedHashAggregateStream { /// Transforms input batch to intermediate aggregate state, without grouping it fn transform_to_states(&self, batch: RecordBatch) -> Result { - let group_values = evaluate_group_by(&self.group_by, &batch)?; + let mut group_values = evaluate_group_by(&self.group_by, &batch)?; let input_values = evaluate_many(&self.aggregate_arguments, &batch)?; let filter_values = evaluate_optional(&self.filter_expressions, &batch)?; - let mut output = group_values.first().cloned().ok_or_else(|| { - internal_datafusion_err!("group_values expected to have at least one element") - })?; + if group_values.len() != 1 { + return internal_err!("group_values expected to have single element"); + } + let mut output = group_values.swap_remove(0); let iter = self .accumulators diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index 232b87de32314..34df643b6cf0c 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -109,7 +109,7 @@ impl StringHashTable { Self { owned, map: TopKHashTable::new(limit, limit * 10), - rnd: ahash::RandomState::default(), + rnd: RandomState::default(), } } } @@ -181,7 +181,7 @@ where Self { owned, map: TopKHashTable::new(limit, limit * 10), - rnd: ahash::RandomState::default(), + rnd: RandomState::default(), } } } diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index 287446328f8de..c8b329fabdaab 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -40,9 +40,9 @@ use futures::StreamExt; /// discards the results, and then prints out an annotated plan with metrics #[derive(Debug, Clone)] pub struct AnalyzeExec { - /// control how much extra to print + /// Control how much extra to print verbose: bool, - /// if statistics should be displayed + /// If statistics should be displayed show_statistics: bool, /// The input plan (the plan being analyzed) pub(crate) input: Arc, @@ -69,12 +69,12 @@ impl AnalyzeExec { } } - /// access to verbose + /// Access to verbose pub fn verbose(&self) -> bool { self.verbose } - /// access to show_statistics + /// Access to show_statistics pub fn show_statistics(&self) -> bool { self.show_statistics } diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index 7caf5b8ab65a3..61fb3599f0131 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -34,6 +34,7 @@ use datafusion_common::Result; use datafusion_execution::TaskContext; use crate::coalesce::{BatchCoalescer, CoalescerState}; +use crate::execution_plan::CardinalityEffect; use futures::ready; use futures::stream::{Stream, StreamExt}; @@ -51,7 +52,7 @@ use futures::stream::{Stream, StreamExt}; pub struct CoalesceBatchesExec { /// The input plan input: Arc, - /// Minimum number of rows for coalesces batches + /// Minimum number of rows for coalescing batches target_batch_size: usize, /// Maximum number of rows to fetch, `None` means fetching all rows fetch: Option, @@ -199,6 +200,10 @@ impl ExecutionPlan for CoalesceBatchesExec { fn fetch(&self) -> Option { self.fetch } + + fn cardinality_effect(&self) -> CardinalityEffect { + 
CardinalityEffect::Equal + } } /// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details. diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 486ae41901db3..f9d4ec6a1a349 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -30,6 +30,7 @@ use super::{ use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; +use crate::execution_plan::CardinalityEffect; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; @@ -178,6 +179,10 @@ impl ExecutionPlan for CoalescePartitionsExec { fn supports_limit_pushdown(&self) -> bool { true } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } } #[cfg(test)] @@ -231,10 +236,10 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); let refs = blocking_exec.refs(); - let coaelesce_partitions_exec = + let coalesce_partitions_exec = Arc::new(CoalescePartitionsExec::new(blocking_exec)); - let fut = collect(coaelesce_partitions_exec, task_ctx); + let fut = collect(coalesce_partitions_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 4b5eea6b760df..844208999d254 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -109,7 +109,7 @@ pub(crate) fn spawn_buffered( builder.spawn(async move { while let Some(item) = input.next().await { if sender.send(item).await.is_err() { - // receiver dropped when query is shutdown early (e.g., limit) or error, + // Receiver dropped when query is shutdown early (e.g., limit) or error, // no need to return propagate the send error. return Ok(()); } @@ -156,7 +156,11 @@ pub fn compute_record_batch_statistics( for partition in batches.iter() { for batch in partition { for (stat_index, col_index) in projection.iter().enumerate() { - null_counts[stat_index] += batch.column(*col_index).null_count(); + null_counts[stat_index] += batch + .column(*col_index) + .logical_nulls() + .map(|nulls| nulls.null_count()) + .unwrap_or_default(); } } } @@ -178,15 +182,15 @@ pub fn compute_record_batch_statistics( /// Write in Arrow IPC format. 
pub struct IPCWriter { - /// path + /// Path pub path: PathBuf, - /// inner writer + /// Inner writer pub writer: FileWriter, - /// batches written + /// Batches written pub num_batches: usize, - /// rows written + /// Rows written pub num_rows: usize, - /// bytes written + /// Bytes written pub num_bytes: usize, } @@ -311,7 +315,7 @@ mod tests { ], )?; - // just select f32,f64 + // Just select f32,f64 let select_projection = Some(vec![0, 1]); let byte_size = batch .project(&select_projection.clone().unwrap()) diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index 0d2653c5c7753..e79b3c817bd1b 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -125,7 +125,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_schema: bool, } impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = IndentVisitor { t: self.format_type, f, @@ -164,7 +164,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_statistics: bool, } impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let t = DisplayFormatType::Default; let mut visitor = GraphvizVisitor { @@ -203,7 +203,7 @@ impl<'a> DisplayableExecutionPlan<'a> { } impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = IndentVisitor { f, t: DisplayFormatType::Default, @@ -231,6 +231,7 @@ impl<'a> DisplayableExecutionPlan<'a> { } } +/// Enum representing the different levels of metrics to display #[derive(Debug, Clone, Copy)] enum ShowMetrics { /// Do not show any metrics @@ -256,7 +257,7 @@ struct IndentVisitor<'a, 'b> { /// How to format each node t: DisplayFormatType, /// Write to this formatter - f: &'a mut fmt::Formatter<'b>, + f: &'a mut Formatter<'b>, /// Indent size indent: usize, /// How to show metrics @@ -317,7 +318,7 @@ impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { } struct GraphvizVisitor<'a, 'b> { - f: &'a mut fmt::Formatter<'b>, + f: &'a mut Formatter<'b>, /// How to format each node t: DisplayFormatType, /// How to show metrics @@ -348,8 +349,8 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { struct Wrapper<'a>(&'a dyn ExecutionPlan, DisplayFormatType); - impl<'a> std::fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + impl<'a> fmt::Display for Wrapper<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { self.0.fmt_as(self.1, f) } } @@ -421,14 +422,14 @@ pub trait DisplayAs { /// different from the default one /// /// Should not include a newline - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result; + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result; } /// A newtype wrapper to display `T` implementing`DisplayAs` using the `Default` mode pub struct DefaultDisplay(pub T); impl fmt::Display for DefaultDisplay { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { self.0.fmt_as(DisplayFormatType::Default, f) } } @@ -437,7 +438,7 @@ impl fmt::Display for DefaultDisplay { pub struct VerboseDisplay(pub T); impl fmt::Display for VerboseDisplay { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 
self.0.fmt_as(DisplayFormatType::Verbose, f) } } @@ -447,7 +448,7 @@ impl fmt::Display for VerboseDisplay { pub struct ProjectSchemaDisplay<'a>(pub &'a SchemaRef); impl<'a> fmt::Display for ProjectSchemaDisplay<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let parts: Vec<_> = self .0 .fields() @@ -463,7 +464,7 @@ impl<'a> fmt::Display for ProjectSchemaDisplay<'a> { pub struct OutputOrderingDisplay<'a>(pub &'a [PhysicalSortExpr]); impl<'a> fmt::Display for OutputOrderingDisplay<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "[")?; for (i, e) in self.0.iter().enumerate() { if i > 0 { diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 4bacea48c3473..f6e0abb94fa88 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -173,7 +173,7 @@ mod tests { let empty = EmptyExec::new(Arc::clone(&schema)); assert_eq!(empty.schema(), schema); - // we should have no results + // We should have no results let iter = empty.execute(0, task_ctx)?; let batches = common::collect(iter).await?; assert!(batches.is_empty()); diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index 542861688dfe1..e6484452d43e5 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use arrow_array::Array; use futures::stream::{StreamExt, TryStreamExt}; use tokio::task::JoinSet; @@ -228,6 +229,16 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// [`TryStreamExt`]: futures::stream::TryStreamExt /// [`RecordBatchStreamAdapter`]: crate::stream::RecordBatchStreamAdapter /// + /// # Error handling + /// + /// Any error that occurs during execution is sent as an `Err` in the output + /// stream. + /// + /// `ExecutionPlan` implementations in DataFusion cancel additional work + /// immediately once an error occurs. The rationale is that if the overall + /// query will return an error, any additional work such as continued + /// polling of inputs will be wasted as it will be thrown away. + /// /// # Cancellation / Aborting Execution /// /// The [`Stream`] that is returned must ensure that any allocated resources @@ -406,6 +417,11 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { fn fetch(&self) -> Option { None } + + /// Gets the effect on cardinality, if known + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Unknown + } } /// Extension trait provides an easy API to fetch various properties of @@ -837,7 +853,7 @@ pub fn execute_input_stream( Ok(Box::pin(RecordBatchStreamAdapter::new( sink_schema, input_stream - .map(move |batch| check_not_null_contraits(batch?, &risky_columns)), + .map(move |batch| check_not_null_constraints(batch?, &risky_columns)), ))) } } @@ -857,7 +873,7 @@ pub fn execute_input_stream( /// This function iterates over the specified column indices and ensures that none /// of the columns contain null values. If any column contains null values, an error /// is returned. 
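Aside: besides the rename, the function below (like `compute_record_batch_statistics` earlier) switches its null check from the physical `null_count()` to `logical_nulls()`. For dictionary or run-end encoded arrays, a row can be logically NULL because it references a null child value even though no null bit is set on the outer array, which is exactly what the new dictionary and REE test cases exercise. A standalone illustration of the difference:

```rust
use std::sync::Arc;

use arrow_array::{Array, DictionaryArray, Int32Array};

fn main() {
    // All keys are valid, but key 1 points at a null dictionary value
    let values = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let keys = Int32Array::from(vec![0, 1, 2]);
    let dictionary = DictionaryArray::new(keys, values);

    // Physical nulls: none, since the key buffer has no null bits set
    assert_eq!(dictionary.null_count(), 0);

    // Logical nulls: one, because row 1 dereferences to a null value
    let logical = dictionary.logical_nulls().expect("values contain a null");
    assert_eq!(logical.null_count(), 1);
}
```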
-pub fn check_not_null_contraits( +pub fn check_not_null_constraints( batch: RecordBatch, column_indices: &Vec, ) -> Result { @@ -870,7 +886,13 @@ pub fn check_not_null_contraits( ); } - if batch.column(index).null_count() > 0 { + if batch + .column(index) + .logical_nulls() + .map(|nulls| nulls.null_count()) + .unwrap_or_default() + > 0 + { return exec_err!( "Invalid batch column at '{}' has null but schema specifies non-nullable", index @@ -888,14 +910,28 @@ pub fn get_plan_string(plan: &Arc) -> Vec { actual.iter().map(|elem| elem.to_string()).collect() } +/// Indicates the effect an execution plan operator will have on the cardinality +/// of its input stream +pub enum CardinalityEffect { + /// Unknown effect. This is the default + Unknown, + /// The operator is guaranteed to produce exactly one row for + /// each input row + Equal, + /// The operator may produce fewer output rows than it receives input rows + LowerEqual, + /// The operator may produce more output rows than it receives input rows + GreaterEqual, +} + #[cfg(test)] mod tests { use super::*; + use arrow_array::{DictionaryArray, Int32Array, NullArray, RunArray}; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; use std::any::Any; use std::sync::Arc; - use arrow_schema::{Schema, SchemaRef}; - use datafusion_common::{Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; @@ -1039,6 +1075,136 @@ mod tests { fn use_execution_plan_as_trait_object(plan: &dyn ExecutionPlan) { let _ = plan.name(); } -} -// pub mod test; + #[test] + fn test_check_not_null_constraints_accept_non_null() -> Result<()> { + check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])), + vec![Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]))], + )?, + &vec![0], + )?; + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_reject_null() -> Result<()> { + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])), + vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]))], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_starts_with( + result.err().unwrap().message().as_ref(), + "Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_with_run_end_array() -> Result<()> { + // some null value inside REE array + let run_ends = Int32Array::from(vec![1, 2, 3, 4]); + let values = Int32Array::from(vec![Some(0), None, Some(1), None]); + let run_end_array = RunArray::try_new(&run_ends, &values)?; + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "a", + run_end_array.data_type().to_owned(), + true, + )])), + vec![Arc::new(run_end_array)], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_starts_with( + result.err().unwrap().message().as_ref(), + "Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_with_dictionary_array_with_null() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(4)])); + let keys = Int32Array::from(vec![0, 1, 2, 3]); + let dictionary = DictionaryArray::new(keys, values); + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "a", + dictionary.data_type().to_owned(), + true, + )])), + vec![Arc::new(dictionary)], 
+ )?, + &vec![0], + ); + assert!(result.is_err()); + assert_starts_with( + result.err().unwrap().message().as_ref(), + "Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_with_dictionary_masking_null() -> Result<()> { + // some null value marked out by dictionary array + let values = Arc::new(Int32Array::from(vec![ + Some(1), + None, // this null value is masked by dictionary keys + Some(3), + Some(4), + ])); + let keys = Int32Array::from(vec![0, /*1,*/ 2, 3]); + let dictionary = DictionaryArray::new(keys, values); + check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "a", + dictionary.data_type().to_owned(), + true, + )])), + vec![Arc::new(dictionary)], + )?, + &vec![0], + )?; + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_on_null_type() -> Result<()> { + // null value of Null type + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Null, true)])), + vec![Arc::new(NullArray::new(3))], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_starts_with( + result.err().unwrap().message().as_ref(), + "Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + fn assert_starts_with(actual: impl AsRef, expected_prefix: impl AsRef) { + let actual = actual.as_ref(); + let expected_prefix = expected_prefix.as_ref(); + assert!( + actual.starts_with(expected_prefix), + "Expected '{}' to start with '{}'", + actual, + expected_prefix + ); + } +} diff --git a/datafusion/physical-plan/src/explain.rs b/datafusion/physical-plan/src/explain.rs index 56dc35e8819d5..96f55a1446b0b 100644 --- a/datafusion/physical-plan/src/explain.rs +++ b/datafusion/physical-plan/src/explain.rs @@ -67,7 +67,7 @@ impl ExplainExec { &self.stringified_plans } - /// access to verbose + /// Access to verbose pub fn verbose(&self) -> bool { self.verbose } @@ -112,7 +112,7 @@ impl ExecutionPlan for ExplainExec { } fn children(&self) -> Vec<&Arc> { - // this is a leaf node and has no children + // This is a leaf node and has no children vec![] } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 417d2098b0832..30b0af19f43b1 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -48,6 +48,7 @@ use datafusion_physical_expr::{ analyze, split_conjunction, AnalysisContext, ConstExpr, ExprBoundaries, PhysicalExpr, }; +use crate::execution_plan::CardinalityEffect; use futures::stream::{Stream, StreamExt}; use log::trace; @@ -114,7 +115,7 @@ impl FilterExec { /// Return new instance of [FilterExec] with the given projection. 
pub fn with_projection(&self, projection: Option>) -> Result { - // check if the projection is valid + // Check if the projection is valid can_project(&self.schema(), projection.as_ref())?; let projection = match projection { @@ -156,7 +157,7 @@ impl FilterExec { self.default_selectivity } - /// projection + /// Projection pub fn projection(&self) -> Option<&Vec> { self.projection.as_ref() } @@ -254,9 +255,9 @@ impl FilterExec { let expr = Arc::new(column) as _; ConstExpr::new(expr).with_across_partitions(true) }); - // this is for statistics + // This is for statistics eq_properties = eq_properties.with_constants(constants); - // this is for logical constant (for example: a = '1', then a could be marked as a constant) + // This is for logical constant (for example: a = '1', then a could be marked as a constant) // to do: how to deal with multiple situation to represent = (for example c1 between 0 and 0) eq_properties = eq_properties.with_constants(Self::extend_constants(input, predicate)); @@ -330,7 +331,7 @@ impl ExecutionPlan for FilterExec { } fn maintains_input_order(&self) -> Vec { - // tell optimizer this operator doesn't reorder its input + // Tell optimizer this operator doesn't reorder its input vec![true] } @@ -372,6 +373,10 @@ impl ExecutionPlan for FilterExec { fn statistics(&self) -> Result { Self::statistics_helper(&self.input, self.predicate(), self.default_selectivity) } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::LowerEqual + } } /// This function ensures that all bounds in the `ExprBoundaries` vector are @@ -420,7 +425,7 @@ struct FilterExecStream { predicate: Arc, /// The input partition to filter. input: SendableRecordBatchStream, - /// runtime metrics recording + /// Runtime metrics recording baseline_metrics: BaselineMetrics, /// The projection indices of the columns in the input schema projection: Option>, @@ -444,7 +449,7 @@ fn filter_and_project( .and_then(|v| v.into_array(batch.num_rows())) .and_then(|array| { Ok(match (as_boolean_array(&array), projection) { - // apply filter array to record batch + // Apply filter array to record batch (Ok(filter_array), None) => filter_record_batch(batch, filter_array)?, (Ok(filter_array), Some(projection)) => { let projected_columns = projection @@ -485,7 +490,7 @@ impl Stream for FilterExecStream { &self.schema, )?; timer.done(); - // skip entirely filtered batches + // Skip entirely filtered batches if filtered_batch.num_rows() == 0 { continue; } @@ -502,7 +507,7 @@ impl Stream for FilterExecStream { } fn size_hint(&self) -> (usize, Option) { - // same number of record batches + // Same number of record batches self.input.size_hint() } } diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 5dc27bc239d26..8b3ef5ae01e47 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -93,7 +93,7 @@ pub struct DataSinkExec { cache: PlanProperties, } -impl fmt::Debug for DataSinkExec { +impl Debug for DataSinkExec { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "DataSinkExec schema: {:?}", self.count_schema) } @@ -148,11 +148,7 @@ impl DataSinkExec { } impl DisplayAs for DataSinkExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!(f, "DataSinkExec: sink=")?; @@ -271,7 +267,7 @@ fn 
make_count_batch(count: u64) -> RecordBatch { } fn make_count_schema() -> SchemaRef { - // define a schema. + // Define a schema. Arc::new(Schema::new(vec![Field::new( "count", DataType::UInt64, diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 11153556f2538..8f49885068fd3 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -19,7 +19,8 @@ //! and producing batches in parallel for the right partitions use super::utils::{ - adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut, + adjust_right_output_partitioning, BatchSplitter, BatchTransformer, + BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut, StatefulStreamResult, }; use crate::coalesce_partitions::CoalescePartitionsExec; @@ -69,16 +70,24 @@ impl CrossJoinExec { /// Create a new [CrossJoinExec]. pub fn new(left: Arc, right: Arc) -> Self { // left then right - let all_columns: Fields = { + let (all_columns, metadata) = { let left_schema = left.schema(); let right_schema = right.schema(); let left_fields = left_schema.fields().iter(); let right_fields = right_schema.fields().iter(); - left_fields.chain(right_fields).cloned().collect() + + let mut metadata = left_schema.metadata().clone(); + metadata.extend(right_schema.metadata().clone()); + + ( + left_fields.chain(right_fields).cloned().collect::(), + metadata, + ) }; - let schema = Arc::new(Schema::new(all_columns)); + let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata)); let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)); + CrossJoinExec { left, right, @@ -239,6 +248,10 @@ impl ExecutionPlan for CrossJoinExec { let reservation = MemoryConsumer::new("CrossJoinExec").register(context.memory_pool()); + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let left_fut = self.left_fut.once(|| { load_left_input( Arc::clone(&self.left), @@ -248,15 +261,29 @@ impl ExecutionPlan for CrossJoinExec { ) }); - Ok(Box::pin(CrossJoinStream { - schema: Arc::clone(&self.schema), - left_fut, - right: stream, - left_index: 0, - join_metrics, - state: CrossJoinStreamState::WaitBuildSide, - left_data: RecordBatch::new_empty(self.left().schema()), - })) + if enforce_batch_size_in_joins { + Ok(Box::pin(CrossJoinStream { + schema: Arc::clone(&self.schema), + left_fut, + right: stream, + left_index: 0, + join_metrics, + state: CrossJoinStreamState::WaitBuildSide, + left_data: RecordBatch::new_empty(self.left().schema()), + batch_transformer: BatchSplitter::new(batch_size), + })) + } else { + Ok(Box::pin(CrossJoinStream { + schema: Arc::clone(&self.schema), + left_fut, + right: stream, + left_index: 0, + join_metrics, + state: CrossJoinStreamState::WaitBuildSide, + left_data: RecordBatch::new_empty(self.left().schema()), + batch_transformer: NoopBatchTransformer::new(), + })) + } } fn statistics(&self) -> Result { @@ -312,7 +339,7 @@ fn stats_cartesian_product( } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. 
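The `CrossJoinExec::new` change above starts propagating schema-level metadata from both inputs instead of dropping it. Because the merge uses `HashMap::extend`, the right side silently wins when both inputs define the same key. A standalone sketch of the merged-schema construction (arrow-schema only, sample key names are made up):

```rust
use std::collections::HashMap;
use std::sync::Arc;
use arrow_schema::{DataType, Field, FieldRef, Schema};

fn main() {
    let left = Schema::new(vec![Field::new("a", DataType::Int32, false)])
        .with_metadata(HashMap::from([("origin".to_string(), "left".to_string())]));
    let right = Schema::new(vec![Field::new("b", DataType::Int32, false)])
        .with_metadata(HashMap::from([("origin".to_string(), "right".to_string())]));

    // Same shape as the patch: left fields then right fields, and the
    // right side's metadata overwrites colliding keys via `extend`.
    let mut metadata = left.metadata().clone();
    metadata.extend(right.metadata().clone());

    let fields: Vec<FieldRef> = left
        .fields()
        .iter()
        .chain(right.fields().iter())
        .cloned()
        .collect();
    let joined = Schema::new(fields).with_metadata(metadata);

    assert_eq!(joined.fields().len(), 2);
    assert_eq!(joined.metadata()["origin"], "right");
}
```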
-struct CrossJoinStream { +struct CrossJoinStream { /// Input schema schema: Arc, /// Future for data from left side @@ -327,9 +354,11 @@ struct CrossJoinStream { state: CrossJoinStreamState, /// Left data left_data: RecordBatch, + /// Batch transformer + batch_transformer: T, } -impl RecordBatchStream for CrossJoinStream { +impl RecordBatchStream for CrossJoinStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } @@ -383,24 +412,24 @@ fn build_batch( } #[async_trait] -impl Stream for CrossJoinStream { +impl Stream for CrossJoinStream { type Item = Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { self.poll_next_impl(cx) } } -impl CrossJoinStream { +impl CrossJoinStream { /// Separate implementation function that unpins the [`CrossJoinStream`] so /// that partial borrows work correctly fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll>> { + ) -> Poll>> { loop { return match self.state { CrossJoinStreamState::WaitBuildSide => { @@ -463,21 +492,33 @@ impl CrossJoinStream { fn build_batches(&mut self) -> Result>> { let right_batch = self.state.try_as_record_batch()?; if self.left_index < self.left_data.num_rows() { - let join_timer = self.join_metrics.join_time.timer(); - let result = - build_batch(self.left_index, right_batch, &self.left_data, &self.schema); - join_timer.done(); - - if let Ok(ref batch) = result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); + match self.batch_transformer.next() { + None => { + let join_timer = self.join_metrics.join_time.timer(); + let result = build_batch( + self.left_index, + right_batch, + &self.left_data, + &self.schema, + ); + join_timer.done(); + + self.batch_transformer.set_batch(result?); + } + Some((batch, last)) => { + if last { + self.left_index += 1; + } + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + return Ok(StatefulStreamResult::Ready(Some(batch))); + } } - self.left_index += 1; - result.map(|r| StatefulStreamResult::Ready(Some(r))) } else { self.state = CrossJoinStreamState::FetchProbeBatch; - Ok(StatefulStreamResult::Continue) } + Ok(StatefulStreamResult::Continue) } } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 48d648c89a354..2d11e03814a31 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -18,6 +18,7 @@ //! [`HashJoinExec`] Partitioned Hash Join Operator use std::fmt; +use std::mem::size_of; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; @@ -415,6 +416,12 @@ impl HashJoinExec { &self.join_type } + /// The schema after join. Please be careful when using this schema, + /// if there is a projection, the schema isn't the same as the output schema. + pub fn join_schema(&self) -> &SchemaRef { + &self.join_schema + } + /// The partitioning mode of this hash join pub fn partition_mode(&self) -> &PartitionMode { &self.mode @@ -843,7 +850,7 @@ async fn collect_left_input( // Estimation of memory size, required for hashtable, prior to allocation. 
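The `batch_transformer` plumbing above follows a small pull-style protocol shared by `BatchSplitter` and `NoopBatchTransformer`: when `next()` returns `None`, the stream computes a fresh joined batch and hands it over with `set_batch`; subsequent `next()` calls yield output chunks together with a `last` flag, and only when `last` is true does the stream advance `left_index`. This is how the new `enforce_batch_size_in_joins` session setting caps output batch sizes without changing the join logic itself. A simplified stand-in, with `Vec<u32>` in place of `RecordBatch`:

```rust
// Simplified sketch of the BatchTransformer protocol from the patch;
// the real trait operates on RecordBatch, this stand-in uses Vec<u32>.
trait BatchTransformer {
    fn set_batch(&mut self, rows: Vec<u32>);
    fn next(&mut self) -> Option<(Vec<u32>, bool)>;
}

struct Splitter {
    batch_size: usize,
    pending: Option<Vec<u32>>,
    offset: usize,
}

impl BatchTransformer for Splitter {
    fn set_batch(&mut self, rows: Vec<u32>) {
        self.pending = Some(rows);
        self.offset = 0;
    }
    fn next(&mut self) -> Option<(Vec<u32>, bool)> {
        // None signals "nothing pending": the caller joins a new batch
        // and stores it with set_batch before polling again.
        let rows = self.pending.as_ref()?;
        let end = (self.offset + self.batch_size).min(rows.len());
        let chunk = rows[self.offset..end].to_vec();
        let last = end == rows.len();
        self.offset = end;
        if last {
            self.pending = None;
        }
        Some((chunk, last))
    }
}

fn main() {
    let mut s = Splitter { batch_size: 3, pending: None, offset: 0 };
    s.set_batch((0..8).collect());
    while let Some((chunk, last)) = s.next() {
        println!("{chunk:?} last={last}");
    }
}
```

The `Noop` variant would simply hand the stored batch back once with `last == true`, preserving the old single-batch behavior.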
// Final result can be verified using `RawTable.allocation_info()` - let fixed_size = std::mem::size_of::(); + let fixed_size = size_of::(); let estimated_hashtable_size = estimate_memory_size::<(u64, u64)>(num_rows, fixed_size)?; @@ -1432,7 +1439,7 @@ impl HashJoinStream { index_alignment_range_start..index_alignment_range_end, self.join_type, self.right_side_ordered, - ); + )?; let result = build_batch_from_indices( &self.schema, @@ -1518,7 +1525,7 @@ impl Stream for HashJoinStream { fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { self.poll_next_impl(cx) } } @@ -3588,10 +3595,7 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); // Expect that an error is returned - let result_string = crate::common::collect(stream) - .await - .unwrap_err() - .to_string(); + let result_string = common::collect(stream).await.unwrap_err().to_string(); assert!( result_string.contains("bad data error"), "actual: {result_string}" diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 029003374accf..358ff02473a67 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -25,7 +25,10 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; -use super::utils::{asymmetric_join_output_partitioning, need_produce_result_in_final}; +use super::utils::{ + asymmetric_join_output_partitioning, need_produce_result_in_final, BatchSplitter, + BatchTransformer, NoopBatchTransformer, StatefulStreamResult, +}; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, @@ -35,8 +38,8 @@ use crate::joins::utils::{ }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ - execution_mode_from_children, DisplayAs, DisplayFormatType, Distribution, - ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + execution_mode_from_children, handle_state, DisplayAs, DisplayFormatType, + Distribution, ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, }; @@ -45,7 +48,9 @@ use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; -use datafusion_common::{exec_datafusion_err, JoinSide, Result, Statistics}; +use datafusion_common::{ + exec_datafusion_err, internal_err, JoinSide, Result, Statistics, +}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_expr::JoinType; @@ -230,10 +235,11 @@ impl NestedLoopJoinExec { asymmetric_join_output_partitioning(left, right, &join_type); // Determine execution mode: - let mut mode = execution_mode_from_children([left, right]); - if mode.is_unbounded() { - mode = ExecutionMode::PipelineBreaking; - } + let mode = if left.execution_mode().is_unbounded() { + ExecutionMode::PipelineBreaking + } else { + execution_mode_from_children([left, right]) + }; PlanProperties::new(eq_properties, output_partitioning, mode) } @@ -345,6 +351,10 @@ impl ExecutionPlan for NestedLoopJoinExec { ) }); + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let outer_table = 
self.right.execute(partition, context)?; let indices_cache = (UInt64Array::new_null(0), UInt32Array::new_null(0)); @@ -352,18 +362,38 @@ impl ExecutionPlan for NestedLoopJoinExec { // Right side has an order and it is maintained during operation. let right_side_ordered = self.maintains_input_order()[1] && self.right.output_ordering().is_some(); - Ok(Box::pin(NestedLoopJoinStream { - schema: Arc::clone(&self.schema), - filter: self.filter.clone(), - join_type: self.join_type, - outer_table, - inner_table, - is_exhausted: false, - column_indices: self.column_indices.clone(), - join_metrics, - indices_cache, - right_side_ordered, - })) + + if enforce_batch_size_in_joins { + Ok(Box::pin(NestedLoopJoinStream { + schema: Arc::clone(&self.schema), + filter: self.filter.clone(), + join_type: self.join_type, + outer_table, + inner_table, + column_indices: self.column_indices.clone(), + join_metrics, + indices_cache, + right_side_ordered, + state: NestedLoopJoinStreamState::WaitBuildSide, + batch_transformer: BatchSplitter::new(batch_size), + left_data: None, + })) + } else { + Ok(Box::pin(NestedLoopJoinStream { + schema: Arc::clone(&self.schema), + filter: self.filter.clone(), + join_type: self.join_type, + outer_table, + inner_table, + column_indices: self.column_indices.clone(), + join_metrics, + indices_cache, + right_side_ordered, + state: NestedLoopJoinStreamState::WaitBuildSide, + batch_transformer: NoopBatchTransformer::new(), + left_data: None, + })) + } } fn metrics(&self) -> Option { @@ -442,8 +472,37 @@ async fn collect_left_input( )) } +/// This enumeration represents various states of the nested loop join algorithm. +#[derive(Debug, Clone)] +enum NestedLoopJoinStreamState { + /// The initial state, indicating that build-side data not collected yet + WaitBuildSide, + /// Indicates that build-side has been collected, and stream is ready for + /// fetching probe-side + FetchProbeBatch, + /// Indicates that a non-empty batch has been fetched from probe-side, and + /// is ready to be processed + ProcessProbeBatch(RecordBatch), + /// Indicates that probe-side has been fully processed + ExhaustedProbeSide, + /// Indicates that NestedLoopJoinStream execution is completed + Completed, +} + +impl NestedLoopJoinStreamState { + /// Tries to extract a `ProcessProbeBatchState` from the + /// `NestedLoopJoinStreamState` enum. Returns an error if state is not + /// `ProcessProbeBatchState`. + fn try_as_process_probe_batch(&mut self) -> Result<&RecordBatch> { + match self { + NestedLoopJoinStreamState::ProcessProbeBatch(state) => Ok(state), + _ => internal_err!("Expected join stream in ProcessProbeBatch state"), + } + } +} + /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct NestedLoopJoinStream { +struct NestedLoopJoinStream { /// Input schema schema: Arc, /// join filter @@ -454,8 +513,6 @@ struct NestedLoopJoinStream { outer_table: SendableRecordBatchStream, /// the inner table data of the nested loop join inner_table: OnceFut, - /// There is nothing to process anymore and left side is processed in case of full join - is_exhausted: bool, /// Information of index and left / right placement of columns column_indices: Vec, // TODO: support null aware equal @@ -466,6 +523,12 @@ struct NestedLoopJoinStream { indices_cache: (UInt64Array, UInt32Array), /// Whether the right side is ordered right_side_ordered: bool, + /// Current state of the stream + state: NestedLoopJoinStreamState, + /// Transforms the output batch before returning. 
+ batch_transformer: T, + /// Result of the left data future + left_data: Option>, } /// Creates a Cartesian product of two input batches, preserving the order of the right batch, @@ -544,107 +607,164 @@ fn build_join_indices( } } -impl NestedLoopJoinStream { +impl NestedLoopJoinStream { fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>> { - // all left row + loop { + return match self.state { + NestedLoopJoinStreamState::WaitBuildSide => { + handle_state!(ready!(self.collect_build_side(cx))) + } + NestedLoopJoinStreamState::FetchProbeBatch => { + handle_state!(ready!(self.fetch_probe_batch(cx))) + } + NestedLoopJoinStreamState::ProcessProbeBatch(_) => { + handle_state!(self.process_probe_batch()) + } + NestedLoopJoinStreamState::ExhaustedProbeSide => { + handle_state!(self.process_unmatched_build_batch()) + } + NestedLoopJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + fn collect_build_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); - let left_data = match ready!(self.inner_table.get_shared(cx)) { - Ok(data) => data, - Err(e) => return Poll::Ready(Some(Err(e))), - }; + // build hash table from left (build) side, if not yet done + self.left_data = Some(ready!(self.inner_table.get_shared(cx))?); build_timer.done(); - // Get or initialize visited_left_side bitmap if required by join type + self.state = NestedLoopJoinStreamState::FetchProbeBatch; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Fetches next batch from probe-side + /// + /// If a non-empty batch has been fetched, updates state to + /// `ProcessProbeBatchState`, otherwise updates state to `ExhaustedProbeSide`. + fn fetch_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.outer_table.poll_next_unpin(cx)) { + None => { + self.state = NestedLoopJoinStreamState::ExhaustedProbeSide; + } + Some(Ok(right_batch)) => { + self.state = NestedLoopJoinStreamState::ProcessProbeBatch(right_batch); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Joins current probe batch with build-side data and produces batch with + /// matched output, updates state to `FetchProbeBatch`. 
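The refactor above replaces the ad-hoc `is_exhausted` flag with the explicit state machine already used by the hash and cross joins (`WaitBuildSide`, `FetchProbeBatch`, `ProcessProbeBatch`, `ExhaustedProbeSide`, `Completed`): `poll_next_impl` loops, and each handler either returns ready output or transitions state and continues, which is what the `handle_state!` arms expand to. A stripped-down sketch of that control flow, with toy types rather than DataFusion's:

```rust
use std::task::Poll;

enum State { Build, Probe, Drain, Done }

// Stand-in for StatefulStreamResult: either output is ready, or the
// state changed and the loop should run again.
enum Step<T> { Ready(Option<T>), Continue }

struct Machine { state: State, probed: usize }

impl Machine {
    fn poll_next_impl(&mut self) -> Poll<Option<String>> {
        loop {
            let step = match self.state {
                State::Build => {
                    self.state = State::Probe; // build side collected
                    Step::Continue
                }
                State::Probe if self.probed < 2 => {
                    self.probed += 1;
                    Step::Ready(Some(format!("probe batch {}", self.probed)))
                }
                State::Probe => {
                    self.state = State::Drain; // probe side exhausted
                    Step::Continue
                }
                State::Drain => {
                    self.state = State::Done; // emit unmatched build rows
                    Step::Ready(Some("unmatched rows".into()))
                }
                State::Done => Step::Ready(None),
            };
            match step {
                Step::Ready(out) => return Poll::Ready(out),
                Step::Continue => continue,
            }
        }
    }
}

fn main() {
    let mut m = Machine { state: State::Build, probed: 0 };
    while let Poll::Ready(Some(batch)) = m.poll_next_impl() {
        println!("{batch}");
    }
}
```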
+ fn process_probe_batch( + &mut self, + ) -> Result>> { + let Some(left_data) = self.left_data.clone() else { + return internal_err!( + "Expected left_data to be Some in ProcessProbeBatch state" + ); + }; let visited_left_side = left_data.bitmap(); + let batch = self.state.try_as_process_probe_batch()?; + + match self.batch_transformer.next() { + None => { + // Setting up timer & updating input metrics + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + + let result = join_left_and_right_batch( + left_data.batch(), + batch, + self.join_type, + self.filter.as_ref(), + &self.column_indices, + &self.schema, + visited_left_side, + &mut self.indices_cache, + self.right_side_ordered, + ); + timer.done(); + + self.batch_transformer.set_batch(result?); + Ok(StatefulStreamResult::Continue) + } + Some((batch, last)) => { + if last { + self.state = NestedLoopJoinStreamState::FetchProbeBatch; + } - // Check is_exhausted before polling the outer_table, such that when the outer table - // does not support `FusedStream`, Self will not poll it again - if self.is_exhausted { - return Poll::Ready(None); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + Ok(StatefulStreamResult::Ready(Some(batch))) + } } + } - self.outer_table - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - Some(Ok(right_batch)) => { - // Setting up timer & updating input metrics - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(right_batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - let result = join_left_and_right_batch( - left_data.batch(), - &right_batch, - self.join_type, - self.filter.as_ref(), - &self.column_indices, - &self.schema, - visited_left_side, - &mut self.indices_cache, - self.right_side_ordered, - ); - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - - Some(result) - } - Some(err) => Some(err), - None => { - if need_produce_result_in_final(self.join_type) { - // At this stage `visited_left_side` won't be updated, so it's - // safe to report about probe completion. 
- // - // Setting `is_exhausted` / returning None will prevent from - // multiple calls of `report_probe_completed()` - if !left_data.report_probe_completed() { - self.is_exhausted = true; - return None; - }; - - // Only setting up timer, input is exhausted - let timer = self.join_metrics.join_time.timer(); - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = - get_final_indices_from_shared_bitmap( - visited_left_side, - self.join_type, - ); - let empty_right_batch = - RecordBatch::new_empty(self.outer_table.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.is_exhausted = true; - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - - Some(result) - } else { - // end of the join loop - None - } - } - }) + /// Processes unmatched build-side rows for certain join types and produces + /// output batch, updates state to `Completed`. + fn process_unmatched_build_batch( + &mut self, + ) -> Result>> { + let Some(left_data) = self.left_data.clone() else { + return internal_err!( + "Expected left_data to be Some in ExhaustedProbeSide state" + ); + }; + let visited_left_side = left_data.bitmap(); + if need_produce_result_in_final(self.join_type) { + // At this stage `visited_left_side` won't be updated, so it's + // safe to report about probe completion. + // + // Setting `is_exhausted` / returning None will prevent from + // multiple calls of `report_probe_completed()` + if !left_data.report_probe_completed() { + self.state = NestedLoopJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(None)); + }; + + // Only setting up timer, input is exhausted + let timer = self.join_metrics.join_time.timer(); + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_shared_bitmap(visited_left_side, self.join_type); + let empty_right_batch = RecordBatch::new_empty(self.outer_table.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + self.state = NestedLoopJoinStreamState::Completed; + + // Recording time + if result.is_ok() { + timer.done(); + } + + Ok(StatefulStreamResult::Ready(Some(result?))) + } else { + // end of the join loop + self.state = NestedLoopJoinStreamState::Completed; + Ok(StatefulStreamResult::Ready(None)) + } } } @@ -684,7 +804,7 @@ fn join_left_and_right_batch( 0..right_batch.num_rows(), join_type, right_side_ordered, - ); + )?; build_batch_from_indices( schema, @@ -705,7 +825,7 @@ fn get_final_indices_from_shared_bitmap( get_final_indices_from_bit_map(&bitmap, join_type) } -impl Stream for NestedLoopJoinStream { +impl Stream for NestedLoopJoinStream { type Item = Result; fn poll_next( @@ -716,14 +836,14 @@ impl Stream for NestedLoopJoinStream { } } -impl RecordBatchStream for NestedLoopJoinStream { +impl RecordBatchStream for NestedLoopJoinStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; use crate::{ common, 
expressions::Column, memory::MemoryExec, repartition::RepartitionExec, @@ -780,7 +900,7 @@ mod tests { }; sort_info.push(sort_expr); } - exec = exec.with_sort_information(vec![sort_info]); + exec = exec.try_with_sort_information(vec![sort_info]).unwrap(); } Arc::new(exec) @@ -850,7 +970,7 @@ mod tests { JoinFilter::new(filter_expression, column_indices, intermediate_schema) } - async fn multi_partitioned_join_collect( + pub(crate) async fn multi_partitioned_join_collect( left: Arc, right: Arc, join_type: &JoinType, diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 2118c1a5266fb..b299b495c5044 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -26,21 +26,20 @@ use std::collections::{HashMap, VecDeque}; use std::fmt::Formatter; use std::fs::File; use std::io::BufReader; -use std::mem; +use std::mem::size_of; use std::ops::Range; use std::pin::Pin; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::task::{Context, Poll}; use arrow::array::*; -use arrow::compute::{self, concat_batches, take, SortOptions}; +use arrow::compute::{self, concat_batches, filter_record_batch, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use arrow::ipc::reader::FileReader; use arrow_array::types::UInt64Type; -use futures::{Stream, StreamExt}; -use hashbrown::HashSet; - use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, @@ -52,6 +51,8 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; use datafusion_physical_expr_common::sort_expr::LexRequirement; +use futures::{Stream, StreamExt}; +use hashbrown::HashSet; use crate::expressions::PhysicalSortExpr; use crate::joins::utils::{ @@ -410,13 +411,13 @@ struct SortMergeJoinMetrics { /// Total time for joining probe-side batches to the build-side batches join_time: metrics::Time, /// Number of batches consumed by this operator - input_batches: metrics::Count, + input_batches: Count, /// Number of rows consumed by this operator - input_rows: metrics::Count, + input_rows: Count, /// Number of batches produced by this operator - output_batches: metrics::Count, + output_batches: Count, /// Number of rows produced by this operator - output_rows: metrics::Count, + output_rows: Count, /// Peak memory used for buffered data. /// Calculated as sum of peak memory values across partitions peak_mem_used: metrics::Gauge, @@ -629,9 +630,9 @@ impl BufferedBatch { .iter() .map(|arr| arr.get_array_memory_size()) .sum::() - + batch.num_rows().next_power_of_two() * mem::size_of::() - + mem::size_of::>() - + mem::size_of::(); + + batch.num_rows().next_power_of_two() * size_of::() + + size_of::>() + + size_of::(); let num_rows = batch.num_rows(); BufferedBatch { @@ -687,7 +688,7 @@ struct SMJStream { /// optional join filter pub filter: Option, /// Staging output array builders - pub output_record_batches: Vec, + pub output_record_batches: JoinedRecordBatches, /// Staging output size, including output batches and staging joined results. /// Increased when we put rows into buffer and decreased after we actually output batches. 
/// Used to trigger output when sufficient rows are ready @@ -702,6 +703,22 @@ struct SMJStream { pub reservation: MemoryReservation, /// Runtime env pub runtime_env: Arc, + /// A unique number for each batch + pub streamed_batch_counter: AtomicUsize, +} + +/// Joined batches with attached join filter information +struct JoinedRecordBatches { + /// Joined batches. Each batch is already joined columns from left and right sources + pub batches: Vec, + /// Filter match mask for each row(matched/non-matched) + pub filter_mask: BooleanBuilder, + /// Row indices to glue together rows in `batches` and `filter_mask` + pub row_indices: UInt64Builder, + /// Which unique batch id the row belongs to + /// It is necessary to differentiate rows that are distributed the way when they point to the same + /// row index but in not the same batches + pub batch_ids: Vec, } impl RecordBatchStream for SMJStream { @@ -710,6 +727,112 @@ impl RecordBatchStream for SMJStream { } } +/// True if next index refers to either: +/// - another batch id +/// - another row index within same batch id +/// - end of row indices +#[inline(always)] +fn last_index_for_row( + row_index: usize, + indices: &UInt64Array, + batch_ids: &[usize], + indices_len: usize, +) -> bool { + row_index == indices_len - 1 + || batch_ids[row_index] != batch_ids[row_index + 1] + || indices.value(row_index) != indices.value(row_index + 1) +} + +// Returns a corrected boolean bitmask for the given join type +// Values in the corrected bitmask can be: true, false, null +// `true` - the row found its match and sent to the output +// `null` - the row ignored, no output +// `false` - the row sent as NULL joined row +fn get_corrected_filter_mask( + join_type: JoinType, + row_indices: &UInt64Array, + batch_ids: &[usize], + filter_mask: &BooleanArray, + expected_size: usize, +) -> Option { + let row_indices_length = row_indices.len(); + let mut corrected_mask: BooleanBuilder = + BooleanBuilder::with_capacity(row_indices_length); + let mut seen_true = false; + + match join_type { + JoinType::Left | JoinType::Right => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); // to be ignored and not set to output + } else { + corrected_mask.append_value(false); // to be converted to null joined row + } + + if last_index { + seen_true = false; + } + } + + // Generate null joined rows for records which have no matching join key + let null_matched = expected_size - corrected_mask.len(); + corrected_mask.extend(vec![Some(false); null_matched]); + Some(corrected_mask.finish()) + } + JoinType::LeftSemi => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) && !seen_true { + seen_true = true; + corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); // to be ignored and not set to output + } + + if last_index { + seen_true = false; + } + } + + Some(corrected_mask.finish()) + } + JoinType::LeftAnti => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + + if filter_mask.value(i) { + seen_true = true; + } + + if last_index { + if !seen_true { + corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); + } + + 
seen_true = false; + } else { + corrected_mask.append_null(); + } + } + + let null_matched = expected_size - corrected_mask.len(); + corrected_mask.extend(vec![Some(true); null_matched]); + Some(corrected_mask.finish()) + } + // Only outer joins needs to keep track of processed rows and apply corrected filter mask + _ => None, + } +} + impl Stream for SMJStream { type Item = Result; @@ -719,7 +842,6 @@ impl Stream for SMJStream { ) -> Poll> { let join_time = self.join_metrics.join_time.clone(); let _timer = join_time.timer(); - loop { match &self.state { SMJState::Init => { @@ -733,6 +855,27 @@ impl Stream for SMJStream { match self.current_ordering { Ordering::Less | Ordering::Equal => { if !streamed_exhausted { + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + ) + { + self.freeze_all()?; + + if !self.output_record_batches.batches.is_empty() + { + let out_filtered_batch = + self.filter_joined_batch()?; + return Poll::Ready(Some(Ok( + out_filtered_batch, + ))); + } + } + self.streamed_joined = false; self.streamed_state = StreamedState::Init; } @@ -786,8 +929,25 @@ impl Stream for SMJStream { } } else { self.freeze_all()?; - if !self.output_record_batches.is_empty() { + if !self.output_record_batches.batches.is_empty() { let record_batch = self.output_record_batch_and_reset()?; + // For non-filtered join output whenever the target output batch size + // is hit. For filtered join its needed to output on later phase + // because target output batch size can be hit in the middle of + // filtering causing the filtering to be incomplete and causing + // correctness issues + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + ) + { + continue; + } + return Poll::Ready(Some(Ok(record_batch))); } return Poll::Pending; @@ -795,11 +955,26 @@ impl Stream for SMJStream { } SMJState::Exhausted => { self.freeze_all()?; - if !self.output_record_batches.is_empty() { - let record_batch = self.output_record_batch_and_reset()?; - return Poll::Ready(Some(Ok(record_batch))); + + if !self.output_record_batches.batches.is_empty() { + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + ) + { + let out = self.filter_joined_batch()?; + return Poll::Ready(Some(Ok(out))); + } else { + let record_batch = self.output_record_batch_and_reset()?; + return Poll::Ready(Some(Ok(record_batch))); + } + } else { + return Poll::Ready(None); } - return Poll::Ready(None); } } } @@ -844,13 +1019,19 @@ impl SMJStream { on_streamed, on_buffered, filter, - output_record_batches: vec![], + output_record_batches: JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }, output_size: 0, batch_size, join_type, join_metrics, reservation, runtime_env, + streamed_batch_counter: AtomicUsize::new(0), }) } @@ -882,6 +1063,10 @@ impl SMJStream { self.join_metrics.input_rows.add(batch.num_rows()); self.streamed_batch = StreamedBatch::new(batch, &self.on_streamed); + // Every incoming streaming batch should have its unique id + // Check `JoinedRecordBatches.self.streamed_batch_counter` documentation + self.streamed_batch_counter + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); self.streamed_state = StreamedState::Ready; } } @@ -1062,14 +1247,14 @@ impl SMJStream { return 
Ok(Ordering::Less); } - return compare_join_arrays( + compare_join_arrays( &self.streamed_batch.join_arrays, self.streamed_batch.idx, &self.buffered_data.head_batch().join_arrays, self.buffered_data.head_batch().range.start, &self.sort_options, self.null_equals_null, - ); + ) } /// Produce join and fill output buffer until reaching target batch size @@ -1122,11 +1307,7 @@ impl SMJStream { }; if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() { - join_streamed = !self - .streamed_batch - .join_filter_matched_idxs - .contains(&(self.streamed_batch.idx as u64)) - && !self.streamed_joined; + join_streamed = !self.streamed_joined; join_buffered = join_streamed; } } @@ -1228,7 +1409,7 @@ impl SMJStream { &buffered_indices, buffered_batch, )? { - self.output_record_batches.push(record_batch); + self.output_record_batches.batches.push(record_batch); } buffered_batch.null_joined.clear(); @@ -1251,7 +1432,7 @@ impl SMJStream { &buffered_indices, buffered_batch, )? { - self.output_record_batches.push(record_batch); + self.output_record_batches.batches.push(record_batch); } buffered_batch.join_filter_failed_map.clear(); } @@ -1300,7 +1481,6 @@ impl SMJStream { }; let streamed_columns_length = streamed_columns.len(); - let buffered_columns_length = buffered_columns.len(); // Prepare the columns we apply join filter on later. // Only for joined rows between streamed and buffered. @@ -1329,15 +1509,14 @@ impl SMJStream { }; let columns = if matches!(self.join_type, JoinType::Right) { - buffered_columns.extend(streamed_columns.clone()); + buffered_columns.extend(streamed_columns); buffered_columns } else { streamed_columns.extend(buffered_columns); streamed_columns }; - let output_batch = - RecordBatch::try_new(Arc::clone(&self.schema), columns.clone())?; + let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; // Apply join filter if any if !filter_columns.is_empty() { @@ -1367,61 +1546,54 @@ impl SMJStream { pre_mask.clone() }; - // For certain join types, we need to adjust the initial mask to handle the join filter. - let maybe_filtered_join_mask: Option<(BooleanArray, Vec)> = - get_filtered_join_mask( - self.join_type, - &streamed_indices, - &mask, - &self.streamed_batch.join_filter_matched_idxs, - &self.buffered_data.scanning_offset, - ); - - let mask = - if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { - self.streamed_batch - .join_filter_matched_idxs - .extend(&filtered_join_mask.1); - &filtered_join_mask.0 - } else { - &mask - }; - // Push the filtered batch which contains rows passing join filter to the output - let filtered_batch = - compute::filter_record_batch(&output_batch, mask)?; - self.output_record_batches.push(filtered_batch); + if matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + ) { + self.output_record_batches + .batches + .push(output_batch.clone()); + } else { + let filtered_batch = filter_record_batch(&output_batch, &mask)?; + self.output_record_batches.batches.push(filtered_batch); + } + + self.output_record_batches.filter_mask.extend(&mask); + self.output_record_batches + .row_indices + .extend(&streamed_indices); + self.output_record_batches.batch_ids.extend(vec![ + self.streamed_batch_counter.load(Relaxed); + streamed_indices.len() + ]); // For outer joins, we need to push the null joined rows to the output if // all joined rows are failed on the join filter. 
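The `filter_mask`, `row_indices`, and `batch_ids` recorded above feed `get_corrected_filter_mask`, whose output is deliberately tri-state per group of joined rows sharing one streamed row (groups are delimited by `last_index_for_row`): `true` emits the row as a match, `null` drops it because another row of the same group decides the outcome, and `false` (possible only at a group's last position when nothing matched) becomes a null-joined row; the trailing `extend` pads positions for rows that never reached the filter. A condensed, single-group re-implementation of the `Left`/`Right` arm:

```rust
// Condensed Left/Right arm of get_corrected_filter_mask, assuming all
// positions belong to one streamed row (so last_index is only at the end).
fn corrected_left_mask(filter_mask: &[bool]) -> Vec<Option<bool>> {
    let mut out = Vec::with_capacity(filter_mask.len());
    let mut seen_true = false;
    for (i, &m) in filter_mask.iter().enumerate() {
        let last_index = i == filter_mask.len() - 1;
        if m {
            seen_true = true;
            out.push(Some(true)); // emit the matched joined row
        } else if seen_true || !last_index {
            out.push(None); // ignored: another row of the group decides
        } else {
            out.push(Some(false)); // no match at all: null-joined row
        }
        if last_index {
            seen_true = false;
        }
    }
    out
}

fn main() {
    // Matches the `[false, false, true]` expectation in the new tests:
    // early failures are ignored, the final hit emits the row.
    assert_eq!(
        corrected_left_mask(&[false, false, true]),
        vec![None, None, Some(true)]
    );
    // No match at all: only the last position becomes a null-join marker.
    assert_eq!(
        corrected_left_mask(&[false, false, false]),
        vec![None, None, Some(false)]
    );
}
```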
// I.e., if all rows joined from a streamed row are failed with the join filter, // we need to join it with nulls as buffered side. - if matches!( - self.join_type, - JoinType::Left | JoinType::Right | JoinType::Full - ) { + if matches!(self.join_type, JoinType::Full) { // We need to get the mask for row indices that the joined rows are failed // on the join filter. I.e., for a row in streamed side, if all joined rows // between it and all buffered rows are failed on the join filter, we need to // output it with null columns from buffered side. For the mask here, it // behaves like LeftAnti join. - let null_mask: BooleanArray = get_filtered_join_mask( - // Set a mask slot as true only if all joined rows of same streamed index - // are failed on the join filter. - // The masking behavior is like LeftAnti join. - JoinType::LeftAnti, - &streamed_indices, - mask, - &self.streamed_batch.join_filter_matched_idxs, - &self.buffered_data.scanning_offset, - ) - .unwrap() - .0; + let not_mask = if mask.null_count() > 0 { + // If the mask contains nulls, we need to use `prep_null_mask_filter` to + // handle the nulls in the mask as false to produce rows where the mask + // was null itself. + compute::not(&compute::prep_null_mask_filter(&mask))? + } else { + compute::not(&mask)? + }; let null_joined_batch = - compute::filter_record_batch(&output_batch, &null_mask)?; + filter_record_batch(&output_batch, ¬_mask)?; - let mut buffered_columns = self + let buffered_columns = self .buffered_schema .fields() .iter() @@ -1433,18 +1605,7 @@ impl SMJStream { }) .collect::>(); - let columns = if matches!(self.join_type, JoinType::Right) { - let streamed_columns = null_joined_batch - .columns() - .iter() - .skip(buffered_columns_length) - .cloned() - .collect::>(); - - buffered_columns.extend(streamed_columns); - buffered_columns - } else { - // Left join or full outer join + let columns = { let mut streamed_columns = null_joined_batch .columns() .iter() @@ -1457,11 +1618,12 @@ impl SMJStream { }; // Push the streamed/buffered batch joined nulls to the output - let null_joined_streamed_batch = RecordBatch::try_new( - Arc::clone(&self.schema), - columns.clone(), - )?; - self.output_record_batches.push(null_joined_streamed_batch); + let null_joined_streamed_batch = + RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + + self.output_record_batches + .batches + .push(null_joined_streamed_batch); // For full join, we also need to output the null joined rows from the buffered side. // Usually this is done by `freeze_buffered`. 
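The `not_mask` computation in the `Full` join path above hinges on arrow's null semantics: `compute::not` keeps null slots null, so a row whose filter result was null would vanish from both the matched and the null-joined output. Running `prep_null_mask_filter` first rewrites nulls as `false`, so after inversion those rows are null-joined like any other filter failure. In kernel terms (a minimal check, not patch code):

```rust
use arrow::compute;
use arrow_array::BooleanArray;

fn main() {
    let mask = BooleanArray::from(vec![Some(true), None, Some(false)]);

    // `not` preserves nulls: a null match slot stays null after
    // inversion and would be dropped by filter_record_batch.
    let inverted = compute::not(&mask).unwrap();
    assert_eq!(
        inverted,
        BooleanArray::from(vec![Some(false), None, Some(true)])
    );

    // `prep_null_mask_filter` rewrites nulls as false first, so after
    // inversion the formerly-null slot becomes true and the row is
    // null-joined instead of silently disappearing.
    let prepared = compute::prep_null_mask_filter(&mask);
    let not_mask = compute::not(&prepared).unwrap();
    assert_eq!(not_mask, BooleanArray::from(vec![false, true, true]));
}
```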
However, if a buffered row is joined with @@ -1494,10 +1656,10 @@ impl SMJStream { } } } else { - self.output_record_batches.push(output_batch); + self.output_record_batches.batches.push(output_batch); } } else { - self.output_record_batches.push(output_batch); + self.output_record_batches.batches.push(output_batch); } } @@ -1507,7 +1669,8 @@ impl SMJStream { } fn output_record_batch_and_reset(&mut self) -> Result { - let record_batch = concat_batches(&self.schema, &self.output_record_batches)?; + let record_batch = + concat_batches(&self.schema, &self.output_record_batches.batches)?; self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(record_batch.num_rows()); // If join filter exists, `self.output_size` is not accurate as we don't know the exact @@ -1520,9 +1683,98 @@ impl SMJStream { } else { self.output_size -= record_batch.num_rows(); } - self.output_record_batches.clear(); + + if !(self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + )) + { + self.output_record_batches.batches.clear(); + } Ok(record_batch) } + + fn filter_joined_batch(&mut self) -> Result { + let record_batch = self.output_record_batch_and_reset()?; + let out_indices = self.output_record_batches.row_indices.finish(); + let out_mask = self.output_record_batches.filter_mask.finish(); + let maybe_corrected_mask = get_corrected_filter_mask( + self.join_type, + &out_indices, + &self.output_record_batches.batch_ids, + &out_mask, + record_batch.num_rows(), + ); + + let corrected_mask = if let Some(ref filtered_join_mask) = maybe_corrected_mask { + filtered_join_mask + } else { + &out_mask + }; + + let mut filtered_record_batch = + filter_record_batch(&record_batch, corrected_mask)?; + let buffered_columns_length = self.buffered_schema.fields.len(); + let streamed_columns_length = self.streamed_schema.fields.len(); + + if matches!(self.join_type, JoinType::Left | JoinType::Right) { + let null_mask = compute::not(corrected_mask)?; + let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?; + + let mut buffered_columns = self + .buffered_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), null_joined_batch.num_rows())) + .collect::>(); + + let columns = if matches!(self.join_type, JoinType::Right) { + let streamed_columns = null_joined_batch + .columns() + .iter() + .skip(buffered_columns_length) + .cloned() + .collect::>(); + + buffered_columns.extend(streamed_columns); + buffered_columns + } else { + // Left join or full outer join + let mut streamed_columns = null_joined_batch + .columns() + .iter() + .take(streamed_columns_length) + .cloned() + .collect::>(); + + streamed_columns.extend(buffered_columns); + streamed_columns + }; + + // Push the streamed/buffered batch joined nulls to the output + let null_joined_streamed_batch = + RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + + filtered_record_batch = concat_batches( + &self.schema, + &[filtered_record_batch, null_joined_streamed_batch], + )?; + } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + let output_column_indices = (0..streamed_columns_length).collect::>(); + filtered_record_batch = + filtered_record_batch.project(&output_column_indices)?; + } + + self.output_record_batches.batches.clear(); + self.output_record_batches.batch_ids = vec![]; + self.output_record_batches.filter_mask = BooleanBuilder::new(); + self.output_record_batches.row_indices = UInt64Builder::new(); + 
Ok(filtered_record_batch) + } } /// Gets the arrays which join filters are applied on. @@ -1631,101 +1883,6 @@ fn get_buffered_columns_from_batch( } } -/// Calculate join filter bit mask considering join type specifics -/// `streamed_indices` - array of streamed datasource JOINED row indices -/// `mask` - array booleans representing computed join filter expression eval result: -/// true = the row index matches the join filter -/// false = the row index doesn't match the join filter -/// `streamed_indices` have the same length as `mask` -/// `matched_indices` array of streaming indices that already has a join filter match -/// `scanning_buffered_offset` current buffered offset across batches -/// -/// This return a tuple of: -/// - corrected mask with respect to the join type -/// - indices of rows in streamed batch that have a join filter match -fn get_filtered_join_mask( - join_type: JoinType, - streamed_indices: &UInt64Array, - mask: &BooleanArray, - matched_indices: &HashSet, - scanning_buffered_offset: &usize, -) -> Option<(BooleanArray, Vec)> { - let mut seen_as_true: bool = false; - let streamed_indices_length = streamed_indices.len(); - let mut corrected_mask: BooleanBuilder = - BooleanBuilder::with_capacity(streamed_indices_length); - - let mut filter_matched_indices: Vec = vec![]; - - #[allow(clippy::needless_range_loop)] - match join_type { - // for LeftSemi Join the filter mask should be calculated in its own way: - // if we find at least one matching row for specific streaming index - // we don't need to check any others for the same index - JoinType::LeftSemi => { - // have we seen a filter match for a streaming index before - for i in 0..streamed_indices_length { - // LeftSemi respects only first true values for specific streaming index, - // others true values for the same index must be false - let streamed_idx = streamed_indices.value(i); - if mask.value(i) - && !seen_as_true - && !matched_indices.contains(&streamed_idx) - { - seen_as_true = true; - corrected_mask.append_value(true); - filter_matched_indices.push(streamed_idx); - } else { - corrected_mask.append_value(false); - } - - // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag - if i < streamed_indices_length - 1 - && streamed_idx != streamed_indices.value(i + 1) - { - seen_as_true = false; - } - } - Some((corrected_mask.finish(), filter_matched_indices)) - } - // LeftAnti semantics: return true if for every x in the collection the join matching filter is false. - // `filter_matched_indices` needs to be set once per streaming index - // to prevent duplicates in the output - JoinType::LeftAnti => { - // have we seen a filter match for a streaming index before - for i in 0..streamed_indices_length { - let streamed_idx = streamed_indices.value(i); - if mask.value(i) - && !seen_as_true - && !matched_indices.contains(&streamed_idx) - { - seen_as_true = true; - filter_matched_indices.push(streamed_idx); - } - - // Reset `seen_as_true` flag and calculate mask for the current streaming index - // - if within the batch it switched to next streaming index(e.g. 
from 0 to 1, or from 1 to 2) - // - if it is at the end of the all buffered batches for the given streaming index, 0 index comes last - if (i < streamed_indices_length - 1 - && streamed_idx != streamed_indices.value(i + 1)) - || (i == streamed_indices_length - 1 - && *scanning_buffered_offset == 0) - { - corrected_mask.append_value( - !matched_indices.contains(&streamed_idx) && !seen_as_true, - ); - seen_as_true = false; - } else { - corrected_mask.append_value(false); - } - } - - Some((corrected_mask.finish(), filter_matched_indices)) - } - _ => None, - } -} - /// Buffered data contains all buffered batches with one unique join key #[derive(Debug, Default)] struct BufferedData { @@ -1966,13 +2123,13 @@ mod tests { use std::sync::Arc; use arrow::array::{Date32Array, Date64Array, Int32Array}; - use arrow::compute::SortOptions; + use arrow::compute::{concat_batches, filter_record_batch, SortOptions}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use arrow_array::builder::{BooleanBuilder, UInt64Builder}; use arrow_array::{BooleanArray, UInt64Array}; - use hashbrown::HashSet; - use datafusion_common::JoinType::{LeftAnti, LeftSemi}; + use datafusion_common::JoinType::*; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; @@ -1982,7 +2139,7 @@ mod tests { use datafusion_execution::TaskContext; use crate::expressions::Column; - use crate::joins::sort_merge_join::get_filtered_join_mask; + use crate::joins::sort_merge_join::{get_corrected_filter_mask, JoinedRecordBatches}; use crate::joins::utils::JoinOn; use crate::joins::SortMergeJoinExec; use crate::memory::MemoryExec; @@ -2175,7 +2332,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", @@ -2214,7 +2371,7 @@ mod tests { ), ]; - let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_columns, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -2252,7 +2409,7 @@ mod tests { ), ]; - let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_columns, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -2291,7 +2448,7 @@ mod tests { ), ]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -2332,7 +2489,7 @@ mod tests { left, right, on, - JoinType::Inner, + Inner, vec![ SortOptions { descending: true, @@ -2382,7 +2539,7 @@ mod tests { ]; let (_, batches) = - join_collect_batch_size_equals_two(left, right, on, JoinType::Inner).await?; + join_collect_batch_size_equals_two(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -2417,7 +2574,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; + let (_, batches) = join_collect(left, right, on, Left).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2449,7 +2606,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let (_, batches) = join_collect(left, right, on, Right).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2481,7 +2638,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Full).await?; + let (_, batches) = join_collect(left, right, on, Full).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2513,7 +2670,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::LeftAnti).await?; + let (_, batches) = join_collect(left, right, on, LeftAnti).await?; let expected = [ "+----+----+----+", "| a1 | b1 | c1 |", @@ -2544,7 +2701,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::LeftSemi).await?; + let (_, batches) = join_collect(left, right, on, LeftSemi).await?; let expected = [ "+----+----+----+", "| a1 | b1 | c1 |", @@ -2577,7 +2734,7 @@ mod tests { Arc::new(Column::new_with_schema("b", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+---+---+---+----+---+----+", "| a | b | c | a | b | c |", @@ -2609,7 +2766,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = ["+------------+------------+------------+------------+------------+------------+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2641,7 +2798,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = ["+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2672,7 +2829,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; + let (_, batches) = join_collect(left, right, on, Left).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2708,7 +2865,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let (_, batches) = join_collect(left, right, on, Right).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2752,7 +2909,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; + let (_, batches) = join_collect(left, right, on, Left).await?; let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2801,7 +2958,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let (_, batches) = join_collect(left, right, on, Right).await?; let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2850,7 +3007,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Full).await?; + let (_, batches) = join_collect(left, right, on, Full).await?; let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2890,14 +3047,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() @@ -2975,14 +3125,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() @@ -3038,14 +3181,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti]; // Enable DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() @@ -3146,14 +3282,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti]; // Enable DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() @@ -3213,171 +3342,677 @@ mod tests { Ok(()) } + fn build_joined_record_batches() -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + ])); + + let mut batches = JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }; + + // Insert already prejoined non-filtered rows + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![10, 10])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 9])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![11])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + 
batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 12])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 13])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![13])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![14, 14])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 11])), + ], + )?); + + let streamed_indices = vec![0, 0]; + batches.batch_ids.extend(vec![0; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![1]; + batches.batch_ids.extend(vec![0; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + batches.batch_ids.extend(vec![1; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0]; + batches.batch_ids.extend(vec![2; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + batches.batch_ids.extend(vec![3; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + batches + .filter_mask + .extend(&BooleanArray::from(vec![true, false])); + batches.filter_mask.extend(&BooleanArray::from(vec![true])); + batches + .filter_mask + .extend(&BooleanArray::from(vec![false, true])); + batches.filter_mask.extend(&BooleanArray::from(vec![false])); + batches + .filter_mask + .extend(&BooleanArray::from(vec![false, false])); + + Ok(batches) + } + #[tokio::test] - async fn left_semi_join_filtered_mask() -> Result<()> { + async fn test_left_outer_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + true, false, false, false, false, false, false, false + ]) + ); + assert_eq!( - get_filtered_join_mask( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + false, false, false, false, false, false, false, false + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + true, true, false, false, false, false, false, false + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true, true, true, false, false, false, false, false]) + ); + + 
assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + Some(true), + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + None, + Some(false), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + let corrected_mask = get_corrected_filter_mask( + Left, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + Some(false), + None, + Some(false) + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); + + // output null rows + + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + Some(true), + None, + Some(true) + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 13 | 1 | 12 |", + "| 1 | 14 | 1 | 11 |", + "+---+----+---+----+", + ], + &[null_joined_batch] + ); + Ok(()) + } + + #[tokio::test] + async fn test_left_semi_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( LeftSemi, - &UInt64Array::from(vec![0, 0, 1, 1]), - &BooleanArray::from(vec![true, true, false, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false, false, false]), vec![0])) + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true]) ); assert_eq!( - get_filtered_join_mask( + get_corrected_filter_mask( LeftSemi, - &UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None]) + ); + 
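These mask tests lean on Arrow's null semantics: `filter_record_batch` never selects a `None` slot, and `not` maps `Some(false)` to `Some(true)` while leaving `None` untouched, which is what lets a single corrected mask drive both the matched-row pass and the null-joined pass above. A self-contained illustration of those two kernels (a sketch, not code from this diff):

```rust
use std::sync::Arc;

use arrow::array::{BooleanArray, Int32Array};
use arrow::compute::{filter_record_batch, not};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, false)]));
    let batch =
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![10, 13, 14]))])?;

    // Tri-state mask: Some(true) selects; Some(false) and None both drop.
    let mask = BooleanArray::from(vec![Some(true), Some(false), None]);
    assert_eq!(filter_record_batch(&batch, &mask)?.num_rows(), 1);

    // `not` preserves nulls, so only the Some(false) row surfaces in the
    // second (null-joined) pass; the None row stays invisible to both.
    let null_mask = not(&mask)?;
    assert_eq!(
        null_mask,
        BooleanArray::from(vec![Some(false), Some(true), None])
    );
    assert_eq!(filter_record_batch(&batch, &null_mask)?.num_rows(), 1);
    Ok(())
}
```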
+ assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], &BooleanArray::from(vec![true, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, true]), vec![0, 1])) + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None]) ); assert_eq!( - get_filtered_join_mask( + get_corrected_filter_mask( LeftSemi, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![false, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, true]), vec![1])) + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) ); assert_eq!( - get_filtered_join_mask( + get_corrected_filter_mask( LeftSemi, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false]), vec![0])) + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) ); assert_eq!( - get_filtered_join_mask( + get_corrected_filter_mask( LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, true, true, true, true, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, true, false, true, false, false]), - vec![0, 1] - )) + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true),]) ); assert_eq!( - get_filtered_join_mask( + get_corrected_filter_mask( LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, false, false, false, false, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, false, false, false, true]), - vec![1] - )) + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, Some(true), None]) ); assert_eq!( - get_filtered_join_mask( + get_corrected_filter_mask( LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![true, false, false, false, false, true]), - &HashSet::from_iter(vec![1]), - &0, - ), - Some(( - BooleanArray::from(vec![true, false, false, false, false, false]), - vec![0] - )) + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + let corrected_mask = get_corrected_filter_mask( + LeftSemi, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + None, + None, + None + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] ); + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + None, + None, + None + ]) + ); + + 
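The `LeftSemi` cases encode a simpler rule than the outer-join one: a streamed row contributes at most one output row, namely the first position whose filter passed, and every other position becomes `None`; there is no `Some(false)` (null-join) state at all. A compact model of that rule, under the same caveats as before (hypothetical helper; the real code also distinguishes source batches by id):

```rust
/// Hypothetical restatement of the LeftSemi mask correction exercised by
/// the asserts above (illustration only).
fn semi_mask(row_indices: &[u64], mask: &[bool]) -> Vec<Option<bool>> {
    let mut out = vec![None; mask.len()];
    let mut i = 0;
    while i < mask.len() {
        // Find the extent of the current streamed row's group.
        let mut j = i;
        while j < mask.len() && row_indices[j] == row_indices[i] {
            j += 1;
        }
        // Keep only the first filter hit in the group, if any.
        if let Some(hit) = (i..j).find(|&k| mask[k]) {
            out[hit] = Some(true);
        }
        i = j;
    }
    out
}
```

`LeftAnti` (tested below) is the mirror image: a group yields `Some(true)` on its last position only when no position passed the filter.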
let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); Ok(()) } #[tokio::test] - async fn left_anti_join_filtered_mask() -> Result<()> { + async fn test_left_anti_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + assert_eq!( - get_filtered_join_mask( + get_corrected_filter_mask( LeftAnti, - &UInt64Array::from(vec![0, 0, 1, 1]), - &BooleanArray::from(vec![true, true, false, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, false, false, true]), vec![0])) + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![Some(true)]) ); assert_eq!( - get_filtered_join_mask( + get_corrected_filter_mask( LeftAnti, - &UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], &BooleanArray::from(vec![true, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, false]), vec![0, 1])) + 2 + ) + .unwrap(), + BooleanArray::from(vec![None, None]) ); assert_eq!( - get_filtered_join_mask( + get_corrected_filter_mask( LeftAnti, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![false, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false]), vec![1])) + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) ); assert_eq!( - get_filtered_join_mask( + get_corrected_filter_mask( LeftAnti, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, true]), vec![0])) + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) ); assert_eq!( - get_filtered_join_mask( + get_corrected_filter_mask( LeftAnti, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, true, true, true, true, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, false, false, false, false]), - vec![0, 1] - )) + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) ); assert_eq!( - get_filtered_join_mask( + get_corrected_filter_mask( LeftAnti, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, false, false, false, false, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, true, false, false, false]), - vec![1] - )) + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) ); + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + 
&BooleanArray::from(vec![false, false, false]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true)]) + ); + + let corrected_mask = get_corrected_filter_mask( + LeftAnti, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(true), + None, + Some(true) + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 13 | 1 | 12 |", + "| 1 | 14 | 1 | 11 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); + + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(false), + None, + Some(false), + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); Ok(()) } diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index ba9384aef1a65..02c71dab3df23 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -19,6 +19,7 @@ //! related functionality, used both in join calculations and optimization rules. use std::collections::{HashMap, VecDeque}; +use std::mem::size_of; use std::sync::Arc; use crate::joins::utils::{JoinFilter, JoinHashMapType}; @@ -31,8 +32,7 @@ use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; use arrow_schema::{Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ - arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, - ScalarValue, + arrow_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue, }; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; @@ -154,8 +154,7 @@ impl PruningJoinHashMap { /// # Returns /// The size of the hash map in bytes. 
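///
/// Note that both components are counted: the `hashbrown` allocation
/// reported by `allocation_info()` and the `next` chain buffer, whose
/// entries are `u64` values (8 bytes each).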
pub(crate) fn size(&self) -> usize { - self.map.allocation_info().1.size() - + self.next.capacity() * std::mem::size_of::<u64>() + self.map.allocation_info().1.size() + self.next.capacity() * size_of::<u64>() } /// Removes hash values from the map and the list based on the given pruning @@ -369,34 +368,40 @@ impl SortedFilterExpr { filter_expr: Arc<dyn PhysicalExpr>, filter_schema: &Schema, ) -> Result<Self> { - let dt = &filter_expr.data_type(filter_schema)?; + let dt = filter_expr.data_type(filter_schema)?; Ok(Self { origin_sorted_expr, filter_expr, - interval: Interval::make_unbounded(dt)?, + interval: Interval::make_unbounded(&dt)?, node_index: 0, }) } + /// Get origin expr information pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr { &self.origin_sorted_expr } + /// Get filter expr information pub fn filter_expr(&self) -> &Arc<dyn PhysicalExpr> { &self.filter_expr } + /// Get interval information pub fn interval(&self) -> &Interval { &self.interval } + /// Sets interval pub fn set_interval(&mut self, interval: Interval) { self.interval = interval; } + /// Node index in ExprIntervalGraph pub fn node_index(&self) -> usize { self.node_index } + /// Node index setter in ExprIntervalGraph pub fn set_node_index(&mut self, node_index: usize) { self.node_index = node_index; @@ -409,41 +414,45 @@ impl SortedFilterExpr { /// on the first or the last value of the expression in `build_input_buffer` /// and `probe_batch`. /// -/// # Arguments +/// # Parameters /// /// * `build_input_buffer` - The [RecordBatch] on the build side of the join. /// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update. /// * `probe_batch` - The `RecordBatch` on the probe side of the join. /// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update. /// -/// ### Note -/// ```text +/// ## Note /// -/// Interval arithmetic is used to calculate viable join ranges for build-side -/// pruning. This is done by first creating an interval for join filter values in -/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the -/// ordering (descending/ascending) of the filter expression. Here, FV denotes the -/// first value on the build side. This range is then compared with the probe side -/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering -/// (ascending/descending) of the probe side. Here, LV denotes the last value on -/// the probe side. +/// Utilizing interval arithmetic, this function computes feasible join intervals +/// on the pruning side by evaluating the prospective value ranges that might +/// emerge in subsequent data batches from the enforcer side. This is done by +/// first creating an interval for join filter values in the pruning side of the +/// join, which spans `[-∞, FV]` or `[FV, ∞]` depending on the ordering (descending/ +/// ascending) of the filter expression. Here, `FV` denotes the first value on the +/// pruning side. This range is then compared with the enforcer side interval, +/// which either spans `[-∞, LV]` or `[LV, ∞]` depending on the ordering (ascending/ +/// descending) of the probe side. Here, `LV` denotes the last value on the enforcer +/// side. /// /// As a concrete example, consider the following query: /// +/// ```text /// SELECT * FROM left_table, right_table /// WHERE /// left_key = right_key AND /// a > b - 3 AND /// a < b + 10 +/// ``` /// -/// where columns "a" and "b" come from tables "left_table" and "right_table", +/// where columns `a` and `b` come from tables `left_table` and `right_table`, /// respectively. 
When a new `RecordBatch` arrives at the right side, the -/// condition a > b - 3 will possibly indicate a prunable range for the left +/// condition `a > b - 3` will possibly indicate a prunable range for the left /// side. Conversely, when a new `RecordBatch` arrives at the left side, the -/// condition a < b + 10 will possibly indicate prunability for the right side. -/// Let’s inspect what happens when a new RecordBatch` arrives at the right +/// condition `a < b + 10` will possibly indicate prunability for the right side. +/// Let’s inspect what happens when a new `RecordBatch` arrives at the right /// side (i.e. when the left side is the build side): /// +/// ```text /// Build Probe /// +-------+ +-------+ /// | a | z | | b | y | @@ -456,13 +465,13 @@ impl SortedFilterExpr { /// |+--|--+| |+--|--+| /// | 7 | 1 | | 6 | 3 | /// +-------+ +-------+ +/// ``` /// /// In this case, the interval representing viable (i.e. joinable) values for -/// column "a" is [1, ∞], and the interval representing possible future values -/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate +/// column `a` is `[1, ∞]`, and the interval representing possible future values +/// for column `b` is `[6, ∞]`. With these intervals at hand, we next calculate /// intervals for the whole filter expression and propagate join constraint by /// traversing the expression graph. -/// ``` pub fn calculate_filter_expr_intervals( build_input_buffer: &RecordBatch, build_sorted_filter_expr: &mut SortedFilterExpr, @@ -710,13 +719,21 @@ fn update_sorted_exprs_with_node_indices( } } -/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. +/// Prepares and sorts expressions based on a given filter, left and right schemas, +/// and sort expressions. /// -/// # Arguments +/// This function prepares sorted filter expressions for both the left and right +/// sides of a join operation. It first builds the filter order for each side +/// based on the provided `ExecutionPlan`. If both sides have valid sorted filter +/// expressions, the function then constructs an expression interval graph and +/// updates the sorted expressions with node indices. The final sorted filter +/// expressions for both sides are then returned. +/// +/// # Parameters /// /// * `filter` - The join filter to base the sorting on. -/// * `left` - The left execution plan. -/// * `right` - The right execution plan. +/// * `left` - The `ExecutionPlan` for the left side of the join. +/// * `right` - The `ExecutionPlan` for the right side of the join. /// * `left_sort_exprs` - The expressions to sort on the left side. /// * `right_sort_exprs` - The expressions to sort on the right side. /// @@ -730,9 +747,11 @@ pub fn prepare_sorted_exprs( left_sort_exprs: &[PhysicalSortExpr], right_sort_exprs: &[PhysicalSortExpr], ) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { - // Build the filter order for the left side - let err = || plan_datafusion_err!("Filter does not include the child order"); + let err = || { + datafusion_common::plan_datafusion_err!("Filter does not include the child order") + }; + // Build the filter order for the left side: let left_temp_sorted_filter_expr = build_filter_input_order( JoinSide::Left, filter, @@ -741,7 +760,7 @@ pub fn prepare_sorted_exprs( )? 
.ok_or_else(err)?; - // Build the filter order for the right side + // Build the filter order for the right side: let right_temp_sorted_filter_expr = build_filter_input_order( JoinSide::Right, filter, @@ -952,15 +971,15 @@ pub mod tests { let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { - index: 0, + index: left_schema.index_of("la1")?, side: JoinSide::Left, }, ColumnIndex { - index: 4, + index: left_schema.index_of("la2")?, side: JoinSide::Left, }, ColumnIndex { - index: 0, + index: right_schema.index_of("ra1")?, side: JoinSide::Right, }, ]; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index ac718a95e9f4c..eb6a30d17e925 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -27,12 +27,12 @@ use std::any::Any; use std::fmt::{self, Debug}; +use std::mem::{size_of, size_of_val}; use std::sync::Arc; use std::task::{Context, Poll}; use std::vec; use crate::common::SharedMemoryReservation; -use crate::handle_state; use crate::joins::hash_join::{equal_rows_arr, update_hash}; use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, @@ -42,8 +42,9 @@ use crate::joins::stream_join_utils::{ }; use crate::joins::utils::{ apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, symmetric_join_output_partitioning, ColumnIndex, JoinFilter, - JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult, + check_join_is_valid, symmetric_join_output_partitioning, BatchSplitter, + BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, JoinOnRef, + NoopBatchTransformer, StatefulStreamResult, }; use crate::{ execution_mode_from_children, @@ -465,23 +466,27 @@ impl ExecutionPlan for SymmetricHashJoinExec { consider using RepartitionExec" ); } - // If `filter_state` and `filter` are both present, then calculate sorted filter expressions - // for both sides, and build an expression graph. - let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = - match (&self.left_sort_exprs, &self.right_sort_exprs, &self.filter) { - (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { - let (left, right, graph) = prepare_sorted_exprs( - filter, - &self.left, - &self.right, - left_sort_exprs, - right_sort_exprs, - )?; - (Some(left), Some(right), Some(graph)) - } - // If `filter_state` or `filter` is not present, then return None for all three values: - _ => (None, None, None), - }; + // If `filter_state` and `filter` are both present, then calculate sorted + // filter expressions for both sides, and build an expression graph. 
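+        // If either sort-order description or the filter itself is absent,
+        // no expression graph is built, so interval-based state pruning is
+        // effectively disabled and both joiners keep their build state for
+        // the lifetime of the stream.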
+ let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match ( + self.left_sort_exprs(), + self.right_sort_exprs(), + &self.filter, + ) { + (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { + let (left, right, graph) = prepare_sorted_exprs( + filter, + &self.left, + &self.right, + left_sort_exprs, + right_sort_exprs, + )?; + (Some(left), Some(right), Some(graph)) + } + // If `filter_state` or `filter` is not present, then return None + // for all three values: + _ => (None, None, None), + }; let (on_left, on_right) = self.on.iter().cloned().unzip(); @@ -494,6 +499,10 @@ impl ExecutionPlan for SymmetricHashJoinExec { let right_stream = self.right.execute(partition, Arc::clone(&context))?; + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let reservation = Arc::new(Mutex::new( MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]")) .register(context.memory_pool()), @@ -502,29 +511,52 @@ impl ExecutionPlan for SymmetricHashJoinExec { reservation.lock().try_grow(g.size())?; } - Ok(Box::pin(SymmetricHashJoinStream { - left_stream, - right_stream, - schema: self.schema(), - filter: self.filter.clone(), - join_type: self.join_type, - random_state: self.random_state.clone(), - left: left_side_joiner, - right: right_side_joiner, - column_indices: self.column_indices.clone(), - metrics: StreamJoinMetrics::new(partition, &self.metrics), - graph, - left_sorted_filter_expr, - right_sorted_filter_expr, - null_equals_null: self.null_equals_null, - state: SHJStreamState::PullRight, - reservation, - })) + if enforce_batch_size_in_joins { + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: StreamJoinMetrics::new(partition, &self.metrics), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + state: SHJStreamState::PullRight, + reservation, + batch_transformer: BatchSplitter::new(batch_size), + })) + } else { + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: StreamJoinMetrics::new(partition, &self.metrics), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + state: SHJStreamState::PullRight, + reservation, + batch_transformer: NoopBatchTransformer::new(), + })) + } } } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct SymmetricHashJoinStream { +struct SymmetricHashJoinStream<T> { /// Input streams left_stream: SendableRecordBatchStream, right_stream: SendableRecordBatchStream, @@ -556,20 +588,24 @@ struct SymmetricHashJoinStream { reservation: SharedMemoryReservation, /// State machine for input execution state: SHJStreamState, + /// Transforms the output batch before returning. 
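+    /// (a `BatchSplitter` when `enforce_batch_size_in_joins` is enabled,
+    /// otherwise a pass-through `NoopBatchTransformer`).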
+ batch_transformer: T, } -impl RecordBatchStream for SymmetricHashJoinStream { +impl<T: BatchTransformer + Unpin + Send> RecordBatchStream + for SymmetricHashJoinStream<T> +{ fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } } -impl Stream for SymmetricHashJoinStream { +impl<T: BatchTransformer + Unpin + Send> Stream for SymmetricHashJoinStream<T> { type Item = Result<RecordBatch>; fn poll_next( mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + cx: &mut Context<'_>, ) -> Poll<Option<Self::Item>> { self.poll_next_impl(cx) } @@ -969,15 +1005,15 @@ pub struct OneSideHashJoiner { impl OneSideHashJoiner { pub fn size(&self) -> usize { let mut size = 0; - size += std::mem::size_of_val(self); - size += std::mem::size_of_val(&self.build_side); + size += size_of_val(self); + size += size_of_val(&self.build_side); size += self.input_buffer.get_array_memory_size(); - size += std::mem::size_of_val(&self.on); + size += size_of_val(&self.on); size += self.hashmap.size(); - size += self.hashes_buffer.capacity() * std::mem::size_of::<u64>(); - size += self.visited_rows.capacity() * std::mem::size_of::<usize>(); - size += std::mem::size_of_val(&self.offset); - size += std::mem::size_of_val(&self.deleted_offset); + size += self.hashes_buffer.capacity() * size_of::<u64>(); + size += self.visited_rows.capacity() * size_of::<usize>(); + size += size_of_val(&self.offset); + size += size_of_val(&self.deleted_offset); size } pub fn new( @@ -1140,7 +1176,7 @@ impl OneSideHashJoiner { /// - Transition to `BothExhausted { final_result: true }`: /// - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are /// exhausted, indicating completion of processing and availability of final results. -impl SymmetricHashJoinStream { +impl<T: BatchTransformer> SymmetricHashJoinStream<T> { /// Implements the main polling logic for the join stream. /// /// This method continuously checks the state of the join stream and @@ -1159,26 +1195,45 @@ impl SymmetricHashJoinStream { cx: &mut Context<'_>, ) -> Poll<Option<Result<RecordBatch>>> { loop { - return match self.state() { - SHJStreamState::PullRight => { - handle_state!(ready!(self.fetch_next_from_right_stream(cx))) - } - SHJStreamState::PullLeft => { - handle_state!(ready!(self.fetch_next_from_left_stream(cx))) - } - SHJStreamState::RightExhausted => { - handle_state!(ready!(self.handle_right_stream_end(cx))) - } - SHJStreamState::LeftExhausted => { - handle_state!(ready!(self.handle_left_stream_end(cx))) + match self.batch_transformer.next() { + None => { + let result = match self.state() { + SHJStreamState::PullRight => { + ready!(self.fetch_next_from_right_stream(cx)) + } + SHJStreamState::PullLeft => { + ready!(self.fetch_next_from_left_stream(cx)) + } + SHJStreamState::RightExhausted => { + ready!(self.handle_right_stream_end(cx)) + } + SHJStreamState::LeftExhausted => { + ready!(self.handle_left_stream_end(cx)) + } + SHJStreamState::BothExhausted { + final_result: false, + } => self.prepare_for_final_results_after_exhaustion(), + SHJStreamState::BothExhausted { final_result: true } => { + return Poll::Ready(None); + } + }; + + match result? 
{ + StatefulStreamResult::Ready(None) => { + return Poll::Ready(None); + } + StatefulStreamResult::Ready(Some(batch)) => { + self.batch_transformer.set_batch(batch); + } + _ => {} + } } - SHJStreamState::BothExhausted { - final_result: false, - } => { - handle_state!(self.prepare_for_final_results_after_exhaustion()) + Some((batch, _)) => { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + return Poll::Ready(Some(Ok(batch))); } - SHJStreamState::BothExhausted { final_result: true } => Poll::Ready(None), - }; + } } } /// Asynchronously pulls the next batch from the right stream. @@ -1384,11 +1439,8 @@ impl SymmetricHashJoinStream { // Combine the left and right results: let result = combine_two_batches(&self.schema, left_result, right_result)?; - // Update the metrics and return the result: - if let Some(batch) = &result { - // Update the metrics: - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); + // Return the result: + if result.is_some() { return Ok(StatefulStreamResult::Ready(result)); } Ok(StatefulStreamResult::Continue) @@ -1412,18 +1464,18 @@ impl SymmetricHashJoinStream { fn size(&self) -> usize { let mut size = 0; - size += std::mem::size_of_val(&self.schema); - size += std::mem::size_of_val(&self.filter); - size += std::mem::size_of_val(&self.join_type); + size += size_of_val(&self.schema); + size += size_of_val(&self.filter); + size += size_of_val(&self.join_type); size += self.left.size(); size += self.right.size(); - size += std::mem::size_of_val(&self.column_indices); + size += size_of_val(&self.column_indices); size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0); - size += std::mem::size_of_val(&self.left_sorted_filter_expr); - size += std::mem::size_of_val(&self.right_sorted_filter_expr); - size += std::mem::size_of_val(&self.random_state); - size += std::mem::size_of_val(&self.null_equals_null); - size += std::mem::size_of_val(&self.metrics); + size += size_of_val(&self.left_sorted_filter_expr); + size += size_of_val(&self.right_sorted_filter_expr); + size += size_of_val(&self.random_state); + size += size_of_val(&self.null_equals_null); + size += size_of_val(&self.metrics); size } @@ -1523,11 +1575,6 @@ impl SymmetricHashJoinStream { let capacity = self.size(); self.metrics.stream_memory_usage.set(capacity); self.reservation.lock().try_resize(capacity)?; - // Update the metrics if we have a batch; otherwise, continue the loop. - if let Some(batch) = &result { - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - } Ok(result) } } @@ -1716,15 +1763,15 @@ mod tests { let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { - index: 0, + index: left_schema.index_of("la1")?, side: JoinSide::Left, }, ColumnIndex { - index: 4, + index: left_schema.index_of("la2")?, side: JoinSide::Left, }, ColumnIndex { - index: 0, + index: right_schema.index_of("ra1")?, side: JoinSide::Right, }, ]; @@ -1771,10 +1818,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) 
as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1825,10 +1869,7 @@ mod tests { let (left, right) = create_memory_table(left_partition, right_partition, vec![], vec![])?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1877,10 +1918,7 @@ mod tests { let (left, right) = create_memory_table(left_partition, right_partition, vec![], vec![])?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; experiment(left, right, None, join_type, on, task_ctx).await?; Ok(()) } @@ -1926,10 +1964,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1987,10 +2022,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2048,10 +2080,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2111,10 +2140,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2170,10 +2196,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), @@ -2237,10 +2260,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), @@ -2296,10 +2316,7 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) 
as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let left_sorted = vec![PhysicalSortExpr { expr: col("lt1", left_schema)?, options: SortOptions { @@ -2380,10 +2397,7 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let left_sorted = vec![PhysicalSortExpr { expr: col("li1", left_schema)?, options: SortOptions { @@ -2473,10 +2487,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Float64, true), diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 264f297ffb4c4..090d60f0bac3d 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -289,7 +289,7 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(10 as $type)), (Operator::Gt, Operator::Lt), ), - // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10 + // left_col - 1 > right_col + 3 AND left_col + 3 < right_col + 15 1 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -300,9 +300,9 @@ macro_rules! join_expr_tests { Operator::Plus, ), ScalarValue::$SCALAR(Some(1 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), ScalarValue::$SCALAR(Some(3 as $type)), - ScalarValue::$SCALAR(Some(10 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(15 as $type)), (Operator::Gt, Operator::Lt), ), // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10 @@ -353,7 +353,8 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(3 as $type)), (Operator::Gt, Operator::Lt), ), - // left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3 + // left_col - 2 >= right_col + 5 AND left_col + 7 <= right_col - 3 + // (filters all input rows) 5 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -369,7 +370,7 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(3 as $type)), (Operator::GtEq, Operator::LtEq), ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 + // left_col + 28 >= right_col - 11 AND left_col + 21 <= right_col + 39 6 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -385,7 +386,7 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(39 as $type)), (Operator::Gt, Operator::LtEq), ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col + 39 + // left_col + 28 >= right_col - 11 AND left_col - 21 <= right_col + 39 7 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -526,10 +527,10 @@ pub fn create_memory_table( ) -> Result<(Arc, Arc)> { let left_schema = left_partition[0].schema(); let left = MemoryExec::try_new(&[left_partition], left_schema, None)? - .with_sort_information(left_sorted); + .try_with_sort_information(left_sorted)?; let right_schema = right_partition[0].schema(); let right = MemoryExec::try_new(&[right_partition], right_schema, None)? 
- .with_sort_information(right_sorted); + .try_with_sort_information(right_sorted)?; Ok((Arc::new(left), Arc::new(right))) } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 89f3feaf07be6..090cf9aa628a7 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -369,7 +369,7 @@ impl JoinHashMapType for JoinHashMap { } } -impl fmt::Debug for JoinHashMap { +impl Debug for JoinHashMap { fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { Ok(()) } @@ -546,15 +546,16 @@ pub struct ColumnIndex { pub side: JoinSide, } -/// Filter applied before join output +/// Filter applied before join output. Fields are crate-public to allow +/// downstream implementations to experiment with custom joins. #[derive(Debug, Clone)] pub struct JoinFilter { /// Filter expression - expression: Arc<dyn PhysicalExpr>, + pub(crate) expression: Arc<dyn PhysicalExpr>, /// Column indices required to construct intermediate batch for filtering - column_indices: Vec<ColumnIndex>, + pub(crate) column_indices: Vec<ColumnIndex>, /// Physical schema of intermediate batch - schema: Schema, + pub(crate) schema: Schema, } impl JoinFilter { @@ -700,7 +701,13 @@ pub fn build_join_schema( .unzip(), }; - (fields.finish(), column_indices) + let metadata = left + .metadata() + .clone() + .into_iter() + .chain(right.metadata().clone()) + .collect(); + (fields.finish().with_metadata(metadata), column_indices) } /// A [`OnceAsync`] can be used to run an async closure once, with subsequent calls @@ -720,8 +727,8 @@ impl<T> Default for OnceAsync<T> { } } -impl<T> std::fmt::Debug for OnceAsync<T> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl<T> Debug for OnceAsync<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "OnceAsync") } } @@ -1280,15 +1287,15 @@ pub(crate) fn adjust_indices_by_join_type( adjust_range: Range<usize>, join_type: JoinType, preserve_order_for_right: bool, -) -> (UInt64Array, UInt32Array) { +) -> Result<(UInt64Array, UInt32Array)> { match join_type { JoinType::Inner => { // matched - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::Left => { // matched - (left_indices, right_indices) + Ok((left_indices, right_indices)) // unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap } JoinType::Right => { @@ -1307,22 +1314,22 @@ pub(crate) fn adjust_indices_by_join_type( // need to remove the duplicated record in the right side let right_indices = get_semi_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right semi` join - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::RightAnti => { // need to remove the duplicated record in the right side // get the anti index for the right side let right_indices = get_anti_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right anti` join - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::LeftSemi | JoinType::LeftAnti => { // matched or unmatched left row will be produced in the end of loop // When visit the right batch, we can output the matched left row and don't need to wait the end of loop - ( + Ok(( UInt64Array::from_iter_values(vec![]), UInt32Array::from_iter_values(vec![]), - ) + )) } } } @@ -1347,27 +1354,64 @@ pub(crate) fn append_right_indices( right_indices: UInt32Array, adjust_range: Range<usize>, preserve_order_for_right: bool, -) -> (UInt64Array, UInt32Array) { +) -> Result<(UInt64Array, 
UInt32Array)> { if preserve_order_for_right { - append_probe_indices_in_order(left_indices, right_indices, adjust_range) + Ok(append_probe_indices_in_order( + left_indices, + right_indices, + adjust_range, + )) } else { let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices); if right_unmatched_indices.is_empty() { - (left_indices, right_indices) + Ok((left_indices, right_indices)) } else { - let unmatched_size = right_unmatched_indices.len(); + // `into_builder()` can fail here when there is nothing to be filtered and + // left_indices or right_indices has the same reference to the cached indices. + // In that case, we use a slower alternative. + // the new left indices: left_indices + null array + let mut new_left_indices_builder = + left_indices.into_builder().unwrap_or_else(|left_indices| { + let mut builder = UInt64Builder::with_capacity( + left_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + left_indices.null_count(), + 0, + "expected left indices to have no nulls" + ); + builder.append_slice(left_indices.values()); + builder + }); + new_left_indices_builder.append_nulls(right_unmatched_indices.len()); + let new_left_indices = UInt64Array::from(new_left_indices_builder.finish()); + // the new right indices: right_indices + right_unmatched_indices - let new_left_indices = left_indices - .iter() - .chain(std::iter::repeat(None).take(unmatched_size)) - .collect(); - let new_right_indices = right_indices - .iter() - .chain(right_unmatched_indices.iter()) - .collect(); - (new_left_indices, new_right_indices) + let mut new_right_indices_builder = right_indices + .into_builder() + .unwrap_or_else(|right_indices| { + let mut builder = UInt32Builder::with_capacity( + right_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + right_indices.null_count(), + 0, + "expected right indices to have no nulls" + ); + builder.append_slice(right_indices.values()); + builder + }); + debug_assert_eq!( + right_unmatched_indices.null_count(), + 0, + "expected right unmatched indices to have no nulls" + ); + new_right_indices_builder.append_slice(right_unmatched_indices.values()); + let new_right_indices = UInt32Array::from(new_right_indices_builder.finish()); + + Ok((new_left_indices, new_right_indices)) } } } @@ -1635,6 +1679,91 @@ pub(crate) fn asymmetric_join_output_partitioning( } } +/// Trait for incrementally generating Join output. +/// +/// This trait is used to limit some join outputs +/// so it does not produce single large batches +pub(crate) trait BatchTransformer: Debug + Clone { + /// Sets the next `RecordBatch` to be processed. + fn set_batch(&mut self, batch: RecordBatch); + + /// Retrieves the next `RecordBatch` from the transformer. + /// Returns `None` if all batches have been produced. + /// The boolean flag indicates whether the batch is the last one. + fn next(&mut self) -> Option<(RecordBatch, bool)>; +} + +#[derive(Debug, Clone)] +/// A batch transformer that does nothing. +pub(crate) struct NoopBatchTransformer { + /// RecordBatch to be processed + batch: Option<RecordBatch>, +} + +impl NoopBatchTransformer { + pub fn new() -> Self { + Self { batch: None } + } +} + +impl BatchTransformer for NoopBatchTransformer { + fn set_batch(&mut self, batch: RecordBatch) { + self.batch = Some(batch); + } + + fn next(&mut self) -> Option<(RecordBatch, bool)> { + self.batch.take().map(|batch| (batch, true)) + } +} + +#[derive(Debug, Clone)] +/// Splits large batches into smaller batches with a maximum number of rows. 
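+///
+/// Illustrative usage (a sketch, not code from this diff):
+///
+/// ```ignore
+/// let mut splitter = BatchSplitter::new(1024);
+/// splitter.set_batch(batch); // e.g. a 4096-row join output
+/// while let Some((chunk, last)) = splitter.next() {
+///     // Four 1024-row chunks; `last` is true only for the final one.
+/// }
+/// ```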
+pub(crate) struct BatchSplitter { + /// RecordBatch to be split + batch: Option<RecordBatch>, + /// Maximum number of rows in a split batch + batch_size: usize, + /// Current row index + row_index: usize, +} + +impl BatchSplitter { + /// Creates a new `BatchSplitter` with the specified batch size. + pub(crate) fn new(batch_size: usize) -> Self { + Self { + batch: None, + batch_size, + row_index: 0, + } + } +} + +impl BatchTransformer for BatchSplitter { + fn set_batch(&mut self, batch: RecordBatch) { + self.batch = Some(batch); + self.row_index = 0; + } + + fn next(&mut self) -> Option<(RecordBatch, bool)> { + let Some(batch) = &self.batch else { + return None; + }; + + let remaining_rows = batch.num_rows() - self.row_index; + let rows_to_slice = remaining_rows.min(self.batch_size); + let sliced_batch = batch.slice(self.row_index, rows_to_slice); + self.row_index += rows_to_slice; + + let mut last = false; + if self.row_index >= batch.num_rows() { + self.batch = None; + last = true; + } + + Some((sliced_batch, last)) + } +} + #[cfg(test)] mod tests { use std::pin::Pin; @@ -1643,11 +1772,13 @@ mod tests { use arrow::datatypes::{DataType, Fields}; use arrow::error::{ArrowError, Result as ArrowResult}; + use arrow_array::Int32Array; use arrow_schema::SortOptions; - use datafusion_common::stats::Precision; + use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; + use rstest::rstest; + fn check( left: &[Column], right: &[Column], @@ -1821,13 +1952,13 @@ mod tests { ) -> Statistics { Statistics { num_rows: if is_exact { - num_rows.map(Precision::Exact) + num_rows.map(Exact) } else { - num_rows.map(Precision::Inexact) + num_rows.map(Inexact) } - .unwrap_or(Precision::Absent), + .unwrap_or(Absent), column_statistics: column_stats, - total_byte_size: Precision::Absent, + total_byte_size: Absent, } } @@ -2073,17 +2204,17 @@ mod tests { assert_eq!( estimate_inner_join_cardinality( Statistics { - num_rows: Precision::Inexact(400), - total_byte_size: Precision::Absent, + num_rows: Inexact(400), + total_byte_size: Absent, column_statistics: left_col_stats, }, Statistics { - num_rows: Precision::Inexact(400), - total_byte_size: Precision::Absent, + num_rows: Inexact(400), + total_byte_size: Absent, column_statistics: right_col_stats, }, ), - Some(Precision::Inexact((400 * 400) / 200)) + Some(Inexact((400 * 400) / 200)) ); Ok(()) } @@ -2091,33 +2222,33 @@ #[test] fn test_inner_join_cardinality_decimal_range() -> Result<()> { let left_col_stats = vec![ColumnStatistics { - distinct_count: Precision::Absent, - min_value: Precision::Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)), - max_value: Precision::Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)), + distinct_count: Absent, + min_value: Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)), + max_value: Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)), ..Default::default() }]; let right_col_stats = vec![ColumnStatistics { - distinct_count: Precision::Absent, - min_value: Precision::Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)), - max_value: Precision::Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)), + distinct_count: Absent, + min_value: Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)), + max_value: Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)), ..Default::default() }]; assert_eq!( estimate_inner_join_cardinality( Statistics { - num_rows: Precision::Inexact(100), - total_byte_size: Precision::Absent, + num_rows: Inexact(100), + total_byte_size: Absent, column_statistics: 
left_col_stats, }, Statistics { - num_rows: Precision::Inexact(100), - total_byte_size: Precision::Absent, + num_rows: Inexact(100), + total_byte_size: Absent, column_statistics: right_col_stats, }, ), - Some(Precision::Inexact(100)) + Some(Inexact(100)) ); Ok(()) } @@ -2554,4 +2685,49 @@ mod tests { Ok(()) } + + fn create_test_batch(num_rows: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let data = Arc::new(Int32Array::from_iter_values(0..num_rows as i32)); + RecordBatch::try_new(schema, vec![data]).unwrap() + } + + fn assert_split_batches( + batches: Vec<(RecordBatch, bool)>, + batch_size: usize, + num_rows: usize, + ) { + let mut row_count = 0; + for (batch, last) in batches.into_iter() { + assert_eq!(batch.num_rows(), (num_rows - row_count).min(batch_size)); + let column = batch + .column(0) + .as_any() + .downcast_ref::<Int32Array>() + .unwrap(); + for i in 0..batch.num_rows() { + assert_eq!(column.value(i), i as i32 + row_count as i32); + } + row_count += batch.num_rows(); + assert_eq!(last, row_count == num_rows); + } + } + + #[rstest] + #[test] + fn test_batch_splitter( + #[values(1, 3, 11)] batch_size: usize, + #[values(1, 6, 50)] num_rows: usize, + ) { + let mut splitter = BatchSplitter::new(batch_size); + splitter.set_batch(create_test_batch(num_rows)); + + let mut batches = Vec::with_capacity(num_rows.div_ceil(batch_size)); + while let Some(batch) = splitter.next() { + batches.push(batch); + } + + assert!(splitter.next().is_none()); + assert_split_batches(batches, batch_size, num_rows); + } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 7cbfd49afb863..845a74eaea48e 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -82,6 +82,7 @@ pub mod windows; pub mod work_table; pub mod udaf { + pub use datafusion_expr::StatisticsArgs; pub use datafusion_physical_expr::aggregate::AggregateFunctionExpr; } diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 360e942226d24..1fe550a930561 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -34,6 +34,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; +use crate::execution_plan::CardinalityEffect; use futures::stream::{Stream, StreamExt}; use log::trace; @@ -336,6 +337,10 @@ impl ExecutionPlan for LocalLimitExec { fn supports_limit_pushdown(&self) -> bool { true } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::LowerEqual + } } /// A Limit stream skips `skip` rows, and then fetch up to `fetch` rows. 
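The hunks below mostly recapitalize comments, but they sit in `LimitStream`'s core skip/fetch path, which can be modeled in isolation as follows (hypothetical helper; the real stream additionally drops its input early and records baseline metrics):

```rust
/// Per-batch skip/fetch accounting: returns the (offset, length) slice of
/// an incoming batch that should be emitted downstream (model only).
fn apply_skip_fetch(num_rows: usize, skip: &mut usize, fetch: &mut usize) -> (usize, usize) {
    let offset = num_rows.min(*skip);
    *skip -= offset;
    let len = (num_rows - offset).min(*fetch);
    *fetch -= len;
    (offset, len)
}

fn main() {
    // skip=3, fetch=10 over two 5-row batches: emit rows 3..5, then 0..5.
    let (mut skip, mut fetch) = (3, 10);
    assert_eq!(apply_skip_fetch(5, &mut skip, &mut fetch), (3, 2));
    assert_eq!(apply_skip_fetch(5, &mut skip, &mut fetch), (0, 5));
    assert_eq!((skip, fetch), (0, 3));
}
```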
@@ -393,7 +398,7 @@ impl LimitStream { if batch.num_rows() > 0 { break poll; } else { - // continue to poll input stream + // Continue to poll input stream } } Poll::Ready(Some(Err(_e))) => break poll, @@ -403,12 +408,12 @@ impl LimitStream { } } - /// fetches from the batch + /// Fetches from the batch fn stream_limit(&mut self, batch: RecordBatch) -> Option<RecordBatch> { // records time on drop let _timer = self.baseline_metrics.elapsed_compute().timer(); if self.fetch == 0 { - self.input = None; // clear input so it can be dropped early + self.input = None; // Clear input so it can be dropped early None } else if batch.num_rows() < self.fetch { // self.fetch -= batch.num_rows(); Some(batch) } else if batch.num_rows() >= self.fetch { let batch_rows = self.fetch; self.fetch = 0; - self.input = None; // clear input so it can be dropped early + self.input = None; // Clear input so it can be dropped early // It is guaranteed that batch_rows is <= batch.num_rows Some(batch.slice(0, batch_rows)) @@ -448,7 +453,7 @@ impl Stream for LimitStream { other => other, }) } - // input has been cleared + // Input has been cleared None => Poll::Ready(None), }; @@ -468,7 +473,7 @@ mod tests { use super::*; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common::collect; - use crate::{common, test}; + use crate::test; use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use arrow_array::RecordBatchOptions; @@ -484,17 +489,17 @@ mod tests { let num_partitions = 4; let csv = test::scan_partitioned(num_partitions); - // input should have 4 partitions + // Input should have 4 partitions assert_eq!(csv.output_partitioning().partition_count(), num_partitions); let limit = GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), 0, Some(7)); - // the result should contain 4 batches (one per input partition) + // The result should contain 4 batches (one per input partition) let iter = limit.execute(0, task_ctx)?; - let batches = common::collect(iter).await?; + let batches = collect(iter).await?; - // there should be a total of 100 rows + // There should be a total of 100 rows let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum(); assert_eq!(row_count, 7); @@ -515,7 +520,7 @@ mod tests { let index = input.index(); assert_eq!(index.value(), 0); - // limit of six needs to consume the entire first record batch + // Limit of six needs to consume the entire first record batch // (5 rows) and 1 row from the second (1 row) let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let limit_stream = @@ -545,7 +550,7 @@ mod tests { let index = input.index(); assert_eq!(index.value(), 0); - // limit of six needs to consume the entire first record batch + // Limit of six needs to consume the entire first record batch // (6 rows) and stop immediately let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let limit_stream = @@ -575,7 +580,7 @@ mod tests { let index = input.index(); assert_eq!(index.value(), 0); - // limit of six needs to consume the entire first record batch + // Limit of six needs to consume the entire first record batch // (6 rows) and stop immediately let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let limit_stream = @@ -593,7 +598,7 @@ mod tests { Ok(()) } - // test cases for "skip" + // Test cases for "skip" async fn skip_and_fetch(skip: usize, fetch: Option<usize>) -> Result<usize> { let task_ctx = Arc::new(TaskContext::default()); let offset = 
GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch); - // the result should contain 4 batches (one per input partition) + // The result should contain 4 batches (one per input partition) let iter = offset.execute(0, task_ctx)?; - let batches = common::collect(iter).await?; + let batches = collect(iter).await?; Ok(batches.iter().map(|batch| batch.num_rows()).sum()) } @@ -628,7 +633,7 @@ mod tests { #[tokio::test] async fn skip_3_fetch_none() -> Result<()> { - // there are total of 400 rows, we skipped 3 rows (offset = 3) + // There are total of 400 rows, we skipped 3 rows (offset = 3) let row_count = skip_and_fetch(3, None).await?; assert_eq!(row_count, 397); Ok(()) } #[tokio::test] async fn skip_3_fetch_10_stats() -> Result<()> { - // there are total of 100 rows, we skipped 3 rows (offset = 3) + // There are total of 100 rows, we skipped 3 rows (offset = 3) let row_count = skip_and_fetch(3, Some(10)).await?; assert_eq!(row_count, 10); Ok(()) } @@ -651,7 +656,7 @@ mod tests { #[tokio::test] async fn skip_400_fetch_1() -> Result<()> { - // there are a total of 400 rows + // There are a total of 400 rows let row_count = skip_and_fetch(400, Some(1)).await?; assert_eq!(row_count, 0); Ok(()) } #[tokio::test] async fn skip_401_fetch_none() -> Result<()> { - // there are total of 400 rows, we skipped 401 rows (offset = 3) + // There are total of 400 rows, we skipped 401 rows (offset = 3) let row_count = skip_and_fetch(401, None).await?; assert_eq!(row_count, 0); Ok(()) diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 3aa445d295cb0..dd4868d1bfcc9 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -33,6 +33,9 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, project_schema, Result}; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use futures::Stream; @@ -66,11 +69,7 @@ impl fmt::Debug for MemoryExec { } impl DisplayAs for MemoryExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let partition_sizes: Vec<_> = @@ -116,7 +115,7 @@ impl ExecutionPlan for MemoryExec { } fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { - // this is a leaf node and has no children + // This is a leaf node and has no children vec![] } @@ -176,7 +175,7 @@ impl MemoryExec { }) } - /// set `show_sizes` to determine whether to display partition sizes + /// Set `show_sizes` to determine whether to display partition sizes pub fn with_show_sizes(mut self, show_sizes: bool) -> Self { self.show_sizes = show_sizes; self @@ -206,16 +205,63 @@ impl MemoryExec { /// where both `a ASC` and `b DESC` can describe the table ordering. With /// [`EquivalenceProperties`], we can keep track of these equivalences /// and treat `a ASC` and `b DESC` as the same ordering requirement. 
- pub fn with_sort_information(mut self, sort_information: Vec<LexOrdering>) -> Self { - self.sort_information = sort_information; + /// + /// Note that if there is an internal projection, that projection will + /// also be applied to the given `sort_information`. + pub fn try_with_sort_information( + mut self, + mut sort_information: Vec<LexOrdering>, + ) -> Result<Self> { + // All sort expressions must refer to the original schema + let fields = self.schema.fields(); + let ambiguous_column = sort_information + .iter() + .flatten() + .flat_map(|expr| collect_columns(&expr.expr)) + .find(|col| { + fields + .get(col.index()) + .map(|field| field.name() != col.name()) + .unwrap_or(true) + }); + if let Some(col) = ambiguous_column { + return internal_err!( + "Column {:?} is not found in the original schema of the MemoryExec", + col + ); + } + // If there is a projection on the source, we also need to project orderings + if let Some(projection) = &self.projection { + let base_eqp = EquivalenceProperties::new_with_orderings( + self.original_schema(), + &sort_information, + ); + let proj_exprs = projection + .iter() + .map(|idx| { + let base_schema = self.original_schema(); + let name = base_schema.field(*idx).name(); + (Arc::new(Column::new(name, *idx)) as _, name.to_string()) + }) + .collect::<Vec<_>>(); + let projection_mapping = + ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?; + sort_information = base_eqp + .project(&projection_mapping, self.schema()) + .oeq_class + .orderings; + } + + self.sort_information = sort_information; // We need to update equivalence properties when updating sort information. let eq_properties = EquivalenceProperties::new_with_orderings( self.schema(), &self.sort_information, ); self.cache = self.cache.with_eq_properties(eq_properties); - self + + Ok(self) } pub fn original_schema(&self) -> SchemaRef { @@ -347,7 +393,7 @@ mod tests { let sort_information = vec![sort1.clone(), sort2.clone()]; let mem_exec = MemoryExec::try_new(&[vec![]], schema, None)?
- .with_sort_information(sort_information); + .try_with_sort_information(sort_information)?; assert_eq!( mem_exec.properties().output_ordering().unwrap(), diff --git a/datafusion/physical-plan/src/metrics/value.rs b/datafusion/physical-plan/src/metrics/value.rs index 22db8f1e4e886..2eb01914ee0ac 100644 --- a/datafusion/physical-plan/src/metrics/value.rs +++ b/datafusion/physical-plan/src/metrics/value.rs @@ -37,7 +37,7 @@ use parking_lot::Mutex; #[derive(Debug, Clone)] pub struct Count { /// value of the metric counter - value: std::sync::Arc, + value: Arc, } impl PartialEq for Count { @@ -86,7 +86,7 @@ impl Count { #[derive(Debug, Clone)] pub struct Gauge { /// value of the metric gauge - value: std::sync::Arc, + value: Arc, } impl PartialEq for Gauge { @@ -168,7 +168,7 @@ impl PartialEq for Time { impl Display for Time { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - let duration = std::time::Duration::from_nanos(self.value() as u64); + let duration = Duration::from_nanos(self.value() as u64); write!(f, "{duration:?}") } } diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 272211d5056e0..5d8ca7e76935e 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -208,7 +208,7 @@ mod tests { let schema = test::aggr_test_schema(); let placeholder = PlaceholderRowExec::new(schema); - // ask for the wrong partition + // Ask for the wrong partition assert!(placeholder.execute(1, Arc::clone(&task_ctx)).is_err()); assert!(placeholder.execute(20, task_ctx).is_err()); Ok(()) @@ -223,7 +223,7 @@ mod tests { let iter = placeholder.execute(0, task_ctx)?; let batches = common::collect(iter).await?; - // should have one item + // Should have one item assert_eq!(batches.len(), 1); Ok(()) @@ -240,7 +240,7 @@ mod tests { let iter = placeholder.execute(n, Arc::clone(&task_ctx))?; let batches = common::collect(iter).await?; - // should have one item + // Should have one item assert_eq!(batches.len(), 1); } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index f1b9cdaf728ff..c1d3f368366f6 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -40,8 +40,9 @@ use datafusion_common::stats::Precision; use datafusion_common::Result; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr::expressions::{CastExpr, Literal}; +use crate::execution_plan::CardinalityEffect; use futures::stream::{Stream, StreamExt}; use log::trace; @@ -89,7 +90,7 @@ impl ProjectionExec { input_schema.metadata().clone(), )); - // construct a map from the input expressions to the output expression of the Projection + // Construct a map from the input expressions to the output expression of the Projection let projection_mapping = ProjectionMapping::try_new(&expr, &input_schema)?; let cache = Self::compute_properties(&input, &projection_mapping, Arc::clone(&schema))?; @@ -182,7 +183,7 @@ impl ExecutionPlan for ProjectionExec { } fn maintains_input_order(&self) -> Vec { - // tell optimizer this operator doesn't reorder its input + // Tell optimizer this operator doesn't reorder its input vec![true] } @@ -233,14 +234,22 @@ impl ExecutionPlan for ProjectionExec { fn supports_limit_pushdown(&self) -> bool { true } + + fn cardinality_effect(&self) -> CardinalityEffect { + 
CardinalityEffect::Equal + } } -/// If e is a direct column reference, returns the field level +/// If 'e' is a direct column reference, returns the field level /// metadata for that field, if any. Otherwise returns None -fn get_field_metadata( +pub(crate) fn get_field_metadata( e: &Arc, input_schema: &Schema, ) -> Option> { + if let Some(cast) = e.as_any().downcast_ref::() { + return get_field_metadata(cast.expr(), input_schema); + } + // Look up field by index in schema (not NAME as there can be more than one // column with the same name) e.as_any() @@ -285,7 +294,7 @@ fn stats_projection( impl ProjectionStream { fn batch_project(&self, batch: &RecordBatch) -> Result { - // records time on drop + // Records time on drop let _timer = self.baseline_metrics.elapsed_compute().timer(); let arrays = self .expr @@ -331,7 +340,7 @@ impl Stream for ProjectionStream { } fn size_hint(&self) -> (usize, Option) { - // same number of record batches + // Same number of record batches self.input.size_hint() } } @@ -347,7 +356,6 @@ impl RecordBatchStream for ProjectionStream { mod tests { use super::*; use crate::common::collect; - use crate::expressions; use crate::test; use arrow_schema::DataType; @@ -409,8 +417,8 @@ mod tests { let schema = get_schema(); let exprs: Vec> = vec![ - Arc::new(expressions::Column::new("col1", 1)), - Arc::new(expressions::Column::new("col0", 0)), + Arc::new(Column::new("col1", 1)), + Arc::new(Column::new("col0", 0)), ]; let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); @@ -443,8 +451,8 @@ mod tests { let schema = get_schema(); let exprs: Vec> = vec![ - Arc::new(expressions::Column::new("col2", 2)), - Arc::new(expressions::Column::new("col0", 0)), + Arc::new(Column::new("col2", 2)), + Arc::new(Column::new("col0", 0)), ]; let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); diff --git a/datafusion/physical-plan/src/repartition/distributor_channels.rs b/datafusion/physical-plan/src/repartition/distributor_channels.rs index 675d26bbfb9fc..2e5ef24beac31 100644 --- a/datafusion/physical-plan/src/repartition/distributor_channels.rs +++ b/datafusion/physical-plan/src/repartition/distributor_channels.rs @@ -829,7 +829,7 @@ mod tests { { let test_waker = Arc::new(TestWaker::default()); let waker = futures::task::waker(Arc::clone(&test_waker)); - let mut cx = std::task::Context::from_waker(&waker); + let mut cx = Context::from_waker(&waker); let res = fut.poll_unpin(&mut cx); (res, test_waker) } diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 093803e3c8b30..601c1e8731523 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -34,21 +34,22 @@ use crate::metrics::BaselineMetrics; use crate::repartition::distributor_channels::{ channels, partition_aware_channels, DistributionReceiver, DistributionSender, }; -use crate::sorts::streaming_merge; +use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::stream::RecordBatchStreamAdapter; use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics}; -use arrow::array::ArrayRef; -use arrow::datatypes::{SchemaRef, UInt64Type}; +use arrow::compute::take_arrays; +use arrow::datatypes::{SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; use arrow_array::{PrimitiveArray, RecordBatchOptions}; use datafusion_common::utils::transpose; -use datafusion_common::{arrow_datafusion_err, not_impl_err, DataFusionError, Result}; +use 
datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr}; +use crate::execution_plan::CardinalityEffect; use futures::stream::Stream; use futures::{FutureExt, StreamExt, TryStreamExt}; use hashbrown::HashMap; @@ -280,7 +281,7 @@ impl BatchPartitioner { .collect(); for (index, hash) in hash_buffer.iter().enumerate() { - indices[(*hash % *partitions as u64) as usize].push(index as u64); + indices[(*hash % *partitions as u64) as usize].push(index as u32); } // Finished building index-arrays for output partitions @@ -292,7 +293,7 @@ impl BatchPartitioner { .into_iter() .enumerate() .filter_map(|(partition, indices)| { - let indices: PrimitiveArray<UInt64Type> = indices.into(); + let indices: PrimitiveArray<UInt32Type> = indices.into(); (!indices.is_empty()).then_some((partition, indices)) }) .map(move |(partition, indices)| { @@ -300,14 +301,7 @@ impl BatchPartitioner { let _timer = partitioner_timer.timer(); // Produce batches based on indices - let columns = batch - .columns() - .iter() - .map(|c| { - arrow::compute::take(c.as_ref(), &indices, None) - .map_err(|e| arrow_datafusion_err!(e)) - }) - .collect::<Result<Vec<ArrayRef>>>()?; + let columns = take_arrays(batch.columns(), &indices, None)?; let mut options = RecordBatchOptions::new(); options = options.with_row_count(Some(indices.len())); @@ -385,6 +379,11 @@ impl BatchPartitioner { /// `───────' `───────' ///``` /// +/// # Error Handling +/// +/// If any of the input partitions return an error, the error is propagated to +/// all output partitions and inputs are not polled again. +/// /// # Output Ordering /// /// If more than one stream is being repartitioned, the output will be some @@ -403,8 +402,6 @@ impl BatchPartitioner { pub struct RepartitionExec { /// Input execution plan input: Arc<dyn ExecutionPlan>, - /// Partitioning scheme to use - partitioning: Partitioning, /// Inner state that is initialized when the first output stream is created.
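Stepping back from the partitioner hunk above: `BatchPartitioner` routes each input row to an output partition as `hash % num_partitions`, and this diff narrows the collected row indices from `u64` to `u32` so the output batches can be gathered with the `UInt32Type` take kernel (`take_arrays`). A self-contained sketch of just the routing step; the hash values are stand-ins for the real `create_hashes` output:

```rust
/// Route rows to output partitions the way `BatchPartitioner::partition`
/// does: partition = hash % num_partitions, collecting row indices per
/// partition.
fn route_rows(hashes: &[u64], num_partitions: usize) -> Vec<Vec<u32>> {
    let mut indices: Vec<Vec<u32>> = vec![Vec::new(); num_partitions];
    for (row, hash) in hashes.iter().enumerate() {
        // Note the diff's change: indices are now u32, matching the
        // UInt32Type `take` kernel used to gather the output columns.
        indices[(hash % num_partitions as u64) as usize].push(row as u32);
    }
    indices
}

fn main() {
    let hashes = [7u64, 12, 3, 8, 5];
    let parts = route_rows(&hashes, 4);
    // Empty partitions are filtered out later (`!indices.is_empty()`)
    // before batches are built with `take_arrays`.
    assert_eq!(parts[0], vec![1, 3]); // hashes 12 and 8
    assert_eq!(parts[3], vec![0, 2]); // hashes 7 and 3
    println!("{parts:?}");
}
```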
state: LazyState, /// Execution metrics @@ -469,7 +466,7 @@ impl RepartitionExec { /// Partitioning scheme to use pub fn partitioning(&self) -> &Partitioning { - &self.partitioning + &self.cache.partitioning } /// Get preserve_order flag of the RepartitionExecutor @@ -496,7 +493,7 @@ impl DisplayAs for RepartitionExec { f, "{}: partitioning={}, input_partitions={}", self.name(), - self.partitioning, + self.partitioning(), self.input.output_partitioning().partition_count() )?; @@ -539,8 +536,10 @@ impl ExecutionPlan for RepartitionExec { self: Arc, mut children: Vec>, ) -> Result> { - let mut repartition = - RepartitionExec::try_new(children.swap_remove(0), self.partitioning.clone())?; + let mut repartition = RepartitionExec::try_new( + children.swap_remove(0), + self.partitioning().clone(), + )?; if self.preserve_order { repartition = repartition.with_preserve_order(); } @@ -548,7 +547,7 @@ impl ExecutionPlan for RepartitionExec { } fn benefits_from_input_partitioning(&self) -> Vec { - vec![matches!(self.partitioning, Partitioning::Hash(_, _))] + vec![matches!(self.partitioning(), Partitioning::Hash(_, _))] } fn maintains_input_order(&self) -> Vec { @@ -568,7 +567,7 @@ impl ExecutionPlan for RepartitionExec { let lazy_state = Arc::clone(&self.state); let input = Arc::clone(&self.input); - let partitioning = self.partitioning.clone(); + let partitioning = self.partitioning().clone(); let metrics = self.metrics.clone(); let preserve_order = self.preserve_order; let name = self.name().to_owned(); @@ -640,15 +639,15 @@ impl ExecutionPlan for RepartitionExec { let merge_reservation = MemoryConsumer::new(format!("{}[Merge {partition}]", name)) .register(context.memory_pool()); - streaming_merge( - input_streams, - schema_captured, - &sort_exprs, - BaselineMetrics::new(&metrics, partition), - context.session_config().batch_size(), - fetch, - merge_reservation, - ) + StreamingMergeBuilder::new() + .with_streams(input_streams) + .with_schema(schema_captured) + .with_expressions(&sort_exprs) + .with_metrics(BaselineMetrics::new(&metrics, partition)) + .with_batch_size(context.session_config().batch_size()) + .with_fetch(fetch) + .with_reservation(merge_reservation) + .build() } else { Ok(Box::pin(RepartitionStream { num_input_partitions, @@ -672,6 +671,10 @@ impl ExecutionPlan for RepartitionExec { fn statistics(&self) -> Result { self.input.statistics() } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } } impl RepartitionExec { @@ -687,7 +690,6 @@ impl RepartitionExec { Self::compute_properties(&input, partitioning.clone(), preserve_order); Ok(RepartitionExec { input, - partitioning, state: Default::default(), metrics: ExecutionPlanMetricsSet::new(), preserve_order, @@ -1027,10 +1029,10 @@ mod tests { {collect, expressions::col, memory::MemoryExec}, }; - use arrow::array::{StringArray, UInt32Array}; + use arrow::array::{ArrayRef, StringArray, UInt32Array}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::cast::as_string_array; - use datafusion_common::{assert_batches_sorted_eq, exec_err}; + use datafusion_common::{arrow_datafusion_err, assert_batches_sorted_eq, exec_err}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use tokio::task::JoinSet; @@ -1134,7 +1136,7 @@ mod tests { // execute and collect results let mut output_partitions = vec![]; - for i in 0..exec.partitioning.partition_count() { + for i in 0..exec.partitioning().partition_count() { // execute this *output* partition and collect all batches let mut stream = 
exec.execute(i, Arc::clone(&task_ctx))?; let mut batches = vec![]; @@ -1324,7 +1326,7 @@ mod tests { // now, purposely drop output stream 0 // *before* any outputs are produced - std::mem::drop(output_stream0); + drop(output_stream0); // Now, start sending input let mut background_task = JoinSet::new(); @@ -1399,7 +1401,7 @@ mod tests { let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced - std::mem::drop(output_stream0); + drop(output_stream0); let mut background_task = JoinSet::new(); background_task.spawn(async move { input.wait().await; @@ -1524,7 +1526,7 @@ mod tests { let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?; // pull partitions - for i in 0..exec.partitioning.partition_count() { + for i in 0..exec.partitioning().partition_count() { let mut stream = exec.execute(i, Arc::clone(&task_ctx))?; let err = arrow_datafusion_err!(stream.next().await.unwrap().unwrap_err().into()); @@ -1676,7 +1678,8 @@ mod test { Arc::new( MemoryExec::try_new(&[vec![]], Arc::clone(schema), None) .unwrap() - .with_sort_information(vec![sort_exprs]), + .try_with_sort_information(vec![sort_exprs]) + .unwrap(), ) } } diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index 875922ac34b54..e0644e3d99e55 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -39,6 +39,7 @@ use futures::Stream; /// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`] type CursorStream = Box>>; +/// Merges a stream of sorted cursors and record batches into a single sorted stream #[derive(Debug)] pub(crate) struct SortPreservingMergeStream { in_progress: BatchBuilder, diff --git a/datafusion/physical-plan/src/sorts/mod.rs b/datafusion/physical-plan/src/sorts/mod.rs index 7c084761fdc30..ab5df37ed327c 100644 --- a/datafusion/physical-plan/src/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -28,4 +28,3 @@ mod stream; pub mod streaming_merge; pub use index::RowIndex; -pub(crate) use streaming_merge::streaming_merge; diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index 70a63e71ad2f2..649c05d52e8ba 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -104,7 +104,7 @@ impl PartialSortExec { input: Arc, common_prefix_length: usize, ) -> Self { - assert!(common_prefix_length > 0); + debug_assert!(common_prefix_length > 0); let preserve_partitioning = false; let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning); Self { @@ -289,7 +289,7 @@ impl ExecutionPlan for PartialSortExec { // Make sure common prefix length is larger than 0 // Otherwise, we should use SortExec. 
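The `common_prefix_length > 0` invariant behind the `assert!` to `debug_assert!` changes in this hunk comes from how `PartialSortExec` works: when the input is already sorted on a prefix of the requested ordering, only the remaining suffix has to be sorted within each group where the prefix is constant, and a zero-length prefix means a full `SortExec` is needed instead. A self-contained sketch, with column names standing in for `PhysicalSortExpr`s:

```rust
/// Length of the shared prefix between the input's existing ordering and
/// the requested ordering, using plain column names as stand-ins for
/// `PhysicalSortExpr`s.
fn common_prefix_length(existing: &[&str], requested: &[&str]) -> usize {
    existing
        .iter()
        .zip(requested)
        .take_while(|(e, r)| e == r)
        .count()
}

fn main() {
    // Input already sorted by (a, b); the query wants (a, b, c):
    // PartialSortExec only sorts within runs where (a, b) is constant.
    assert_eq!(common_prefix_length(&["a", "b"], &["a", "b", "c"]), 2);

    // No shared prefix: a partial sort is useless here and a full SortExec
    // is required, which is exactly what the (debug_)assert guards against.
    assert_eq!(common_prefix_length(&["x"], &["a", "b"]), 0);
}
```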
- assert!(self.common_prefix_length > 0); + debug_assert!(self.common_prefix_length > 0); Ok(Box::pin(PartialSortStream { input, diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index fb03ceb15c378..921678a4ad923 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -30,7 +30,7 @@ use crate::limit::LimitStream; use crate::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, }; -use crate::sorts::streaming_merge::streaming_merge; +use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::spill::{read_spill_as_stream, spill_record_batches}; use crate::stream::RecordBatchStreamAdapter; use crate::topk::TopK; @@ -40,7 +40,7 @@ use crate::{ SendableRecordBatchStream, Statistics, }; -use arrow::compute::{concat_batches, lexsort_to_indices, take, SortColumn}; +use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays, SortColumn}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; @@ -54,6 +54,7 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::LexOrdering; use datafusion_physical_expr_common::sort_expr::PhysicalSortRequirement; +use crate::execution_plan::CardinalityEffect; use futures::{StreamExt, TryStreamExt}; use log::{debug, trace}; @@ -341,21 +342,17 @@ impl ExternalSorter { streams.push(stream); } - streaming_merge( - streams, - Arc::clone(&self.schema), - &self.expr, - self.metrics.baseline.clone(), - self.batch_size, - self.fetch, - self.reservation.new_empty(), - ) - } else if !self.in_mem_batches.is_empty() { - self.in_mem_sort_stream(self.metrics.baseline.clone()) + StreamingMergeBuilder::new() + .with_streams(streams) + .with_schema(Arc::clone(&self.schema)) + .with_expressions(&self.expr) + .with_metrics(self.metrics.baseline.clone()) + .with_batch_size(self.batch_size) + .with_fetch(self.fetch) + .with_reservation(self.reservation.new_empty()) + .build() } else { - Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone( - &self.schema, - )))) + self.in_mem_sort_stream(self.metrics.baseline.clone()) } } @@ -500,7 +497,11 @@ impl ExternalSorter { &mut self, metrics: BaselineMetrics, ) -> Result { - assert_ne!(self.in_mem_batches.len(), 0); + if self.in_mem_batches.is_empty() { + return Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone( + &self.schema, + )))); + } // The elapsed compute timer is updated when the value is dropped. // There is no need for an explicit call to drop. @@ -508,7 +509,7 @@ impl ExternalSorter { let _timer = elapsed_compute.timer(); if self.in_mem_batches.len() == 1 { - let batch = self.in_mem_batches.remove(0); + let batch = self.in_mem_batches.swap_remove(0); let reservation = self.reservation.take(); return self.sort_batch_stream(batch, metrics, reservation); } @@ -533,15 +534,15 @@ impl ExternalSorter { }) .collect::>()?; - streaming_merge( - streams, - Arc::clone(&self.schema), - &self.expr, - metrics, - self.batch_size, - self.fetch, - self.merge_reservation.new_empty(), - ) + StreamingMergeBuilder::new() + .with_streams(streams) + .with_schema(Arc::clone(&self.schema)) + .with_expressions(&self.expr) + .with_metrics(metrics) + .with_batch_size(self.batch_size) + .with_fetch(self.fetch) + .with_reservation(self.merge_reservation.new_empty()) + .build() } /// Sorts a single `RecordBatch` into a single stream. @@ -616,11 +617,7 @@ pub fn sort_batch( lexsort_to_indices(&sort_columns, fetch)? 
}; - let columns = batch - .columns() - .iter() - .map(|c| take(c.as_ref(), &indices, None)) - .collect::>()?; + let columns = take_arrays(batch.columns(), &indices, None)?; let options = RecordBatchOptions::new().with_row_count(Some(indices.len())); Ok(RecordBatch::try_new_with_options( @@ -818,11 +815,7 @@ impl SortExec { } impl DisplayAs for SortExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let expr = PhysicalSortExpr::format_list(&self.expr); @@ -975,6 +968,14 @@ impl ExecutionPlan for SortExec { fn fetch(&self) -> Option { self.fetch } + + fn cardinality_effect(&self) -> CardinalityEffect { + if self.fetch.is_none() { + CardinalityEffect::Equal + } else { + CardinalityEffect::LowerEqual + } + } } #[cfg(test)] @@ -1013,7 +1014,7 @@ mod tests { } impl DisplayAs for SortedUnboundedExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!(f, "UnboundableExec",).unwrap() diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index f83bb58d08dd1..31a4ed61cf9e8 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -24,7 +24,7 @@ use crate::common::spawn_buffered; use crate::expressions::PhysicalSortExpr; use crate::limit::LimitStream; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use crate::sorts::streaming_merge; +use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, @@ -65,6 +65,11 @@ use log::{debug, trace}; /// Input Streams Output stream /// (sorted) (sorted) /// ``` +/// +/// # Error Handling +/// +/// If any of the input partitions return an error, the error is propagated to +/// the output and inputs are not polled again. 
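The streaming-merge refactor that follows replaces the positional `streaming_merge(...)` call with a `StreamingMergeBuilder`. A stripped-down, std-only model of that builder is sketched here to show the shape of the API and the `build()`-time validation; the field and setter names mirror the diff, while the stream and metric types are simplified stand-ins rather than the real DataFusion types:

```rust
/// Simplified model of the `StreamingMergeBuilder` introduced in this diff:
/// optional fields collected via `with_*` setters and validated in `build()`.
#[derive(Default)]
struct MergeBuilder<'a> {
    streams: Vec<&'a str>,      // stand-in for Vec<SendableRecordBatchStream>
    expressions: &'a [&'a str], // stand-in for &[PhysicalSortExpr]
    batch_size: Option<usize>,
    fetch: Option<usize>,
}

impl<'a> MergeBuilder<'a> {
    fn new() -> Self {
        Self::default()
    }
    fn with_streams(mut self, streams: Vec<&'a str>) -> Self {
        self.streams = streams;
        self
    }
    fn with_expressions(mut self, expressions: &'a [&'a str]) -> Self {
        self.expressions = expressions;
        self
    }
    fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = Some(batch_size);
        self
    }
    fn with_fetch(mut self, fetch: Option<usize>) -> Self {
        self.fetch = fetch;
        self
    }
    fn build(self) -> Result<String, String> {
        // Same shape as the real build(): reject empty streams/expressions
        // early, then unwrap the mandatory fields.
        if self.streams.is_empty() {
            return Err("Streams cannot be empty for streaming merge".into());
        }
        if self.expressions.is_empty() {
            return Err("Sort expressions cannot be empty for streaming merge".into());
        }
        let batch_size = self.batch_size.expect("Batch size cannot be empty");
        Ok(format!(
            "merge {} streams by {:?}, batch_size={batch_size}, fetch={:?}",
            self.streams.len(),
            self.expressions,
            self.fetch
        ))
    }
}

fn main() {
    let merged = MergeBuilder::new()
        .with_streams(vec!["partition-0", "partition-1"])
        .with_expressions(&["a ASC"])
        .with_batch_size(8192)
        .with_fetch(Some(100))
        .build();
    assert!(merged.is_ok());

    // Missing sort expressions is still an error, as before the refactor.
    assert!(MergeBuilder::new()
        .with_streams(vec!["partition-0"])
        .with_expressions(&[])
        .build()
        .is_err());
}
```

One payoff of the builder over the old eight-argument function, visible in the hunks below: call sites name every parameter, and optional pieces like `fetch` no longer have to be threaded positionally.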
#[derive(Debug)] pub struct SortPreservingMergeExec { /// Input plan input: Arc<dyn ExecutionPlan>, @@ -268,15 +273,15 @@ impl ExecutionPlan for SortPreservingMergeExec { debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute"); - let result = streaming_merge( - receivers, - schema, - &self.expr, - BaselineMetrics::new(&self.metrics, partition), - context.session_config().batch_size(), - self.fetch, - reservation, - )?; + let result = StreamingMergeBuilder::new() + .with_streams(receivers) + .with_schema(schema) + .with_expressions(&self.expr) + .with_metrics(BaselineMetrics::new(&self.metrics, partition)) + .with_batch_size(context.session_config().batch_size()) + .with_fetch(self.fetch) + .with_reservation(reservation) + .build()?; debug!("Got stream result from SortPreservingMergeStream::new_from_receivers"); @@ -941,7 +946,7 @@ mod tests { while let Some(batch) = stream.next().await { sender.send(batch).await.unwrap(); // This causes the MergeStream to wait for more input - tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + tokio::time::sleep(Duration::from_millis(10)).await; } Ok(()) @@ -955,16 +960,15 @@ mod tests { MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool); let fetch = None; - let merge_stream = streaming_merge( - streams, - batches.schema(), - sort.as_slice(), - BaselineMetrics::new(&metrics, 0), - task_ctx.session_config().batch_size(), - fetch, - reservation, - ) - .unwrap(); + let merge_stream = StreamingMergeBuilder::new() + .with_streams(streams) + .with_schema(batches.schema()) + .with_expressions(sort.as_slice()) + .with_metrics(BaselineMetrics::new(&metrics, 0)) + .with_batch_size(task_ctx.session_config().batch_size()) + .with_fetch(fetch) + .with_reservation(reservation) + .build()?; let mut merged = common::collect(merge_stream).await.unwrap(); diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs index 9e6618dd1af58..ad640d8e8470d 100644 --- a/datafusion/physical-plan/src/sorts/streaming_merge.rs +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -49,49 +49,120 @@ macro_rules! merge_helper { }}; } -/// Perform a streaming merge of [`SendableRecordBatchStream`] based on provided sort expressions -/// while preserving order. -pub fn streaming_merge( +#[derive(Default)] +pub struct StreamingMergeBuilder<'a> { streams: Vec<SendableRecordBatchStream>, - schema: SchemaRef, - expressions: &[PhysicalSortExpr], - metrics: BaselineMetrics, - batch_size: usize, + schema: Option<SchemaRef>, + expressions: &'a [PhysicalSortExpr], + metrics: Option<BaselineMetrics>, + batch_size: Option<usize>, fetch: Option<usize>, - reservation: MemoryReservation, -) -> Result<SendableRecordBatchStream> { - // If there are no sort expressions, preserving the order - // doesn't mean anything (and result in infinite loops) - if expressions.is_empty() { - return internal_err!("Sort expressions cannot be empty for streaming merge"); + reservation: Option<MemoryReservation>, +} + +impl<'a> StreamingMergeBuilder<'a> { + pub fn new() -> Self { + Self::default() } - // Special case single column comparisons with optimized cursor implementations - if expressions.len() == 1 { - let sort = expressions[0].clone(); - let data_type = sort.expr.data_type(schema.as_ref())?; - downcast_primitive!
{ - data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation), - DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - _ => {} - } + + pub fn with_streams(mut self, streams: Vec<SendableRecordBatchStream>) -> Self { + self.streams = streams; + self } - let streams = RowCursorStream::try_new( - schema.as_ref(), - expressions, - streams, - reservation.new_empty(), - )?; - - Ok(Box::pin(SortPreservingMergeStream::new( - Box::new(streams), - schema, - metrics, - batch_size, - fetch, - reservation, - ))) + pub fn with_schema(mut self, schema: SchemaRef) -> Self { + self.schema = Some(schema); + self + } + + pub fn with_expressions(mut self, expressions: &'a [PhysicalSortExpr]) -> Self { + self.expressions = expressions; + self + } + + pub fn with_metrics(mut self, metrics: BaselineMetrics) -> Self { + self.metrics = Some(metrics); + self + } + + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = Some(batch_size); + self + } + + pub fn with_fetch(mut self, fetch: Option<usize>) -> Self { + self.fetch = fetch; + self + } + + pub fn with_reservation(mut self, reservation: MemoryReservation) -> Self { + self.reservation = Some(reservation); + self + } + + pub fn build(self) -> Result<SendableRecordBatchStream> { + let Self { + streams, + schema, + metrics, + batch_size, + reservation, + fetch, + expressions, + } = self; + + // Early return if streams or expressions are empty + let checks = [ + ( + streams.is_empty(), + "Streams cannot be empty for streaming merge", + ), + ( + expressions.is_empty(), + "Sort expressions cannot be empty for streaming merge", + ), + ]; + + if let Some((_, error_message)) = checks.iter().find(|(condition, _)| *condition) + { + return internal_err!("{}", error_message); + } + + // Unwrapping mandatory fields + let schema = schema.expect("Schema cannot be empty for streaming merge"); + let metrics = metrics.expect("Metrics cannot be empty for streaming merge"); + let batch_size = + batch_size.expect("Batch size cannot be empty for streaming merge"); + let reservation = + reservation.expect("Reservation cannot be empty for streaming merge"); + + // Special case single column comparisons with optimized cursor implementations + if expressions.len() == 1 { + let sort = expressions[0].clone(); + let data_type = sort.expr.data_type(schema.as_ref())?; + downcast_primitive!
{ + data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation), + DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + _ => {} + } + } + + let streams = RowCursorStream::try_new( + schema.as_ref(), + expressions, + streams, + reservation.new_empty(), + )?; + Ok(Box::pin(SortPreservingMergeStream::new( + Box::new(streams), + schema, + metrics, + batch_size, + fetch, + reservation, + ))) + } } diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index faeb4799f5aff..ec4c9dd502a60 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -56,7 +56,7 @@ pub(crate) struct ReceiverStreamBuilder { } impl ReceiverStreamBuilder { - /// create new channels with the specified buffer size + /// Create new channels with the specified buffer size pub fn new(capacity: usize) -> Self { let (tx, rx) = tokio::sync::mpsc::channel(capacity); @@ -83,10 +83,10 @@ impl ReceiverStreamBuilder { } /// Spawn a blocking task that will be aborted if this builder (or the stream - /// built from it) are dropped + /// built from it) are dropped. /// - /// this is often used to spawn tasks that write to the sender - /// retrieved from `Self::tx` + /// This is often used to spawn tasks that write to the sender + /// retrieved from `Self::tx`. pub fn spawn_blocking(&mut self, f: F) where F: FnOnce() -> Result<()>, @@ -103,7 +103,7 @@ impl ReceiverStreamBuilder { mut join_set, } = self; - // don't need tx + // Doesn't need tx drop(tx); // future that checks the result of the join set, and propagates panic if seen @@ -112,7 +112,7 @@ impl ReceiverStreamBuilder { match result { Ok(task_result) => { match task_result { - // nothing to report + // Nothing to report Ok(_) => continue, // This means a blocking task error Err(error) => return Some(Err(error)), @@ -215,7 +215,7 @@ pub struct RecordBatchReceiverStreamBuilder { } impl RecordBatchReceiverStreamBuilder { - /// create new channels with the specified buffer size + /// Create new channels with the specified buffer size pub fn new(schema: SchemaRef, capacity: usize) -> Self { Self { schema, @@ -256,7 +256,7 @@ impl RecordBatchReceiverStreamBuilder { self.inner.spawn_blocking(f) } - /// runs the `partition` of the `input` ExecutionPlan on the + /// Runs the `partition` of the `input` ExecutionPlan on the /// tokio threadpool and writes its outputs to this stream /// /// If the input partition produces an error, the error will be @@ -299,7 +299,7 @@ impl RecordBatchReceiverStreamBuilder { return Ok(()); } - // stop after the first error is encontered (don't + // Stop after the first error is encountered (Don't // drive all streams to completion) if is_err { debug!( @@ -437,12 +437,12 @@ impl ObservedStream { } impl RecordBatchStream for ObservedStream { - fn schema(&self) -> arrow::datatypes::SchemaRef { + fn schema(&self) -> SchemaRef { self.inner.schema() } } -impl futures::Stream for ObservedStream { +impl Stream for ObservedStream { type Item = Result; fn poll_next( @@ -483,13 +483,13 @@ mod test { async fn 
record_batch_receiver_stream_propagates_panics_early_shutdown() { let schema = schema(); - // make 2 partitions, second partition panics before the first + // Make 2 partitions, second partition panics before the first let num_partitions = 2; let input = PanicExec::new(Arc::clone(&schema), num_partitions) .with_partition_panic(0, 10) .with_partition_panic(1, 3); // partition 1 should panic first (after 3 ) - // ensure that the panic results in an early shutdown (that + // Ensure that the panic results in an early shutdown (that // everything stops after the first panic). // Since the stream reads every other batch: (0,1,0,1,0,panic) @@ -512,10 +512,10 @@ mod test { builder.run_input(Arc::new(input), 0, Arc::clone(&task_ctx)); let stream = builder.build(); - // input should still be present + // Input should still be present assert!(std::sync::Weak::strong_count(&refs) > 0); - // drop the stream, ensure the refs go to zero + // Drop the stream, ensure the refs go to zero drop(stream); assert_strong_count_converges_to_zero(refs).await; } @@ -539,7 +539,7 @@ mod test { builder.run_input(Arc::new(error_stream), 0, Arc::clone(&task_ctx)); let mut stream = builder.build(); - // get the first result, which should be an error + // Get the first result, which should be an error let first_batch = stream.next().await.unwrap(); let first_err = first_batch.unwrap_err(); assert_eq!(first_err.strip_backtrace(), "Execution error: Test1"); @@ -570,7 +570,7 @@ mod test { } let mut stream = builder.build(); - // drain the stream until it is complete, panic'ing on error + // Drain the stream until it is complete, panic'ing on error let mut num_batches = 0; while let Some(next) = stream.next().await { next.unwrap(); diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index b02e4fb5738d3..cdb94af1fe8a7 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -163,7 +163,7 @@ impl StreamingTableExec { } } -impl std::fmt::Debug for StreamingTableExec { +impl Debug for StreamingTableExec { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("LazyMemTableExec").finish_non_exhaustive() } @@ -295,7 +295,7 @@ mod test { #[tokio::test] async fn test_no_limit() { let exec = TestBuilder::new() - // make 2 batches, each with 100 rows + // Make 2 batches, each with 100 rows .with_batches(vec![make_partition(100), make_partition(100)]) .build(); @@ -306,9 +306,9 @@ mod test { #[tokio::test] async fn test_limit() { let exec = TestBuilder::new() - // make 2 batches, each with 100 rows + // Make 2 batches, each with 100 rows .with_batches(vec![make_partition(100), make_partition(100)]) - // limit to only the first 75 rows back + // Limit to only the first 75 rows back .with_limit(Some(75)) .build(); diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index 4da43b313403a..90ec9b1068501 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -65,7 +65,7 @@ pub fn aggr_test_schema() -> SchemaRef { Arc::new(schema) } -/// returns record batch with 3 columns of i32 in memory +/// Returns record batch with 3 columns of i32 in memory pub fn build_table_i32( a: (&str, &Vec), b: (&str, &Vec), @@ -88,7 +88,7 @@ pub fn build_table_i32( .unwrap() } -/// returns memory table scan wrapped around record batch with 3 columns of i32 +/// Returns memory table scan wrapped around record batch with 3 columns of i32 pub fn build_table_scan_i32( 
a: (&str, &Vec), b: (&str, &Vec), @@ -125,7 +125,7 @@ pub fn mem_exec(partitions: usize) -> MemoryExec { MemoryExec::try_new(&data, schema, projection).unwrap() } -// construct a stream partition for test purposes +// Construct a stream partition for test purposes #[derive(Debug)] pub struct TestPartitionStream { pub schema: SchemaRef, diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index d3f1a4fd96caf..9b46ad2ec7b14 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -21,6 +21,7 @@ use arrow::{ compute::interleave, row::{RowConverter, Rows, SortField}, }; +use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; use arrow_array::{Array, ArrayRef, RecordBatch}; @@ -225,7 +226,7 @@ impl TopK { /// return the size of memory used by this operator, in bytes fn size(&self) -> usize { - std::mem::size_of::() + size_of::() + self.row_converter.size() + self.scratch_rows.size() + self.heap.size() @@ -444,8 +445,8 @@ impl TopKHeap { /// return the size of memory used by this heap, in bytes fn size(&self) -> usize { - std::mem::size_of::() - + (self.inner.capacity() * std::mem::size_of::()) + size_of::() + + (self.inner.capacity() * size_of::()) + self.store.size() + self.owned_bytes } @@ -636,9 +637,8 @@ impl RecordBatchStore { /// returns the size of memory used by this store, including all /// referenced `RecordBatch`es, in bytes pub fn size(&self) -> usize { - std::mem::size_of::() - + self.batches.capacity() - * (std::mem::size_of::() + std::mem::size_of::()) + size_of::() + + self.batches.capacity() * (size_of::() + size_of::()) + self.batches_size } } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 78b25686054d8..433dda870defa 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -468,26 +468,41 @@ pub fn can_interleave>>( } fn union_schema(inputs: &[Arc]) -> SchemaRef { - let fields: Vec = (0..inputs[0].schema().fields().len()) + let first_schema = inputs[0].schema(); + + let fields = (0..first_schema.fields().len()) .map(|i| { inputs .iter() - .filter_map(|input| { - if input.schema().fields().len() > i { - Some(input.schema().field(i).clone()) - } else { - None - } + .enumerate() + .map(|(input_idx, input)| { + let field = input.schema().field(i).clone(); + let mut metadata = field.metadata().clone(); + + let other_metadatas = inputs + .iter() + .enumerate() + .filter(|(other_idx, _)| *other_idx != input_idx) + .flat_map(|(_, other_input)| { + other_input.schema().field(i).metadata().clone().into_iter() + }); + + metadata.extend(other_metadatas); + field.with_metadata(metadata) }) - .find_or_first(|f| f.is_nullable()) + .find_or_first(Field::is_nullable) + // We can unwrap this because if inputs was empty, this would've already panic'ed when we + // indexed into inputs[0]. .unwrap() }) + .collect::>(); + + let all_metadata_merged = inputs + .iter() + .flat_map(|i| i.schema().metadata().clone().into_iter()) .collect(); - Arc::new(Schema::new_with_metadata( - fields, - inputs[0].schema().metadata().clone(), - )) + Arc::new(Schema::new_with_metadata(fields, all_metadata_merged)) } /// CombinedRecordBatchStream can be used to combine a Vec of SendableRecordBatchStreams into one @@ -800,11 +815,11 @@ mod tests { .collect::>(); let child1 = Arc::new( MemoryExec::try_new(&[], Arc::clone(&schema), None)? 
- .with_sort_information(first_orderings), + .try_with_sort_information(first_orderings)?, ); let child2 = Arc::new( MemoryExec::try_new(&[], Arc::clone(&schema), None)? - .with_sort_information(second_orderings), + .try_with_sort_information(second_orderings)?, ); let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema)); diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 50af6b4960a50..3e312b7451bef 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -62,9 +62,9 @@ pub struct UnnestExec { input: Arc, /// The schema once the unnest is applied schema: SchemaRef, - /// indices of the list-typed columns in the input schema + /// Indices of the list-typed columns in the input schema list_column_indices: Vec, - /// indices of the struct-typed columns in the input schema + /// Indices of the struct-typed columns in the input schema struct_column_indices: Vec, /// Options options: UnnestOptions, @@ -115,12 +115,12 @@ impl UnnestExec { &self.input } - /// indices of the list-typed columns in the input schema + /// Indices of the list-typed columns in the input schema pub fn list_column_indices(&self) -> &[ListUnnest] { &self.list_column_indices } - /// indices of the struct-typed columns in the input schema + /// Indices of the struct-typed columns in the input schema pub fn struct_column_indices(&self) -> &[usize] { &self.struct_column_indices } @@ -203,7 +203,7 @@ impl ExecutionPlan for UnnestExec { #[derive(Clone, Debug)] struct UnnestMetrics { - /// total time for column unnesting + /// Total time for column unnesting elapsed_compute: metrics::Time, /// Number of batches consumed input_batches: metrics::Count, @@ -411,7 +411,7 @@ fn list_unnest_at_level( level_to_unnest: usize, options: &UnnestOptions, ) -> Result<(Vec, usize)> { - // extract unnestable columns at this level + // Extract unnestable columns at this level let (arrs_to_unnest, list_unnest_specs): (Vec>, Vec<_>) = list_type_unnests .iter() @@ -422,7 +422,7 @@ fn list_unnest_at_level( *unnesting, )); } - // this means the unnesting on this item has started at higher level + // This means the unnesting on this item has started at higher level // and need to continue until depth reaches 1 if level_to_unnest < unnesting.depth { return Some(( @@ -434,7 +434,7 @@ fn list_unnest_at_level( }) .unzip(); - // filter out so that list_arrays only contain column with the highest depth + // Filter out so that list_arrays only contain column with the highest depth // at the same time, during iteration remove this depth so next time we don't have to unnest them again let longest_length = find_longest_length(&arrs_to_unnest, options)?; let unnested_length = longest_length.as_primitive::(); @@ -456,7 +456,7 @@ fn list_unnest_at_level( // Create the take indices array for other columns let take_indices = create_take_indicies(unnested_length, total_length); - // dimension of arrays in batch is untouch, but the values are repeated + // Dimension of arrays in batch is untouched, but the values are repeated // as the side effect of unnesting let ret = repeat_arrs_from_indices(batch, &take_indices)?; unnested_temp_arrays @@ -548,8 +548,8 @@ fn build_batch( // This arr always has the same column count with the input batch let mut flatten_arrs = vec![]; - // original batch has the same columns - // all unnesting results are written to temp_batch + // Original batch has the same columns + // All unnesting results are written to temp_batch for 
depth in (1..=max_recursion).rev() { let input = match depth == max_recursion { true => batch.columns(), @@ -593,11 +593,11 @@ fn build_batch( .map(|(order, unnest_def)| (*unnest_def, order)) .collect(); - // one original column may be unnested multiple times into separate columns + // One original column may be unnested multiple times into separate columns let mut multi_unnested_per_original_index = unnested_array_map .into_iter() .map( - // each item in unnested_columns is the result of unnesting the same input column + // Each item in unnested_columns is the result of unnesting the same input column // we need to sort them to conform with the original expression order // e.g unnest(unnest(col)) must goes before unnest(col) |(original_index, mut unnested_columns)| { @@ -636,7 +636,7 @@ fn build_batch( .into_iter() .enumerate() .flat_map(|(col_idx, arr)| { - // convert original column into its unnested version(s) + // Convert original column into its unnested version(s) // Plural because one column can be unnested with different recursion level // and into separate output columns match multi_unnested_per_original_index.remove(&col_idx) { @@ -905,12 +905,10 @@ fn repeat_arrs_from_indices( #[cfg(test)] mod tests { use super::*; - use arrow::{ - datatypes::{Field, Int32Type}, - util::pretty::pretty_format_batches, - }; + use arrow::datatypes::{Field, Int32Type}; use arrow_array::{GenericListArray, OffsetSizeTrait, StringArray}; use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; + use datafusion_common::assert_batches_eq; // Create a GenericListArray with the following list values: // [A, B, C], [], NULL, [D], NULL, [NULL, F] @@ -986,7 +984,7 @@ mod tests { list_array: &dyn ListArrayType, lengths: Vec, expected: Vec>, - ) -> datafusion_common::Result<()> { + ) -> Result<()> { let length_array = Int64Array::from(lengths); let unnested_array = unnest_list_array(list_array, &length_array, 3 * 6)?; let strs = unnested_array.as_string::().iter().collect::>(); @@ -995,7 +993,7 @@ mod tests { } #[test] - fn test_build_batch_list_arr_recursive() -> datafusion_common::Result<()> { + fn test_build_batch_list_arr_recursive() -> Result<()> { // col1 | col2 // [[1,2,3],null,[4,5]] | ['a','b'] // [[7,8,9,10], null, [11,12,13]] | ['c','d'] @@ -1092,43 +1090,42 @@ mod tests { &HashSet::default(), &UnnestOptions { preserve_nulls: true, + recursions: vec![], }, )?; - let actual = - format!("{}", pretty_format_batches(vec![ret].as_ref())?).to_lowercase(); - let expected = r#" -+---------------------------------+---------------------------------+---------------------------------+ -| col1_unnest_placeholder_depth_1 | col1_unnest_placeholder_depth_2 | col2_unnest_placeholder_depth_1 | -+---------------------------------+---------------------------------+---------------------------------+ -| [1, 2, 3] | 1 | a | -| | 2 | b | -| [4, 5] | 3 | | -| [1, 2, 3] | | a | -| | | b | -| [4, 5] | | | -| [1, 2, 3] | 4 | a | -| | 5 | b | -| [4, 5] | | | -| [7, 8, 9, 10] | 7 | c | -| | 8 | d | -| [11, 12, 13] | 9 | | -| | 10 | | -| [7, 8, 9, 10] | | c | -| | | d | -| [11, 12, 13] | | | -| [7, 8, 9, 10] | 11 | c | -| | 12 | d | -| [11, 12, 13] | 13 | | -| | | e | -+---------------------------------+---------------------------------+---------------------------------+ - "# - .trim(); - assert_eq!(actual, expected); + + let expected = &[ +"+---------------------------------+---------------------------------+---------------------------------+", +"| col1_unnest_placeholder_depth_1 | col1_unnest_placeholder_depth_2 | 
col2_unnest_placeholder_depth_1 |", +"+---------------------------------+---------------------------------+---------------------------------+", +"| [1, 2, 3] | 1 | a |", +"| | 2 | b |", +"| [4, 5] | 3 | |", +"| [1, 2, 3] | | a |", +"| | | b |", +"| [4, 5] | | |", +"| [1, 2, 3] | 4 | a |", +"| | 5 | b |", +"| [4, 5] | | |", +"| [7, 8, 9, 10] | 7 | c |", +"| | 8 | d |", +"| [11, 12, 13] | 9 | |", +"| | 10 | |", +"| [7, 8, 9, 10] | | c |", +"| | | d |", +"| [11, 12, 13] | | |", +"| [7, 8, 9, 10] | 11 | c |", +"| | 12 | d |", +"| [11, 12, 13] | 13 | |", +"| | | e |", +"+---------------------------------+---------------------------------+---------------------------------+", + ]; + assert_batches_eq!(expected, &[ret]); Ok(()) } #[test] - fn test_unnest_list_array() -> datafusion_common::Result<()> { + fn test_unnest_list_array() -> Result<()> { // [A, B, C], [], NULL, [D], NULL, [NULL, F] let list_array = make_generic_array::(); verify_unnest_list_array( @@ -1176,8 +1173,11 @@ mod tests { list_arrays: &[ArrayRef], preserve_nulls: bool, expected: Vec, - ) -> datafusion_common::Result<()> { - let options = UnnestOptions { preserve_nulls }; + ) -> Result<()> { + let options = UnnestOptions { + preserve_nulls, + recursions: vec![], + }; let longest_length = find_longest_length(list_arrays, &options)?; let expected_array = Int64Array::from(expected); assert_eq!( @@ -1191,7 +1191,7 @@ mod tests { } #[test] - fn test_longest_list_length() -> datafusion_common::Result<()> { + fn test_longest_list_length() -> Result<()> { // Test with single ListArray // [A, B, C], [], NULL, [D], NULL, [NULL, F] let list_array = Arc::new(make_generic_array::()) as ArrayRef; @@ -1223,7 +1223,7 @@ mod tests { } #[test] - fn test_create_take_indicies() -> datafusion_common::Result<()> { + fn test_create_take_indicies() -> Result<()> { let length_array = Int64Array::from(vec![2, 3, 1]); let take_indicies = create_take_indicies(&length_array, 6); let expected = Int64Array::from(vec![0, 0, 1, 1, 1, 2]); diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index e01aea1fdd6bc..991146d245a70 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -47,7 +47,7 @@ pub struct ValuesExec { } impl ValuesExec { - /// create a new values exec from data as expr + /// Create a new values exec from data as expr pub fn try_new( schema: SchemaRef, data: Vec>>, @@ -57,7 +57,7 @@ impl ValuesExec { } let n_row = data.len(); let n_col = schema.fields().len(); - // we have this single row batch as a placeholder to satisfy evaluation argument + // We have this single row batch as a placeholder to satisfy evaluation argument // and generate a single output row let batch = RecordBatch::try_new_with_options( Arc::new(Schema::empty()), @@ -126,7 +126,7 @@ impl ValuesExec { }) } - /// provides the data + /// Provides the data pub fn data(&self) -> Vec { self.data.clone() } @@ -219,6 +219,7 @@ mod tests { use crate::test::{self, make_partition}; use arrow_schema::{DataType, Field}; + use datafusion_common::stats::{ColumnStatistics, Precision}; #[tokio::test] async fn values_empty_case() -> Result<()> { @@ -269,4 +270,34 @@ mod tests { let _ = ValuesExec::try_new(schema, vec![vec![lit(ScalarValue::UInt32(None))]]) .unwrap_err(); } + + #[test] + fn values_stats_with_nulls_only() -> Result<()> { + let data = vec![ + vec![lit(ScalarValue::Null)], + vec![lit(ScalarValue::Null)], + vec![lit(ScalarValue::Null)], + ]; + let rows = data.len(); + let values = ValuesExec::try_new( 
+ Arc::new(Schema::new(vec![Field::new("col0", DataType::Null, true)])), + data, + )?; + + assert_eq!( + values.statistics()?, + Statistics { + num_rows: Precision::Exact(rows), + total_byte_size: Precision::Exact(8), // not important + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(rows), // there are only nulls + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + },], + } + ); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 001e134581c03..6495657339fa9 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -40,17 +40,17 @@ use crate::{ SendableRecordBatchStream, Statistics, WindowExpr, }; use ahash::RandomState; +use arrow::compute::take_record_batch; use arrow::{ array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, - compute::{concat, concat_batches, sort_to_indices}, + compute::{concat, concat_batches, sort_to_indices, take_arrays}, datatypes::SchemaRef, record_batch::RecordBatch, }; use datafusion_common::hash_utils::create_hashes; use datafusion_common::stats::Precision; use datafusion_common::utils::{ - evaluate_partition_ranges, get_arrayref_at_indices, get_at_indices, - get_record_batch_at_indices, get_row_at_idx, + evaluate_partition_ranges, get_at_indices, get_row_at_idx, }; use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -257,17 +257,11 @@ impl ExecutionPlan for BoundedWindowAggExec { fn required_input_ordering(&self) -> Vec> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); - if self.input_order_mode != InputOrderMode::Sorted - || self.ordered_partition_by_indices.len() >= partition_bys.len() - { - let partition_bys = self - .ordered_partition_by_indices - .iter() - .map(|idx| &partition_bys[*idx]); - vec![calc_requirements(partition_bys, order_keys)] - } else { - vec![calc_requirements(partition_bys, order_keys)] - } + let partition_bys = self + .ordered_partition_by_indices + .iter() + .map(|idx| &partition_bys[*idx]); + vec![calc_requirements(partition_bys, order_keys)] } fn required_input_distribution(&self) -> Vec { @@ -542,7 +536,9 @@ impl PartitionSearcher for LinearSearch { // We should emit columns according to row index ordering. let sorted_indices = sort_to_indices(&all_indices, None, None)?; // Construct new column according to row ordering. 
This fixes ordering - get_arrayref_at_indices(&new_columns, &sorted_indices).map(Some) + take_arrays(&new_columns, &sorted_indices, None) + .map(Some) + .map_err(|e| arrow_datafusion_err!(e)) } fn evaluate_partition_batches( @@ -562,7 +558,7 @@ impl PartitionSearcher for LinearSearch { let mut new_indices = UInt32Builder::with_capacity(indices.len()); new_indices.append_slice(&indices); let indices = new_indices.finish(); - Ok((row, get_record_batch_at_indices(record_batch, &indices)?)) + Ok((row, take_record_batch(record_batch, &indices)?)) }) .collect() } diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 6e1cb8db5f09e..7ebb7e71ec57b 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -21,21 +21,16 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::{ - expressions::{ - cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile, - PhysicalSortExpr, - }, + expressions::{Literal, NthValue, PhysicalSortExpr}, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, }; use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; -use datafusion_common::{ - exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; use datafusion_expr::{ - BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, - WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, ReversedUDWF, WindowFrame, + WindowFunctionDefinition, WindowUDF, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::equivalence::collapse_lex_req; @@ -51,7 +46,9 @@ mod utils; mod window_agg_exec; pub use bounded_window_agg_exec::BoundedWindowAggExec; +use datafusion_functions_window_common::expr::ExpressionArgs; use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_physical_expr::expressions::Column; pub use datafusion_physical_expr::window::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowExpr, @@ -120,7 +117,8 @@ pub fn create_window_expr( .schema(Arc::new(input_schema.clone())) .alias(name) .with_ignore_nulls(ignore_nulls) - .build()?; + .build() + .map(Arc::new)?; window_expr_from_aggregate_expr( partition_by, order_by, @@ -130,7 +128,7 @@ pub fn create_window_expr( } // TODO: Ordering not supported for Window UDFs yet WindowFunctionDefinition::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( - create_udwf_window_expr(fun, args, input_schema, name)?, + create_udwf_window_expr(fun, args, input_schema, name, ignore_nulls)?, partition_by, order_by, window_frame, @@ -143,7 +141,7 @@ fn window_expr_from_aggregate_expr( partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, - aggregate: AggregateFunctionExpr, + aggregate: Arc, ) -> Arc { // Is there a potentially unlimited sized window frame? let unbounded_window = window_frame.start_bound.is_unbounded(); @@ -165,52 +163,16 @@ fn window_expr_from_aggregate_expr( } } -fn get_scalar_value_from_args( - args: &[Arc], - index: usize, -) -> Result> { - Ok(if let Some(field) = args.get(index) { - let tmp = field - .as_any() - .downcast_ref::() - .ok_or_else(|| DataFusionError::NotImplemented( - format!("There is only support Literal types for field at idx: {index} in Window Function"), - ))? 
- .value() - .clone(); - Some(tmp) - } else { - None - }) -} - fn get_signed_integer(value: ScalarValue) -> Result { - if !value.data_type().is_integer() { - return Err(DataFusionError::Execution( - "Expected an integer value".to_string(), - )); + if value.is_null() { + return Ok(0); } - value.cast_to(&DataType::Int64)?.try_into() -} -fn get_unsigned_integer(value: ScalarValue) -> Result { if !value.data_type().is_integer() { - return Err(DataFusionError::Execution( - "Expected an integer value".to_string(), - )); + return exec_err!("Expected an integer value"); } - value.cast_to(&DataType::UInt64)?.try_into() -} -fn get_casted_value( - default_value: Option, - dtype: &DataType, -) -> Result { - match default_value { - Some(v) if !v.data_type().is_null() => v.cast_to(dtype), - // If None or Null datatype - _ => ScalarValue::try_from(dtype), - } + value.cast_to(&DataType::Int64)?.try_into() } fn create_built_in_window_expr( @@ -224,64 +186,6 @@ fn create_built_in_window_expr( let out_data_type: &DataType = input_schema.field_with_name(&name)?.data_type(); Ok(match fun { - BuiltInWindowFunction::Rank => Arc::new(rank(name, out_data_type)), - BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, out_data_type)), - BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, out_data_type)), - BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, out_data_type)), - BuiltInWindowFunction::Ntile => { - let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { - DataFusionError::Execution( - "NTILE requires a positive integer".to_string(), - ) - })?; - - if n.is_null() { - return exec_err!("NTILE requires a positive integer, but finds NULL"); - } - - if n.is_unsigned() { - let n = get_unsigned_integer(n)?; - Arc::new(Ntile::new(name, n, out_data_type)) - } else { - let n: i64 = get_signed_integer(n)?; - if n <= 0 { - return exec_err!("NTILE requires a positive integer"); - } - Arc::new(Ntile::new(name, n as u64, out_data_type)) - } - } - BuiltInWindowFunction::Lag => { - let arg = Arc::clone(&args[0]); - let shift_offset = get_scalar_value_from_args(args, 1)? - .map(get_signed_integer) - .map_or(Ok(None), |v| v.map(Some))?; - let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; - Arc::new(lag( - name, - out_data_type.clone(), - arg, - shift_offset, - default_value, - ignore_nulls, - )) - } - BuiltInWindowFunction::Lead => { - let arg = Arc::clone(&args[0]); - let shift_offset = get_scalar_value_from_args(args, 1)? - .map(get_signed_integer) - .map_or(Ok(None), |v| v.map(Some))?; - let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; - Arc::new(lead( - name, - out_data_type.clone(), - arg, - shift_offset, - default_value, - ignore_nulls, - )) - } BuiltInWindowFunction::NthValue => { let arg = Arc::clone(&args[0]); let n = get_signed_integer( @@ -329,6 +233,7 @@ fn create_udwf_window_expr( args: &[Arc], input_schema: &Schema, name: String, + ignore_nulls: bool, ) -> Result> { // need to get the types into an owned vec for some reason let input_types: Vec<_> = args @@ -341,6 +246,8 @@ fn create_udwf_window_expr( args: args.to_vec(), input_types, name, + is_reversed: false, + ignore_nulls, })) } @@ -353,6 +260,12 @@ struct WindowUDFExpr { name: String, /// Types of input expressions input_types: Vec, + /// This is set to `true` only if the user-defined window function + /// expression supports evaluation in reverse order, and the + /// evaluation order is reversed. 
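To make the new `is_reversed` and `ReversedUDWF` plumbing concrete: when the planner wants to evaluate a window function over input sorted in the opposite direction, a window UDF can declare itself order-insensitive, unsupported, or replaceable by a counterpart (for example `lag` and `lead`), flipping `is_reversed` in the process. A self-contained sketch of that contract; the types here are simplified stand-ins for the real `WindowUDFExpr::reverse_expr` shown in the hunk below:

```rust
/// Stand-in for `ReversedUDWF`: how a window UDF reacts to reversed input.
#[derive(Clone, Debug, PartialEq)]
enum Reversed {
    Identical,              // order-insensitive: reuse the same expression
    NotSupported,           // caller must keep the original sort
    Reversed(&'static str), // name of the counterpart function to swap in
}

#[derive(Clone, Debug, PartialEq)]
struct WindowExpr {
    name: &'static str,
    is_reversed: bool,
}

fn reverse_expr(expr: &WindowExpr, support: Reversed) -> Option<WindowExpr> {
    match support {
        // Same expression works unchanged on reversed input.
        Reversed::Identical => Some(expr.clone()),
        // Reversal is impossible; signal the caller with None.
        Reversed::NotSupported => None,
        // Swap in the counterpart and flip `is_reversed`, as the real
        // `WindowUDFExpr::reverse_expr` does.
        Reversed::Reversed(name) => Some(WindowExpr {
            name,
            is_reversed: !expr.is_reversed,
        }),
    }
}

fn main() {
    let lag = WindowExpr { name: "lag", is_reversed: false };
    assert_eq!(
        reverse_expr(&lag, Reversed::Reversed("lead")),
        Some(WindowExpr { name: "lead", is_reversed: true })
    );
    assert_eq!(reverse_expr(&lag, Reversed::NotSupported), None);
}
```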
+    is_reversed: bool,
+    /// Set to `true` if `IGNORE NULLS` is defined, `false` otherwise.
+    ignore_nulls: bool,
 }
 
 impl BuiltInWindowFunctionExpr for WindowUDFExpr {
@@ -366,11 +279,18 @@
     }
 
     fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        self.args.clone()
+        self.fun
+            .expressions(ExpressionArgs::new(&self.args, &self.input_types))
     }
 
     fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
-        self.fun.partition_evaluator_factory()
+        self.fun
+            .partition_evaluator_factory(PartitionEvaluatorArgs::new(
+                &self.args,
+                &self.input_types,
+                self.is_reversed,
+                self.ignore_nulls,
+            ))
     }
 
     fn name(&self) -> &str {
@@ -378,7 +298,18 @@
     }
 
     fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
-        None
+        match self.fun.reverse_expr() {
+            ReversedUDWF::Identical => Some(Arc::new(self.clone())),
+            ReversedUDWF::NotSupported => None,
+            ReversedUDWF::Reversed(fun) => Some(Arc::new(WindowUDFExpr {
+                fun,
+                args: self.args.clone(),
+                name: self.name.clone(),
+                input_types: self.input_types.clone(),
+                is_reversed: !self.is_reversed,
+                ignore_nulls: self.ignore_nulls,
+            })),
+        }
     }
 
     fn get_result_ordering(&self, schema: &SchemaRef) -> Option<PhysicalSortExpr> {
diff --git a/datafusion/physical-plan/src/windows/utils.rs b/datafusion/physical-plan/src/windows/utils.rs
index 3cf92daae0fb2..13332ea82fa14 100644
--- a/datafusion/physical-plan/src/windows/utils.rs
+++ b/datafusion/physical-plan/src/windows/utils.rs
@@ -31,5 +31,7 @@ pub(crate) fn create_schema(
     for expr in window_expr {
         builder.push(expr.field()?);
     }
-    Ok(builder.finish())
+    Ok(builder
+        .finish()
+        .with_metadata(input_schema.metadata().clone()))
 }
diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs
index ba95640a87c7f..61d444171cc72 100644
--- a/datafusion/physical-plan/src/work_table.rs
+++ b/datafusion/physical-plan/src/work_table.rs
@@ -225,31 +225,31 @@ mod tests {
     #[test]
     fn test_work_table() {
         let work_table = WorkTable::new();
-        // can't take from empty work_table
+        // Can't take from empty work_table
         assert!(work_table.take().is_err());
 
         let pool = Arc::new(UnboundedMemoryPool::default()) as _;
         let mut reservation = MemoryConsumer::new("test_work_table").register(&pool);
 
-        // update batch to work_table
+        // Update batch to work_table
         let array: ArrayRef = Arc::new((0..5).collect::<Int32Array>());
         let batch = RecordBatch::try_from_iter(vec![("col", array)]).unwrap();
         reservation.try_grow(100).unwrap();
         work_table.update(ReservedBatches::new(vec![batch.clone()], reservation));
-        // take from work_table
+        // Take from work_table
         let reserved_batches = work_table.take().unwrap();
         assert_eq!(reserved_batches.batches, vec![batch.clone()]);
 
-        // consume the batch by the MemoryStream
+        // Consume the batch by the MemoryStream
         let memory_stream =
             MemoryStream::try_new(reserved_batches.batches, batch.schema(), None)
                 .unwrap()
                 .with_reservation(reserved_batches.reservation);
 
-        // should still be reserved
+        // Should still be reserved
         assert_eq!(pool.reserved(), 100);
 
-        // the reservation should be freed after drop the memory_stream
+        // The reservation should be freed after drop the memory_stream
         drop(memory_stream);
         assert_eq!(pool.reserved(), 0);
     }
diff --git a/datafusion/proto-common/Cargo.toml b/datafusion/proto-common/Cargo.toml
index 5051c8f9322ff..6c53e1b1ced0c 100644
--- a/datafusion/proto-common/Cargo.toml
+++ b/datafusion/proto-common/Cargo.toml
@@ -26,7 +26,7 @@
 homepage = { workspace = true }
 repository = { workspace = true }
 license = { workspace = true }
 authors = { workspace = true }
-rust-version = "1.78"
+rust-version = "1.79"
 
 # Exclude proto files so crates.io consumers don't need protoc
 exclude = ["*.proto"]
diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml
index 0914669f82fa8..6e5783f467a70 100644
--- a/datafusion/proto-common/gen/Cargo.toml
+++ b/datafusion/proto-common/gen/Cargo.toml
@@ -20,7 +20,7 @@ name = "gen-common"
 description = "Code generation for proto"
 version = "0.1.0"
 edition = { workspace = true }
-rust-version = "1.78"
+rust-version = "1.79"
 authors = { workspace = true }
 homepage = { workspace = true }
 repository = { workspace = true }
diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto
index d1506fcd64f09..7f8bce6b206e3 100644
--- a/datafusion/proto-common/proto/datafusion_common.proto
+++ b/datafusion/proto-common/proto/datafusion_common.proto
@@ -494,6 +494,7 @@ message ParquetOptions {
   bool bloom_filter_on_read = 26; // default = true
   bool bloom_filter_on_write = 27; // default = false
   bool schema_force_view_types = 28; // default = false
+  bool binary_as_string = 29; // default = false
 
   oneof metadata_size_hint_opt {
     uint64 metadata_size_hint = 4;
diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs
index d1b4374fc0e71..d848f795c6841 100644
--- a/datafusion/proto-common/src/from_proto/mod.rs
+++ b/datafusion/proto-common/src/from_proto/mod.rs
@@ -897,7 +897,7 @@ impl TryFrom<&protobuf::ParquetOptions> for ParquetOptions {
             pruning: value.pruning,
             skip_metadata: value.skip_metadata,
             metadata_size_hint: value
-                .metadata_size_hint_opt.clone()
+                .metadata_size_hint_opt
                 .map(|opt| match opt {
                     protobuf::parquet_options::MetadataSizeHintOpt::MetadataSizeHint(v) => Some(v as usize),
                 })
@@ -958,6 +958,7 @@ impl TryFrom<&protobuf::ParquetOptions> for ParquetOptions {
             maximum_parallel_row_group_writers: value.maximum_parallel_row_group_writers as usize,
             maximum_buffered_record_batches_per_stream: value.maximum_buffered_record_batches_per_stream as usize,
             schema_force_view_types: value.schema_force_view_types,
+            binary_as_string: value.binary_as_string,
         })
     }
 }
@@ -979,7 +980,7 @@ impl TryFrom<&protobuf::ParquetColumnOptions> for ParquetColumnOptions {
                 })
                 .unwrap_or(None),
             max_statistics_size: value
-                .max_statistics_size_opt.clone()
+                .max_statistics_size_opt
                 .map(|opt| match opt {
                     protobuf::parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => Some(v as usize),
                 })
@@ -990,18 +991,18 @@
                 protobuf::parquet_column_options::EncodingOpt::Encoding(v) => Some(v),
             })
             .unwrap_or(None),
-            bloom_filter_enabled: value.bloom_filter_enabled_opt.clone().map(|opt| match opt {
+            bloom_filter_enabled: value.bloom_filter_enabled_opt.map(|opt| match opt {
                 protobuf::parquet_column_options::BloomFilterEnabledOpt::BloomFilterEnabled(v) => Some(v),
             })
             .unwrap_or(None),
             bloom_filter_fpp: value
-                .bloom_filter_fpp_opt.clone()
+                .bloom_filter_fpp_opt
                 .map(|opt| match opt {
                     protobuf::parquet_column_options::BloomFilterFppOpt::BloomFilterFpp(v) => Some(v),
                 })
                 .unwrap_or(None),
             bloom_filter_ndv: value
-                .bloom_filter_ndv_opt.clone()
+                .bloom_filter_ndv_opt
                 .map(|opt| match opt {
                     protobuf::parquet_column_options::BloomFilterNdvOpt::BloomFilterNdv(v) => Some(v),
                 })
diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs
index fa5d1f442754d..e8b46fbf7012f 100644
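Aside: the one-line `create_schema` change earlier in this patch (datafusion/physical-plan/src/windows/utils.rs) keeps the input schema's key/value metadata on the derived output schema instead of silently dropping it. A self-contained sketch of that pattern, with invented field names and metadata keys:

```rust
use std::collections::HashMap;

use arrow_schema::{DataType, Field, Schema};

fn main() {
    let mut metadata = HashMap::new();
    metadata.insert("origin".to_string(), "example".to_string());
    let input = Schema::new(vec![Field::new("a", DataType::Int64, false)])
        .with_metadata(metadata);

    // Derive an output schema with one extra column, as a window exec does...
    let mut fields: Vec<Field> = input
        .fields()
        .iter()
        .map(|f| f.as_ref().clone())
        .collect();
    fields.push(Field::new("row_number", DataType::UInt64, false));

    // ...and carry the input metadata forward, mirroring the patched create_schema.
    let output = Schema::new(fields).with_metadata(input.metadata().clone());
    assert_eq!(
        output.metadata().get("origin").map(String::as_str),
        Some("example")
    );
}
```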
--- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -1548,18 +1548,22 @@ impl serde::Serialize for CsvOptions { let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvOptions", len)?; if !self.has_header.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("hasHeader", pbjson::private::base64::encode(&self.has_header).as_str())?; } if !self.delimiter.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("delimiter", pbjson::private::base64::encode(&self.delimiter).as_str())?; } if !self.quote.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("quote", pbjson::private::base64::encode(&self.quote).as_str())?; } if !self.escape.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("escape", pbjson::private::base64::encode(&self.escape).as_str())?; } if self.compression != 0 { @@ -1569,6 +1573,7 @@ impl serde::Serialize for CsvOptions { } if self.schema_infer_max_rec != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("schemaInferMaxRec", ToString::to_string(&self.schema_infer_max_rec).as_str())?; } if !self.date_format.is_empty() { @@ -1591,18 +1596,22 @@ impl serde::Serialize for CsvOptions { } if !self.comment.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("comment", pbjson::private::base64::encode(&self.comment).as_str())?; } if !self.double_quote.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("doubleQuote", pbjson::private::base64::encode(&self.double_quote).as_str())?; } if !self.newlines_in_values.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("newlinesInValues", pbjson::private::base64::encode(&self.newlines_in_values).as_str())?; } if !self.terminator.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("terminator", pbjson::private::base64::encode(&self.terminator).as_str())?; } struct_ser.end() @@ -2276,14 +2285,17 @@ impl serde::Serialize for Decimal128 { let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal128", len)?; if !self.value.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; } if self.p != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; } if self.s != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; } struct_ser.end() @@ -2410,14 +2422,17 @@ impl serde::Serialize for Decimal256 { let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256", len)?; if !self.value.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("value", 
pbjson::private::base64::encode(&self.value).as_str())?; } if self.p != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; } if self.s != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; } struct_ser.end() @@ -3080,6 +3095,7 @@ impl serde::Serialize for Field { } if self.dict_id != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("dictId", ToString::to_string(&self.dict_id).as_str())?; } if self.dict_ordered { @@ -3484,6 +3500,7 @@ impl serde::Serialize for IntervalMonthDayNanoValue { } if self.nanos != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("nanos", ToString::to_string(&self.nanos).as_str())?; } struct_ser.end() @@ -3917,6 +3934,7 @@ impl serde::Serialize for JsonOptions { } if self.schema_infer_max_rec != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("schemaInferMaxRec", ToString::to_string(&self.schema_infer_max_rec).as_str())?; } struct_ser.end() @@ -4474,6 +4492,7 @@ impl serde::Serialize for ParquetColumnOptions { match v { parquet_column_options::BloomFilterNdvOpt::BloomFilterNdv(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("bloomFilterNdv", ToString::to_string(&v).as_str())?; } } @@ -4894,6 +4913,9 @@ impl serde::Serialize for ParquetOptions { if self.schema_force_view_types { len += 1; } + if self.binary_as_string { + len += 1; + } if self.dictionary_page_size_limit != 0 { len += 1; } @@ -4951,10 +4973,12 @@ impl serde::Serialize for ParquetOptions { } if self.data_pagesize_limit != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("dataPagesizeLimit", ToString::to_string(&self.data_pagesize_limit).as_str())?; } if self.write_batch_size != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("writeBatchSize", ToString::to_string(&self.write_batch_size).as_str())?; } if !self.writer_version.is_empty() { @@ -4965,10 +4989,12 @@ impl serde::Serialize for ParquetOptions { } if self.maximum_parallel_row_group_writers != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("maximumParallelRowGroupWriters", ToString::to_string(&self.maximum_parallel_row_group_writers).as_str())?; } if self.maximum_buffered_record_batches_per_stream != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("maximumBufferedRecordBatchesPerStream", ToString::to_string(&self.maximum_buffered_record_batches_per_stream).as_str())?; } if self.bloom_filter_on_read { @@ -4980,16 +5006,22 @@ impl serde::Serialize for ParquetOptions { if self.schema_force_view_types { struct_ser.serialize_field("schemaForceViewTypes", &self.schema_force_view_types)?; } + if self.binary_as_string { + struct_ser.serialize_field("binaryAsString", &self.binary_as_string)?; + } if self.dictionary_page_size_limit != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] 
struct_ser.serialize_field("dictionaryPageSizeLimit", ToString::to_string(&self.dictionary_page_size_limit).as_str())?; } if self.data_page_row_count_limit != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("dataPageRowCountLimit", ToString::to_string(&self.data_page_row_count_limit).as_str())?; } if self.max_row_group_size != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("maxRowGroupSize", ToString::to_string(&self.max_row_group_size).as_str())?; } if !self.created_by.is_empty() { @@ -4999,6 +5031,7 @@ impl serde::Serialize for ParquetOptions { match v { parquet_options::MetadataSizeHintOpt::MetadataSizeHint(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("metadataSizeHint", ToString::to_string(&v).as_str())?; } } @@ -5028,6 +5061,7 @@ impl serde::Serialize for ParquetOptions { match v { parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("maxStatisticsSize", ToString::to_string(&v).as_str())?; } } @@ -5036,6 +5070,7 @@ impl serde::Serialize for ParquetOptions { match v { parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("columnIndexTruncateLength", ToString::to_string(&v).as_str())?; } } @@ -5058,6 +5093,7 @@ impl serde::Serialize for ParquetOptions { match v { parquet_options::BloomFilterNdvOpt::BloomFilterNdv(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("bloomFilterNdv", ToString::to_string(&v).as_str())?; } } @@ -5099,6 +5135,8 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { "bloomFilterOnWrite", "schema_force_view_types", "schemaForceViewTypes", + "binary_as_string", + "binaryAsString", "dictionary_page_size_limit", "dictionaryPageSizeLimit", "data_page_row_count_limit", @@ -5140,7 +5178,8 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { MaximumBufferedRecordBatchesPerStream, BloomFilterOnRead, BloomFilterOnWrite, - schemaForceViewTypes, + SchemaForceViewTypes, + BinaryAsString, DictionaryPageSizeLimit, DataPageRowCountLimit, MaxRowGroupSize, @@ -5188,7 +5227,8 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { "maximumBufferedRecordBatchesPerStream" | "maximum_buffered_record_batches_per_stream" => Ok(GeneratedField::MaximumBufferedRecordBatchesPerStream), "bloomFilterOnRead" | "bloom_filter_on_read" => Ok(GeneratedField::BloomFilterOnRead), "bloomFilterOnWrite" | "bloom_filter_on_write" => Ok(GeneratedField::BloomFilterOnWrite), - "schemaForceViewTypes" | "schema_force_view_types" => Ok(GeneratedField::schemaForceViewTypes), + "schemaForceViewTypes" | "schema_force_view_types" => Ok(GeneratedField::SchemaForceViewTypes), + "binaryAsString" | "binary_as_string" => Ok(GeneratedField::BinaryAsString), "dictionaryPageSizeLimit" | "dictionary_page_size_limit" => Ok(GeneratedField::DictionaryPageSizeLimit), "dataPageRowCountLimit" | "data_page_row_count_limit" => Ok(GeneratedField::DataPageRowCountLimit), "maxRowGroupSize" | "max_row_group_size" => Ok(GeneratedField::MaxRowGroupSize), @@ -5235,6 +5275,7 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { let mut bloom_filter_on_read__ = None; let 
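A note on the generated churn above: pbjson emits `bytes` fields as base64 strings and 64-bit integers as decimal strings, following the proto3 JSON mapping (string encoding keeps values safe for JavaScript's 53-bit integer range); the extra `#[allow(clippy::needless_borrows_for_generic_args)]` merely silences a lint that newer clippy split out of `needless_borrow`. A hand-rolled equivalent of the serialization convention, using the `base64` and `serde_json` crates with a stripped-down stand-in struct rather than the generated type:

```rust
use serde::ser::{Serialize, SerializeStruct, Serializer};

struct Decimal128 {
    value: Vec<u8>, // raw decimal bytes, as in the proto message
    p: i64,
    s: i64,
}

impl Serialize for Decimal128 {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use base64::Engine as _;
        let mut st = serializer.serialize_struct("Decimal128", 3)?;
        // bytes -> base64 text, per the proto3 JSON mapping
        st.serialize_field(
            "value",
            &base64::engine::general_purpose::STANDARD.encode(&self.value),
        )?;
        // 64-bit integers -> decimal strings
        st.serialize_field("p", &self.p.to_string())?;
        st.serialize_field("s", &self.s.to_string())?;
        st.end()
    }
}

fn main() {
    let d = Decimal128 { value: vec![1, 2, 3], p: 38, s: 10 };
    let json = serde_json::to_string(&d).unwrap();
    assert_eq!(json, r#"{"value":"AQID","p":"38","s":"10"}"#);
}
```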
mut bloom_filter_on_write__ = None; let mut schema_force_view_types__ = None; + let mut binary_as_string__ = None; let mut dictionary_page_size_limit__ = None; let mut data_page_row_count_limit__ = None; let mut max_row_group_size__ = None; @@ -5336,12 +5377,18 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { } bloom_filter_on_write__ = Some(map_.next_value()?); } - GeneratedField::schemaForceViewTypes => { + GeneratedField::SchemaForceViewTypes => { if schema_force_view_types__.is_some() { return Err(serde::de::Error::duplicate_field("schemaForceViewTypes")); } schema_force_view_types__ = Some(map_.next_value()?); } + GeneratedField::BinaryAsString => { + if binary_as_string__.is_some() { + return Err(serde::de::Error::duplicate_field("binaryAsString")); + } + binary_as_string__ = Some(map_.next_value()?); + } GeneratedField::DictionaryPageSizeLimit => { if dictionary_page_size_limit__.is_some() { return Err(serde::de::Error::duplicate_field("dictionaryPageSizeLimit")); @@ -5443,6 +5490,7 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { bloom_filter_on_read: bloom_filter_on_read__.unwrap_or_default(), bloom_filter_on_write: bloom_filter_on_write__.unwrap_or_default(), schema_force_view_types: schema_force_view_types__.unwrap_or_default(), + binary_as_string: binary_as_string__.unwrap_or_default(), dictionary_page_size_limit: dictionary_page_size_limit__.unwrap_or_default(), data_page_row_count_limit: data_page_row_count_limit__.unwrap_or_default(), max_row_group_size: max_row_group_size__.unwrap_or_default(), @@ -5867,6 +5915,7 @@ impl serde::Serialize for ScalarFixedSizeBinary { let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarFixedSizeBinary", len)?; if !self.values.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("values", pbjson::private::base64::encode(&self.values).as_str())?; } if self.length != 0 { @@ -5986,10 +6035,12 @@ impl serde::Serialize for ScalarNestedValue { let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarNestedValue", len)?; if !self.ipc_message.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("ipcMessage", pbjson::private::base64::encode(&self.ipc_message).as_str())?; } if !self.arrow_data.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("arrowData", pbjson::private::base64::encode(&self.arrow_data).as_str())?; } if let Some(v) = self.schema.as_ref() { @@ -6130,10 +6181,12 @@ impl serde::Serialize for scalar_nested_value::Dictionary { let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarNestedValue.Dictionary", len)?; if !self.ipc_message.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("ipcMessage", pbjson::private::base64::encode(&self.ipc_message).as_str())?; } if !self.arrow_data.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("arrowData", pbjson::private::base64::encode(&self.arrow_data).as_str())?; } struct_ser.end() @@ -6354,10 +6407,12 @@ impl serde::Serialize for ScalarTime64Value { match v { scalar_time64_value::Value::Time64MicrosecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] 
struct_ser.serialize_field("time64MicrosecondValue", ToString::to_string(&v).as_str())?; } scalar_time64_value::Value::Time64NanosecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("time64NanosecondValue", ToString::to_string(&v).as_str())?; } } @@ -6471,18 +6526,22 @@ impl serde::Serialize for ScalarTimestampValue { match v { scalar_timestamp_value::Value::TimeMicrosecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("timeMicrosecondValue", ToString::to_string(&v).as_str())?; } scalar_timestamp_value::Value::TimeNanosecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("timeNanosecondValue", ToString::to_string(&v).as_str())?; } scalar_timestamp_value::Value::TimeSecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("timeSecondValue", ToString::to_string(&v).as_str())?; } scalar_timestamp_value::Value::TimeMillisecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("timeMillisecondValue", ToString::to_string(&v).as_str())?; } } @@ -6645,6 +6704,7 @@ impl serde::Serialize for ScalarValue { } scalar_value::Value::Int64Value(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("int64Value", ToString::to_string(&v).as_str())?; } scalar_value::Value::Uint8Value(v) => { @@ -6658,6 +6718,7 @@ impl serde::Serialize for ScalarValue { } scalar_value::Value::Uint64Value(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("uint64Value", ToString::to_string(&v).as_str())?; } scalar_value::Value::Float32Value(v) => { @@ -6695,6 +6756,7 @@ impl serde::Serialize for ScalarValue { } scalar_value::Value::Date64Value(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("date64Value", ToString::to_string(&v).as_str())?; } scalar_value::Value::IntervalYearmonthValue(v) => { @@ -6702,18 +6764,22 @@ impl serde::Serialize for ScalarValue { } scalar_value::Value::DurationSecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("durationSecondValue", ToString::to_string(&v).as_str())?; } scalar_value::Value::DurationMillisecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("durationMillisecondValue", ToString::to_string(&v).as_str())?; } scalar_value::Value::DurationMicrosecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("durationMicrosecondValue", ToString::to_string(&v).as_str())?; } scalar_value::Value::DurationNanosecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("durationNanosecondValue", ToString::to_string(&v).as_str())?; } scalar_value::Value::TimestampValue(v) => { @@ -6724,14 +6790,17 @@ impl serde::Serialize for ScalarValue { } scalar_value::Value::BinaryValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] 
struct_ser.serialize_field("binaryValue", pbjson::private::base64::encode(&v).as_str())?; } scalar_value::Value::LargeBinaryValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("largeBinaryValue", pbjson::private::base64::encode(&v).as_str())?; } scalar_value::Value::BinaryViewValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("binaryViewValue", pbjson::private::base64::encode(&v).as_str())?; } scalar_value::Value::Time64Value(v) => { diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index d6f982278d67d..939a4b3c2cd2a 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -1,11 +1,9 @@ // This file is @generated by prost-build. -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ColumnRelation { #[prost(string, tag = "1")] pub relation: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Column { #[prost(string, tag = "1")] @@ -13,7 +11,6 @@ pub struct Column { #[prost(message, optional, tag = "2")] pub relation: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DfField { #[prost(message, optional, tag = "1")] @@ -21,7 +18,6 @@ pub struct DfField { #[prost(message, optional, tag = "2")] pub qualifier: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DfSchema { #[prost(message, repeated, tag = "1")] @@ -32,40 +28,33 @@ pub struct DfSchema { ::prost::alloc::string::String, >, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CsvFormat { #[prost(message, optional, tag = "5")] pub options: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetFormat { #[prost(message, optional, tag = "2")] pub options: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct AvroFormat {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct NdJsonFormat { #[prost(message, optional, tag = "1")] pub options: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PrimaryKeyConstraint { #[prost(uint64, repeated, tag = "1")] pub indices: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct UniqueConstraint { #[prost(uint64, repeated, tag = "1")] pub indices: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Constraint { #[prost(oneof = "constraint::ConstraintMode", tags = "1, 2")] @@ -73,7 +62,6 @@ pub struct Constraint { } /// Nested message and enum types in `Constraint`. 
pub mod constraint { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum ConstraintMode { #[prost(message, tag = "1")] @@ -82,19 +70,15 @@ pub mod constraint { Unique(super::UniqueConstraint), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Constraints { #[prost(message, repeated, tag = "1")] pub constraints: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct AvroOptions {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct ArrowOptions {} -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Schema { #[prost(message, repeated, tag = "1")] @@ -105,7 +89,6 @@ pub struct Schema { ::prost::alloc::string::String, >, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Field { /// name of the field @@ -128,7 +111,6 @@ pub struct Field { #[prost(bool, tag = "7")] pub dict_ordered: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Timestamp { #[prost(enumeration = "TimeUnit", tag = "1")] @@ -136,29 +118,25 @@ pub struct Timestamp { #[prost(string, tag = "2")] pub timezone: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Decimal { #[prost(uint32, tag = "3")] pub precision: u32, #[prost(int32, tag = "4")] pub scale: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Decimal256Type { #[prost(uint32, tag = "3")] pub precision: u32, #[prost(int32, tag = "4")] pub scale: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct List { #[prost(message, optional, boxed, tag = "1")] pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FixedSizeList { #[prost(message, optional, boxed, tag = "1")] @@ -166,7 +144,6 @@ pub struct FixedSizeList { #[prost(int32, tag = "2")] pub list_size: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Dictionary { #[prost(message, optional, boxed, tag = "1")] @@ -174,13 +151,11 @@ pub struct Dictionary { #[prost(message, optional, boxed, tag = "2")] pub value: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Struct { #[prost(message, repeated, tag = "1")] pub sub_field_types: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Map { #[prost(message, optional, boxed, tag = "1")] @@ -188,7 +163,6 @@ pub struct Map { #[prost(bool, tag = "2")] pub keys_sorted: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Union { #[prost(message, repeated, tag = "1")] @@ -199,7 +173,6 @@ pub struct Union { pub type_ids: 
::prost::alloc::vec::Vec, } /// Used for List/FixedSizeList/LargeList/Struct/Map -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarNestedValue { #[prost(bytes = "vec", tag = "1")] @@ -213,7 +186,6 @@ pub struct ScalarNestedValue { } /// Nested message and enum types in `ScalarNestedValue`. pub mod scalar_nested_value { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Dictionary { #[prost(bytes = "vec", tag = "1")] @@ -222,16 +194,14 @@ pub mod scalar_nested_value { pub arrow_data: ::prost::alloc::vec::Vec, } } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct ScalarTime32Value { #[prost(oneof = "scalar_time32_value::Value", tags = "1, 2")] pub value: ::core::option::Option, } /// Nested message and enum types in `ScalarTime32Value`. pub mod scalar_time32_value { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum Value { #[prost(int32, tag = "1")] Time32SecondValue(i32), @@ -239,16 +209,14 @@ pub mod scalar_time32_value { Time32MillisecondValue(i32), } } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct ScalarTime64Value { #[prost(oneof = "scalar_time64_value::Value", tags = "1, 2")] pub value: ::core::option::Option, } /// Nested message and enum types in `ScalarTime64Value`. pub mod scalar_time64_value { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum Value { #[prost(int64, tag = "1")] Time64MicrosecondValue(i64), @@ -256,7 +224,6 @@ pub mod scalar_time64_value { Time64NanosecondValue(i64), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarTimestampValue { #[prost(string, tag = "5")] @@ -266,8 +233,7 @@ pub struct ScalarTimestampValue { } /// Nested message and enum types in `ScalarTimestampValue`. 
pub mod scalar_timestamp_value { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum Value { #[prost(int64, tag = "1")] TimeMicrosecondValue(i64), @@ -279,7 +245,6 @@ pub mod scalar_timestamp_value { TimeMillisecondValue(i64), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarDictionaryValue { #[prost(message, optional, tag = "1")] @@ -287,16 +252,14 @@ pub struct ScalarDictionaryValue { #[prost(message, optional, boxed, tag = "2")] pub value: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct IntervalDayTimeValue { #[prost(int32, tag = "1")] pub days: i32, #[prost(int32, tag = "2")] pub milliseconds: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct IntervalMonthDayNanoValue { #[prost(int32, tag = "1")] pub months: i32, @@ -305,7 +268,6 @@ pub struct IntervalMonthDayNanoValue { #[prost(int64, tag = "3")] pub nanos: i64, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionField { #[prost(int32, tag = "1")] @@ -313,7 +275,6 @@ pub struct UnionField { #[prost(message, optional, tag = "2")] pub field: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionValue { /// Note that a null union value must have one or more fields, so we @@ -327,7 +288,6 @@ pub struct UnionValue { #[prost(enumeration = "UnionMode", tag = "4")] pub mode: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarFixedSizeBinary { #[prost(bytes = "vec", tag = "1")] @@ -335,7 +295,6 @@ pub struct ScalarFixedSizeBinary { #[prost(int32, tag = "2")] pub length: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarValue { #[prost( @@ -346,7 +305,6 @@ pub struct ScalarValue { } /// Nested message and enum types in `ScalarValue`. pub mod scalar_value { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum Value { /// was PrimitiveScalarType null_value = 19; @@ -434,7 +392,6 @@ pub mod scalar_value { UnionValue(::prost::alloc::boxed::Box), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Decimal128 { #[prost(bytes = "vec", tag = "1")] @@ -444,7 +401,6 @@ pub struct Decimal128 { #[prost(int64, tag = "3")] pub s: i64, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Decimal256 { #[prost(bytes = "vec", tag = "1")] @@ -455,7 +411,6 @@ pub struct Decimal256 { pub s: i64, } /// Serialized data type -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ArrowType { #[prost( @@ -466,7 +421,6 @@ pub struct ArrowType { } /// Nested message and enum types in `ArrowType`. 
pub mod arrow_type { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum ArrowTypeEnum { /// arrow::Type::NA @@ -557,16 +511,13 @@ pub mod arrow_type { /// i32 Two = 2; /// } /// } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct EmptyMessage {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct JsonWriterOptions { #[prost(enumeration = "CompressionTypeVariant", tag = "1")] pub compression: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CsvWriterOptions { /// Compression type @@ -604,7 +555,6 @@ pub struct CsvWriterOptions { pub double_quote: bool, } /// Options controlling CSV format -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CsvOptions { /// Indicates if the CSV has a header row @@ -657,8 +607,7 @@ pub struct CsvOptions { pub terminator: ::prost::alloc::vec::Vec, } /// Options controlling CSV format -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct JsonOptions { /// Compression type #[prost(enumeration = "CompressionTypeVariant", tag = "1")] @@ -667,7 +616,6 @@ pub struct JsonOptions { #[prost(uint64, tag = "2")] pub schema_infer_max_rec: u64, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct TableParquetOptions { #[prost(message, optional, tag = "1")] @@ -680,7 +628,6 @@ pub struct TableParquetOptions { ::prost::alloc::string::String, >, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetColumnSpecificOptions { #[prost(string, tag = "1")] @@ -688,7 +635,6 @@ pub struct ParquetColumnSpecificOptions { #[prost(message, optional, tag = "2")] pub options: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetColumnOptions { #[prost(oneof = "parquet_column_options::BloomFilterEnabledOpt", tags = "1")] @@ -722,56 +668,47 @@ pub struct ParquetColumnOptions { } /// Nested message and enum types in `ParquetColumnOptions`. 
pub mod parquet_column_options { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum BloomFilterEnabledOpt { #[prost(bool, tag = "1")] BloomFilterEnabled(bool), } - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum EncodingOpt { #[prost(string, tag = "2")] Encoding(::prost::alloc::string::String), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum DictionaryEnabledOpt { #[prost(bool, tag = "3")] DictionaryEnabled(bool), } - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum CompressionOpt { #[prost(string, tag = "4")] Compression(::prost::alloc::string::String), } - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum StatisticsEnabledOpt { #[prost(string, tag = "5")] StatisticsEnabled(::prost::alloc::string::String), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum BloomFilterFppOpt { #[prost(double, tag = "6")] BloomFilterFpp(f64), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum BloomFilterNdvOpt { #[prost(uint64, tag = "7")] BloomFilterNdv(u64), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum MaxStatisticsSizeOpt { #[prost(uint32, tag = "8")] MaxStatisticsSize(u32), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetOptions { /// Regular fields @@ -820,6 +757,9 @@ pub struct ParquetOptions { /// default = false #[prost(bool, tag = "28")] pub schema_force_view_types: bool, + /// default = false + #[prost(bool, tag = "29")] + pub binary_as_string: bool, #[prost(uint64, tag = "12")] pub dictionary_page_size_limit: u64, #[prost(uint64, tag = "18")] @@ -859,62 +799,52 @@ pub struct ParquetOptions { } /// Nested message and enum types in `ParquetOptions`. 
pub mod parquet_options { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum MetadataSizeHintOpt { #[prost(uint64, tag = "4")] MetadataSizeHint(u64), } - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum CompressionOpt { #[prost(string, tag = "10")] Compression(::prost::alloc::string::String), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum DictionaryEnabledOpt { #[prost(bool, tag = "11")] DictionaryEnabled(bool), } - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum StatisticsEnabledOpt { #[prost(string, tag = "13")] StatisticsEnabled(::prost::alloc::string::String), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum MaxStatisticsSizeOpt { #[prost(uint64, tag = "14")] MaxStatisticsSize(u64), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum ColumnIndexTruncateLengthOpt { #[prost(uint64, tag = "17")] ColumnIndexTruncateLength(u64), } - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum EncodingOpt { #[prost(string, tag = "19")] Encoding(::prost::alloc::string::String), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum BloomFilterFppOpt { #[prost(double, tag = "21")] BloomFilterFpp(f64), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum BloomFilterNdvOpt { #[prost(uint64, tag = "22")] BloomFilterNdv(u64), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Precision { #[prost(enumeration = "PrecisionInfo", tag = "1")] @@ -922,7 +852,6 @@ pub struct Precision { #[prost(message, optional, tag = "2")] pub val: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Statistics { #[prost(message, optional, tag = "1")] @@ -932,7 +861,6 @@ pub struct Statistics { #[prost(message, repeated, tag = "3")] pub column_stats: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ColumnStats { #[prost(message, optional, tag = "1")] @@ -963,14 +891,14 @@ impl JoinType { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - JoinType::Inner => "INNER", - JoinType::Left => "LEFT", - JoinType::Right => "RIGHT", - JoinType::Full => "FULL", - JoinType::Leftsemi => "LEFTSEMI", - JoinType::Leftanti => "LEFTANTI", - JoinType::Rightsemi => "RIGHTSEMI", - JoinType::Rightanti => "RIGHTANTI", + Self::Inner => "INNER", + Self::Left => "LEFT", + Self::Right => "RIGHT", + Self::Full => "FULL", + Self::Leftsemi => "LEFTSEMI", + Self::Leftanti => "LEFTANTI", + Self::Rightsemi => "RIGHTSEMI", + Self::Rightanti => "RIGHTANTI", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -1001,8 +929,8 @@ impl JoinConstraint { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - JoinConstraint::On => "ON", - JoinConstraint::Using => "USING", + Self::On => "ON", + Self::Using => "USING", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1029,10 +957,10 @@ impl TimeUnit { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - TimeUnit::Second => "Second", - TimeUnit::Millisecond => "Millisecond", - TimeUnit::Microsecond => "Microsecond", - TimeUnit::Nanosecond => "Nanosecond", + Self::Second => "Second", + Self::Millisecond => "Millisecond", + Self::Microsecond => "Microsecond", + Self::Nanosecond => "Nanosecond", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1060,9 +988,9 @@ impl IntervalUnit { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - IntervalUnit::YearMonth => "YearMonth", - IntervalUnit::DayTime => "DayTime", - IntervalUnit::MonthDayNano => "MonthDayNano", + Self::YearMonth => "YearMonth", + Self::DayTime => "DayTime", + Self::MonthDayNano => "MonthDayNano", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1088,8 +1016,8 @@ impl UnionMode { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - UnionMode::Sparse => "sparse", - UnionMode::Dense => "dense", + Self::Sparse => "sparse", + Self::Dense => "dense", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1117,11 +1045,11 @@ impl CompressionTypeVariant { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - CompressionTypeVariant::Gzip => "GZIP", - CompressionTypeVariant::Bzip2 => "BZIP2", - CompressionTypeVariant::Xz => "XZ", - CompressionTypeVariant::Zstd => "ZSTD", - CompressionTypeVariant::Uncompressed => "UNCOMPRESSED", + Self::Gzip => "GZIP", + Self::Bzip2 => "BZIP2", + Self::Xz => "XZ", + Self::Zstd => "ZSTD", + Self::Uncompressed => "UNCOMPRESSED", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1149,8 +1077,8 @@ impl JoinSide { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - JoinSide::LeftSide => "LEFT_SIDE", - JoinSide::RightSide => "RIGHT_SIDE", + Self::LeftSide => "LEFT_SIDE", + Self::RightSide => "RIGHT_SIDE", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1176,9 +1104,9 @@ impl PrecisionInfo { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - PrecisionInfo::Exact => "EXACT", - PrecisionInfo::Inexact => "INEXACT", - PrecisionInfo::Absent => "ABSENT", + Self::Exact => "EXACT", + Self::Inexact => "INEXACT", + Self::Absent => "ABSENT", } } /// Creates an enum from field names used in the ProtoBuf definition. 
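The long run of `as_str_name` hunks above is purely mechanical regeneration noise: newer prost-build writes `Self::Variant` match arms instead of repeating the enum name, with identical behavior. Reduced to its essence (a toy enum, not the generated one):

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum JoinType {
    Inner,
    Left,
}

impl JoinType {
    // `Self::Inner` and `JoinType::Inner` compile to the same thing;
    // the shorthand just survives enum renames with less churn.
    fn as_str_name(&self) -> &'static str {
        match self {
            Self::Inner => "INNER",
            Self::Left => "LEFT",
        }
    }
}

fn main() {
    assert_eq!(JoinType::Inner.as_str_name(), "INNER");
}
```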
diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs
index ebb53ae7577cf..f9b8973e2d413 100644
--- a/datafusion/proto-common/src/to_proto/mod.rs
+++ b/datafusion/proto-common/src/to_proto/mod.rs
@@ -831,6 +831,7 @@ impl TryFrom<&ParquetOptions> for protobuf::ParquetOptions {
             maximum_parallel_row_group_writers: value.maximum_parallel_row_group_writers as u64,
             maximum_buffered_record_batches_per_stream: value.maximum_buffered_record_batches_per_stream as u64,
             schema_force_view_types: value.schema_force_view_types,
+            binary_as_string: value.binary_as_string,
         })
     }
 }
diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml
index d65c6ccaa660b..9e4b331a01bfa 100644
--- a/datafusion/proto/Cargo.toml
+++ b/datafusion/proto/Cargo.toml
@@ -27,14 +27,11 @@ repository = { workspace = true }
 license = { workspace = true }
 authors = { workspace = true }
 # Specify MSRV here as `cargo msrv` doesn't support workspace version
-rust-version = "1.78"
+rust-version = "1.79"
 
 # Exclude proto files so crates.io consumers don't need protoc
 exclude = ["*.proto"]
 
-[lints]
-workspace = true
-
 [lib]
 name = "datafusion_proto"
 path = "src/lib.rs"
diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml
index ea28ac86e8df8..aee8fac4a1209 100644
--- a/datafusion/proto/gen/Cargo.toml
+++ b/datafusion/proto/gen/Cargo.toml
@@ -20,7 +20,7 @@ name = "gen"
 description = "Code generation for proto"
 version = "0.1.0"
 edition = { workspace = true }
-rust-version = "1.78"
+rust-version = "1.79"
 authors = { workspace = true }
 homepage = { workspace = true }
 repository = { workspace = true }
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 91050b2346de7..387934847cfd7 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -163,6 +163,7 @@ message CreateExternalTableNode {
   datafusion_common.DfSchema schema = 4;
   repeated string table_partition_cols = 5;
   bool if_not_exists = 6;
+  bool temporary = 14;
   string definition = 7;
   repeated SortExprNodeCollection order_exprs = 10;
   bool unbounded = 11;
@@ -200,6 +201,7 @@ message CreateViewNode {
   TableReference name = 5;
   LogicalPlanNode input = 2;
   bool or_replace = 3;
+  bool temporary = 6;
   string definition = 4;
 }
 
@@ -262,7 +264,7 @@ message CopyToNode {
 
 message UnnestNode {
   LogicalPlanNode input = 1;
-  repeated ColumnUnnestExec exec_columns = 2;
+  repeated datafusion_common.Column exec_columns = 2;
   repeated ColumnUnnestListItem list_type_columns = 3;
   repeated uint64 struct_type_columns = 4;
   repeated uint64 dependency_indices = 5;
@@ -283,17 +285,15 @@ message ColumnUnnestListRecursion {
   uint32 depth = 2;
 }
 
-message ColumnUnnestExec {
-  datafusion_common.Column column = 1;
-  oneof UnnestType {
-    ColumnUnnestListRecursions list = 2;
-    datafusion_common.EmptyMessage struct = 3;
-    datafusion_common.EmptyMessage inferred = 4;
-  }
-}
-
 message UnnestOptions {
   bool preserve_nulls = 1;
+  repeated RecursionUnnestOption recursions = 2;
+}
+
+message RecursionUnnestOption {
+  datafusion_common.Column output_column = 1;
+  datafusion_common.Column input_column = 2;
+  uint32 depth = 3;
 }
 
 message UnionNode {
@@ -508,13 +508,13 @@ message ScalarUDFExprNode {
 enum BuiltInWindowFunction {
   UNSPECIFIED = 0; // https://protobuf.dev/programming-guides/dos-donts/#unspecified-enum
   // ROW_NUMBER = 0;
-  RANK = 1;
-  DENSE_RANK = 2;
-  PERCENT_RANK = 3;
-  CUME_DIST = 4;
-  NTILE = 5;
-  LAG = 6;
-  LEAD = 7;
+  // RANK = 1;
+  // DENSE_RANK = 2;
+  // PERCENT_RANK = 3;
+  // CUME_DIST = 4;
+  // NTILE = 5;
+  // LAG = 6;
+  // LEAD = 7;
   FIRST_VALUE = 8;
   LAST_VALUE = 9;
   NTH_VALUE = 10;
@@ -526,7 +526,7 @@ message WindowExprNode {
     string udaf = 3;
     string udwf = 9;
   }
-  LogicalExprNode expr = 4;
+  repeated LogicalExprNode exprs = 4;
  repeated LogicalExprNode partition_by = 5;
   repeated SortExprNode order_by = 6;
   // repeated LogicalExprNode filter = 7;
@@ -731,14 +731,21 @@ message PartitionColumn {
 
 message FileSinkConfig {
   reserved 6; // writer_mode
+  reserved 8; // was `overwrite` which has been superseded by `insert_op`
   string object_store_url = 1;
   repeated PartitionedFile file_groups = 2;
   repeated string table_paths = 3;
   datafusion_common.Schema output_schema = 4;
   repeated PartitionColumn table_partition_cols = 5;
-  bool overwrite = 8;
   bool keep_partition_by_columns = 9;
+  InsertOp insert_op = 10;
+}
+
+enum InsertOp {
+  Append = 0;
+  Overwrite = 1;
+  Replace = 2;
 }
 
 message JsonSink {
diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs
index 16de2c7772415..939a4b3c2cd2a 100644
--- a/datafusion/proto/src/generated/datafusion_proto_common.rs
+++ b/datafusion/proto/src/generated/datafusion_proto_common.rs
@@ -757,6 +757,9 @@ pub struct ParquetOptions {
     /// default = false
     #[prost(bool, tag = "28")]
     pub schema_force_view_types: bool,
+    /// default = false
+    #[prost(bool, tag = "29")]
+    pub binary_as_string: bool,
     #[prost(uint64, tag = "12")]
     pub dictionary_page_size_limit: u64,
     #[prost(uint64, tag = "18")]
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index d4dba6dcaec43..39ad2ab9cb555 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -1662,13 +1662,6 @@ impl serde::Serialize for BuiltInWindowFunction {
     {
         let variant = match self {
             Self::Unspecified => "UNSPECIFIED",
-            Self::Rank => "RANK",
-            Self::DenseRank => "DENSE_RANK",
-            Self::PercentRank => "PERCENT_RANK",
-            Self::CumeDist => "CUME_DIST",
-            Self::Ntile => "NTILE",
-            Self::Lag => "LAG",
-            Self::Lead => "LEAD",
             Self::FirstValue => "FIRST_VALUE",
             Self::LastValue => "LAST_VALUE",
             Self::NthValue => "NTH_VALUE",
@@ -1684,13 +1677,6 @@ impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction {
     {
         const FIELDS: &[&str] = &[
             "UNSPECIFIED",
-            "RANK",
-            "DENSE_RANK",
-            "PERCENT_RANK",
-            "CUME_DIST",
-            "NTILE",
-            "LAG",
-            "LEAD",
             "FIRST_VALUE",
             "LAST_VALUE",
             "NTH_VALUE",
@@ -1735,13 +1721,6 @@
             {
                 match value {
                     "UNSPECIFIED" => Ok(BuiltInWindowFunction::Unspecified),
-                    "RANK" => Ok(BuiltInWindowFunction::Rank),
-                    "DENSE_RANK" => Ok(BuiltInWindowFunction::DenseRank),
-                    "PERCENT_RANK" => Ok(BuiltInWindowFunction::PercentRank),
-                    "CUME_DIST" => Ok(BuiltInWindowFunction::CumeDist),
-                    "NTILE" => Ok(BuiltInWindowFunction::Ntile),
-                    "LAG" => Ok(BuiltInWindowFunction::Lag),
-                    "LEAD" => Ok(BuiltInWindowFunction::Lead),
                     "FIRST_VALUE" => Ok(BuiltInWindowFunction::FirstValue),
                     "LAST_VALUE" => Ok(BuiltInWindowFunction::LastValue),
                     "NTH_VALUE" => Ok(BuiltInWindowFunction::NthValue),
@@ -2321,145 +2300,6 @@ impl<'de> serde::Deserialize<'de> for ColumnIndex {
         deserializer.deserialize_struct("datafusion.ColumnIndex", FIELDS, GeneratedVisitor)
     }
 }
-impl serde::Serialize for ColumnUnnestExec {
-    #[allow(deprecated)]
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: serde::Serializer,
-    {
-        use serde::ser::SerializeStruct;
-        let mut len = 0;
-        if
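The `FileSinkConfig` change above retires the boolean `overwrite` flag (tag 8 becomes `reserved`, so a future field cannot collide with old wire data) in favor of the three-valued `InsertOp`. Prost models proto enum fields as `i32`, so decoding goes through a fallible conversion; a reduced, hand-written sketch of that round-trip (the generated code is morally equivalent):

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum InsertOp {
    Append = 0,
    Overwrite = 1,
    Replace = 2,
}

impl TryFrom<i32> for InsertOp {
    type Error = String;
    fn try_from(v: i32) -> Result<Self, Self::Error> {
        match v {
            0 => Ok(Self::Append),
            1 => Ok(Self::Overwrite),
            2 => Ok(Self::Replace),
            other => Err(format!("Invalid variant {other}")),
        }
    }
}

fn main() {
    // Wire value written by a producer:
    assert_eq!(InsertOp::try_from(2), Ok(InsertOp::Replace));
    // Unknown values surface as errors instead of silently defaulting.
    assert!(InsertOp::try_from(9).is_err());
}
```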
self.column.is_some() { - len += 1; - } - if self.unnest_type.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ColumnUnnestExec", len)?; - if let Some(v) = self.column.as_ref() { - struct_ser.serialize_field("column", v)?; - } - if let Some(v) = self.unnest_type.as_ref() { - match v { - column_unnest_exec::UnnestType::List(v) => { - struct_ser.serialize_field("list", v)?; - } - column_unnest_exec::UnnestType::Struct(v) => { - struct_ser.serialize_field("struct", v)?; - } - column_unnest_exec::UnnestType::Inferred(v) => { - struct_ser.serialize_field("inferred", v)?; - } - } - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ColumnUnnestExec { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "column", - "list", - "struct", - "inferred", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Column, - List, - Struct, - Inferred, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "column" => Ok(GeneratedField::Column), - "list" => Ok(GeneratedField::List), - "struct" => Ok(GeneratedField::Struct), - "inferred" => Ok(GeneratedField::Inferred), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ColumnUnnestExec; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ColumnUnnestExec") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut column__ = None; - let mut unnest_type__ = None; - while let Some(k) = map_.next_key()? 
{ - match k { - GeneratedField::Column => { - if column__.is_some() { - return Err(serde::de::Error::duplicate_field("column")); - } - column__ = map_.next_value()?; - } - GeneratedField::List => { - if unnest_type__.is_some() { - return Err(serde::de::Error::duplicate_field("list")); - } - unnest_type__ = map_.next_value::<::std::option::Option<_>>()?.map(column_unnest_exec::UnnestType::List) -; - } - GeneratedField::Struct => { - if unnest_type__.is_some() { - return Err(serde::de::Error::duplicate_field("struct")); - } - unnest_type__ = map_.next_value::<::std::option::Option<_>>()?.map(column_unnest_exec::UnnestType::Struct) -; - } - GeneratedField::Inferred => { - if unnest_type__.is_some() { - return Err(serde::de::Error::duplicate_field("inferred")); - } - unnest_type__ = map_.next_value::<::std::option::Option<_>>()?.map(column_unnest_exec::UnnestType::Inferred) -; - } - } - } - Ok(ColumnUnnestExec { - column: column__, - unnest_type: unnest_type__, - }) - } - } - deserializer.deserialize_struct("datafusion.ColumnUnnestExec", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for ColumnUnnestListItem { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -3202,6 +3042,9 @@ impl serde::Serialize for CreateExternalTableNode { if self.if_not_exists { len += 1; } + if self.temporary { + len += 1; + } if !self.definition.is_empty() { len += 1; } @@ -3239,6 +3082,9 @@ impl serde::Serialize for CreateExternalTableNode { if self.if_not_exists { struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; } + if self.temporary { + struct_ser.serialize_field("temporary", &self.temporary)?; + } if !self.definition.is_empty() { struct_ser.serialize_field("definition", &self.definition)?; } @@ -3276,6 +3122,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "tablePartitionCols", "if_not_exists", "ifNotExists", + "temporary", "definition", "order_exprs", "orderExprs", @@ -3294,6 +3141,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { Schema, TablePartitionCols, IfNotExists, + Temporary, Definition, OrderExprs, Unbounded, @@ -3327,6 +3175,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "schema" => Ok(GeneratedField::Schema), "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), + "temporary" => Ok(GeneratedField::Temporary), "definition" => Ok(GeneratedField::Definition), "orderExprs" | "order_exprs" => Ok(GeneratedField::OrderExprs), "unbounded" => Ok(GeneratedField::Unbounded), @@ -3358,6 +3207,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { let mut schema__ = None; let mut table_partition_cols__ = None; let mut if_not_exists__ = None; + let mut temporary__ = None; let mut definition__ = None; let mut order_exprs__ = None; let mut unbounded__ = None; @@ -3402,6 +3252,12 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { } if_not_exists__ = Some(map_.next_value()?); } + GeneratedField::Temporary => { + if temporary__.is_some() { + return Err(serde::de::Error::duplicate_field("temporary")); + } + temporary__ = Some(map_.next_value()?); + } GeneratedField::Definition => { if definition__.is_some() { return Err(serde::de::Error::duplicate_field("definition")); @@ -3451,6 +3307,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { schema: schema__, table_partition_cols: table_partition_cols__.unwrap_or_default(), if_not_exists: 
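// Note: both `CreateExternalTableNode` (tag 14) and `CreateViewNode` (tag 6)
// gain a `temporary` bool so that `CREATE TEMPORARY ...` statements survive a
// proto round-trip; like the other flags it decodes with
// `unwrap_or_default()`, i.e. `false` when absent. A minimal sketch of the
// new field on the prost side (field values are illustrative only):
//
//     let view = protobuf::CreateViewNode {
//         name: Some(name.clone().into()),
//         input: Some(Box::new(input_node)),
//         or_replace: false,
//         temporary: true, // previously not representable
//         definition: definition.clone().unwrap_or_default(),
//     };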
if_not_exists__.unwrap_or_default(), + temporary: temporary__.unwrap_or_default(), definition: definition__.unwrap_or_default(), order_exprs: order_exprs__.unwrap_or_default(), unbounded: unbounded__.unwrap_or_default(), @@ -3480,6 +3337,9 @@ impl serde::Serialize for CreateViewNode { if self.or_replace { len += 1; } + if self.temporary { + len += 1; + } if !self.definition.is_empty() { len += 1; } @@ -3493,6 +3353,9 @@ impl serde::Serialize for CreateViewNode { if self.or_replace { struct_ser.serialize_field("orReplace", &self.or_replace)?; } + if self.temporary { + struct_ser.serialize_field("temporary", &self.temporary)?; + } if !self.definition.is_empty() { struct_ser.serialize_field("definition", &self.definition)?; } @@ -3510,6 +3373,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { "input", "or_replace", "orReplace", + "temporary", "definition", ]; @@ -3518,6 +3382,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { Name, Input, OrReplace, + Temporary, Definition, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -3543,6 +3408,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { "name" => Ok(GeneratedField::Name), "input" => Ok(GeneratedField::Input), "orReplace" | "or_replace" => Ok(GeneratedField::OrReplace), + "temporary" => Ok(GeneratedField::Temporary), "definition" => Ok(GeneratedField::Definition), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -3566,6 +3432,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { let mut name__ = None; let mut input__ = None; let mut or_replace__ = None; + let mut temporary__ = None; let mut definition__ = None; while let Some(k) = map_.next_key()? { match k { @@ -3587,6 +3454,12 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { } or_replace__ = Some(map_.next_value()?); } + GeneratedField::Temporary => { + if temporary__.is_some() { + return Err(serde::de::Error::duplicate_field("temporary")); + } + temporary__ = Some(map_.next_value()?); + } GeneratedField::Definition => { if definition__.is_some() { return Err(serde::de::Error::duplicate_field("definition")); @@ -3599,6 +3472,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { name: name__, input: input__, or_replace: or_replace__.unwrap_or_default(), + temporary: temporary__.unwrap_or_default(), definition: definition__.unwrap_or_default(), }) } @@ -5832,10 +5706,10 @@ impl serde::Serialize for FileSinkConfig { if !self.table_partition_cols.is_empty() { len += 1; } - if self.overwrite { + if self.keep_partition_by_columns { len += 1; } - if self.keep_partition_by_columns { + if self.insert_op != 0 { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.FileSinkConfig", len)?; @@ -5854,12 +5728,14 @@ impl serde::Serialize for FileSinkConfig { if !self.table_partition_cols.is_empty() { struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; } - if self.overwrite { - struct_ser.serialize_field("overwrite", &self.overwrite)?; - } if self.keep_partition_by_columns { struct_ser.serialize_field("keepPartitionByColumns", &self.keep_partition_by_columns)?; } + if self.insert_op != 0 { + let v = InsertOp::try_from(self.insert_op) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.insert_op)))?; + struct_ser.serialize_field("insertOp", &v)?; + } struct_ser.end() } } @@ -5880,9 +5756,10 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "outputSchema", "table_partition_cols", "tablePartitionCols", - "overwrite", "keep_partition_by_columns", 
"keepPartitionByColumns", + "insert_op", + "insertOp", ]; #[allow(clippy::enum_variant_names)] @@ -5892,8 +5769,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { TablePaths, OutputSchema, TablePartitionCols, - Overwrite, KeepPartitionByColumns, + InsertOp, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5920,8 +5797,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "tablePaths" | "table_paths" => Ok(GeneratedField::TablePaths), "outputSchema" | "output_schema" => Ok(GeneratedField::OutputSchema), "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), - "overwrite" => Ok(GeneratedField::Overwrite), "keepPartitionByColumns" | "keep_partition_by_columns" => Ok(GeneratedField::KeepPartitionByColumns), + "insertOp" | "insert_op" => Ok(GeneratedField::InsertOp), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5946,8 +5823,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { let mut table_paths__ = None; let mut output_schema__ = None; let mut table_partition_cols__ = None; - let mut overwrite__ = None; let mut keep_partition_by_columns__ = None; + let mut insert_op__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::ObjectStoreUrl => { @@ -5980,18 +5857,18 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { } table_partition_cols__ = Some(map_.next_value()?); } - GeneratedField::Overwrite => { - if overwrite__.is_some() { - return Err(serde::de::Error::duplicate_field("overwrite")); - } - overwrite__ = Some(map_.next_value()?); - } GeneratedField::KeepPartitionByColumns => { if keep_partition_by_columns__.is_some() { return Err(serde::de::Error::duplicate_field("keepPartitionByColumns")); } keep_partition_by_columns__ = Some(map_.next_value()?); } + GeneratedField::InsertOp => { + if insert_op__.is_some() { + return Err(serde::de::Error::duplicate_field("insertOp")); + } + insert_op__ = Some(map_.next_value::()? 
as i32); + } } } Ok(FileSinkConfig { @@ -6000,8 +5877,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { table_paths: table_paths__.unwrap_or_default(), output_schema: output_schema__, table_partition_cols: table_partition_cols__.unwrap_or_default(), - overwrite: overwrite__.unwrap_or_default(), keep_partition_by_columns: keep_partition_by_columns__.unwrap_or_default(), + insert_op: insert_op__.unwrap_or_default(), }) } } @@ -7198,6 +7075,80 @@ impl<'de> serde::Deserialize<'de> for InListNode { deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for InsertOp { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Append => "Append", + Self::Overwrite => "Overwrite", + Self::Replace => "Replace", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for InsertOp { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "Append", + "Overwrite", + "Replace", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = InsertOp; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "Append" => Ok(InsertOp::Append), + "Overwrite" => Ok(InsertOp::Overwrite), + "Replace" => Ok(InsertOp::Replace), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for InterleaveExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -17550,6 +17501,135 @@ impl<'de> serde::Deserialize<'de> for ProjectionNode { deserializer.deserialize_struct("datafusion.ProjectionNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for RecursionUnnestOption { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.output_column.is_some() { + len += 1; + } + if self.input_column.is_some() { + len += 1; + } + if self.depth != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.RecursionUnnestOption", len)?; + if let Some(v) = self.output_column.as_ref() { + struct_ser.serialize_field("outputColumn", v)?; + } + if let Some(v) = self.input_column.as_ref() { + struct_ser.serialize_field("inputColumn", v)?; + } + if self.depth != 0 { + struct_ser.serialize_field("depth", &self.depth)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for RecursionUnnestOption { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + 
const FIELDS: &[&str] = &[ + "output_column", + "outputColumn", + "input_column", + "inputColumn", + "depth", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + OutputColumn, + InputColumn, + Depth, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "outputColumn" | "output_column" => Ok(GeneratedField::OutputColumn), + "inputColumn" | "input_column" => Ok(GeneratedField::InputColumn), + "depth" => Ok(GeneratedField::Depth), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = RecursionUnnestOption; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.RecursionUnnestOption") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut output_column__ = None; + let mut input_column__ = None; + let mut depth__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::OutputColumn => { + if output_column__.is_some() { + return Err(serde::de::Error::duplicate_field("outputColumn")); + } + output_column__ = map_.next_value()?; + } + GeneratedField::InputColumn => { + if input_column__.is_some() { + return Err(serde::de::Error::duplicate_field("inputColumn")); + } + input_column__ = map_.next_value()?; + } + GeneratedField::Depth => { + if depth__.is_some() { + return Err(serde::de::Error::duplicate_field("depth")); + } + depth__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(RecursionUnnestOption { + output_column: output_column__, + input_column: input_column__, + depth: depth__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.RecursionUnnestOption", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for RepartitionExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -20472,10 +20552,16 @@ impl serde::Serialize for UnnestOptions { if self.preserve_nulls { len += 1; } + if !self.recursions.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.UnnestOptions", len)?; if self.preserve_nulls { struct_ser.serialize_field("preserveNulls", &self.preserve_nulls)?; } + if !self.recursions.is_empty() { + struct_ser.serialize_field("recursions", &self.recursions)?; + } struct_ser.end() } } @@ -20488,11 +20574,13 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { const FIELDS: &[&str] = &[ "preserve_nulls", "preserveNulls", + "recursions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { PreserveNulls, + Recursions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -20515,6 +20603,7 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { { match value { "preserveNulls" | "preserve_nulls" => 
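// Note: `UnnestOptions` now serializes `recursions` (tag 2) alongside
// `preserve_nulls`: one `RecursionUnnestOption` per column carrying an
// explicit recursion depth. A sketch of the proto-side value, mirroring the
// `From<&UnnestOptions>` conversion later in this diff (column values are
// illustrative):
//
//     let opts = protobuf::UnnestOptions {
//         preserve_nulls: true,
//         recursions: vec![protobuf::RecursionUnnestOption {
//             input_column: Some((&input_col).into()),
//             output_column: Some((&output_col).into()),
//             depth: 2, // unnest two levels of list nesting
//         }],
//     };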
Ok(GeneratedField::PreserveNulls), + "recursions" => Ok(GeneratedField::Recursions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -20535,6 +20624,7 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { V: serde::de::MapAccess<'de>, { let mut preserve_nulls__ = None; + let mut recursions__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::PreserveNulls => { @@ -20543,10 +20633,17 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { } preserve_nulls__ = Some(map_.next_value()?); } + GeneratedField::Recursions => { + if recursions__.is_some() { + return Err(serde::de::Error::duplicate_field("recursions")); + } + recursions__ = Some(map_.next_value()?); + } } } Ok(UnnestOptions { preserve_nulls: preserve_nulls__.unwrap_or_default(), + recursions: recursions__.unwrap_or_default(), }) } } @@ -21212,7 +21309,7 @@ impl serde::Serialize for WindowExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.exprs.is_empty() { len += 1; } if !self.partition_by.is_empty() { @@ -21231,8 +21328,8 @@ impl serde::Serialize for WindowExprNode { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.WindowExprNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if !self.exprs.is_empty() { + struct_ser.serialize_field("exprs", &self.exprs)?; } if !self.partition_by.is_empty() { struct_ser.serialize_field("partitionBy", &self.partition_by)?; @@ -21273,7 +21370,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "exprs", "partition_by", "partitionBy", "order_by", @@ -21290,7 +21387,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Exprs, PartitionBy, OrderBy, WindowFrame, @@ -21319,7 +21416,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "exprs" => Ok(GeneratedField::Exprs), "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), @@ -21346,7 +21443,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut exprs__ = None; let mut partition_by__ = None; let mut order_by__ = None; let mut window_frame__ = None; @@ -21354,11 +21451,11 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { let mut window_function__ = None; while let Some(k) = map_.next_key()? 
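// Note: `WindowExprNode.exprs` (repeated, tag 4) replaces the single optional
// `expr`, so window functions taking several arguments (e.g. NTH_VALUE's
// value and index) no longer lose everything past the first argument. On the
// decode side this collapses to a single call, as shown later in this diff:
//
//     let args = parse_exprs(&expr.exprs, registry, codec)?;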
{ match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Exprs => { + if exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("exprs")); } - expr__ = map_.next_value()?; + exprs__ = Some(map_.next_value()?); } GeneratedField::PartitionBy => { if partition_by__.is_some() { @@ -21407,7 +21504,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { } } Ok(WindowExprNode { - expr: expr__, + exprs: exprs__.unwrap_or_default(), partition_by: partition_by__.unwrap_or_default(), order_by: order_by__.unwrap_or_default(), window_frame: window_frame__, diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d5df037bef1d4..18c94ff4c6e4b 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -240,6 +240,8 @@ pub struct CreateExternalTableNode { pub table_partition_cols: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, #[prost(bool, tag = "6")] pub if_not_exists: bool, + #[prost(bool, tag = "14")] + pub temporary: bool, #[prost(string, tag = "7")] pub definition: ::prost::alloc::string::String, #[prost(message, repeated, tag = "10")] @@ -303,6 +305,8 @@ pub struct CreateViewNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(bool, tag = "3")] pub or_replace: bool, + #[prost(bool, tag = "6")] + pub temporary: bool, #[prost(string, tag = "4")] pub definition: ::prost::alloc::string::String, } @@ -396,7 +400,7 @@ pub struct UnnestNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "2")] - pub exec_columns: ::prost::alloc::vec::Vec, + pub exec_columns: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "3")] pub list_type_columns: ::prost::alloc::vec::Vec, #[prost(uint64, repeated, tag = "4")] @@ -428,28 +432,20 @@ pub struct ColumnUnnestListRecursion { pub depth: u32, } #[derive(Clone, PartialEq, ::prost::Message)] -pub struct ColumnUnnestExec { - #[prost(message, optional, tag = "1")] - pub column: ::core::option::Option, - #[prost(oneof = "column_unnest_exec::UnnestType", tags = "2, 3, 4")] - pub unnest_type: ::core::option::Option, -} -/// Nested message and enum types in `ColumnUnnestExec`. 
-pub mod column_unnest_exec { - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum UnnestType { - #[prost(message, tag = "2")] - List(super::ColumnUnnestListRecursions), - #[prost(message, tag = "3")] - Struct(super::super::datafusion_common::EmptyMessage), - #[prost(message, tag = "4")] - Inferred(super::super::datafusion_common::EmptyMessage), - } -} -#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct UnnestOptions { #[prost(bool, tag = "1")] pub preserve_nulls: bool, + #[prost(message, repeated, tag = "2")] + pub recursions: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RecursionUnnestOption { + #[prost(message, optional, tag = "1")] + pub output_column: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub input_column: ::core::option::Option, + #[prost(uint32, tag = "3")] + pub depth: u32, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionNode { @@ -534,7 +530,7 @@ pub mod logical_expr_node { TryCast(::prost::alloc::boxed::Box), /// window expressions #[prost(message, tag = "18")] - WindowExpr(::prost::alloc::boxed::Box), + WindowExpr(super::WindowExprNode), /// AggregateUDF expressions #[prost(message, tag = "19")] AggregateUdfExpr(::prost::alloc::boxed::Box), @@ -731,8 +727,8 @@ pub struct ScalarUdfExprNode { } #[derive(Clone, PartialEq, ::prost::Message)] pub struct WindowExprNode { - #[prost(message, optional, boxed, tag = "4")] - pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "4")] + pub exprs: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "5")] pub partition_by: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "6")] @@ -1067,10 +1063,10 @@ pub struct FileSinkConfig { pub output_schema: ::core::option::Option, #[prost(message, repeated, tag = "5")] pub table_partition_cols: ::prost::alloc::vec::Vec, - #[prost(bool, tag = "8")] - pub overwrite: bool, #[prost(bool, tag = "9")] pub keep_partition_by_columns: bool, + #[prost(enumeration = "InsertOp", tag = "10")] + pub insert_op: i32, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct JsonSink { @@ -1834,13 +1830,13 @@ pub enum BuiltInWindowFunction { /// Unspecified = 0, /// ROW_NUMBER = 0; - Rank = 1, - DenseRank = 2, - PercentRank = 3, - CumeDist = 4, - Ntile = 5, - Lag = 6, - Lead = 7, + /// RANK = 1; + /// DENSE_RANK = 2; + /// PERCENT_RANK = 3; + /// CUME_DIST = 4; + /// NTILE = 5; + /// LAG = 6; + /// LEAD = 7; FirstValue = 8, LastValue = 9, NthValue = 10, @@ -1853,13 +1849,6 @@ impl BuiltInWindowFunction { pub fn as_str_name(&self) -> &'static str { match self { Self::Unspecified => "UNSPECIFIED", - Self::Rank => "RANK", - Self::DenseRank => "DENSE_RANK", - Self::PercentRank => "PERCENT_RANK", - Self::CumeDist => "CUME_DIST", - Self::Ntile => "NTILE", - Self::Lag => "LAG", - Self::Lead => "LEAD", Self::FirstValue => "FIRST_VALUE", Self::LastValue => "LAST_VALUE", Self::NthValue => "NTH_VALUE", @@ -1869,13 +1858,6 @@ impl BuiltInWindowFunction { pub fn from_str_name(value: &str) -> ::core::option::Option { match value { "UNSPECIFIED" => Some(Self::Unspecified), - "RANK" => Some(Self::Rank), - "DENSE_RANK" => Some(Self::DenseRank), - "PERCENT_RANK" => Some(Self::PercentRank), - "CUME_DIST" => Some(Self::CumeDist), - "NTILE" => Some(Self::Ntile), - "LAG" => Some(Self::Lag), - "LEAD" => Some(Self::Lead), "FIRST_VALUE" => Some(Self::FirstValue), "LAST_VALUE" => Some(Self::LastValue), "NTH_VALUE" => Some(Self::NthValue), @@ -1969,6 +1951,35 @@ impl DateUnit { } 
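// Note: with RANK, DENSE_RANK, PERCENT_RANK, CUME_DIST, NTILE, LAG and LEAD
// removed from `BuiltInWindowFunction` (their tags are kept as comments so
// the enum numbering stays stable), those functions are reached through
// user-defined window functions instead, as the round-trip tests at the end
// of this diff now import:
//
//     use datafusion::functions_window::expr_fn::{
//         cume_dist, dense_rank, lag, lead, ntile, percent_rank, rank, row_number,
//     };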
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum InsertOp { + Append = 0, + Overwrite = 1, + Replace = 2, +} +impl InsertOp { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Append => "Append", + Self::Overwrite => "Overwrite", + Self::Replace => "Replace", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "Append" => Some(Self::Append), + "Overwrite" => Some(Self::Overwrite), + "Replace" => Some(Self::Replace), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum PartitionMode { CollectLeft = 0, Partitioned = 1, diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index 0f9f9d335afe7..02be3e11c1cbe 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -161,7 +161,7 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { &self, _buf: &[u8], _inputs: &[datafusion_expr::LogicalPlan], - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result { not_impl_err!("Method not implemented") } @@ -179,17 +179,15 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { _buf: &[u8], _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, - _ctx: &datafusion::prelude::SessionContext, - ) -> datafusion_common::Result< - std::sync::Arc, - > { + _ctx: &SessionContext, + ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } fn try_encode_table_provider( &self, _table_ref: &TableReference, - _node: std::sync::Arc, + _node: Arc, _buf: &mut Vec, ) -> datafusion_common::Result<()> { not_impl_err!("Method not implemented") @@ -273,7 +271,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { &self, _buf: &[u8], _inputs: &[datafusion_expr::LogicalPlan], - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result { not_impl_err!("Method not implemented") } @@ -291,17 +289,15 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { _buf: &[u8], _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, - _ctx: &datafusion::prelude::SessionContext, - ) -> datafusion_common::Result< - std::sync::Arc, - > { + _ctx: &SessionContext, + ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } fn try_encode_table_provider( &self, _table_ref: &TableReference, - _node: std::sync::Arc, + _node: Arc, _buf: &mut Vec, ) -> datafusion_common::Result<()> { not_impl_err!("Method not implemented") @@ -407,6 +403,7 @@ impl TableParquetOptionsProto { maximum_parallel_row_group_writers: global_options.global.maximum_parallel_row_group_writers as u64, maximum_buffered_record_batches_per_stream: global_options.global.maximum_buffered_record_batches_per_stream as u64, schema_force_view_types: global_options.global.schema_force_view_types, + binary_as_string: global_options.global.binary_as_string, }), column_specific_options: column_specific_options.into_iter().map(|(column_name, options)| { ParquetColumnSpecificOptions { @@ 
-497,6 +494,7 @@ impl From<&ParquetOptionsProto> for ParquetOptions { maximum_parallel_row_group_writers: proto.maximum_parallel_row_group_writers as usize, maximum_buffered_record_batches_per_stream: proto.maximum_buffered_record_batches_per_stream as usize, schema_force_view_types: proto.schema_force_view_types, + binary_as_string: proto.binary_as_string, } } } @@ -572,7 +570,7 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { &self, _buf: &[u8], _inputs: &[datafusion_expr::LogicalPlan], - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result { not_impl_err!("Method not implemented") } @@ -590,17 +588,15 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { _buf: &[u8], _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, - _ctx: &datafusion::prelude::SessionContext, - ) -> datafusion_common::Result< - std::sync::Arc, - > { + _ctx: &SessionContext, + ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } fn try_encode_table_provider( &self, _table_ref: &TableReference, - _node: std::sync::Arc, + _node: Arc, _buf: &mut Vec, ) -> datafusion_common::Result<()> { not_impl_err!("Method not implemented") @@ -662,7 +658,7 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec { &self, _buf: &[u8], _inputs: &[datafusion_expr::LogicalPlan], - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result { not_impl_err!("Method not implemented") } @@ -680,17 +676,15 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec { _buf: &[u8], _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, - _ctx: &datafusion::prelude::SessionContext, - ) -> datafusion_common::Result< - std::sync::Arc, - > { + _ctx: &SessionContext, + ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } fn try_encode_table_provider( &self, _table_ref: &TableReference, - _node: std::sync::Arc, + _node: Arc, _buf: &mut Vec, ) -> datafusion_common::Result<()> { not_impl_err!("Method not implemented") @@ -722,7 +716,7 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec { &self, _buf: &[u8], _inputs: &[datafusion_expr::LogicalPlan], - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result { not_impl_err!("Method not implemented") } @@ -740,17 +734,15 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec { _buf: &[u8], _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, - _cts: &datafusion::prelude::SessionContext, - ) -> datafusion_common::Result< - std::sync::Arc, - > { + _cts: &SessionContext, + ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } fn try_encode_table_provider( &self, _table_ref: &TableReference, - _node: std::sync::Arc, + _node: Arc, _buf: &mut Vec, ) -> datafusion_common::Result<()> { not_impl_err!("Method not implemented") diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 893255ccc8ce0..27bda7dd5ace6 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -19,8 +19,8 @@ use std::sync::Arc; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - exec_datafusion_err, internal_err, plan_datafusion_err, Result, ScalarValue, - TableReference, UnnestOptions, + exec_datafusion_err, internal_err, plan_datafusion_err, RecursionUnnestOption, + Result, ScalarValue, TableReference, 
UnnestOptions, }; use datafusion_expr::expr::{Alias, Placeholder, Sort}; use datafusion_expr::expr::{Unnest, WildcardOptions}; @@ -56,6 +56,15 @@ impl From<&protobuf::UnnestOptions> for UnnestOptions { fn from(opts: &protobuf::UnnestOptions) -> Self { Self { preserve_nulls: opts.preserve_nulls, + recursions: opts + .recursions + .iter() + .map(|r| RecursionUnnestOption { + input_column: r.input_column.as_ref().unwrap().into(), + output_column: r.output_column.as_ref().unwrap().into(), + depth: r.depth as usize, + }) + .collect::>(), } } } @@ -142,14 +151,7 @@ impl From for BuiltInWindowFunction { fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self { match built_in_function { protobuf::BuiltInWindowFunction::Unspecified => todo!(), - protobuf::BuiltInWindowFunction::Rank => Self::Rank, - protobuf::BuiltInWindowFunction::PercentRank => Self::PercentRank, - protobuf::BuiltInWindowFunction::DenseRank => Self::DenseRank, - protobuf::BuiltInWindowFunction::Lag => Self::Lag, - protobuf::BuiltInWindowFunction::Lead => Self::Lead, protobuf::BuiltInWindowFunction::FirstValue => Self::FirstValue, - protobuf::BuiltInWindowFunction::CumeDist => Self::CumeDist, - protobuf::BuiltInWindowFunction::Ntile => Self::Ntile, protobuf::BuiltInWindowFunction::NthValue => Self::NthValue, protobuf::BuiltInWindowFunction::LastValue => Self::LastValue, } @@ -289,10 +291,7 @@ pub fn parse_expr( .map_err(|_| Error::unknown("BuiltInWindowFunction", *i))? .into(); - let args = - parse_optional_expr(expr.expr.as_deref(), registry, codec)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = parse_exprs(&expr.exprs, registry, codec)?; Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::BuiltInWindowFunction( @@ -312,10 +311,7 @@ pub fn parse_expr( None => registry.udaf(udaf_name)?, }; - let args = - parse_optional_expr(expr.expr.as_deref(), registry, codec)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = parse_exprs(&expr.exprs, registry, codec)?; Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, @@ -332,10 +328,7 @@ pub fn parse_expr( None => registry.udwf(udwf_name)?, }; - let args = - parse_optional_expr(expr.expr.as_deref(), registry, codec)? 
- .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = parse_exprs(&expr.exprs, registry, codec)?; Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 7156cee66affa..b90ae88aa74ab 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -19,11 +19,10 @@ use std::collections::HashMap; use std::fmt::Debug; use std::sync::Arc; -use crate::protobuf::column_unnest_exec::UnnestType; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; use crate::protobuf::{ - ColumnUnnestExec, ColumnUnnestListItem, ColumnUnnestListRecursion, - ColumnUnnestListRecursions, CustomTableScanNode, SortExprNodeCollection, + ColumnUnnestListItem, ColumnUnnestListRecursion, CustomTableScanNode, + SortExprNodeCollection, }; use crate::{ convert_required, into_required, @@ -62,15 +61,14 @@ use datafusion_expr::{ dml, logical_plan::{ builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, - CreateExternalTable, CreateView, CrossJoin, DdlStatement, Distinct, - EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare, Projection, - Repartition, Sort, SubqueryAlias, TableScan, Values, Window, + CreateExternalTable, CreateView, DdlStatement, Distinct, EmptyRelation, + Extension, Join, JoinConstraint, Prepare, Projection, Repartition, Sort, + SubqueryAlias, TableScan, Values, Window, }, DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, WindowUDF, }; -use datafusion_expr::{AggregateUDF, ColumnUnnestList, ColumnUnnestType, Unnest}; -use datafusion_proto_common::EmptyMessage; +use datafusion_expr::{AggregateUDF, ColumnUnnestList, FetchType, SkipType, Unnest}; use self::to_proto::{serialize_expr, serialize_exprs}; use crate::logical_plan::to_proto::serialize_sorts; @@ -283,6 +281,7 @@ impl AsLogicalPlan for LogicalPlanNode { .collect::, _>>() .map_err(|e| e.into()) }?; + LogicalPlanBuilder::values(values)?.build() } LogicalPlanType::Projection(projection) => { @@ -451,7 +450,7 @@ impl AsLogicalPlan for LogicalPlanNode { )? .build() } - LogicalPlanType::CustomScan(scan) => { + CustomScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; let schema = Arc::new(schema); let mut projection = None; @@ -579,6 +578,7 @@ impl AsLogicalPlan for LogicalPlanNode { .clone(), order_exprs, if_not_exists: create_extern_table.if_not_exists, + temporary: create_extern_table.temporary, definition, unbounded: create_extern_table.unbounded, options: create_extern_table.options.clone(), @@ -601,6 +601,7 @@ impl AsLogicalPlan for LogicalPlanNode { Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { name: from_table_reference(create_view.name.as_ref(), "CreateView")?, + temporary: create_view.temporary, input: Arc::new(plan), or_replace: create_view.or_replace, definition, @@ -843,13 +844,13 @@ impl AsLogicalPlan for LogicalPlanNode { .prepare(prepare.name.clone(), data_types)? 
.build() } - LogicalPlanType::DropView(dropview) => Ok(datafusion_expr::LogicalPlan::Ddl( - datafusion_expr::DdlStatement::DropView(DropView { + LogicalPlanType::DropView(dropview) => { + Ok(LogicalPlan::Ddl(DdlStatement::DropView(DropView { name: from_table_reference(dropview.name.as_ref(), "DropView")?, if_exists: dropview.if_exists, schema: Arc::new(convert_required!(dropview.schema)?), - }), - )), + }))) + } LogicalPlanType::CopyTo(copy) => { let input: LogicalPlan = into_logical_plan!(copy.input, ctx, extension_codec)?; @@ -858,48 +859,20 @@ impl AsLogicalPlan for LogicalPlanNode { extension_codec.try_decode_file_format(©.file_type, ctx)?, ); - Ok(datafusion_expr::LogicalPlan::Copy( - datafusion_expr::dml::CopyTo { - input: Arc::new(input), - output_url: copy.output_url.clone(), - partition_by: copy.partition_by.clone(), - file_type, - options: Default::default(), - }, - )) + Ok(LogicalPlan::Copy(dml::CopyTo { + input: Arc::new(input), + output_url: copy.output_url.clone(), + partition_by: copy.partition_by.clone(), + file_type, + options: Default::default(), + })) } LogicalPlanType::Unnest(unnest) => { let input: LogicalPlan = into_logical_plan!(unnest.input, ctx, extension_codec)?; - Ok(datafusion_expr::LogicalPlan::Unnest(Unnest { + Ok(LogicalPlan::Unnest(Unnest { input: Arc::new(input), - exec_columns: unnest - .exec_columns - .iter() - .map(|c| { - ( - c.column.as_ref().unwrap().to_owned().into(), - match c.unnest_type.as_ref().unwrap() { - UnnestType::Inferred(_) => ColumnUnnestType::Inferred, - UnnestType::Struct(_) => ColumnUnnestType::Struct, - UnnestType::List(l) => ColumnUnnestType::List( - l.recursions - .iter() - .map(|ul| ColumnUnnestList { - output_column: ul - .output_column - .as_ref() - .unwrap() - .to_owned() - .into(), - depth: ul.depth as usize, - }) - .collect(), - ), - }, - ) - }) - .collect(), + exec_columns: unnest.exec_columns.iter().map(|c| c.into()).collect(), list_type_columns: unnest .list_type_columns .iter() @@ -951,7 +924,7 @@ impl AsLogicalPlan for LogicalPlanNode { } as u64; let values_list = serialize_exprs(values.iter().flatten(), extension_codec)?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Values( protobuf::ValuesNode { n_cols, @@ -1043,7 +1016,7 @@ impl AsLogicalPlan for LogicalPlanNode { exprs_vec.push(expr_vec); } - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ListingScan( protobuf::ListingTableScanNode { file_format_type: Some(file_format_type), @@ -1069,12 +1042,12 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } else if let Some(view_table) = source.downcast_ref::() { - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ViewScan(Box::new( protobuf::ViewTableScanNode { table_name: Some(table_name.clone().into()), input: Some(Box::new( - protobuf::LogicalPlanNode::try_from_logical_plan( + LogicalPlanNode::try_from_logical_plan( view_table.logical_plan(), extension_codec, )?, @@ -1107,11 +1080,11 @@ impl AsLogicalPlan for LogicalPlanNode { } } LogicalPlan::Projection(Projection { expr, input, .. 
}) => { - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Projection(Box::new( protobuf::ProjectionNode { input: Some(Box::new( - protobuf::LogicalPlanNode::try_from_logical_plan( + LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, )?, @@ -1123,12 +1096,11 @@ impl AsLogicalPlan for LogicalPlanNode { }) } LogicalPlan::Filter(filter) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - filter.input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + filter.input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Selection(Box::new( protobuf::SelectionNode { input: Some(Box::new(input)), @@ -1141,12 +1113,11 @@ impl AsLogicalPlan for LogicalPlanNode { }) } LogicalPlan::Distinct(Distinct::All(input)) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Distinct(Box::new( protobuf::DistinctNode { input: Some(Box::new(input)), @@ -1161,16 +1132,15 @@ impl AsLogicalPlan for LogicalPlanNode { input, .. })) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; let sort_expr = match sort_expr { None => vec![], Some(sort_expr) => serialize_sorts(sort_expr, extension_codec)?, }; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::DistinctOn(Box::new( protobuf::DistinctOnNode { on_expr: serialize_exprs(on_expr, extension_codec)?, @@ -1184,12 +1154,11 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Window(Window { input, window_expr, .. }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Window(Box::new( protobuf::WindowNode { input: Some(Box::new(input)), @@ -1204,12 +1173,11 @@ impl AsLogicalPlan for LogicalPlanNode { input, .. }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Aggregate(Box::new( protobuf::AggregateNode { input: Some(Box::new(input)), @@ -1229,16 +1197,14 @@ impl AsLogicalPlan for LogicalPlanNode { null_equals_null, .. 
}) => { - let left: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - left.as_ref(), - extension_codec, - )?; - let right: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - right.as_ref(), - extension_codec, - )?; + let left: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + left.as_ref(), + extension_codec, + )?; + let right: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + right.as_ref(), + extension_codec, + )?; let (left_join_key, right_join_key) = on .iter() .map(|(l, r)| { @@ -1257,7 +1223,7 @@ impl AsLogicalPlan for LogicalPlanNode { .as_ref() .map(|e| serialize_expr(e, extension_codec)) .map_or(Ok(None), |v| v.map(Some))?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( protobuf::JoinNode { left: Some(Box::new(left)), @@ -1276,12 +1242,11 @@ impl AsLogicalPlan for LogicalPlanNode { not_impl_err!("LogicalPlan serde is not yet implemented for subqueries") } LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::SubqueryAlias(Box::new( protobuf::SubqueryAliasNode { input: Some(Box::new(input)), @@ -1290,31 +1255,40 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Limit(Limit { input, skip, fetch }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + LogicalPlan::Limit(limit) => { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + limit.input.as_ref(), + extension_codec, + )?; + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return Err(proto_error( + "LogicalPlan::Limit only supports literal skip values", + )); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? 
else { + return Err(proto_error( + "LogicalPlan::Limit only supports literal fetch values", + )); + }; + + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Limit(Box::new( protobuf::LimitNode { input: Some(Box::new(input)), - skip: *skip as i64, + skip: skip as i64, fetch: fetch.unwrap_or(i64::MAX as usize) as i64, }, ))), }) } LogicalPlan::Sort(Sort { input, expr, fetch }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; let sort_expr: Vec = serialize_sorts(expr, extension_codec)?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Sort(Box::new( protobuf::SortNode { input: Some(Box::new(input)), @@ -1329,11 +1303,10 @@ impl AsLogicalPlan for LogicalPlanNode { partitioning_scheme, }) => { use datafusion::logical_expr::Partitioning; - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; // Assumed common usize field was batch size // Used u64 to avoid any nastiness involving large values; most data clusters are probably uniformly 64 bits anyway @@ -1354,7 +1327,7 @@ impl AsLogicalPlan for LogicalPlanNode { } }; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Repartition(Box::new( protobuf::RepartitionNode { input: Some(Box::new(input)), @@ -1365,7 +1338,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row, ..
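// Note: `LogicalPlan::Limit` now stores skip/fetch as expressions, so
// serialization first resolves them via `get_skip_type()` / `get_fetch_type()`
// and bails out with a `proto_error` for anything non-literal rather than
// silently miscoding it. The guard pattern, as used above:
//
//     let SkipType::Literal(skip) = limit.get_skip_type()? else {
//         return Err(proto_error("LogicalPlan::Limit only supports literal skip values"));
//     };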
- }) => Ok(protobuf::LogicalPlanNode { + }) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::EmptyRelation( protobuf::EmptyRelationNode { produce_one_row: *produce_one_row, @@ -1386,6 +1359,7 @@ impl AsLogicalPlan for LogicalPlanNode { options, constraints, column_defaults, + temporary, }, )) => { let mut converted_order_exprs: Vec = vec![]; @@ -1403,7 +1377,7 @@ impl AsLogicalPlan for LogicalPlanNode { .insert(col_name.clone(), serialize_expr(expr, extension_codec)?); } - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateExternalTable( protobuf::CreateExternalTableNode { name: Some(name.clone().into()), @@ -1412,6 +1386,7 @@ impl AsLogicalPlan for LogicalPlanNode { schema: Some(df_schema.try_into()?), table_partition_cols: table_partition_cols.clone(), if_not_exists: *if_not_exists, + temporary: *temporary, order_exprs: converted_order_exprs, definition: definition.clone().unwrap_or_default(), unbounded: *unbounded, @@ -1427,7 +1402,8 @@ impl AsLogicalPlan for LogicalPlanNode { input, or_replace, definition, - })) => Ok(protobuf::LogicalPlanNode { + temporary, + })) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateView(Box::new( protobuf::CreateViewNode { name: Some(name.clone().into()), @@ -1436,6 +1412,7 @@ impl AsLogicalPlan for LogicalPlanNode { extension_codec, )?)), or_replace: *or_replace, + temporary: *temporary, definition: definition.clone().unwrap_or_default(), }, ))), @@ -1446,7 +1423,7 @@ impl AsLogicalPlan for LogicalPlanNode { if_not_exists, schema: df_schema, }, - )) => Ok(protobuf::LogicalPlanNode { + )) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateCatalogSchema( protobuf::CreateCatalogSchemaNode { schema_name: schema_name.clone(), @@ -1459,7 +1436,7 @@ impl AsLogicalPlan for LogicalPlanNode { catalog_name, if_not_exists, schema: df_schema, - })) => Ok(protobuf::LogicalPlanNode { + })) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateCatalog( protobuf::CreateCatalogNode { catalog_name: catalog_name.clone(), @@ -1469,11 +1446,11 @@ impl AsLogicalPlan for LogicalPlanNode { )), }), LogicalPlan::Analyze(a) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( + let input = LogicalPlanNode::try_from_logical_plan( a.input.as_ref(), extension_codec, )?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Analyze(Box::new( protobuf::AnalyzeNode { input: Some(Box::new(input)), @@ -1483,11 +1460,11 @@ impl AsLogicalPlan for LogicalPlanNode { }) } LogicalPlan::Explain(a) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( + let input = LogicalPlanNode::try_from_logical_plan( a.plan.as_ref(), extension_codec, )?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Explain(Box::new( protobuf::ExplainNode { input: Some(Box::new(input)), @@ -1500,37 +1477,14 @@ impl AsLogicalPlan for LogicalPlanNode { let inputs: Vec = union .inputs .iter() - .map(|i| { - protobuf::LogicalPlanNode::try_from_logical_plan( - i, - extension_codec, - ) - }) + .map(|i| LogicalPlanNode::try_from_logical_plan(i, extension_codec)) .collect::>()?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Union( protobuf::UnionNode { inputs }, )), }) } - LogicalPlan::CrossJoin(CrossJoin { left, right, .. 
}) => { - let left = protobuf::LogicalPlanNode::try_from_logical_plan( - left.as_ref(), - extension_codec, - )?; - let right = protobuf::LogicalPlanNode::try_from_logical_plan( - right.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { - logical_plan_type: Some(LogicalPlanType::CrossJoin(Box::new( - protobuf::CrossJoinNode { - left: Some(Box::new(left)), - right: Some(Box::new(right)), - }, - ))), - }) - } LogicalPlan::Extension(extension) => { let mut buf: Vec = vec![]; extension_codec.try_encode(extension, &mut buf)?; @@ -1539,15 +1493,10 @@ impl AsLogicalPlan for LogicalPlanNode { .node .inputs() .iter() - .map(|i| { - protobuf::LogicalPlanNode::try_from_logical_plan( - i, - extension_codec, - ) - }) + .map(|i| LogicalPlanNode::try_from_logical_plan(i, extension_codec)) .collect::>()?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Extension( LogicalExtensionNode { node: buf, inputs }, )), @@ -1558,11 +1507,9 @@ impl AsLogicalPlan for LogicalPlanNode { data_types, input, }) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( - input, - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input = + LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Prepare(Box::new( protobuf::PrepareNode { name: name.clone(), @@ -1584,10 +1531,8 @@ impl AsLogicalPlan for LogicalPlanNode { schema, options, }) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( - input, - extension_codec, - )?; + let input = + LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; let proto_unnest_list_items = list_type_columns .iter() .map(|(index, ul)| ColumnUnnestListItem { @@ -1598,38 +1543,13 @@ impl AsLogicalPlan for LogicalPlanNode { }), }) .collect(); - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Unnest(Box::new( protobuf::UnnestNode { input: Some(Box::new(input)), exec_columns: exec_columns .iter() - .map(|(col, unnesting)| ColumnUnnestExec { - column: Some(col.into()), - unnest_type: Some(match unnesting { - ColumnUnnestType::Inferred => { - UnnestType::Inferred(EmptyMessage {}) - } - ColumnUnnestType::Struct => { - UnnestType::Struct(EmptyMessage {}) - } - ColumnUnnestType::List(list) => { - UnnestType::List(ColumnUnnestListRecursions { - recursions: list - .iter() - .map(|ul| ColumnUnnestListRecursion { - output_column: Some( - ul.output_column - .to_owned() - .into(), - ), - depth: ul.depth as _, - }) - .collect(), - }) - } - }), - }) + .map(|col| col.into()) .collect(), list_type_columns: proto_unnest_list_items, struct_type_columns: struct_type_columns @@ -1659,7 +1579,7 @@ impl AsLogicalPlan for LogicalPlanNode { name, if_exists, schema, - })) => Ok(protobuf::LogicalPlanNode { + })) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::DropView( protobuf::DropViewNode { name: Some(name.clone().into()), @@ -1690,15 +1610,13 @@ impl AsLogicalPlan for LogicalPlanNode { partition_by, .. 
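// Note: two simplifications land in this region. `LogicalPlan::CrossJoin` is
// no longer a separate plan node; its dedicated serialization arm is deleted
// without replacement, presumably because cross joins are now expressed
// through the regular `Join` path. And `Unnest` columns serialize as plain
// `Column`s via `.map(|col| col.into())`, with per-column recursion depth
// carried by `UnnestOptions.recursions` instead of the removed
// `ColumnUnnestExec` oneof.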
}) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( - input, - extension_codec, - )?; + let input = + LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; let mut buf = Vec::new(); extension_codec .try_encode_file_format(&mut buf, file_type_to_format(file_type)?)?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new( protobuf::CopyToNode { input: Some(Box::new(input)), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 63d1a007c1e55..5a6f3a32c668e 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -30,6 +30,7 @@ use datafusion_expr::{ WindowFrameUnits, WindowFunctionDefinition, }; +use crate::protobuf::RecursionUnnestOption; use crate::protobuf::{ self, plan_type::PlanTypeEnum::{ @@ -49,6 +50,15 @@ impl From<&UnnestOptions> for protobuf::UnnestOptions { fn from(opts: &UnnestOptions) -> Self { Self { preserve_nulls: opts.preserve_nulls, + recursions: opts + .recursions + .iter() + .map(|r| RecursionUnnestOption { + input_column: Some((&r.input_column).into()), + output_column: Some((&r.output_column).into()), + depth: r.depth as u32, + }) + .collect(), } } } @@ -117,13 +127,6 @@ impl From<&BuiltInWindowFunction> for protobuf::BuiltInWindowFunction { BuiltInWindowFunction::FirstValue => Self::FirstValue, BuiltInWindowFunction::LastValue => Self::LastValue, BuiltInWindowFunction::NthValue => Self::NthValue, - BuiltInWindowFunction::Ntile => Self::Ntile, - BuiltInWindowFunction::CumeDist => Self::CumeDist, - BuiltInWindowFunction::PercentRank => Self::PercentRank, - BuiltInWindowFunction::Rank => Self::Rank, - BuiltInWindowFunction::Lag => Self::Lag, - BuiltInWindowFunction::Lead => Self::Lead, - BuiltInWindowFunction::DenseRank => Self::DenseRank, } } } @@ -336,25 +339,19 @@ pub fn serialize_expr( ) } }; - let arg_expr: Option> = if !args.is_empty() { - let arg = &args[0]; - Some(Box::new(serialize_expr(arg, codec)?)) - } else { - None - }; let partition_by = serialize_exprs(partition_by, codec)?; let order_by = serialize_sorts(order_by, codec)?; let window_frame: Option = Some(window_frame.try_into()?); - let window_expr = Box::new(protobuf::WindowExprNode { - expr: arg_expr, + let window_expr = protobuf::WindowExprNode { + exprs: serialize_exprs(args, codec)?, window_function: Some(window_function), partition_by, order_by, window_frame, fun_definition, - }); + }; protobuf::LogicalExprNode { expr_type: Some(ExprType::WindowExpr(window_expr)), } diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 00826d311a09c..6fbdf516f6c74 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use arrow::compute::SortOptions; use chrono::{TimeZone, Utc}; +use datafusion_expr::dml::InsertOp; use object_store::path::Path; use object_store::ObjectMeta; @@ -660,13 +661,18 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { Ok((name.clone(), data_type)) }) .collect::>>()?; + let insert_op = match conf.insert_op() { + protobuf::InsertOp::Append => InsertOp::Append, + protobuf::InsertOp::Overwrite => InsertOp::Overwrite, + protobuf::InsertOp::Replace => InsertOp::Replace, + }; Ok(Self { object_store_url: ObjectStoreUrl::parse(&conf.object_store_url)?, file_groups, table_paths, output_schema: 
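// Note: on the physical side, `conf.insert_op()` is prost's generated enum
// accessor: it converts the stored i32 back to `protobuf::InsertOp`, falling
// back to the zero variant (`Append`) for out-of-range values, and the match
// above maps it one-to-one onto `datafusion_expr::dml::InsertOp`.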
Arc::new(convert_required!(conf.output_schema)?), table_partition_cols, - overwrite: conf.overwrite, + insert_op, keep_partition_by_columns: conf.keep_partition_by_columns, }) } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 9a6850cb21089..326c7acab3928 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -488,7 +488,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { }) .collect::, _>>()?; - let physical_aggr_expr: Vec = hash_agg + let physical_aggr_expr: Vec> = hash_agg .aggr_expr .iter() .zip(hash_agg.aggr_expr_name.iter()) @@ -518,6 +518,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .with_distinct(agg_node.distinct) .order_by(ordering_req) .build() + .map(Arc::new) } } }).transpose()?.ok_or_else(|| { @@ -850,7 +851,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { "physical_plan::from_proto() Unexpected expr {self:?}" )) })?; - if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr { + if let ExprType::Sort(sort_expr) = expr { let expr = sort_expr .expr .as_ref() @@ -897,7 +898,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { "physical_plan::from_proto() Unexpected expr {self:?}" )) })?; - if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr { + if let ExprType::Sort(sort_expr) = expr { let expr = sort_expr .expr .as_ref() @@ -1712,9 +1713,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Sort( - sort_expr, - )), + expr_type: Some(ExprType::Sort(sort_expr)), }) }) .collect::>>()?; @@ -1781,9 +1780,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Sort( - sort_expr, - )), + expr_type: Some(ExprType::Sort(sort_expr)), }) }) .collect::>>()?; diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 36c49f2f8ee03..f46761b268ed9 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,9 +23,8 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, InListExpr, IsNotNullExpr, - IsNullExpr, Literal, NegativeExpr, NotExpr, NthValue, Ntile, Rank, RankType, - ScalarRegexMatchExpr, TryCastExpr, WindowShift, + BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, + Literal, NegativeExpr, NotExpr, NthValue, ScalarRegexMatchExpr, TryCastExpr, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -49,7 +48,7 @@ use crate::protobuf::{ use super::PhysicalExtensionCodec; pub fn serialize_physical_aggr_expr( - aggr_expr: AggregateFunctionExpr, + aggr_expr: Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { let expressions = serialize_physical_exprs(&aggr_expr.expressions(), codec)?; @@ -109,59 +108,24 @@ pub fn serialize_physical_window_expr( let expr = built_in_window_expr.get_built_in_func_expr(); let built_in_fn_expr = expr.as_any(); - let 
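The aggregate builder now returns the bare expression and callers wrap it with `.build().map(Arc::new)`. A small sketch of why that works: Result::map lifts Arc::new over the success value, so the `?` operator yields the shared pointer directly (toy types, not the DataFusion builder):

use std::sync::Arc;

struct AggExpr(String);

fn build() -> Result<AggExpr, String> {
    Ok(AggExpr("avg(b)".to_string()))
}

fn main() -> Result<(), String> {
    // Result<AggExpr, _> becomes Result<Arc<AggExpr>, _>; `?` then yields Arc<AggExpr>.
    let shared: Arc<AggExpr> = build().map(Arc::new)?;
    assert_eq!(shared.0, "avg(b)");
    Ok(())
}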
builtin_fn = if let Some(rank_expr) = built_in_fn_expr.downcast_ref::<Rank>() - { - match rank_expr.get_type() { - RankType::Basic => protobuf::BuiltInWindowFunction::Rank, - RankType::Dense => protobuf::BuiltInWindowFunction::DenseRank, - RankType::Percent => protobuf::BuiltInWindowFunction::PercentRank, - } - } else if built_in_fn_expr.downcast_ref::<CumeDist>().is_some() { - protobuf::BuiltInWindowFunction::CumeDist - } else if let Some(ntile_expr) = built_in_fn_expr.downcast_ref::<Ntile>() { - args.insert( - 0, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( - ntile_expr.get_n() as i64, - )))), - ); - protobuf::BuiltInWindowFunction::Ntile - } else if let Some(window_shift_expr) = - built_in_fn_expr.downcast_ref::<WindowShift>() - { - args.insert( - 1, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( - window_shift_expr.get_shift_offset(), - )))), - ); - args.insert( - 2, - Arc::new(Literal::new(window_shift_expr.get_default_value())), - ); - - if window_shift_expr.get_shift_offset() >= 0 { - protobuf::BuiltInWindowFunction::Lag - } else { - protobuf::BuiltInWindowFunction::Lead - } - } else if let Some(nth_value_expr) = built_in_fn_expr.downcast_ref::<NthValue>() { - match nth_value_expr.get_kind() { - NthValueKind::First => protobuf::BuiltInWindowFunction::FirstValue, - NthValueKind::Last => protobuf::BuiltInWindowFunction::LastValue, - NthValueKind::Nth(n) => { - args.insert( - 1, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64( - Some(n), - ))), - ); - protobuf::BuiltInWindowFunction::NthValue + let builtin_fn = + if let Some(nth_value_expr) = built_in_fn_expr.downcast_ref::<NthValue>() { + match nth_value_expr.get_kind() { + NthValueKind::First => protobuf::BuiltInWindowFunction::FirstValue, + NthValueKind::Last => protobuf::BuiltInWindowFunction::LastValue, + NthValueKind::Nth(n) => { + args.insert( + 1, + Arc::new(Literal::new( + datafusion_common::ScalarValue::Int64(Some(n)), + )), + ); + protobuf::BuiltInWindowFunction::NthValue + } } - } - } else { - return not_impl_err!("BuiltIn function not supported: {expr:?}"); - }; + } else { + return not_impl_err!("BuiltIn function not supported: {expr:?}"); + }; ( physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32), @@ -661,8 +625,8 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { table_paths, output_schema: Some(conf.output_schema.as_ref().try_into()?), table_partition_cols, - overwrite: conf.overwrite, keep_partition_by_columns: conf.keep_partition_by_columns, + insert_op: conf.insert_op as i32, }) } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 6f513279f2f60..14d91913e7cdd 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -47,7 +47,10 @@ use datafusion::functions_aggregate::expr_fn::{ }; use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_nested::map::map; -use datafusion::functions_window::row_number::row_number; +use datafusion::functions_window::expr_fn::{ + cume_dist, dense_rank, lag, lead, ntile, percent_rank, rank, row_number, +}; +use datafusion::functions_window::rank::rank_udwf; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::TableOptions; @@ -73,9 +76,9 @@ use datafusion_functions_aggregate::expr_fn::{ approx_distinct, array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, nth_value, }; -use
datafusion_functions_aggregate::kurtosis_pop::kurtosis_pop; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, @@ -937,8 +940,18 @@ async fn roundtrip_expr_api() -> Result<()> { vec![lit(1), lit(2), lit(3)], vec![lit(10), lit(20), lit(30)], ), + cume_dist(), row_number(), - kurtosis_pop(lit(1)), + rank(), + dense_rank(), + percent_rank(), + lead(col("b"), None, None), + lead(col("b"), Some(2), None), + lead(col("b"), Some(2), Some(ScalarValue::from(100))), + lag(col("b"), None, None), + lag(col("b"), Some(2), None), + lag(col("b"), Some(2), Some(ScalarValue::from(100))), + ntile(lit(3)), nth_value(col("b"), 1, vec![]), nth_value( col("b"), @@ -1062,6 +1075,10 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { expr: exprs.swap_remove(0), }) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[derive(Debug)] @@ -2154,7 +2171,7 @@ fn roundtrip_aggregate_udf() { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -2302,9 +2319,7 @@ fn roundtrip_window() { // 1. without window_frame let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::BuiltInWindowFunction( - datafusion_expr::BuiltInWindowFunction::Rank, - ), + WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) .partition_by(vec![col("col1")]) @@ -2315,9 +2330,7 @@ fn roundtrip_window() { // 2. with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::BuiltInWindowFunction( - datafusion_expr::BuiltInWindowFunction::Rank, - ), + WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) .partition_by(vec![col("col1")]) @@ -2334,9 +2347,7 @@ fn roundtrip_window() { ); let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::BuiltInWindowFunction( - datafusion_expr::BuiltInWindowFunction::Rank, - ), + WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) .partition_by(vec![col("col1")]) @@ -2384,7 +2395,7 @@ fn roundtrip_window() { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -2457,7 +2468,10 @@ fn roundtrip_window() { &self.signature } - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { make_partition_evaluator() } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index db84a08e5b408..4a9bf6afb49e6 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -27,6 +27,7 @@ use arrow::csv::WriterBuilder; use arrow::datatypes::{Fields, TimeUnit}; use datafusion::physical_expr::aggregate::AggregateExprBuilder; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_expr::dml::InsertOp; use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf; use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::min_max::max_udaf; @@ -72,7 +73,6 @@ use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use 
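The expanded expr_fn test matrix above feeds lead and lag with an optional offset and default, which is exactly what the switch from a single optional `expr` field to a repeated `exprs` field on WindowExprNode accommodates. An illustrative sketch, with stand-in structs rather than the generated protobuf types:

// Sketch: why a repeated field is needed. lead(b, 2, 100) carries three
// arguments; a single optional `expr` can only round-trip the first one.
#[derive(Debug)]
struct OldWindowExprNode { expr: Option<String> }  // keeps args[0] only
#[derive(Debug)]
struct NewWindowExprNode { exprs: Vec<String> }    // keeps every argument

fn main() {
    let args = vec!["b".to_string(), "2".to_string(), "100".to_string()];
    let old = OldWindowExprNode { expr: args.first().cloned() };
    let new = NewWindowExprNode { exprs: args.clone() };
    assert_eq!(old.expr.iter().count(), 1); // offset and default were dropped
    assert_eq!(new.exprs.len(), 3);         // the full argument list round-trips
}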
datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion::physical_plan::windows::{ @@ -304,7 +304,8 @@ fn roundtrip_window() -> Result<()> { ) .schema(Arc::clone(&schema)) .alias("avg(b)") - .build()?, + .build() + .map(Arc::new)?, &[], &[], Arc::new(WindowFrame::new(None)), @@ -320,7 +321,8 @@ fn roundtrip_window() -> Result<()> { let sum_expr = AggregateExprBuilder::new(sum_udaf(), args) .schema(Arc::clone(&schema)) .alias("SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING") - .build()?; + .build() + .map(Arc::new)?; let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( sum_expr, @@ -366,13 +368,13 @@ fn rountrip_aggregate() -> Result<()> { .alias("NTH_VALUE(b, 1)") .build()?; - let test_cases: Vec> = vec![ + let test_cases = vec![ // AVG - vec![avg_expr], + vec![Arc::new(avg_expr)], // NTH_VALUE - vec![nth_expr], + vec![Arc::new(nth_expr)], // STRING_AGG - vec![str_agg_expr], + vec![Arc::new(str_agg_expr)], ]; for aggregates in test_cases { @@ -399,12 +401,13 @@ fn rountrip_aggregate_with_limit() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec = + let aggregates = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) .schema(Arc::clone(&schema)) .alias("AVG(b)") - .build()?, + .build() + .map(Arc::new)?, ]; let agg = AggregateExec::try_new( @@ -428,13 +431,14 @@ fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec = vec![AggregateExprBuilder::new( + let aggregates = vec![AggregateExprBuilder::new( approx_percentile_cont_udaf(), vec![col("b", &schema)?, lit(0.5)], ) .schema(Arc::clone(&schema)) .alias("APPROX_PERCENTILE_CONT(b, 0.5)") - .build()?]; + .build() + .map(Arc::new)?]; let agg = AggregateExec::try_new( AggregateMode::Final, @@ -463,13 +467,14 @@ fn rountrip_aggregate_with_sort() -> Result<()> { }, }]; - let aggregates: Vec = + let aggregates = vec![ AggregateExprBuilder::new(array_agg_udaf(), vec![col("b", &schema)?]) .schema(Arc::clone(&schema)) .alias("ARRAY_AGG(b)") .order_by(sort_exprs) - .build()?, + .build() + .map(Arc::new)?, ]; let agg = AggregateExec::try_new( @@ -530,12 +535,13 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec = + let aggregates = vec![ AggregateExprBuilder::new(Arc::new(udaf), vec![col("b", &schema)?]) .schema(Arc::clone(&schema)) .alias("example_agg") - .build()?, + .build() + .map(Arc::new)?, ]; roundtrip_test_with_context( @@ -1000,7 +1006,8 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { AggregateExprBuilder::new(max_udaf(), vec![udf_expr as Arc]) .schema(schema.clone()) .alias("max") - .build()?; + .build() + .map(Arc::new)?; let window = Arc::new(WindowAggExec::try_new( vec![Arc::new(PlainAggregateWindowExpr::new( @@ -1051,7 +1058,8 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { let aggr_expr = AggregateExprBuilder::new(Arc::clone(&udaf), aggr_args.clone()) .schema(Arc::clone(&schema)) .alias("aggregate_udf") - .build()?; + .build() + .map(Arc::new)?; let filter = 
Arc::new(FilterExec::try_new( Arc::new(BinaryExpr::new( @@ -1078,7 +1086,8 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { .alias("aggregate_udf") .distinct() .ignore_nulls() - .build()?; + .build() + .map(Arc::new)?; let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, @@ -1143,7 +1152,7 @@ fn roundtrip_json_sink() -> Result<()> { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, }; let data_sink = Arc::new(JsonSink::new( @@ -1179,7 +1188,7 @@ fn roundtrip_csv_sink() -> Result<()> { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, }; let data_sink = Arc::new(CsvSink::new( @@ -1238,7 +1247,7 @@ fn roundtrip_parquet_sink() -> Result<()> { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, }; let data_sink = Arc::new(ParquetSink::new( diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 5c4b83fe38e11..1eef1b718ba6f 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -46,6 +46,7 @@ arrow-array = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +indexmap = { workspace = true } log = { workspace = true } regex = { workspace = true } sqlparser = { workspace = true } @@ -55,6 +56,7 @@ strum = { version = "0.26.1", features = ["derive"] } ctor = { workspace = true } datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } +datafusion-functions-nested = { workspace = true } datafusion-functions-window = { workspace = true } env_logger = { workspace = true } paste = "^1.0" diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index 4c380f0b37a31..c288d6ca70674 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -98,8 +98,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - // Each recursive CTE consists from two parts in the logical plan: - // 1. A static term (the left hand side on the SQL, where the + // Each recursive CTE consists of two parts in the logical plan: + // 1. A static term (the left-hand side on the SQL, where the // referencing to the same CTE is not allowed) // // 2. 
A recursive term (the right hand side, and the recursive diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index ddafc4e3a03a8..619eadcf0fb86 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -237,7 +237,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - // user-defined function (UDF) should have precedence + // User-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args))); @@ -260,12 +260,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ); } - // then, window function + // Then, window function if let Some(WindowType::WindowSpec(window)) = over { let partition_by = window .partition_by .into_iter() - // ignore window spec PARTITION BY for scalar values + // Ignore window spec PARTITION BY for scalar values // as they do not change and thus do not generate new partitions .filter(|e| !matches!(e, sqlparser::ast::Expr::Value { .. },)) .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) @@ -383,7 +383,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, name: &str, ) -> Result { - // check udaf first + // Check udaf first let udaf = self.context_provider.get_aggregate_meta(name); // Use the builtin window function instead of the user-defined aggregate function if udaf.as_ref().is_some_and(|udaf| { @@ -432,6 +432,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { qualifier: None, options: WildcardOptions::default(), }), + FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(object_name)) => { + let qualifier = self.object_name_to_table_reference(object_name)?; + // Sanity check on qualifier with schema + let qualified_indices = schema.fields_indices_with_qualified(&qualifier); + if qualified_indices.is_empty() { + return plan_err!("Invalid qualifier {qualifier}"); + } + Ok(Expr::Wildcard { + qualifier: Some(qualifier), + options: WildcardOptions::default(), + }) + } _ => not_impl_err!("Unsupported qualified wildcard argument: {sql:?}"), } } diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 36776c690235b..e103f68fc9275 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -19,8 +19,8 @@ use arrow_schema::Field; use sqlparser::ast::{Expr as SQLExpr, Ident}; use datafusion_common::{ - internal_err, not_impl_err, plan_datafusion_err, Column, DFSchema, DataFusionError, - Result, TableReference, + internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, + DataFusionError, Result, TableReference, }; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{Case, Expr}; @@ -113,18 +113,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|id| self.ident_normalizer.normalize(id)) .collect::>(); - // Currently not supporting more than one nested level - // Though ideally once that support is in place, this code should work with it - // TODO: remove when can support multiple nested identifiers - if ids.len() > 5 { - return not_impl_err!("Compound identifier: {ids:?}"); - } - let search_result = search_dfschema(&ids, schema); match search_result { - // found matching field with spare identifier(s) for nested field(s) in structure + // Found matching field with spare identifier(s) for nested field(s) in structure Some((field, qualifier, nested_names)) if 
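The new QualifiedWildcard arm above validates the qualifier of `t.*` against the schema before planning it. A self-contained sketch of that sanity check, using a toy schema representation rather than DFSchema:

// Sketch: reject `t.*` when no field belongs to qualifier `t`, as in the
// qualified-wildcard arm above. Toy (qualifier, name) pairs stand in for fields.
fn fields_indices_with_qualifier(fields: &[(Option<&str>, &str)], q: &str) -> Vec<usize> {
    fields
        .iter()
        .enumerate()
        .filter(|(_, (qual, _))| *qual == Some(q))
        .map(|(i, _)| i)
        .collect()
}

fn main() {
    let schema = [(Some("t"), "a"), (Some("t"), "b"), (Some("u"), "c")];
    assert_eq!(fields_indices_with_qualifier(&schema, "t"), vec![0, 1]);
    // An empty result corresponds to the "Invalid qualifier" plan error above.
    assert!(fields_indices_with_qualifier(&schema, "x").is_empty());
}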
!nested_names.is_empty() => { - // found matching field with spare identifier(s) for nested field(s) in structure + // Found matching field with spare identifier(s) for nested field(s) in structure for planner in self.context_provider.get_expr_planners() { if let Ok(planner_result) = planner.plan_compound_identifier( field, @@ -133,36 +126,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) { match planner_result { PlannerResult::Planned(expr) => { - // sanity check on column - schema - .check_ambiguous_name(qualifier, field.name())?; return Ok(expr); } PlannerResult::Original(_args) => {} } } } - not_impl_err!( - "Compound identifiers not supported by ExprPlanner: {ids:?}" - ) + plan_err!("could not parse compound identifier from {ids:?}") } - // found matching field with no spare identifier(s) + // Found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { - // sanity check on column - schema.check_ambiguous_name(qualifier, field.name())?; Ok(Expr::Column(Column::from((qualifier, field)))) } None => { - // return default where use all identifiers to not have a nested field + // Return default where use all identifiers to not have a nested field // this len check is because at 5 identifiers will have to have a nested field if ids.len() == 5 { not_impl_err!("compound identifier: {ids:?}") } else { - // check the outer_query_schema and try to find a match + // Check the outer_query_schema and try to find a match if let Some(outer) = planner_context.outer_query_schema() { let search_result = search_dfschema(&ids, outer); match search_result { - // found matching field with spare identifier(s) for nested field(s) in structure + // Found matching field with spare identifier(s) for nested field(s) in structure Some((field, qualifier, nested_names)) if !nested_names.is_empty() => { @@ -172,15 +158,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Column::from((qualifier, field)).quoted_flat_name() ) } - // found matching field with no spare identifier(s) + // Found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { - // found an exact match on a qualified name in the outer plan schema, so this is an outer reference column + // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column Ok(Expr::OuterReferenceColumn( field.data_type().clone(), Column::from((qualifier, field)), )) } - // found no matching field, will return a default + // Found no matching field, will return a default None => { let s = &ids[0..ids.len()]; // safe unwrap as s can never be empty or exceed the bounds @@ -191,11 +177,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } else { let s = &ids[0..ids.len()]; - // safe unwrap as s can never be empty or exceed the bounds + // Safe unwrap as s can never be empty or exceed the bounds let (relation, column_name) = form_identifier(s).unwrap(); - // sanity check on column - schema - .check_ambiguous_name(relation.as_ref(), column_name)?; Ok(Expr::Column(Column::new(relation, column_name))) } } @@ -328,15 +311,15 @@ fn search_dfschema<'ids, 'schema>( fn generate_schema_search_terms( ids: &[String], ) -> impl Iterator, &String, &[String])> { - // take at most 4 identifiers to form a Column to search with + // Take at most 4 identifiers to form a Column to search with // - 1 for the column name // - 0 to 3 for the TableReference let bound = ids.len().min(4); - // search terms from most specific to least specific + // Search terms from most specific to least 
specific (0..bound).rev().map(|i| { let nested_names_index = i + 1; let qualifier_and_column = &ids[0..nested_names_index]; - // safe unwrap as qualifier_and_column can never be empty or exceed the bounds + // Safe unwrap as qualifier_and_column can never be empty or exceed the bounds let (relation, column_name) = form_identifier(qualifier_and_column).unwrap(); (relation, column_name, &ids[nested_names_index..]) }) @@ -348,7 +331,7 @@ mod test { #[test] // testing according to documentation of generate_schema_search_terms function - // where ensure generated search terms are in correct order with correct values + // where it ensures generated search terms are in correct order with correct values fn test_generate_schema_search_terms() -> Result<()> { type ExpectedItem = ( Option, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 34e119c45fdfe..432e8668c52e9 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -57,7 +57,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { enum StackEntry { SQLExpr(Box), - Operator(sqlparser::ast::BinaryOperator), + Operator(BinaryOperator), } // Virtual stack machine to convert SQLExpr to Expr diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index 6a3a4d6ccbb75..00289806876fe 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -102,7 +102,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { expr_vec.push(Sort::new( expr, asc, - // when asc is true, by default nulls last to be consistent with postgres + // When asc is true, by default nulls last to be consistent with postgres // postgres rule: https://www.postgresql.org/docs/current/queries-order.html nulls_first.unwrap_or(!asc), )) diff --git a/datafusion/sql/src/expr/unary_op.rs b/datafusion/sql/src/expr/unary_op.rs index 2a341fb7c4467..06988eb03893b 100644 --- a/datafusion/sql/src/expr/unary_op.rs +++ b/datafusion/sql/src/expr/unary_op.rs @@ -16,8 +16,11 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{not_impl_err, DFSchema, Result}; -use datafusion_expr::Expr; +use datafusion_common::{not_impl_err, plan_err, DFSchema, Result}; +use datafusion_expr::{ + type_coercion::{is_interval, is_timestamp}, + Expr, ExprSchemable, +}; use sqlparser::ast::{Expr as SQLExpr, UnaryOperator, Value}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -33,11 +36,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.sql_expr_to_logical_expr(expr, schema, planner_context)?, ))), UnaryOperator::Plus => { - Ok(self.sql_expr_to_logical_expr(expr, schema, planner_context)?) + let operand = + self.sql_expr_to_logical_expr(expr, schema, planner_context)?; + let (data_type, _) = operand.data_type_and_nullable(schema)?; + if data_type.is_numeric() + || is_interval(&data_type) + || is_timestamp(&data_type) + { + Ok(operand) + } else { + plan_err!("Unary operator '+' only supports numeric, interval and timestamp types") + } } UnaryOperator::Minus => { match expr { - // optimization: if it's a number literal, we apply the negative operator + // Optimization: if it's a number literal, we apply the negative operator // here directly to calculate the new literal. 
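generate_schema_search_terms, shown above, tries identifier splits from the most specific qualifier to the least. A runnable sketch of the same splitting logic, under the simplifying assumption that the qualifier is kept as a plain slice instead of a TableReference:

// Sketch: split identifiers into (qualifier, column, nested) candidates,
// most specific qualifier first, as described in the comments above.
fn search_terms(ids: &[String]) -> Vec<(&[String], &String, &[String])> {
    let bound = ids.len().min(4); // at most 3 qualifier parts plus 1 column name
    (0..bound)
        .rev() // most specific qualifier first
        .map(|i| {
            let nested = i + 1;
            (&ids[0..i], &ids[i], &ids[nested..])
        })
        .collect()
}

fn main() {
    let ids: Vec<String> = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
    let terms = search_terms(&ids);
    // First candidate treats a.b as qualifier and c as the column, no nesting;
    // the last treats a as the column with b.c as nested field accesses.
    assert_eq!(terms.len(), 3);
    assert_eq!(terms[0].1, "c");
    assert_eq!(terms[2].2, &["b".to_string(), "c".to_string()][..]);
}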
SQLExpr::Value(Value::Number(n, _)) => { self.parse_sql_number(&n, true) @@ -45,7 +58,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Interval(interval) => { self.sql_interval_to_expr(true, interval) } - // not a literal, apply negative operator on expression + // Not a literal, apply negative operator on expression _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr( expr, schema, diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index be0909b58468f..7dc15de7ad710 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -235,7 +235,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let value = interval_literal(*interval.value, negative)?; // leading_field really means the unit if specified - // for example, "month" in `INTERVAL '5' month` + // For example, "month" in `INTERVAL '5' month` let value = match interval.leading_field.as_ref() { Some(leading_field) => format!("{value} {leading_field}"), None => value, @@ -323,9 +323,9 @@ const fn try_decode_hex_char(c: u8) -> Option { fn parse_decimal_128(unsigned_number: &str, negative: bool) -> Result { // remove leading zeroes let trimmed = unsigned_number.trim_start_matches('0'); - // parse precision and scale, remove decimal point if exists + // Parse precision and scale, remove decimal point if exists let (precision, scale, replaced_str) = if trimmed == "." { - // special cases for numbers such as “0.”, “000.”, and so on. + // Special cases for numbers such as “0.”, “000.”, and so on. (1, 0, Cow::Borrowed("0")) } else if let Some(i) = trimmed.find('.') { ( @@ -334,7 +334,7 @@ fn parse_decimal_128(unsigned_number: &str, negative: bool) -> Result { Cow::Owned(trimmed.replace('.', "")), ) } else { - // no decimal point, keep as is + // No decimal point, keep as is (trimmed.len(), 0, Cow::Borrowed(trimmed)) }; @@ -344,7 +344,7 @@ fn parse_decimal_128(unsigned_number: &str, negative: bool) -> Result { ))) })?; - // check precision overflow + // Check precision overflow if precision as u8 > DECIMAL128_MAX_PRECISION { return Err(DataFusionError::from(ParserError(format!( "Cannot parse {replaced_str} as i128 when building decimal: precision overflow" diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 2df8d89c59bc8..8a984f1645e9b 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -181,7 +181,7 @@ pub(crate) type LexOrdering = Vec; #[derive(Debug, Clone, PartialEq, Eq)] pub struct CreateExternalTable { /// Table name - pub name: String, + pub name: ObjectName, /// Optional schema pub columns: Vec, /// File type (Parquet, NDJSON, CSV, etc) @@ -194,6 +194,8 @@ pub struct CreateExternalTable { pub order_exprs: Vec, /// Option to not error if table already exists pub if_not_exists: bool, + /// Whether the table is a temporary table + pub temporary: bool, /// Infinite streams? 
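parse_decimal_128 above derives precision and scale from the trimmed literal before parsing the digits as i128. A minimal sketch of that derivation (string handling only; the i128 parse and overflow checks are omitted):

// Sketch: precision/scale extraction from an unsigned decimal literal whose
// leading zeros were already trimmed, mirroring parse_decimal_128 above.
fn precision_scale(trimmed: &str) -> (usize, usize, String) {
    if trimmed == "." {
        (1, 0, "0".to_string()) // inputs like "0.", "000."
    } else if let Some(i) = trimmed.find('.') {
        // Digits after the point give the scale; all digits give the precision.
        (trimmed.len() - 1, trimmed.len() - i - 1, trimmed.replace('.', ""))
    } else {
        (trimmed.len(), 0, trimmed.to_string()) // no decimal point, keep as is
    }
}

fn main() {
    assert_eq!(precision_scale("123.45"), (5, 2, "12345".to_string()));
    assert_eq!(precision_scale("42"), (2, 0, "42".to_string()));
    assert_eq!(precision_scale("."), (1, 0, "0".to_string()));
}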
pub unbounded: bool, /// Table(provider) specific options @@ -699,6 +701,10 @@ impl<'a> DFParser<'a> { &mut self, unbounded: bool, ) -> Result { + let temporary = self + .parser + .parse_one_of_keywords(&[Keyword::TEMP, Keyword::TEMPORARY]) + .is_some(); self.parser.expect_keyword(Keyword::TABLE)?; let if_not_exists = self.parser @@ -761,10 +767,10 @@ impl<'a> DFParser<'a> { // Note that mixing both names and definitions is not allowed let peeked = self.parser.peek_nth_token(2); if peeked == Token::Comma || peeked == Token::RParen { - // list of column names + // List of column names builder.table_partition_cols = Some(self.parse_partitions()?) } else { - // list of column defs + // List of column defs let (cols, cons) = self.parse_columns()?; builder.table_partition_cols = Some( cols.iter().map(|col| col.name.to_string()).collect(), @@ -813,13 +819,14 @@ impl<'a> DFParser<'a> { } let create = CreateExternalTable { - name: table_name.to_string(), + name: table_name, columns, file_type: builder.file_type.unwrap(), location: builder.location.unwrap(), table_partition_cols: builder.table_partition_cols.unwrap_or(vec![]), order_exprs: builder.order_exprs, if_not_exists, + temporary, unbounded, options: builder.options.unwrap_or(Vec::new()), constraints, @@ -850,7 +857,7 @@ impl<'a> DFParser<'a> { options.push((key, value)); let comma = self.parser.consume_token(&Token::Comma); if self.parser.consume_token(&Token::RParen) { - // allow a trailing comma, even though it's not in standard + // Allow a trailing comma, even though it's not in standard break; } else if !comma { return self.expected( @@ -915,14 +922,16 @@ mod tests { // positive case let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv'"; let display = None; + let name = ObjectName(vec![Ident::from("t")]); let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(display))], file_type: "CSV".to_string(), location: "foo.csv".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -932,13 +941,14 @@ mod tests { // positive case: leading space let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' "; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(None))], file_type: "CSV".to_string(), location: "foo.csv".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -949,13 +959,14 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' ;"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(None))], file_type: "CSV".to_string(), location: "foo.csv".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -966,13 +977,14 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS (format.delimiter '|')"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(display))], file_type: 
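The TEMP/TEMPORARY change above relies on sqlparser's consume-if-matches behavior: the keyword is eaten when present and the token stream is left untouched otherwise. A toy cursor sketch of that pattern (illustrative types, not sqlparser's Parser):

// Sketch: optional-keyword handling, as in parse_one_of_keywords above.
struct Cursor<'a> { tokens: &'a [&'a str], pos: usize }

impl<'a> Cursor<'a> {
    // Consume the next token if it matches one of `kws`, else leave it alone.
    fn parse_one_of_keywords(&mut self, kws: &[&str]) -> Option<&'a str> {
        let tok = *self.tokens.get(self.pos)?;
        if kws.iter().any(|k| tok.eq_ignore_ascii_case(k)) {
            self.pos += 1;
            Some(tok)
        } else {
            None
        }
    }
}

fn main() {
    let mut c = Cursor { tokens: &["TEMPORARY", "TABLE", "t"], pos: 0 };
    let temporary = c.parse_one_of_keywords(&["TEMP", "TEMPORARY"]).is_some();
    assert!(temporary && c.pos == 1); // parsing resumes at TABLE
}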
"CSV".to_string(), location: "foo.csv".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![( "format.delimiter".into(), @@ -986,13 +998,14 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1, p2) LOCATION 'foo.csv'"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(display))], file_type: "CSV".to_string(), location: "foo.csv".into(), table_partition_cols: vec!["p1".to_string(), "p2".to_string()], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1013,13 +1026,14 @@ mod tests { ]; for (sql, compression) in sqls { let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(display))], file_type: "CSV".to_string(), location: "foo.csv".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![( "format.compression".into(), @@ -1033,13 +1047,14 @@ mod tests { // positive case: it is ok for parquet files not to have columns specified let sql = "CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION 'foo.parquet'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "PARQUET".to_string(), location: "foo.parquet".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1049,13 +1064,14 @@ mod tests { // positive case: it is ok for parquet files to be other than upper case let sql = "CREATE EXTERNAL TABLE t STORED AS parqueT LOCATION 'foo.parquet'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "PARQUET".to_string(), location: "foo.parquet".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1065,13 +1081,14 @@ mod tests { // positive case: it is ok for avro files not to have columns specified let sql = "CREATE EXTERNAL TABLE t STORED AS AVRO LOCATION 'foo.avro'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "AVRO".to_string(), location: "foo.avro".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1082,13 +1099,14 @@ mod tests { let sql = "CREATE EXTERNAL TABLE IF NOT EXISTS t STORED AS PARQUET LOCATION 'foo.parquet'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "PARQUET".to_string(), location: "foo.parquet".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: true, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1099,7 +1117,7 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int) LOCATION 'foo.csv'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![ 
make_column_def("c1", DataType::Int(None)), make_column_def("p1", DataType::Int(None)), @@ -1109,6 +1127,7 @@ mod tests { table_partition_cols: vec!["p1".to_string()], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1132,13 +1151,14 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t STORED AS x OPTIONS ('k1' 'v1') LOCATION 'blahblah'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "X".to_string(), location: "blahblah".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![("k1".into(), Value::SingleQuotedString("v1".into()))], constraints: vec![], @@ -1149,13 +1169,14 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t STORED AS x OPTIONS ('k1' 'v1', k2 v2) LOCATION 'blahblah'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "X".to_string(), location: "blahblah".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![ ("k1".into(), Value::SingleQuotedString("v1".into())), @@ -1188,7 +1209,7 @@ mod tests { ]; for (sql, (asc, nulls_first)) in sqls.iter().zip(expected.into_iter()) { let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(None))], file_type: "CSV".to_string(), location: "foo.csv".into(), @@ -1203,6 +1224,7 @@ mod tests { with_fill: None, }]], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1214,7 +1236,7 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int, c2 int) STORED AS CSV WITH ORDER (c1 ASC, c2 DESC NULLS FIRST) LOCATION 'foo.csv'"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(display)), make_column_def("c2", DataType::Int(display)), @@ -1243,6 +1265,7 @@ mod tests { }, ]], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1253,7 +1276,7 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int, c2 int) STORED AS CSV WITH ORDER (c1 - c2 ASC) LOCATION 'foo.csv'"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(display)), make_column_def("c2", DataType::Int(display)), @@ -1278,6 +1301,7 @@ mod tests { with_fill: None, }]], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1297,7 +1321,7 @@ mod tests { 'TRUNCATE' 'NO', 'format.has_header' 'true')"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(None)), make_column_def("c2", DataType::Float(None)), @@ -1322,6 +1346,7 @@ mod tests { with_fill: None, }]], if_not_exists: true, + temporary: false, unbounded: true, options: vec![ ( diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 5cbe1d7c014ad..072d2320fccf8 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -117,7 +117,7 @@ impl ValueNormalizer { /// CTEs, Views, subqueries and 
PREPARE statements. The states include /// Common Table Expression (CTE) provided with WITH clause and /// Parameter Data Types provided with PREPARE statement and the query schema of the -/// outer query plan +/// outer query plan. /// /// # Cloning /// @@ -138,6 +138,8 @@ pub struct PlannerContext { /// The joined schemas of all FROM clauses planned so far. When planning LATERAL /// FROM clauses, this should become a suffix of the `outer_query_schema`. outer_from_schema: Option, + /// The query schema defined by the table + create_table_schema: Option, } impl Default for PlannerContext { @@ -154,6 +156,7 @@ impl PlannerContext { ctes: HashMap::new(), outer_query_schema: None, outer_from_schema: None, + create_table_schema: None, } } @@ -166,12 +169,12 @@ impl PlannerContext { self } - // return a reference to the outer queries schema + // Return a reference to the outer query's schema pub fn outer_query_schema(&self) -> Option<&DFSchema> { self.outer_query_schema.as_ref().map(|s| s.as_ref()) } - /// sets the outer query schema, returning the existing one, if + /// Sets the outer query schema, returning the existing one, if /// any pub fn set_outer_query_schema( &mut self, @@ -181,12 +184,24 @@ impl PlannerContext { schema } - // return a clone of the outer FROM schema + pub fn set_table_schema( + &mut self, + mut schema: Option, + ) -> Option { + std::mem::swap(&mut self.create_table_schema, &mut schema); + schema + } + + pub fn table_schema(&self) -> Option { + self.create_table_schema.clone() + } + + // Return a clone of the outer FROM schema pub fn outer_from_schema(&self) -> Option> { self.outer_from_schema.clone() } - /// sets the outer FROM schema, returning the existing one, if any + /// Sets the outer FROM schema, returning the existing one, if any pub fn set_outer_from_schema( &mut self, mut schema: Option, @@ -195,11 +210,11 @@ impl PlannerContext { schema } - /// extends the FROM schema, returning the existing one, if any + /// Extends the FROM schema, returning the existing one, if any pub fn extend_outer_from_schema(&mut self, schema: &DFSchemaRef) -> Result<()> { - self.outer_from_schema = match self.outer_from_schema.as_ref() { - Some(from_schema) => Some(Arc::new(from_schema.join(schema)?)), - None => Some(Arc::clone(schema)), + match self.outer_from_schema.as_mut() { + Some(from_schema) => Arc::make_mut(from_schema).merge(schema), + None => self.outer_from_schema = Some(Arc::clone(schema)), }; Ok(()) } @@ -209,7 +224,7 @@ impl PlannerContext { &self.prepare_param_data_types } - /// returns true if there is a Common Table Expression (CTE) / + /// Returns true if there is a Common Table Expression (CTE) / /// Subquery for the specified name pub fn contains_cte(&self, cte_name: &str) -> bool { self.ctes.contains_key(cte_name) @@ -387,12 +402,26 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { match sql_type { - SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) - | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_sql_type, _)) => { + SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) => { // Arrays may be multi-dimensional. 
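extend_outer_from_schema now mutates the schema in place through Arc::make_mut, which clones the underlying value only when the Arc is shared. A sketch of that copy-on-write update, with a Vec<String> standing in for DFSchema and extend standing in for merge:

// Sketch: copy-on-write update via Arc::make_mut, as in
// extend_outer_from_schema above.
use std::sync::Arc;

fn extend(existing: &mut Option<Arc<Vec<String>>>, extra: &Arc<Vec<String>>) {
    match existing.as_mut() {
        // Mutate through the Arc; the Vec is cloned only if another owner exists.
        Some(schema) => Arc::make_mut(schema).extend(extra.iter().cloned()),
        None => *existing = Some(Arc::clone(extra)),
    }
}

fn main() {
    let mut outer: Option<Arc<Vec<String>>> = None;
    let t1 = Arc::new(vec!["t1.a".to_string()]);
    let t2 = Arc::new(vec!["t2.b".to_string()]);
    extend(&mut outer, &t1);
    extend(&mut outer, &t2);
    assert_eq!(outer.unwrap().len(), 2);
}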
let inner_data_type = self.convert_data_type(inner_sql_type)?; Ok(DataType::new_list(inner_data_type, true)) } + SQLDataType::Array(ArrayElemTypeDef::SquareBracket( + inner_sql_type, + maybe_array_size, + )) => { + let inner_data_type = self.convert_data_type(inner_sql_type)?; + if let Some(array_size) = maybe_array_size { + Ok(DataType::new_fixed_size_list( + inner_data_type, + *array_size as i32, + true, + )) + } else { + Ok(DataType::new_list(inner_data_type, true)) + } + } SQLDataType::Array(ArrayElemTypeDef::None) => { not_impl_err!("Arrays with unspecified type is not supported") } @@ -506,9 +535,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::CharVarying(_) | SQLDataType::CharacterLargeObject(_) | SQLDataType::CharLargeObject(_) - // precision is not supported + // Precision is not supported | SQLDataType::Timestamp(Some(_), _) - // precision is not supported + // Precision is not supported | SQLDataType::Time(Some(_), _) | SQLDataType::Dec(_) | SQLDataType::BigNumeric(_) @@ -572,7 +601,7 @@ pub fn object_name_to_table_reference( object_name: ObjectName, enable_normalization: bool, ) -> Result { - // use destructure to make it clear no fields on ObjectName are ignored + // Use destructure to make it clear no fields on ObjectName are ignored let ObjectName(idents) = object_name; idents_to_table_reference(idents, enable_normalization) } @@ -583,7 +612,7 @@ pub(crate) fn idents_to_table_reference( enable_normalization: bool, ) -> Result { struct IdentTaker(Vec); - /// take the next identifier from the back of idents, panic'ing if + /// Take the next identifier from the back of idents, panic'ing if /// there are none left impl IdentTaker { fn take(&mut self, enable_normalization: bool) -> String { diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 71328cfd018c3..1ef009132f9e3 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -19,15 +19,14 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{not_impl_err, plan_err, Constraints, Result, ScalarValue}; +use datafusion_common::{not_impl_err, Constraints, DFSchema, Result}; use datafusion_expr::expr::Sort; use datafusion_expr::{ - CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, - Operator, + CreateMemoryTable, DdlStatement, Distinct, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ Expr as SQLExpr, Offset as SQLOffset, OrderBy, OrderByExpr, Query, SelectInto, - SetExpr, Value, + SetExpr, }; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -54,7 +53,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // so we need to process `SELECT` and `ORDER BY` together. let oby_exprs = to_order_by_exprs(query.order_by)?; let plan = self.select_to_plan(*select, oby_exprs, planner_context)?; - let plan = self.limit(plan, query.offset, query.limit)?; + let plan = + self.limit(plan, query.offset, query.limit, planner_context)?; // Process the `SELECT INTO` after `LIMIT`. 
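The new SquareBracket arm above distinguishes `INT[3]`, which becomes a fixed-size list, from `INT[]`, which stays a growable list. A toy sketch of the mapping (a stand-in enum rather than arrow_schema::DataType):

// Sketch: SQL array syntax to list-type mapping, as in convert_data_type above.
#[derive(Debug, PartialEq)]
enum Ty {
    Int,
    List(Box<Ty>),
    FixedSizeList(Box<Ty>, i32),
}

fn convert_array(inner: Ty, maybe_size: Option<u64>) -> Ty {
    match maybe_size {
        Some(n) => Ty::FixedSizeList(Box::new(inner), n as i32), // INT[3]
        None => Ty::List(Box::new(inner)),                       // INT[]
    }
}

fn main() {
    assert_eq!(
        convert_array(Ty::Int, Some(3)),
        Ty::FixedSizeList(Box::new(Ty::Int), 3)
    );
    assert_eq!(convert_array(Ty::Int, None), Ty::List(Box::new(Ty::Int)));
}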
self.select_into(plan, select_into) } @@ -69,7 +69,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { None, )?; let plan = self.order_by(plan, order_by_rex)?; - self.limit(plan, query.offset, query.limit) + self.limit(plan, query.offset, query.limit, planner_context) } } } @@ -80,40 +80,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: LogicalPlan, skip: Option, fetch: Option, + planner_context: &mut PlannerContext, ) -> Result { if skip.is_none() && fetch.is_none() { return Ok(input); } - let skip = match skip { - Some(skip_expr) => { - let expr = self.sql_to_expr( - skip_expr.value, - input.schema(), - &mut PlannerContext::new(), - )?; - let n = get_constant_result(&expr, "OFFSET")?; - convert_usize_with_check(n, "OFFSET") - } - _ => Ok(0), - }?; - - let fetch = match fetch { - Some(limit_expr) - if limit_expr != sqlparser::ast::Expr::Value(Value::Null) => - { - let expr = self.sql_to_expr( - limit_expr, - input.schema(), - &mut PlannerContext::new(), - )?; - let n = get_constant_result(&expr, "LIMIT")?; - Some(convert_usize_with_check(n, "LIMIT")?) - } - _ => None, - }; - - LogicalPlanBuilder::from(input).limit(skip, fetch)?.build() + // skip and fetch expressions are not allowed to reference columns from the input plan + let empty_schema = DFSchema::empty(); + + let skip = skip + .map(|o| self.sql_to_expr(o.value, &empty_schema, planner_context)) + .transpose()?; + let fetch = fetch + .map(|e| self.sql_to_expr(e, &empty_schema, planner_context)) + .transpose()?; + LogicalPlanBuilder::from(input) + .limit_by_expr(skip, fetch)? + .build() } /// Wrap the logical in a sort @@ -150,6 +134,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), if_not_exists: false, or_replace: false, + temporary: false, column_defaults: vec![], }, ))), @@ -158,54 +143,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } -/// Retrieves the constant result of an expression, evaluating it if possible. -/// -/// This function takes an expression and an argument name as input and returns -/// a `Result` indicating either the constant result of the expression or an -/// error if the expression cannot be evaluated. -/// -/// # Arguments -/// -/// * `expr` - An `Expr` representing the expression to evaluate. -/// * `arg_name` - The name of the argument for error messages. -/// -/// # Returns -/// -/// * `Result` - An `Ok` variant containing the constant result if evaluation is successful, -/// or an `Err` variant containing an error message if evaluation fails. -/// -/// tracks a more general solution -fn get_constant_result(expr: &Expr, arg_name: &str) -> Result { - match expr { - Expr::Literal(ScalarValue::Int64(Some(s))) => Ok(*s), - Expr::BinaryExpr(binary_expr) => { - let lhs = get_constant_result(&binary_expr.left, arg_name)?; - let rhs = get_constant_result(&binary_expr.right, arg_name)?; - let res = match binary_expr.op { - Operator::Plus => lhs + rhs, - Operator::Minus => lhs - rhs, - Operator::Multiply => lhs * rhs, - _ => return plan_err!("Unsupported operator for {arg_name} clause"), - }; - Ok(res) - } - _ => plan_err!("Unexpected expression in {arg_name} clause"), - } -} - -/// Converts an `i64` to `usize`, performing a boundary check. -fn convert_usize_with_check(n: i64, arg_name: &str) -> Result { - if n < 0 { - plan_err!("{arg_name} must be >= 0, '{n}' was provided.") - } else { - Ok(n as usize) - } -} - /// Returns the order by expressions from the query. 
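LIMIT and OFFSET are now planned as ordinary expressions against an empty schema, since they may not reference input columns, and the Option-of-Result plumbing goes through `.map(...).transpose()?`. A small sketch of that transpose pattern (a string parser stands in for sql_to_expr):

// Sketch: Option<Result<T>> to Result<Option<T>> via transpose, as in the
// skip/fetch handling above.
fn plan_expr(raw: &str) -> Result<i64, String> {
    raw.parse::<i64>().map_err(|e| e.to_string())
}

fn plan_limit(
    skip: Option<&str>,
    fetch: Option<&str>,
) -> Result<(Option<i64>, Option<i64>), String> {
    // LIMIT/OFFSET may each be absent; when present, planning them can fail.
    let skip = skip.map(plan_expr).transpose()?;
    let fetch = fetch.map(plan_expr).transpose()?;
    Ok((skip, fetch))
}

fn main() {
    assert_eq!(plan_limit(Some("2"), None), Ok((Some(2), None)));
    assert!(plan_limit(Some("not a number"), Some("5")).is_err());
}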
fn to_order_by_exprs(order_by: Option) -> Result> { let Some(OrderBy { exprs, interpolate }) = order_by else { - // if no order by, return an empty array + // If no order by, return an empty array. return Ok(vec![]); }; if let Some(_interpolate) = interpolate { diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index 409533a3eaa58..3f34608e37565 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -151,7 +151,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .build() } } - JoinConstraint::None => not_impl_err!("NONE constraint is not supported"), + JoinConstraint::None => LogicalPlanBuilder::from(left) + .join_on(right, join_type, [])? + .build(), } } } diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index f8ebb04f38103..256cc58e71dc4 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -70,7 +70,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .build()?; (plan, alias) } else { - // normalize name and alias + // Normalize name and alias let table_ref = self.object_name_to_table_reference(name)?; let table_name = table_ref.to_string(); let cte = planner_context.get_cte(&table_name); @@ -163,7 +163,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { subquery: TableFactor, planner_context: &mut PlannerContext, ) -> Result { - // At this point for a syntacitally valid query the outer_from_schema is + // At this point for a syntactically valid query the outer_from_schema is // guaranteed to be set, so the `.unwrap()` call will never panic. This // is the case because we only call this method for lateral table // factors, and those can never be the first factor in a FROM list. This diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index c93d9e6fc4357..80a08da5e35d6 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -25,8 +25,8 @@ use crate::utils::{ }; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::UnnestOptions; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; +use datafusion_common::{RecursionUnnestOption, UnnestOptions}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_sorts, @@ -38,6 +38,7 @@ use datafusion_expr::{ qualified_wildcard_with_options, wildcard_with_options, Aggregate, Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning, }; +use indexmap::IndexMap; use sqlparser::ast::{ Distinct, Expr as SQLExpr, GroupByExpr, NamedWindowExpr, OrderByExpr, WildcardAdditionalOptions, WindowType, @@ -52,7 +53,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { order_by: Vec, planner_context: &mut PlannerContext, ) -> Result { - // check for unsupported syntax first + // Check for unsupported syntax first if !select.cluster_by.is_empty() { return not_impl_err!("CLUSTER BY"); } @@ -69,17 +70,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return not_impl_err!("SORT BY"); } - // process `from` clause + // Process `from` clause let plan = self.plan_from_tables(select.from, planner_context)?; let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); - // process `where` clause + // Process `where` clause let base_plan = self.plan_selection(select.selection, plan, planner_context)?; - // handle named windows before processing the projection expression + // Handle 
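The JoinConstraint::None arm above now builds a join with an empty condition list instead of erroring. For an inner join, no predicates means every left row pairs with every right row, as this toy sketch shows (assuming inner-join semantics; other join types differ):

// Sketch: a join with an empty condition list degenerates to a cross product.
fn join_on(left: &[i32], right: &[i32], on: &[fn(i32, i32) -> bool]) -> Vec<(i32, i32)> {
    left.iter()
        .flat_map(|l| right.iter().map(move |r| (*l, *r)))
        // An empty `on` keeps every pair, i.e. CROSS JOIN behavior.
        .filter(|(l, r)| on.iter().all(|p| p(*l, *r)))
        .collect()
}

fn main() {
    let pairs = join_on(&[1, 2], &[10, 20], &[]);
    assert_eq!(pairs.len(), 4); // 2 x 2 rows survive with no predicates
}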
named windows before processing the projection expression check_conflicting_windows(&select.named_window)?; match_window_definitions(&mut select.projection, &select.named_window)?; - // process the SELECT expressions + // Process the SELECT expressions let select_exprs = self.prepare_select_exprs( &base_plan, select.projection, @@ -87,7 +88,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, )?; - // having and group by clause may reference aliases defined in select projection + // Having and group by clause may reference aliases defined in select projection let projected_plan = self.project(base_plan.clone(), select_exprs.clone())?; // Place the fields of the base plan at the front so that when there are references @@ -107,7 +108,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; let order_by_rex = normalize_sorts(order_by_rex, &projected_plan)?; - // this alias map is resolved and looked up in both having exprs and group by exprs + // This alias map is resolved and looked up in both having exprs and group by exprs let alias_map = extract_aliases(&select_exprs); // Optionally the HAVING expression. @@ -137,16 +138,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .transpose()?; - // The outer expressions we will search through for - // aggregates. Aggregates may be sourced from the SELECT... - let mut aggr_expr_haystack = select_exprs.clone(); - // ... or from the HAVING. - if let Some(having_expr) = &having_expr_opt { - aggr_expr_haystack.push(having_expr.clone()); - } - + // The outer expressions we will search through for aggregates. + // Aggregates may be sourced from the SELECT list or from the HAVING expression. + let aggr_expr_haystack = select_exprs.iter().chain(having_expr_opt.iter()); // All of the aggregate expressions (deduplicated). 
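The haystack above chains the SELECT expressions with the optional HAVING expression, avoiding the previous clone-and-push into a temporary Vec. Option::iter yields zero or one item, which is what makes the chain work; a sketch with plain strings in place of Expr:

// Sketch: allocation-free haystack over SELECT plus optional HAVING, as above.
fn find_matches<'a>(haystack: impl Iterator<Item = &'a String>) -> Vec<&'a String> {
    haystack.filter(|e| e.contains("sum")).collect()
}

fn main() {
    let select_exprs = vec!["sum(a)".to_string(), "b".to_string()];
    let having: Option<String> = Some("sum(a) > 10".to_string());
    // Option::iter contributes the HAVING expression only when it is present.
    let matches = find_matches(select_exprs.iter().chain(having.iter()));
    assert_eq!(matches.len(), 2);
}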
- let aggr_exprs = find_aggregate_exprs(&aggr_expr_haystack); + let aggr_exprs = find_aggregate_exprs(aggr_expr_haystack); // All of the group by expressions let group_by_exprs = if let GroupByExpr::Expressions(exprs, _) = select.group_by { @@ -159,7 +155,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, )?; - // aliases from the projection can conflict with same-named expressions in the input + // Aliases from the projection can conflict with same-named expressions in the input let mut alias_map = alias_map.clone(); for f in base_plan.schema().fields() { alias_map.remove(f.name()); @@ -192,7 +188,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect() }; - // process group by, aggregation or having + // Process group by, aggregation or having let (plan, mut select_exprs_post_aggr, having_expr_post_aggr) = if !group_by_exprs .is_empty() || !aggr_exprs.is_empty() @@ -219,7 +215,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan }; - // process window function + // Process window function let window_func_exprs = find_window_exprs(&select_exprs_post_aggr); let plan = if window_func_exprs.is_empty() { @@ -227,7 +223,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { let plan = LogicalPlanBuilder::window_plan(plan, window_func_exprs.clone())?; - // re-write the projection + // Re-write the projection select_exprs_post_aggr = select_exprs_post_aggr .iter() .map(|expr| rebase_expr(expr, &window_func_exprs, &plan)) @@ -236,10 +232,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan }; - // try process unnest expression or do the final projection + // Try processing unnest expression or do the final projection let plan = self.try_process_unnest(plan, select_exprs_post_aggr)?; - // process distinct clause + // Process distinct clause let plan = match select.distinct { None => Ok(plan), Some(Distinct::Distinct) => { @@ -304,9 +300,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Each expr in select_exprs can contains multiple unnest stage // The transformation happen bottom up, one at a time for each iteration - // Only exaust the loop if no more unnest transformation is found + // Only exhaust the loop if no more unnest transformation is found for i in 0.. { - let mut unnest_columns = vec![]; + let mut unnest_columns = IndexMap::new(); // from which column used for projection, before the unnest happen // including non unnest column and unnest column let mut inner_projection_exprs = vec![]; @@ -334,14 +330,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { break; } else { // Set preserve_nulls to false to ensure compatibility with DuckDB and PostgreSQL - let unnest_options = UnnestOptions::new().with_preserve_nulls(false); - + let mut unnest_options = UnnestOptions::new().with_preserve_nulls(false); + let mut unnest_col_vec = vec![]; + + for (col, maybe_list_unnest) in unnest_columns.into_iter() { + if let Some(list_unnest) = maybe_list_unnest { + unnest_options = list_unnest.into_iter().fold( + unnest_options, + |options, unnest_list| { + options.with_recursions(RecursionUnnestOption { + input_column: col.clone(), + output_column: unnest_list.output_column, + depth: unnest_list.depth, + }) + }, + ); + } + unnest_col_vec.push(col); + } let plan = LogicalPlanBuilder::from(intermediate_plan) .project(inner_projection_exprs)? - .unnest_columns_recursive_with_options( - unnest_columns, - unnest_options, - )? + .unnest_columns_with_options(unnest_col_vec, unnest_options)? 
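The fold above threads UnnestOptions through each recorded list recursion, and unnest_columns is now an IndexMap, which, unlike HashMap, preserves insertion order, so unnest columns keep their query order. A toy sketch of the fold (an illustrative Options type, not UnnestOptions):

// Sketch: threading a builder-style options value through a fold, attaching
// one recorded recursion per entry, as in the unnest handling above.
#[derive(Debug, Default)]
struct Options { recursions: Vec<(String, usize)> }

impl Options {
    fn with_recursion(mut self, output_column: String, depth: usize) -> Self {
        self.recursions.push((output_column, depth));
        self // consume and return, so fold can thread the value through
    }
}

fn main() {
    let list_unnests = vec![("c_depth_1".to_string(), 1), ("c_depth_2".to_string(), 2)];
    let options = list_unnests
        .into_iter()
        .fold(Options::default(), |opts, (col, depth)| opts.with_recursion(col, depth));
    assert_eq!(options.recursions.len(), 2);
}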
.build()?; intermediate_plan = plan; intermediate_select_exprs = outer_projection_exprs; @@ -390,7 +399,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .. } = agg; - // process unnest of group_by_exprs, and input of agg will be rewritten + // Process unnest of group_by_exprs, and input of agg will be rewritten // for example: // // ``` @@ -410,7 +419,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut intermediate_select_exprs = group_expr; loop { - let mut unnest_columns = vec![]; + let mut unnest_columns = IndexMap::new(); let mut inner_projection_exprs = vec![]; let outer_projection_exprs = rewrite_recursive_unnests_bottom_up( @@ -423,7 +432,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if unnest_columns.is_empty() { break; } else { - let unnest_options = UnnestOptions::new().with_preserve_nulls(false); + let mut unnest_options = UnnestOptions::new().with_preserve_nulls(false); let mut projection_exprs = match &aggr_expr_using_columns { Some(exprs) => (*exprs).clone(), @@ -445,12 +454,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; projection_exprs.extend(inner_projection_exprs); + let mut unnest_col_vec = vec![]; + + for (col, maybe_list_unnest) in unnest_columns.into_iter() { + if let Some(list_unnest) = maybe_list_unnest { + unnest_options = list_unnest.into_iter().fold( + unnest_options, + |options, unnest_list| { + options.with_recursions(RecursionUnnestOption { + input_column: col.clone(), + output_column: unnest_list.output_column, + depth: unnest_list.depth, + }) + }, + ); + } + unnest_col_vec.push(col); + } + intermediate_plan = LogicalPlanBuilder::from(intermediate_plan) .project(projection_exprs)? - .unnest_columns_recursive_with_options( - unnest_columns, - unnest_options, - )? + .unnest_columns_with_options(unnest_col_vec, unnest_options)? .build()?; intermediate_select_exprs = outer_projection_exprs; @@ -477,6 +501,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let filter_expr = self.sql_to_expr(predicate_expr, plan.schema(), planner_context)?; + + // Check for aggregation functions + let aggregate_exprs = + find_aggregate_exprs(std::slice::from_ref(&filter_expr)); + if !aggregate_exprs.is_empty() { + return plan_err!( + "Aggregate functions are not allowed in the WHERE clause. 
Consider using HAVING instead" + ); + } + let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; let filter_expr = normalize_col_with_schemas_and_ambiguity_check( diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 29dfe25993f13..abb9912b712a7 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -30,14 +30,15 @@ use crate::planner::{ use crate::utils::normalize_ident; use arrow_schema::{DataType, Fields}; +use datafusion_common::error::_plan_err; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ exec_err, not_impl_err, plan_datafusion_err, plan_err, schema_err, - unqualified_field_not_found, Column, Constraints, DFSchema, DFSchemaRef, + unqualified_field_not_found, Column, Constraint, Constraints, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, SchemaError, SchemaReference, TableReference, ToDFSchema, }; -use datafusion_expr::dml::CopyTo; +use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check; use datafusion_expr::logical_plan::builder::project; use datafusion_expr::logical_plan::DdlStatement; @@ -53,7 +54,7 @@ use datafusion_expr::{ TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart, Volatility, WriteOp, }; -use sqlparser::ast; +use sqlparser::ast::{self, SqliteOnConflict}; use sqlparser::ast::{ Assignment, AssignmentTarget, ColumnDef, CreateIndex, CreateTable, CreateTableOptions, Delete, DescribeAlias, Expr as SQLExpr, FromTable, Ident, Insert, @@ -98,7 +99,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec constraints.push(ast::TableConstraint::Unique { + } => constraints.push(TableConstraint::Unique { name: name.clone(), columns: vec![column.name.clone()], characteristics: *characteristics, @@ -110,7 +111,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec constraints.push(ast::TableConstraint::PrimaryKey { + } => constraints.push(TableConstraint::PrimaryKey { name: name.clone(), columns: vec![column.name.clone()], characteristics: *characteristics, @@ -124,7 +125,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec constraints.push(ast::TableConstraint::ForeignKey { + } => constraints.push(TableConstraint::ForeignKey { name: name.clone(), columns: vec![], foreign_table: foreign_table.clone(), @@ -134,7 +135,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec { - constraints.push(ast::TableConstraint::Check { + constraints.push(TableConstraint::Check { name: name.clone(), expr: Box::new(expr.clone()), }) @@ -393,13 +394,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Build column default values let column_defaults = self.build_column_defaults(&columns, planner_context)?; + + let has_columns = !columns.is_empty(); + let schema = self.build_schema(columns)?.to_dfschema_ref()?; + if has_columns { + planner_context.set_table_schema(Some(Arc::clone(&schema))); + } + match query { Some(query) => { let plan = self.query_to_plan(*query, planner_context)?; let input_schema = plan.schema(); - let plan = if !columns.is_empty() { - let schema = self.build_schema(columns)?.to_dfschema_ref()?; + let plan = if has_columns { if schema.fields().len() != input_schema.fields().len() { return plan_err!( "Mismatch: {} columns specified, but result has {} columns", @@ -427,7 +434,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan }; - let constraints = 
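A side note on the new WHERE-clause guard above: `std::slice::from_ref` views one value as a one-element slice, so the single predicate can be handed to the slice-taking `find_aggregate_exprs` without allocating a temporary `Vec`. The std helper in isolation:

```rust
// A slice-taking API, standing in for `find_aggregate_exprs(&[Expr])`.
fn count_evens(xs: &[i64]) -> usize {
    xs.iter().filter(|x| **x % 2 == 0).count()
}

fn main() {
    let x = 4_i64;
    // No `vec![x]` allocation: borrow the single value as a &[i64] of length 1.
    assert_eq!(count_evens(std::slice::from_ref(&x)), 1);
}
```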
Constraints::new_from_table_constraints( + let constraints = Self::new_constraint_from_table_constraints( &all_constraints, plan.schema(), )?; @@ -440,18 +447,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if_not_exists, or_replace, column_defaults, + temporary, }, ))) } None => { - let schema = self.build_schema(columns)?.to_dfschema_ref()?; let plan = EmptyRelation { produce_one_row: false, schema, }; let plan = LogicalPlan::EmptyRelation(plan); - let constraints = Constraints::new_from_table_constraints( + let constraints = Self::new_constraint_from_table_constraints( &all_constraints, plan.schema(), )?; @@ -463,6 +470,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if_not_exists, or_replace, column_defaults, + temporary, }, ))) } @@ -498,9 +506,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if if_not_exists { return not_impl_err!("If not exists not supported")?; } - if temporary { - return not_impl_err!("Temporary views not supported")?; - } if to.is_some() { return not_impl_err!("To not supported")?; } @@ -526,6 +531,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), or_replace, definition: sql, + temporary, }))) } Statement::ShowCreate { obj_type, obj_name } => match obj_type { @@ -665,12 +671,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { returning, ignore, table_alias, - replace_into, + mut replace_into, priority, insert_alias, }) => { - if or.is_some() { - plan_err!("Inserts with or clauses not supported")?; + if let Some(or) = or { + match or { + SqliteOnConflict::Replace => replace_into = true, + _ => plan_err!("Inserts with {or} clause is not supported")?, + } } if partitioned.is_some() { plan_err!("Partitioned inserts not yet supported")?; @@ -698,9 +707,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { "Inserts with a table alias not supported: {table_alias:?}" )? }; - if replace_into { - plan_err!("Inserts with a `REPLACE INTO` clause not supported")? - }; if let Some(priority) = priority { plan_err!( "Inserts with a `PRIORITY` clause not supported: {priority:?}" @@ -710,7 +716,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan_err!("Inserts with an alias not supported")?; } let _ = into; // optional keyword doesn't change behavior - self.insert_to_plan(table_name, columns, source, overwrite) + self.insert_to_plan(table_name, columns, source, overwrite, replace_into) } Statement::Update { table, @@ -770,7 +776,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let isolation_level: ast::TransactionIsolationLevel = modes .iter() - .filter_map(|m: &ast::TransactionMode| match m { + .filter_map(|m: &TransactionMode| match m { TransactionMode::AccessMode(_) => None, TransactionMode::IsolationLevel(level) => Some(level), }) @@ -779,7 +785,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .unwrap_or(ast::TransactionIsolationLevel::Serializable); let access_mode: ast::TransactionAccessMode = modes .iter() - .filter_map(|m: &ast::TransactionMode| match m { + .filter_map(|m: &TransactionMode| match m { TransactionMode::AccessMode(mode) => Some(mode), TransactionMode::IsolationLevel(_) => None, }) @@ -878,14 +884,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } None => None, }; - // at the moment functions can't be qualified `schema.name` + // At the moment functions can't be qualified `schema.name` let name = match &name.0[..] { [] => exec_err!("Function should have name")?, [n] => n.value.clone(), [..] 
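Returning to the INSERT handling above: together with the `insert_to_plan` change further down, mapping `SqliteOnConflict::Replace` onto `replace_into` means `INSERT OR REPLACE INTO t ...` now plans as a replace instead of erroring. A compressed sketch of that flow, with a local enum standing in for `InsertOp`:

```rust
#[derive(Debug, PartialEq)]
enum InsertOp {
    Append,
    Overwrite,
    Replace,
}

fn insert_op(overwrite: bool, replace_into: bool) -> Result<InsertOp, String> {
    // Mirrors the (overwrite, replace_into) match in `insert_to_plan`.
    match (overwrite, replace_into) {
        (false, false) => Ok(InsertOp::Append),
        (true, false) => Ok(InsertOp::Overwrite),
        (false, true) => Ok(InsertOp::Replace),
        (true, true) => Err("conflicting insert operations".to_string()),
    }
}

fn main() {
    // `INSERT OR REPLACE INTO t ...` arrives with replace_into = true.
    assert_eq!(insert_op(false, true), Ok(InsertOp::Replace));
}
```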
=> not_impl_err!("Qualified functions are not supported")?, }; // - // convert resulting expression to data fusion expression + // Convert resulting expression to data fusion expression // let arg_types = args.as_ref().map(|arg| { arg.iter().map(|t| t.data_type.clone()).collect::>() @@ -933,10 +939,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { func_desc, .. } => { - // according to postgresql documentation it can be only one function + // According to postgresql documentation it can be only one function // specified in drop statement if let Some(desc) = func_desc.first() { - // at the moment functions can't be qualified `schema.name` + // At the moment functions can't be qualified `schema.name` let name = match &desc.name.0[..] { [] => exec_err!("Function should have name")?, [n] => n.value.clone(), @@ -1028,7 +1034,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { filter: Option, ) -> Result { if self.has_table("information_schema", "tables") { - // we only support the basic "SHOW TABLES" + // We only support the basic "SHOW TABLES" // https://github.com/apache/datafusion/issues/3188 if db_name.is_some() || filter.is_some() || full || extended { plan_err!("Unsupported parameters to SHOW TABLES") @@ -1059,7 +1065,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } fn copy_to_plan(&self, statement: CopyToStatement) -> Result { - // determine if source is table or query and handle accordingly + // Determine if source is table or query and handle accordingly let copy_source = statement.source; let (input, input_schema, table_ref) = match copy_source { CopyToSource::Relation(object_name) => { @@ -1100,7 +1106,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .to_string(), ) }; - // try to infer file format from file extension + // Try to infer file format from file extension let extension: &str = &Path::new(&statement.target) .extension() .ok_or_else(e)? @@ -1198,6 +1204,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { location, table_partition_cols, if_not_exists, + temporary, order_exprs, unbounded, options, @@ -1239,10 +1246,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let ordered_exprs = self.build_order_by(order_exprs, &df_schema, &mut planner_context)?; - // External tables do not support schemas at the moment, so the name is just a table name - let name = TableReference::bare(name); + let name = self.object_name_to_table_reference(name)?; let constraints = - Constraints::new_from_table_constraints(&all_constraints, &df_schema)?; + Self::new_constraint_from_table_constraints(&all_constraints, &df_schema)?; Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable( PlanCreateExternalTable { schema: df_schema, @@ -1251,6 +1257,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { file_type, table_partition_cols, if_not_exists, + temporary, definition, order_exprs: ordered_exprs, unbounded, @@ -1261,6 +1268,74 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } + /// Convert each [TableConstraint] to corresponding [Constraint] + fn new_constraint_from_table_constraints( + constraints: &[TableConstraint], + df_schema: &DFSchemaRef, + ) -> Result { + let constraints = constraints + .iter() + .map(|c: &TableConstraint| match c { + TableConstraint::Unique { name, columns, .. 
} => { + let field_names = df_schema.field_names(); + // Get unique constraint indices in the schema: + let indices = columns + .iter() + .map(|u| { + let idx = field_names + .iter() + .position(|item| *item == u.value) + .ok_or_else(|| { + let name = name + .as_ref() + .map(|name| format!("with name '{name}' ")) + .unwrap_or("".to_string()); + DataFusionError::Execution( + format!("Column for unique constraint {}not found in schema: {}", name,u.value) + ) + })?; + Ok(idx) + }) + .collect::>>()?; + Ok(Constraint::Unique(indices)) + } + TableConstraint::PrimaryKey { columns, .. } => { + let field_names = df_schema.field_names(); + // Get primary key indices in the schema: + let indices = columns + .iter() + .map(|pk| { + let idx = field_names + .iter() + .position(|item| *item == pk.value) + .ok_or_else(|| { + DataFusionError::Execution(format!( + "Column for primary key not found in schema: {}", + pk.value + )) + })?; + Ok(idx) + }) + .collect::>>()?; + Ok(Constraint::PrimaryKey(indices)) + } + TableConstraint::ForeignKey { .. } => { + _plan_err!("Foreign key constraints are not currently supported") + } + TableConstraint::Check { .. } => { + _plan_err!("Check constraints are not currently supported") + } + TableConstraint::Index { .. } => { + _plan_err!("Indexes are not currently supported") + } + TableConstraint::FulltextOrSpatial { .. } => { + _plan_err!("Indexes are not currently supported") + } + }) + .collect::>>()?; + Ok(Constraints::new_unverified(constraints)) + } + fn parse_options_map( &self, options: Vec<(String, Value)>, @@ -1407,11 +1482,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut variable_lower = variable.to_lowercase(); if variable_lower == "timezone" || variable_lower == "time.zone" { - // we could introduce alias in OptionDefinition if this string matching thing grows + // We could introduce alias in OptionDefinition if this string matching thing grows variable_lower = "datafusion.execution.time_zone".to_string(); } - // parse value string from Expr + // Parse value string from Expr let value_string = match &value[0] { SQLExpr::Identifier(i) => ident_to_string(i), SQLExpr::Value(v) => match crate::utils::value_to_string(v) { @@ -1420,7 +1495,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } Some(v) => v, }, - // for capture signed number e.g. +8, -8 + // For capture signed number e.g. +8, -8 SQLExpr::UnaryOp { op, expr } => match op { UnaryOperator::Plus => format!("+{expr}"), UnaryOperator::Minus => format!("-{expr}"), @@ -1575,7 +1650,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { None => { // If the target table has an alias, use it to qualify the column name if let Some(alias) = &table_alias { - datafusion_expr::Expr::Column(Column::new( + Expr::Column(Column::new( Some(self.ident_normalizer.normalize(alias.name.clone())), field.name(), )) @@ -1605,6 +1680,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { columns: Vec, source: Box, overwrite: bool, + replace_into: bool, ) -> Result { // Do a table lookup to verify the table exists let table_name = self.object_name_to_table_reference(table_name)?; @@ -1614,10 +1690,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Get insert fields and target table's value indices // - // if value_indices[i] = Some(j), it means that the value of the i-th target table's column is + // If value_indices[i] = Some(j), it means that the value of the i-th target table's column is // derived from the j-th output of the source. 
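Stepping back to `new_constraint_from_table_constraints` above: the UNIQUE and PRIMARY KEY arms reduce to the same lookup, mapping each constraint column name to its position in the schema and failing with a descriptive error when a name is absent. A minimal sketch with plain strings in place of the `DFSchema` field names:

```rust
fn key_indices(field_names: &[String], key_cols: &[&str]) -> Result<Vec<usize>, String> {
    key_cols
        .iter()
        .map(|col| {
            field_names
                .iter()
                .position(|f| f.as_str() == *col)
                .ok_or_else(|| format!("Column for primary key not found in schema: {col}"))
        })
        // Result<Vec<usize>, String>, like the collect::<Result<Vec<_>>>() calls above
        .collect()
}

fn main() {
    let schema = vec!["id".to_string(), "email".to_string()];
    assert_eq!(key_indices(&schema, &["email"]), Ok(vec![1]));
    assert!(key_indices(&schema, &["missing"]).is_err());
}
```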
// - // if value_indices[i] = None, it means that the value of the i-th target table's column is + // If value_indices[i] = None, it means that the value of the i-th target table's column is // not provided, and should be filled with a default value later. let (fields, value_indices) = if columns.is_empty() { // Empty means we're inserting into all columns of the table @@ -1707,16 +1783,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>>()?; let source = project(source, exprs)?; - let op = if overwrite { - WriteOp::InsertOverwrite - } else { - WriteOp::InsertInto + let insert_op = match (overwrite, replace_into) { + (false, false) => InsertOp::Append, + (true, false) => InsertOp::Overwrite, + (false, true) => InsertOp::Replace, + (true, true) => plan_err!("Conflicting insert operations: `overwrite` and `replace_into` cannot both be true")?, }; let plan = LogicalPlan::Dml(DmlStatement::new( table_name, Arc::new(table_schema), - op, + WriteOp::Insert(insert_op), Arc::new(source), )); Ok(plan) @@ -1748,7 +1825,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let table_ref = self.object_name_to_table_reference(sql_table_name)?; let _ = self.context_provider.get_table_source(table_ref)?; - // treat both FULL and EXTENDED as the same + // Treat both FULL and EXTENDED as the same let select_list = if full || extended { "*" } else { diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index 71ff712985cdb..2de1ce9125a7d 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -182,7 +182,28 @@ impl SelectBuilder { self } pub fn selection(&mut self, value: Option) -> &mut Self { - self.selection = value; + // With filter pushdown optimization, the LogicalPlan can have filters defined as part of `TableScan` and `Filter` nodes. + // To avoid overwriting one of the filters, we combine the existing filter with the additional filter. + // Example: | + // | Projection: customer.c_phone AS cntrycode, customer.c_acctbal | + // | Filter: CAST(customer.c_acctbal AS Decimal128(38, 6)) > () | + // | Subquery: + // | .. | + // | TableScan: customer, full_filters=[customer.c_mktsegment = Utf8("BUILDING")] + match (&self.selection, value) { + (Some(existing_selection), Some(new_selection)) => { + self.selection = Some(ast::Expr::BinaryOp { + left: Box::new(existing_selection.clone()), + op: ast::BinaryOperator::And, + right: Box::new(new_selection), + }); + } + (None, Some(new_selection)) => { + self.selection = Some(new_selection); + } + (_, None) => (), + } + self } pub fn group_by(&mut self, value: ast::GroupByExpr) -> &mut Self { diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index d8a4fb2542643..88159ab6df15c 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -18,12 +18,17 @@ use std::sync::Arc; use arrow_schema::TimeUnit; +use datafusion_expr::Expr; use regex::Regex; use sqlparser::{ - ast::{self, Ident, ObjectName, TimezoneInfo}, + ast::{self, Function, Ident, ObjectName, TimezoneInfo}, keywords::ALL_KEYWORDS, }; +use datafusion_common::Result; + +use super::{utils::date_part_to_sql, Unparser}; + /// `Dialect` to use for Unparsing /// /// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`) @@ -54,8 +59,8 @@ pub trait Dialect: Send + Sync { /// Does the dialect use DOUBLE PRECISION to represent Float64 rather than DOUBLE? /// E.g. 
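The `SelectBuilder::selection` change above is worth isolating: with filter pushdown, a `Filter` node and a `TableScan` can both contribute predicates, so a second caller must AND into, rather than overwrite, the existing selection. The combination logic on its own, with strings standing in for `ast::Expr`:

```rust
fn merge_selection(existing: Option<String>, new: Option<String>) -> Option<String> {
    match (existing, new) {
        // Both present: combine with AND instead of overwriting.
        (Some(a), Some(b)) => Some(format!("({a}) AND ({b})")),
        (None, some @ Some(_)) => some,
        (existing, None) => existing,
    }
}

fn main() {
    let merged = merge_selection(
        Some("c_acctbal > 0".to_string()),
        Some("c_mktsegment = 'BUILDING'".to_string()),
    );
    assert_eq!(
        merged.as_deref(),
        Some("(c_acctbal > 0) AND (c_mktsegment = 'BUILDING')")
    );
}
```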
Postgres uses DOUBLE PRECISION instead of DOUBLE - fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { - sqlparser::ast::DataType::Double + fn float64_ast_dtype(&self) -> ast::DataType { + ast::DataType::Double } /// The SQL type to use for Arrow Utf8 unparsing @@ -81,6 +86,12 @@ pub trait Dialect: Send + Sync { ast::DataType::BigInt(None) } + /// The SQL type to use for Arrow Int32 unparsing + /// Most dialects use Integer, but some, like MySQL, require SIGNED + fn int32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Integer(None) + } + /// The SQL type to use for Timestamp unparsing /// Most dialects use Timestamp, but some, like MySQL, require Datetime /// Some dialects like Dremio does not support WithTimeZone and requires always Timestamp @@ -99,8 +110,8 @@ pub trait Dialect: Send + Sync { /// The SQL type to use for Arrow Date32 unparsing /// Most dialects use Date, but some, like SQLite require TEXT - fn date32_cast_dtype(&self) -> sqlparser::ast::DataType { - sqlparser::ast::DataType::Date + fn date32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Date } /// Does the dialect support specifying column aliases as part of alias table definition? @@ -108,6 +119,24 @@ pub trait Dialect: Send + Sync { fn supports_column_alias_in_table_alias(&self) -> bool { true } + + /// Whether the dialect requires a table alias for any subquery in the FROM clause + /// This affects behavior when deriving logical plans for Sort, Limit, etc. + fn requires_derived_table_alias(&self) -> bool { + false + } + + /// Allows the dialect to override scalar function unparsing if the dialect has specific rules. + /// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is + /// a custom implementation for the function. + fn scalar_function_to_sql_overrides( + &self, + _unparser: &Unparser, + _func_name: &str, + _args: &[Expr], + ) -> Result> { + Ok(None) + } } /// `IntervalStyle` to use for unparsing @@ -145,7 +174,7 @@ impl Dialect for DefaultDialect { fn identifier_quote_style(&self, identifier: &str) -> Option { let identifier_regex = Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_]*$").unwrap(); let id_upper = identifier.to_uppercase(); - // special case ignore "ID", see https://github.com/sqlparser-rs/sqlparser-rs/issues/1382 + // Special case ignore "ID", see https://github.com/sqlparser-rs/sqlparser-rs/issues/1382 // ID is a keyword in ClickHouse, but we don't want to quote it when unparsing SQL here if (id_upper != "ID" && ALL_KEYWORDS.contains(&id_upper.as_str())) || !identifier_regex.is_match(identifier) @@ -168,8 +197,69 @@ impl Dialect for PostgreSqlDialect { IntervalStyle::PostgresVerbose } - fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { - sqlparser::ast::DataType::DoublePrecision + fn float64_ast_dtype(&self) -> ast::DataType { + ast::DataType::DoublePrecision + } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "round" { + return Ok(Some( + self.round_to_sql_enforce_numeric(unparser, func_name, args)?, + )); + } + + Ok(None) + } +} + +impl PostgreSqlDialect { + fn round_to_sql_enforce_numeric( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result { + let mut args = unparser.function_args_to_sql(args)?; + + // Enforce the first argument to be Numeric + if let Some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(expr))) = + args.first_mut() + { + if let ast::Expr::Cast { data_type, .. 
} = expr { + // Don't create an additional cast wrapper if we can update the existing one + *data_type = ast::DataType::Numeric(ast::ExactNumberInfo::None); + } else { + // Wrap the expression in a new cast + *expr = ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(expr.clone()), + data_type: ast::DataType::Numeric(ast::ExactNumberInfo::None), + format: None, + }; + } + } + + Ok(ast::Expr::Function(Function { + name: ObjectName(vec![Ident { + value: func_name.to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + })) } } @@ -204,6 +294,10 @@ impl Dialect for MySqlDialect { ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![]) } + fn int32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![]) + } + fn timestamp_cast_dtype( &self, _time_unit: &TimeUnit, @@ -211,6 +305,23 @@ impl Dialect for MySqlDialect { ) -> ast::DataType { ast::DataType::Datetime(None) } + + fn requires_derived_table_alias(&self) -> bool { + true + } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "date_part" { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + + Ok(None) + } } pub struct SqliteDialect {} @@ -224,13 +335,26 @@ impl Dialect for SqliteDialect { DateFieldExtractStyle::Strftime } - fn date32_cast_dtype(&self) -> sqlparser::ast::DataType { - sqlparser::ast::DataType::Text + fn date32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Text } fn supports_column_alias_in_table_alias(&self) -> bool { false } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "date_part" { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + + Ok(None) + } } pub struct CustomDialect { @@ -238,15 +362,17 @@ pub struct CustomDialect { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, - float64_ast_dtype: sqlparser::ast::DataType, + float64_ast_dtype: ast::DataType, utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, int64_cast_dtype: ast::DataType, + int32_cast_dtype: ast::DataType, timestamp_cast_dtype: ast::DataType, timestamp_tz_cast_dtype: ast::DataType, - date32_cast_dtype: sqlparser::ast::DataType, + date32_cast_dtype: ast::DataType, supports_column_alias_in_table_alias: bool, + requires_derived_table_alias: bool, } impl Default for CustomDialect { @@ -256,24 +382,26 @@ impl Default for CustomDialect { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::SQLStandard, - float64_ast_dtype: sqlparser::ast::DataType::Double, + float64_ast_dtype: ast::DataType::Double, utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, int64_cast_dtype: ast::DataType::BigInt(None), + int32_cast_dtype: ast::DataType::Integer(None), timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), timestamp_tz_cast_dtype: ast::DataType::Timestamp( None, TimezoneInfo::WithTimeZone, ), - date32_cast_dtype: 
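The new `scalar_function_to_sql_overrides` hook is the extension point each of these dialect impls uses: return `Ok(None)` to keep the default unparsing, or a `Some` expression to take over. A toy reduction of the pattern, including a Postgres-style `round` rewrite that retargets an existing cast or wraps the argument in a new one (all types here are simplified string stand-ins, and the `Result` wrapper is dropped):

```rust
trait Dialect {
    // Default: no override, fall back to generic function unparsing.
    fn scalar_function_to_sql_overrides(&self, func: &str, args: &[String]) -> Option<String> {
        let _ = (func, args);
        None
    }
}

struct Postgres;

impl Dialect for Postgres {
    fn scalar_function_to_sql_overrides(&self, func: &str, args: &[String]) -> Option<String> {
        if func == "round" {
            let mut args = args.to_vec();
            if let Some(first) = args.first_mut() {
                // Retarget an existing cast, otherwise wrap in a new one,
                // mirroring the shape of round_to_sql_enforce_numeric above.
                *first = if let Some(inner) = first.strip_prefix("CAST(") {
                    format!("CAST({}", inner.replace("AS DOUBLE)", "AS NUMERIC)"))
                } else {
                    format!("CAST({first} AS NUMERIC)")
                };
            }
            return Some(format!("round({})", args.join(", ")));
        }
        None
    }
}

fn main() {
    let sql = Postgres
        .scalar_function_to_sql_overrides("round", &["CAST(a AS DOUBLE)".into(), "2".into()]);
    assert_eq!(sql.as_deref(), Some("round(CAST(a AS NUMERIC), 2)"));
}
```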
sqlparser::ast::DataType::Date, + date32_cast_dtype: ast::DataType::Date, supports_column_alias_in_table_alias: true, + requires_derived_table_alias: false, } } } impl CustomDialect { - // create a CustomDialect + // Create a CustomDialect #[deprecated(note = "please use `CustomDialectBuilder` instead")] pub fn new(identifier_quote_style: Option) -> Self { Self { @@ -300,7 +428,7 @@ impl Dialect for CustomDialect { self.interval_style } - fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { + fn float64_ast_dtype(&self) -> ast::DataType { self.float64_ast_dtype.clone() } @@ -320,6 +448,10 @@ impl Dialect for CustomDialect { self.int64_cast_dtype.clone() } + fn int32_cast_dtype(&self) -> ast::DataType { + self.int32_cast_dtype.clone() + } + fn timestamp_cast_dtype( &self, _time_unit: &TimeUnit, @@ -332,13 +464,30 @@ impl Dialect for CustomDialect { } } - fn date32_cast_dtype(&self) -> sqlparser::ast::DataType { + fn date32_cast_dtype(&self) -> ast::DataType { self.date32_cast_dtype.clone() } fn supports_column_alias_in_table_alias(&self) -> bool { self.supports_column_alias_in_table_alias } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "date_part" { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + + Ok(None) + } + + fn requires_derived_table_alias(&self) -> bool { + self.requires_derived_table_alias + } } /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern @@ -360,15 +509,17 @@ pub struct CustomDialectBuilder { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, - float64_ast_dtype: sqlparser::ast::DataType, + float64_ast_dtype: ast::DataType, utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, int64_cast_dtype: ast::DataType, + int32_cast_dtype: ast::DataType, timestamp_cast_dtype: ast::DataType, timestamp_tz_cast_dtype: ast::DataType, date32_cast_dtype: ast::DataType, supports_column_alias_in_table_alias: bool, + requires_derived_table_alias: bool, } impl Default for CustomDialectBuilder { @@ -384,18 +535,20 @@ impl CustomDialectBuilder { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::PostgresVerbose, - float64_ast_dtype: sqlparser::ast::DataType::Double, + float64_ast_dtype: ast::DataType::Double, utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, int64_cast_dtype: ast::DataType::BigInt(None), + int32_cast_dtype: ast::DataType::Integer(None), timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), timestamp_tz_cast_dtype: ast::DataType::Timestamp( None, TimezoneInfo::WithTimeZone, ), - date32_cast_dtype: sqlparser::ast::DataType::Date, + date32_cast_dtype: ast::DataType::Date, supports_column_alias_in_table_alias: true, + requires_derived_table_alias: false, } } @@ -410,11 +563,13 @@ impl CustomDialectBuilder { large_utf8_cast_dtype: self.large_utf8_cast_dtype, date_field_extract_style: self.date_field_extract_style, int64_cast_dtype: self.int64_cast_dtype, + int32_cast_dtype: self.int32_cast_dtype, timestamp_cast_dtype: self.timestamp_cast_dtype, timestamp_tz_cast_dtype: self.timestamp_tz_cast_dtype, date32_cast_dtype: self.date32_cast_dtype, supports_column_alias_in_table_alias: self .supports_column_alias_in_table_alias, + 
requires_derived_table_alias: self.requires_derived_table_alias,
        }
    }
@@ -424,7 +579,7 @@
         self
     }
 
-    /// Customize the dialect to supports `NULLS FIRST` in `ORDER BY` clauses
+    /// Customize the dialect to support `NULLS FIRST` in `ORDER BY` clauses
     pub fn with_supports_nulls_first_in_sort(
         mut self,
         supports_nulls_first_in_sort: bool,
@@ -449,10 +604,7 @@
     }
 
     /// Customize the dialect with a specific SQL type for Float64 casting: DOUBLE, DOUBLE PRECISION, etc.
-    pub fn with_float64_ast_dtype(
-        mut self,
-        float64_ast_dtype: sqlparser::ast::DataType,
-    ) -> Self {
+    pub fn with_float64_ast_dtype(mut self, float64_ast_dtype: ast::DataType) -> Self {
         self.float64_ast_dtype = float64_ast_dtype;
         self
     }
@@ -487,6 +639,12 @@
         self
     }
 
+    /// Customize the dialect with a specific SQL type for Int32 casting: Integer, SIGNED, etc.
+    pub fn with_int32_cast_dtype(mut self, int32_cast_dtype: ast::DataType) -> Self {
+        self.int32_cast_dtype = int32_cast_dtype;
+        self
+    }
+
     /// Customize the dialect with a specific SQL type for Timestamp casting: Timestamp, Datetime, etc.
     pub fn with_timestamp_cast_dtype(
         mut self,
@@ -503,7 +661,7 @@
         self
     }
 
-    /// Customize the dialect to supports column aliases as part of alias table definition
+    /// Customize the dialect to support column aliases as part of alias table definition
     pub fn with_supports_column_alias_in_table_alias(
         mut self,
         supports_column_alias_in_table_alias: bool,
@@ -511,4 +669,12 @@
         self.supports_column_alias_in_table_alias = supports_column_alias_in_table_alias;
         self
     }
+
+    /// Customize the dialect to require a table alias for any subquery in the FROM clause
+    pub fn with_requires_derived_table_alias(
+        mut self,
+        requires_derived_table_alias: bool,
+    ) -> Self {
+        self.requires_derived_table_alias = requires_derived_table_alias;
+        self
+    }
 }
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index b924268a7657f..b41b0a54b86f0 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -15,16 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
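Like the other `with_*` setters above, the new `with_requires_derived_table_alias` consumes and returns `self` so options chain fluently. A minimal standalone version of the convention, with toy field types:

```rust
#[derive(Default)]
struct DialectBuilder {
    int32_cast_dtype: String,
    requires_derived_table_alias: bool,
}

impl DialectBuilder {
    fn with_int32_cast_dtype(mut self, dtype: &str) -> Self {
        self.int32_cast_dtype = dtype.to_string();
        self
    }
    fn with_requires_derived_table_alias(mut self, requires: bool) -> Self {
        self.requires_derived_table_alias = requires;
        self
    }
}

fn main() {
    // Setters move `self`, so configuration reads as one chained expression.
    let b = DialectBuilder::default()
        .with_int32_cast_dtype("SIGNED")
        .with_requires_derived_table_alias(true);
    assert!(b.requires_derived_table_alias && b.int32_cast_dtype == "SIGNED");
}
```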
-use datafusion_expr::ScalarUDF; +use datafusion_expr::expr::Unnest; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, Interval, - ObjectName, TimezoneInfo, UnaryOperator, + self, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, + TimezoneInfo, UnaryOperator, }; use std::sync::Arc; use std::vec; -use super::dialect::{DateFieldExtractStyle, IntervalStyle}; +use super::dialect::IntervalStyle; use super::Unparser; use arrow::datatypes::{Decimal128Type, Decimal256Type, DecimalType}; use arrow::util::display::array_value_to_string; @@ -77,13 +77,8 @@ pub fn expr_to_sql(expr: &Expr) -> Result { unparser.expr_to_sql(expr) } -pub fn sort_to_sql(sort: &Sort) -> Result { - let unparser = Unparser::default(); - unparser.sort_to_sql(sort) -} - const LOWEST: &BinaryOperator = &BinaryOperator::Or; -// closest precedence we have to IS operator is BitwiseAnd (any other) in PG docs +// Closest precedence we have to IS operator is BitwiseAnd (any other) in PG docs // (https://www.postgresql.org/docs/7.2/sql-precedence.html) const IS: &BinaryOperator = &BinaryOperator::BitwiseAnd; @@ -116,47 +111,14 @@ impl Unparser<'_> { Expr::ScalarFunction(ScalarFunction { func, args }) => { let func_name = func.name(); - if let Some(expr) = - self.scalar_function_to_sql_overrides(func_name, func, args) + if let Some(expr) = self + .dialect + .scalar_function_to_sql_overrides(self, func_name, args)? { return Ok(expr); } - let args = args - .iter() - .map(|e| { - if matches!( - e, - Expr::Wildcard { - qualifier: None, - .. - } - ) { - Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) - } else { - self.expr_to_sql_inner(e).map(|e| { - FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) - }) - } - }) - .collect::>>()?; - - Ok(ast::Expr::Function(Function { - name: ast::ObjectName(vec![Ident { - value: func_name.to_string(), - quote_style: None, - }]), - args: ast::FunctionArguments::List(ast::FunctionArgumentList { - duplicate_treatment: None, - args, - clauses: vec![], - }), - filter: None, - null_treatment: None, - over: None, - within_group: vec![], - parameters: ast::FunctionArguments::None, - })) + self.scalar_function_to_sql(func_name, args) } Expr::Between(Between { expr, @@ -263,9 +225,10 @@ impl Unparser<'_> { ast::WindowFrameUnits::Groups } }; - let order_by: Vec = order_by + + let order_by = order_by .iter() - .map(sort_to_sql) + .map(|sort_expr| self.sort_to_sql(sort_expr)) .collect::>>()?; let start_bound = self.convert_bound(&window_frame.start_bound)?; @@ -285,7 +248,7 @@ impl Unparser<'_> { })); Ok(ast::Expr::Function(Function { - name: ast::ObjectName(vec![Ident { + name: ObjectName(vec![Ident { value: func_name.to_string(), quote_style: None, }]), @@ -329,7 +292,7 @@ impl Unparser<'_> { None => None, }; Ok(ast::Expr::Function(Function { - name: ast::ObjectName(vec![Ident { + name: ObjectName(vec![Ident { value: func_name.to_string(), quote_style: None, }]), @@ -504,10 +467,34 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Placeholder(p.id.to_string()))) } Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), - Expr::Unnest(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), + Expr::Unnest(unnest) => self.unnest_to_sql(unnest), } } + pub fn scalar_function_to_sql( + &self, + func_name: &str, + args: &[Expr], + ) -> Result { + let args = self.function_args_to_sql(args)?; + Ok(ast::Expr::Function(Function { + name: ObjectName(vec![Ident { + value: 
func_name.to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + })) + } + pub fn sort_to_sql(&self, sort: &Sort) -> Result { let Sort { expr, @@ -530,90 +517,9 @@ impl Unparser<'_> { }) } - fn scalar_function_to_sql_overrides( - &self, - func_name: &str, - _func: &Arc, - args: &[Expr], - ) -> Option { - if func_name.to_lowercase() == "date_part" { - match (self.dialect.date_field_extract_style(), args.len()) { - (DateFieldExtractStyle::Extract, 2) => { - let date_expr = self.expr_to_sql(&args[1]).ok()?; - - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &args[0] { - let field = match field.to_lowercase().as_str() { - "year" => ast::DateTimeField::Year, - "month" => ast::DateTimeField::Month, - "day" => ast::DateTimeField::Day, - "hour" => ast::DateTimeField::Hour, - "minute" => ast::DateTimeField::Minute, - "second" => ast::DateTimeField::Second, - _ => return None, - }; - - return Some(ast::Expr::Extract { - field, - expr: Box::new(date_expr), - syntax: ast::ExtractSyntax::From, - }); - } - } - (DateFieldExtractStyle::Strftime, 2) => { - let column = self.expr_to_sql(&args[1]).ok()?; - - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &args[0] { - let field = match field.to_lowercase().as_str() { - "year" => "%Y", - "month" => "%m", - "day" => "%d", - "hour" => "%H", - "minute" => "%M", - "second" => "%S", - _ => return None, - }; - - return Some(ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident { - value: "strftime".to_string(), - quote_style: None, - }]), - args: ast::FunctionArguments::List( - ast::FunctionArgumentList { - duplicate_treatment: None, - args: vec![ - ast::FunctionArg::Unnamed( - ast::FunctionArgExpr::Expr(ast::Expr::Value( - ast::Value::SingleQuotedString( - field.to_string(), - ), - )), - ), - ast::FunctionArg::Unnamed( - ast::FunctionArgExpr::Expr(column), - ), - ], - clauses: vec![], - }, - ), - filter: None, - null_treatment: None, - over: None, - within_group: vec![], - parameters: ast::FunctionArguments::None, - })); - } - } - _ => {} // no overrides for DateFieldExtractStyle::DatePart, because it's already a date_part - } - } - - None - } - fn ast_type_for_date64_in_cast(&self) -> ast::DataType { if self.dialect.use_timestamp_for_date64() { - ast::DataType::Timestamp(None, ast::TimezoneInfo::None) + ast::DataType::Timestamp(None, TimezoneInfo::None) } else { ast::DataType::Datetime(None) } @@ -665,7 +571,10 @@ impl Unparser<'_> { } } - fn function_args_to_sql(&self, args: &[Expr]) -> Result> { + pub(crate) fn function_args_to_sql( + &self, + args: &[Expr], + ) -> Result> { args.iter() .map(|e| { if matches!( @@ -685,16 +594,16 @@ impl Unparser<'_> { } /// This function can create an identifier with or without quotes based on the dialect rules - pub(super) fn new_ident_quoted_if_needs(&self, ident: String) -> ast::Ident { + pub(super) fn new_ident_quoted_if_needs(&self, ident: String) -> Ident { let quote_style = self.dialect.identifier_quote_style(&ident); - ast::Ident { + Ident { value: ident, quote_style, } } - pub(super) fn new_ident_without_quote_style(&self, str: String) -> ast::Ident { - ast::Ident { + pub(super) fn new_ident_without_quote_style(&self, str: String) -> Ident { + Ident { value: str, quote_style: None, } @@ -704,7 +613,7 @@ impl Unparser<'_> { &self, lhs: ast::Expr, rhs: 
ast::Expr, - op: ast::BinaryOperator, + op: BinaryOperator, ) -> ast::Expr { ast::Expr::BinaryOp { left: Box::new(lhs), @@ -786,10 +695,10 @@ impl Unparser<'_> { match expr { ast::Expr::Nested(_) | ast::Expr::Identifier(_) | ast::Expr::Value(_) => 100, ast::Expr::BinaryOp { op, .. } => self.sql_op_precedence(op), - // closest precedence we currently have to Between is PGLikeMatch + // Closest precedence we currently have to Between is PGLikeMatch // (https://www.postgresql.org/docs/7.2/sql-precedence.html) ast::Expr::Between { .. } => { - self.sql_op_precedence(&ast::BinaryOperator::PGLikeMatch) + self.sql_op_precedence(&BinaryOperator::PGLikeMatch) } _ => 0, } @@ -819,70 +728,70 @@ impl Unparser<'_> { fn sql_to_op(&self, op: &BinaryOperator) -> Result { match op { - ast::BinaryOperator::Eq => Ok(Operator::Eq), - ast::BinaryOperator::NotEq => Ok(Operator::NotEq), - ast::BinaryOperator::Lt => Ok(Operator::Lt), - ast::BinaryOperator::LtEq => Ok(Operator::LtEq), - ast::BinaryOperator::Gt => Ok(Operator::Gt), - ast::BinaryOperator::GtEq => Ok(Operator::GtEq), - ast::BinaryOperator::Plus => Ok(Operator::Plus), - ast::BinaryOperator::Minus => Ok(Operator::Minus), - ast::BinaryOperator::Multiply => Ok(Operator::Multiply), - ast::BinaryOperator::Divide => Ok(Operator::Divide), - ast::BinaryOperator::Modulo => Ok(Operator::Modulo), - ast::BinaryOperator::And => Ok(Operator::And), - ast::BinaryOperator::Or => Ok(Operator::Or), - ast::BinaryOperator::PGRegexMatch => Ok(Operator::RegexMatch), - ast::BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch), - ast::BinaryOperator::PGRegexNotMatch => Ok(Operator::RegexNotMatch), - ast::BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch), - ast::BinaryOperator::PGILikeMatch => Ok(Operator::ILikeMatch), - ast::BinaryOperator::PGNotLikeMatch => Ok(Operator::NotLikeMatch), - ast::BinaryOperator::PGLikeMatch => Ok(Operator::LikeMatch), - ast::BinaryOperator::PGNotILikeMatch => Ok(Operator::NotILikeMatch), - ast::BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd), - ast::BinaryOperator::BitwiseOr => Ok(Operator::BitwiseOr), - ast::BinaryOperator::BitwiseXor => Ok(Operator::BitwiseXor), - ast::BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight), - ast::BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft), - ast::BinaryOperator::StringConcat => Ok(Operator::StringConcat), - ast::BinaryOperator::AtArrow => Ok(Operator::AtArrow), - ast::BinaryOperator::ArrowAt => Ok(Operator::ArrowAt), + BinaryOperator::Eq => Ok(Operator::Eq), + BinaryOperator::NotEq => Ok(Operator::NotEq), + BinaryOperator::Lt => Ok(Operator::Lt), + BinaryOperator::LtEq => Ok(Operator::LtEq), + BinaryOperator::Gt => Ok(Operator::Gt), + BinaryOperator::GtEq => Ok(Operator::GtEq), + BinaryOperator::Plus => Ok(Operator::Plus), + BinaryOperator::Minus => Ok(Operator::Minus), + BinaryOperator::Multiply => Ok(Operator::Multiply), + BinaryOperator::Divide => Ok(Operator::Divide), + BinaryOperator::Modulo => Ok(Operator::Modulo), + BinaryOperator::And => Ok(Operator::And), + BinaryOperator::Or => Ok(Operator::Or), + BinaryOperator::PGRegexMatch => Ok(Operator::RegexMatch), + BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch), + BinaryOperator::PGRegexNotMatch => Ok(Operator::RegexNotMatch), + BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch), + BinaryOperator::PGILikeMatch => Ok(Operator::ILikeMatch), + BinaryOperator::PGNotLikeMatch => Ok(Operator::NotLikeMatch), + BinaryOperator::PGLikeMatch => Ok(Operator::LikeMatch), + 
BinaryOperator::PGNotILikeMatch => Ok(Operator::NotILikeMatch), + BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd), + BinaryOperator::BitwiseOr => Ok(Operator::BitwiseOr), + BinaryOperator::BitwiseXor => Ok(Operator::BitwiseXor), + BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight), + BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft), + BinaryOperator::StringConcat => Ok(Operator::StringConcat), + BinaryOperator::AtArrow => Ok(Operator::AtArrow), + BinaryOperator::ArrowAt => Ok(Operator::ArrowAt), _ => not_impl_err!("unsupported operation: {op:?}"), } } - fn op_to_sql(&self, op: &Operator) -> Result { + fn op_to_sql(&self, op: &Operator) -> Result { match op { - Operator::Eq => Ok(ast::BinaryOperator::Eq), - Operator::NotEq => Ok(ast::BinaryOperator::NotEq), - Operator::Lt => Ok(ast::BinaryOperator::Lt), - Operator::LtEq => Ok(ast::BinaryOperator::LtEq), - Operator::Gt => Ok(ast::BinaryOperator::Gt), - Operator::GtEq => Ok(ast::BinaryOperator::GtEq), - Operator::Plus => Ok(ast::BinaryOperator::Plus), - Operator::Minus => Ok(ast::BinaryOperator::Minus), - Operator::Multiply => Ok(ast::BinaryOperator::Multiply), - Operator::Divide => Ok(ast::BinaryOperator::Divide), - Operator::Modulo => Ok(ast::BinaryOperator::Modulo), - Operator::And => Ok(ast::BinaryOperator::And), - Operator::Or => Ok(ast::BinaryOperator::Or), + Operator::Eq => Ok(BinaryOperator::Eq), + Operator::NotEq => Ok(BinaryOperator::NotEq), + Operator::Lt => Ok(BinaryOperator::Lt), + Operator::LtEq => Ok(BinaryOperator::LtEq), + Operator::Gt => Ok(BinaryOperator::Gt), + Operator::GtEq => Ok(BinaryOperator::GtEq), + Operator::Plus => Ok(BinaryOperator::Plus), + Operator::Minus => Ok(BinaryOperator::Minus), + Operator::Multiply => Ok(BinaryOperator::Multiply), + Operator::Divide => Ok(BinaryOperator::Divide), + Operator::Modulo => Ok(BinaryOperator::Modulo), + Operator::And => Ok(BinaryOperator::And), + Operator::Or => Ok(BinaryOperator::Or), Operator::IsDistinctFrom => not_impl_err!("unsupported operation: {op:?}"), Operator::IsNotDistinctFrom => not_impl_err!("unsupported operation: {op:?}"), - Operator::RegexMatch => Ok(ast::BinaryOperator::PGRegexMatch), - Operator::RegexIMatch => Ok(ast::BinaryOperator::PGRegexIMatch), - Operator::RegexNotMatch => Ok(ast::BinaryOperator::PGRegexNotMatch), - Operator::RegexNotIMatch => Ok(ast::BinaryOperator::PGRegexNotIMatch), - Operator::ILikeMatch => Ok(ast::BinaryOperator::PGILikeMatch), - Operator::NotLikeMatch => Ok(ast::BinaryOperator::PGNotLikeMatch), - Operator::LikeMatch => Ok(ast::BinaryOperator::PGLikeMatch), - Operator::NotILikeMatch => Ok(ast::BinaryOperator::PGNotILikeMatch), - Operator::BitwiseAnd => Ok(ast::BinaryOperator::BitwiseAnd), - Operator::BitwiseOr => Ok(ast::BinaryOperator::BitwiseOr), - Operator::BitwiseXor => Ok(ast::BinaryOperator::BitwiseXor), - Operator::BitwiseShiftRight => Ok(ast::BinaryOperator::PGBitwiseShiftRight), - Operator::BitwiseShiftLeft => Ok(ast::BinaryOperator::PGBitwiseShiftLeft), - Operator::StringConcat => Ok(ast::BinaryOperator::StringConcat), + Operator::RegexMatch => Ok(BinaryOperator::PGRegexMatch), + Operator::RegexIMatch => Ok(BinaryOperator::PGRegexIMatch), + Operator::RegexNotMatch => Ok(BinaryOperator::PGRegexNotMatch), + Operator::RegexNotIMatch => Ok(BinaryOperator::PGRegexNotIMatch), + Operator::ILikeMatch => Ok(BinaryOperator::PGILikeMatch), + Operator::NotLikeMatch => Ok(BinaryOperator::PGNotLikeMatch), + Operator::LikeMatch => Ok(BinaryOperator::PGLikeMatch), + Operator::NotILikeMatch => 
Ok(BinaryOperator::PGNotILikeMatch), + Operator::BitwiseAnd => Ok(BinaryOperator::BitwiseAnd), + Operator::BitwiseOr => Ok(BinaryOperator::BitwiseOr), + Operator::BitwiseXor => Ok(BinaryOperator::BitwiseXor), + Operator::BitwiseShiftRight => Ok(BinaryOperator::PGBitwiseShiftRight), + Operator::BitwiseShiftLeft => Ok(BinaryOperator::PGBitwiseShiftLeft), + Operator::StringConcat => Ok(BinaryOperator::StringConcat), Operator::AtArrow => not_impl_err!("unsupported operation: {op:?}"), Operator::ArrowAt => not_impl_err!("unsupported operation: {op:?}"), } @@ -1026,17 +935,17 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false))) } ScalarValue::UInt64(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::Utf8(Some(str)) => Ok(ast::Expr::Value( - ast::Value::SingleQuotedString(str.to_string()), - )), + ScalarValue::Utf8(Some(str)) => { + Ok(ast::Expr::Value(SingleQuotedString(str.to_string()))) + } ScalarValue::Utf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::Utf8View(Some(str)) => Ok(ast::Expr::Value( - ast::Value::SingleQuotedString(str.to_string()), - )), + ScalarValue::Utf8View(Some(str)) => { + Ok(ast::Expr::Value(SingleQuotedString(str.to_string()))) + } ScalarValue::Utf8View(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::LargeUtf8(Some(str)) => Ok(ast::Expr::Value( - ast::Value::SingleQuotedString(str.to_string()), - )), + ScalarValue::LargeUtf8(Some(str)) => { + Ok(ast::Expr::Value(SingleQuotedString(str.to_string()))) + } ScalarValue::LargeUtf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Binary(Some(_)) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Binary(None) => Ok(ast::Expr::Value(ast::Value::Null)), @@ -1069,7 +978,7 @@ impl Unparser<'_> { Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, - expr: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( + expr: Box::new(ast::Expr::Value(SingleQuotedString( date.to_string(), ))), data_type: ast::DataType::Date, @@ -1092,7 +1001,7 @@ impl Unparser<'_> { Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, - expr: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( + expr: Box::new(ast::Expr::Value(SingleQuotedString( datetime.to_string(), ))), data_type: self.ast_type_for_date64_in_cast(), @@ -1229,7 +1138,7 @@ impl Unparser<'_> { return Ok(ast::Expr::Interval(interval)); } - // calculate the best single interval to represent the provided days and microseconds + // Calculate the best single interval to represent the provided days and microseconds let microseconds = microseconds + (days as i64 * 24 * 60 * 60 * 1_000_000); @@ -1334,9 +1243,9 @@ impl Unparser<'_> { IntervalStyle::SQLStandard => match v { ScalarValue::IntervalYearMonth(Some(v)) => { let interval = Interval { - value: Box::new(ast::Expr::Value( - ast::Value::SingleQuotedString(v.to_string()), - )), + value: Box::new(ast::Expr::Value(SingleQuotedString( + v.to_string(), + ))), leading_field: Some(ast::DateTimeField::Month), leading_precision: None, last_field: None, @@ -1355,11 +1264,9 @@ impl Unparser<'_> { let millis = v.milliseconds % 1_000; let interval = Interval { - value: Box::new(ast::Expr::Value( - ast::Value::SingleQuotedString(format!( - "{days} {hours}:{mins}:{secs}.{millis:3}" - )), - )), + value: Box::new(ast::Expr::Value(SingleQuotedString(format!( + "{days} {hours}:{mins}:{secs}.{millis:3}" + )))), leading_field: Some(ast::DateTimeField::Day), leading_precision: None, last_field: Some(ast::DateTimeField::Second), @@ -1370,9 +1277,9 @@ impl 
Unparser<'_> { ScalarValue::IntervalMonthDayNano(Some(v)) => { if v.months >= 0 && v.days == 0 && v.nanoseconds == 0 { let interval = Interval { - value: Box::new(ast::Expr::Value( - ast::Value::SingleQuotedString(v.months.to_string()), - )), + value: Box::new(ast::Expr::Value(SingleQuotedString( + v.months.to_string(), + ))), leading_field: Some(ast::DateTimeField::Month), leading_precision: None, last_field: None, @@ -1391,11 +1298,9 @@ impl Unparser<'_> { let millis = (v.nanoseconds % 1_000_000_000) / 1_000_000; let interval = Interval { - value: Box::new(ast::Expr::Value( - ast::Value::SingleQuotedString(format!( - "{days} {hours}:{mins}:{secs}.{millis:03}" - )), - )), + value: Box::new(ast::Expr::Value(SingleQuotedString( + format!("{days} {hours}:{mins}:{secs}.{millis:03}"), + ))), leading_field: Some(ast::DateTimeField::Day), leading_precision: None, last_field: Some(ast::DateTimeField::Second), @@ -1432,6 +1337,29 @@ impl Unparser<'_> { } } + /// Converts an UNNEST operation to an AST expression by wrapping it as a function call, + /// since there is no direct representation for UNNEST in the AST. + fn unnest_to_sql(&self, unnest: &Unnest) -> Result { + let args = self.function_args_to_sql(std::slice::from_ref(&unnest.expr))?; + + Ok(ast::Expr::Function(Function { + name: ObjectName(vec![Ident { + value: "UNNEST".to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + })) + } + fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result { match data_type { DataType::Null => { @@ -1440,7 +1368,7 @@ impl Unparser<'_> { DataType::Boolean => Ok(ast::DataType::Bool), DataType::Int8 => Ok(ast::DataType::TinyInt(None)), DataType::Int16 => Ok(ast::DataType::SmallInt(None)), - DataType::Int32 => Ok(ast::DataType::Integer(None)), + DataType::Int32 => Ok(self.dialect.int32_cast_dtype()), DataType::Int64 => Ok(self.dialect.int64_cast_dtype()), DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)), DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)), @@ -1554,7 +1482,10 @@ mod tests { use datafusion_functions_aggregate::expr_fn::sum; use datafusion_functions_window::row_number::row_number_udwf; - use crate::unparser::dialect::{CustomDialect, CustomDialectBuilder}; + use crate::unparser::dialect::{ + CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, Dialect, + PostgreSqlDialect, + }; use super::*; @@ -1944,6 +1875,15 @@ mod tests { }), r#"CAST(a AS DECIMAL(12,0))"#, ), + ( + Expr::Unnest(Unnest { + expr: Box::new(Expr::Column(Column { + relation: Some(TableReference::partial("schema", "table")), + name: "array_col".to_string(), + })), + }), + r#"UNNEST("schema"."table".array_col)"#, + ), ]; for (expr, expected) in tests { @@ -2018,11 +1958,8 @@ mod tests { #[test] fn custom_dialect_float64_ast_dtype() -> Result<()> { for (float64_ast_dtype, identifier) in [ - (sqlparser::ast::DataType::Double, "DOUBLE"), - ( - sqlparser::ast::DataType::DoublePrecision, - "DOUBLE PRECISION", - ), + (ast::DataType::Double, "DOUBLE"), + (ast::DataType::DoublePrecision, "DOUBLE PRECISION"), ] { let dialect = CustomDialectBuilder::new() .with_float64_ast_dtype(float64_ast_dtype) @@ -2338,6 +2275,34 @@ mod tests { Ok(()) } + #[test] + fn custom_dialect_with_int32_cast_dtype() -> Result<()> { + let default_dialect = CustomDialectBuilder::new().build(); + 
let mysql_dialect = CustomDialectBuilder::new() + .with_int32_cast_dtype(ast::DataType::Custom( + ObjectName(vec![Ident::new("SIGNED")]), + vec![], + )) + .build(); + + for (dialect, identifier) in + [(default_dialect, "INTEGER"), (mysql_dialect, "SIGNED")] + { + let unparser = Unparser::new(&dialect); + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Int32, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"CAST(a AS {identifier})"#); + + assert_eq!(actual, expected); + } + Ok(()) + } + #[test] fn custom_dialect_with_timestamp_cast_dtype() -> Result<()> { let default_dialect = CustomDialectBuilder::new().build(); @@ -2411,10 +2376,7 @@ mod tests { expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( "variation".to_string(), )))), - data_type: DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Utf8), - ), + data_type: DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), }), "'variation'", )]; @@ -2428,4 +2390,39 @@ mod tests { assert_eq!(actual, expected); } } + + #[test] + fn test_round_scalar_fn_to_expr() -> Result<()> { + let default_dialect: Arc = Arc::new( + CustomDialectBuilder::new() + .with_identifier_quote_style('"') + .build(), + ); + let postgres_dialect: Arc = Arc::new(PostgreSqlDialect {}); + + for (dialect, identifier) in + [(default_dialect, "DOUBLE"), (postgres_dialect, "NUMERIC")] + { + let unparser = Unparser::new(dialect.as_ref()); + let expr = Expr::ScalarFunction(ScalarFunction { + func: Arc::new(ScalarUDF::from( + datafusion_functions::math::round::RoundFunc::new(), + )), + args: vec![ + Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Float64, + }), + Expr::Literal(ScalarValue::Int64(Some(2))), + ], + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"round(CAST("a" AS {identifier}), 2)"#); + + assert_eq!(actual, expected); + } + Ok(()) + } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index a76e26aa7d989..2c38a1d36c1ea 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -15,19 +15,6 @@ // specific language governing permissions and limitations // under the License. 
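One expr.rs change above deserves a callout before the plan-level changes: `Expr::Unnest` now unparses instead of returning `not_impl_err!`. Since the sqlparser AST has no dedicated UNNEST node, the unparser emits a synthetic function call, as the new test case expects. At the string level:

```rust
// There is no dedicated UNNEST expression in the sqlparser AST, so the
// unparser renders it as an ordinary function call. String-level sketch:
fn unnest_to_sql(arg_sql: &str) -> String {
    format!("UNNEST({arg_sql})")
}

fn main() {
    assert_eq!(
        unnest_to_sql(r#""schema"."table".array_col"#),
        r#"UNNEST("schema"."table".array_col)"#
    );
}
```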
-use crate::unparser::utils::unproject_agg_exprs; -use datafusion_common::{ - internal_err, not_impl_err, - tree_node::{TransformedResult, TreeNode}, - Column, DataFusionError, Result, TableReference, -}; -use datafusion_expr::{ - expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, - LogicalPlanBuilder, Projection, SortExpr, -}; -use sqlparser::ast::{self, Ident, SetExpr}; -use std::sync::Arc; - use super::{ ast::{ BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder, @@ -38,9 +25,25 @@ use super::{ rewrite_plan_for_sort_on_non_projected_fields, subquery_alias_inner_query_and_columns, TableAliasRewriter, }, - utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant}, + utils::{ + find_agg_node_within_select, find_unnest_node_within_select, + find_window_nodes_within_select, try_transform_to_simple_table_scan_with_filters, + unproject_sort_expr, unproject_unnest_expr, unproject_window_exprs, + }, Unparser, }; +use crate::unparser::utils::unproject_agg_exprs; +use datafusion_common::{ + internal_err, not_impl_err, + tree_node::{TransformedResult, TreeNode}, + Column, DataFusionError, Result, TableReference, +}; +use datafusion_expr::{ + expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, + LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan, +}; +use sqlparser::ast::{self, Ident, SetExpr}; +use std::sync::Arc; /// Convert a DataFusion [`LogicalPlan`] to [`ast::Statement`] /// @@ -94,7 +97,6 @@ impl Unparser<'_> { | LogicalPlan::Aggregate(_) | LogicalPlan::Sort(_) | LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Repartition(_) | LogicalPlan::Union(_) | LogicalPlan::TableScan(_) @@ -172,13 +174,26 @@ impl Unparser<'_> { p: &Projection, select: &mut SelectBuilder, ) -> Result<()> { - match find_agg_node_within_select(plan, None, true) { - Some(AggVariant::Aggregate(agg)) => { - let items = p - .expr - .iter() + let mut exprs = p.expr.clone(); + + // If an Unnest node is found within the select, find and unproject the unnest column + if let Some(unnest) = find_unnest_node_within_select(plan) { + exprs = exprs + .into_iter() + .map(|e| unproject_unnest_expr(e, unnest)) + .collect::>>()?; + }; + + match ( + find_agg_node_within_select(plan, true), + find_window_nodes_within_select(plan, None, true), + ) { + (Some(agg), window) => { + let window_option = window.as_deref(); + let items = exprs + .into_iter() .map(|proj_expr| { - let unproj = unproject_agg_exprs(proj_expr, agg)?; + let unproj = unproject_agg_exprs(proj_expr, agg, window_option)?; self.select_item_to_sql(&unproj) }) .collect::>>()?; @@ -192,10 +207,9 @@ impl Unparser<'_> { vec![], )); } - Some(AggVariant::Window(window)) => { - let items = p - .expr - .iter() + (None, Some(window)) => { + let items = exprs + .into_iter() .map(|proj_expr| { let unproj = unproject_window_exprs(proj_expr, &window)?; self.select_item_to_sql(&unproj) @@ -204,9 +218,8 @@ impl Unparser<'_> { select.projection(items); } - None => { - let items = p - .expr + _ => { + let items = exprs .iter() .map(|e| self.select_item_to_sql(e)) .collect::>>()?; @@ -216,9 +229,14 @@ impl Unparser<'_> { Ok(()) } - fn derive(&self, plan: &LogicalPlan, relation: &mut RelationBuilder) -> Result<()> { + fn derive( + &self, + plan: &LogicalPlan, + relation: &mut RelationBuilder, + alias: Option, + ) -> Result<()> { let mut derived_builder = DerivedRelationBuilder::default(); - derived_builder.lateral(false).alias(None).subquery({ + 
derived_builder.lateral(false).alias(alias).subquery({ let inner_statement = self.plan_to_sql(plan)?; if let ast::Statement::Query(inner_query) = inner_statement { inner_query @@ -233,6 +251,23 @@ impl Unparser<'_> { Ok(()) } + fn derive_with_dialect_alias( + &self, + alias: &str, + plan: &LogicalPlan, + relation: &mut RelationBuilder, + ) -> Result<()> { + if self.dialect.requires_derived_table_alias() { + self.derive( + plan, + relation, + Some(self.new_table_alias(alias.to_string(), vec![])), + ) + } else { + self.derive(plan, relation, None) + } + } + fn select_to_sql_recursively( &self, plan: &LogicalPlan, @@ -242,12 +277,9 @@ impl Unparser<'_> { ) -> Result<()> { match plan { LogicalPlan::TableScan(scan) => { - if scan.projection.is_some() - || !scan.filters.is_empty() - || scan.fetch.is_some() + if let Some(unparsed_table_scan) = + Self::unparse_table_scan_pushdown(plan, None)? { - let unparsed_table_scan = - Self::unparse_table_scan_pushdown(plan, None)?; return self.select_to_sql_recursively( &unparsed_table_scan, query, @@ -281,16 +313,21 @@ impl Unparser<'_> { // Projection can be top-level plan for derived table if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_projection", + plan, + relation, + ); } self.reconstruct_select_statement(plan, p, select)?; self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) } LogicalPlan::Filter(filter) => { - if let Some(AggVariant::Aggregate(agg)) = - find_agg_node_within_select(plan, None, select.already_projected()) + if let Some(agg) = + find_agg_node_within_select(plan, select.already_projected()) { - let unprojected = unproject_agg_exprs(&filter.predicate, agg)?; + let unprojected = + unproject_agg_exprs(filter.predicate.clone(), agg, None)?; let filter_expr = self.expr_to_sql(&unprojected)?; select.having(Some(filter_expr)); } else { @@ -308,21 +345,22 @@ impl Unparser<'_> { LogicalPlan::Limit(limit) => { // Limit can be top-level plan for derived table if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_limit", + plan, + relation, + ); } - if let Some(fetch) = limit.fetch { + if let Some(fetch) = &limit.fetch { let Some(query) = query.as_mut() else { return internal_err!( "Limit operator only valid in a statement context." ); }; - query.limit(Some(ast::Expr::Value(ast::Value::Number( - fetch.to_string(), - false, - )))); + query.limit(Some(self.expr_to_sql(fetch)?)); } - if limit.skip > 0 { + if let Some(skip) = &limit.skip { let Some(query) = query.as_mut() else { return internal_err!( "Offset operator only valid in a statement context." @@ -330,10 +368,7 @@ impl Unparser<'_> { }; query.offset(Some(ast::Offset { rows: ast::OffsetRows::None, - value: ast::Expr::Value(ast::Value::Number( - limit.skip.to_string(), - false, - )), + value: self.expr_to_sql(skip)?, })); } @@ -347,15 +382,36 @@ impl Unparser<'_> { LogicalPlan::Sort(sort) => { // Sort can be top-level plan for derived table if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_sort", + plan, + relation, + ); } - if let Some(query_ref) = query { - query_ref.order_by(self.sorts_to_sql(sort.expr.clone())?); - } else { + let Some(query_ref) = query else { return internal_err!( "Sort operator only valid in a statement context." 
); - } + }; + + if let Some(fetch) = sort.fetch { + query_ref.limit(Some(ast::Expr::Value(ast::Value::Number( + fetch.to_string(), + false, + )))); + }; + + let agg = find_agg_node_within_select(plan, select.already_projected()); + // unproject sort expressions + let sort_exprs: Vec<SortExpr> = sort + .expr + .iter() + .map(|sort_expr| { + unproject_sort_expr(sort_expr, agg, sort.input.as_ref()) + }) + .collect::<Result<Vec<_>>>()?; + + query_ref.order_by(self.sorts_to_sql(&sort_exprs)?); self.select_to_sql_recursively( sort.input.as_ref(), @@ -376,7 +432,11 @@ LogicalPlan::Distinct(distinct) => { // Distinct can be top-level plan for derived table if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_distinct", + plan, + relation, + ); } let (select_distinct, input) = match distinct { Distinct::All(input) => (ast::Distinct::Distinct, input.as_ref()), @@ -393,7 +453,7 @@ .collect::<Result<Vec<_>>>()?; if let Some(sort_expr) = &on.sort_expr { if let Some(query_ref) = query { - query_ref.order_by(self.sorts_to_sql(sort_expr.clone())?); + query_ref.order_by(self.sorts_to_sql(sort_expr)?); } else { return internal_err!( "Sort operator only valid in a statement context." @@ -408,55 +468,77 @@ self.select_to_sql_recursively(input, query, select, relation) } LogicalPlan::Join(join) => { - let join_constraint = self.join_constraint_to_sql( - join.join_constraint, - &join.on, - join.filter.as_ref(), - )?; + let mut table_scan_filters = vec![]; - let mut right_relation = RelationBuilder::default(); + let left_plan = + match try_transform_to_simple_table_scan_with_filters(&join.left)? { + Some((plan, filters)) => { + table_scan_filters.extend(filters); + Arc::new(plan) + } + None => Arc::clone(&join.left), + }; self.select_to_sql_recursively( - join.left.as_ref(), + left_plan.as_ref(), query, select, relation, )?; + + let right_plan = + match try_transform_to_simple_table_scan_with_filters(&join.right)?
{ + Some((plan, filters)) => { + table_scan_filters.extend(filters); + Arc::new(plan) + } + None => Arc::clone(&join.right), + }; + + let mut right_relation = RelationBuilder::default(); + self.select_to_sql_recursively( - join.right.as_ref(), + right_plan.as_ref(), query, select, &mut right_relation, )?; - let Ok(Some(relation)) = right_relation.build() else { - return internal_err!("Failed to build right relation"); - }; + let join_filters = if table_scan_filters.is_empty() { + join.filter.clone() + } else { + // Combine `table_scan_filters` into a single filter using `AND` + let Some(combined_filters) = + table_scan_filters.into_iter().reduce(|acc, filter| { + Expr::BinaryExpr(BinaryExpr { + left: Box::new(acc), + op: Operator::And, + right: Box::new(filter), + }) + }) + else { + return internal_err!("Failed to combine TableScan filters"); + }; - let ast_join = ast::Join { - relation, - global: false, - join_operator: self - .join_operator_to_sql(join.join_type, join_constraint), + // Combine `join.filter` with `combined_filters` using `AND` + match &join.filter { + Some(filter) => Some(Expr::BinaryExpr(BinaryExpr { + left: Box::new(filter.clone()), + op: Operator::And, + right: Box::new(combined_filters), + })), + None => Some(combined_filters), + } }; - let mut from = select.pop_from().unwrap(); - from.push_join(ast_join); - select.push_from(from); - - Ok(()) - } - LogicalPlan::CrossJoin(cross_join) => { - // Cross joins are the same as unconditional inner joins - let mut right_relation = RelationBuilder::default(); - self.select_to_sql_recursively( - cross_join.left.as_ref(), - query, - select, - relation, + let join_constraint = self.join_constraint_to_sql( + join.join_constraint, + &join.on, + join_filters.as_ref(), )?; + self.select_to_sql_recursively( - cross_join.right.as_ref(), + right_plan.as_ref(), query, select, &mut right_relation, @@ -469,12 +551,8 @@ let ast_join = ast::Join { relation, global: false, - join_operator: self.join_operator_to_sql( - JoinType::Inner, - ast::JoinConstraint::On(ast::Expr::Value(ast::Value::Boolean( - true, - ))), - ), + join_operator: self + .join_operator_to_sql(join.join_type, join_constraint), }; let mut from = select.pop_from().unwrap(); from.push_join(ast_join); @@ -485,10 +563,18 @@ LogicalPlan::SubqueryAlias(plan_alias) => { let (plan, mut columns) = subquery_alias_inner_query_and_columns(plan_alias); - let plan = Self::unparse_table_scan_pushdown( + let unparsed_table_scan = Self::unparse_table_scan_pushdown( plan, Some(plan_alias.alias.clone()), )?; + // if the child plan is a TableScan with pushdown operations, we don't need to + // create an additional subquery for it + if !select.already_projected() && unparsed_table_scan.is_none() { + select.projection(vec![ast::SelectItem::Wildcard( + ast::WildcardAdditionalOptions::default(), + )]); + } + let plan = unparsed_table_scan.unwrap_or_else(|| plan.clone()); if !columns.is_empty() && !self.dialect.supports_column_alias_in_table_alias() { @@ -529,6 +615,15 @@ ); } + // Covers cases where the UNION is a subquery and the projection is at the top level + if select.already_projected() { + return self.derive_with_dialect_alias( + "derived_union", + plan, + relation, + ); + } + let input_exprs: Vec<SetExpr> = union .inputs .iter() @@ -565,19 +660,51 @@ Ok(()) } LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"), + LogicalPlan::Unnest(unnest) => { + if !unnest.struct_type_columns.is_empty() {
return internal_err!( + "Struct type columns are not currently supported in UNNEST: {:?}", + unnest.struct_type_columns + ); + } + + // In the case of UNNEST, the Unnest node is followed by a duplicate Projection node that we should skip. + // Otherwise, there will be a duplicate SELECT clause. + // | Projection: table.col1, UNNEST(table.col2) + // | Unnest: UNNEST(table.col2) + // | Projection: table.col1, table.col2 AS UNNEST(table.col2) + // | Filter: table.col3 = Int64(3) + // | TableScan: table projection=None + if let LogicalPlan::Projection(p) = unnest.input.as_ref() { + // continue with projection input + self.select_to_sql_recursively(&p.input, query, select, relation) + } else { + internal_err!("Unnest input is not a Projection: {unnest:?}") + } + } _ => not_impl_err!("Unsupported operator: {plan:?}"), } } + fn is_scan_with_pushdown(scan: &TableScan) -> bool { + scan.projection.is_some() || !scan.filters.is_empty() || scan.fetch.is_some() + } + + /// Try to unparse a table scan with pushdown operations into a new subquery plan. + /// If the table scan is without any pushdown operations, return None. fn unparse_table_scan_pushdown( plan: &LogicalPlan, alias: Option<TableReference>, - ) -> Result<LogicalPlan> { + ) -> Result<Option<LogicalPlan>> { match plan { LogicalPlan::TableScan(table_scan) => { + if !Self::is_scan_with_pushdown(table_scan) { + return Ok(None); + } + let table_schema = table_scan.source.schema(); let mut filter_alias_rewriter = alias.as_ref().map(|alias_name| TableAliasRewriter { - table_schema: table_scan.source.schema(), + table_schema: &table_schema, alias_name: alias_name.clone(), }); @@ -586,6 +713,17 @@ Arc::clone(&table_scan.source), None, )?; + // We will rebase the column references to the new alias if it exists. + // If the projection or filters are empty, we will append alias to the table scan. + // + // Example: + // select t1.c1 from t1 where t1.c1 > 1 -> select a.c1 from t1 as a where a.c1 > 1 + if let Some(ref alias) = alias { + if table_scan.projection.is_some() || !table_scan.filters.is_empty() { + builder = builder.alias(alias.clone())?; + } + } + if let Some(project_vec) = &table_scan.projection { let project_columns = project_vec .iter() @@ -603,9 +741,6 @@ } }) .collect::<Vec<_>>(); - if let Some(alias) = alias { - builder = builder.alias(alias)?; - } builder = builder.project(project_columns)?; } @@ -635,18 +770,59 @@ builder = builder.limit(0, Some(fetch))?; } - builder.build() + // If the table scan has an alias but no projection or filters, it means no column references are rebased. + // So we will append the alias to this subquery. + // Example: + // select * from t1 limit 10 -> (select * from t1 limit 10) as a + if let Some(alias) = alias { + if table_scan.projection.is_none() && table_scan.filters.is_empty() { + builder = builder.alias(alias)?; + } + } + + Ok(Some(builder.build()?)) } LogicalPlan::SubqueryAlias(subquery_alias) => { - let new_plan = Self::unparse_table_scan_pushdown( + Self::unparse_table_scan_pushdown( &subquery_alias.input, Some(subquery_alias.alias.clone()), - )?; - LogicalPlanBuilder::from(new_plan) - .alias(subquery_alias.alias.clone())? - .build() + ) + } + // SubqueryAlias could be rewritten to a plan with a projection as the top node by [rewrite::subquery_alias_inner_query_and_columns]. + // The inner table scan could be a scan with pushdown operations. + LogicalPlan::Projection(projection) => { + if let Some(plan) = + Self::unparse_table_scan_pushdown(&projection.input, alias.clone())?
+ { + let exprs = if alias.is_some() { + let mut alias_rewriter = + alias.as_ref().map(|alias_name| TableAliasRewriter { + table_schema: plan.schema().as_arrow(), + alias_name: alias_name.clone(), + }); + projection + .expr + .iter() + .cloned() + .map(|expr| { + if let Some(ref mut rewriter) = alias_rewriter { + expr.rewrite(rewriter).data() + } else { + Ok(expr) + } + }) + .collect::<Result<Vec<_>>>()?
-// The query is invalid as `j1.j1_id` is not a valid identifier in the derived table -// `(SELECT j1.j1_id FROM j1)` -// -// With this logic, the unparsed query will be: -// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)` -// -// Caveat: this won't handle the case like `select * from (select 1, 2) AS a (b, c)` -// as the parser gives a wrong plan which has mismatch `Int(1)` types: Literal and -// Column in the Projections. Once the parser side is fixed, this logic should work +/// This logic is to work out the columns and inner query for SubqueryAlias plan for both types of +/// subquery +/// - `(SELECT column_a as a from table) AS A` +/// - `(SELECT column_a from table) AS A (a)` +/// +/// A roundtrip example for table alias with columns +/// +/// query: SELECT id FROM (SELECT j1_id from j1) AS c (id) +/// +/// LogicalPlan: +/// Projection: c.id +/// SubqueryAlias: c +/// Projection: j1.j1_id AS id +/// Projection: j1.j1_id +/// TableScan: j1 +/// +/// Before introducing this logic, the unparsed query would be `SELECT c.id FROM (SELECT j1.j1_id AS +/// id FROM (SELECT j1.j1_id FROM j1)) AS c`. +/// The query is invalid as `j1.j1_id` is not a valid identifier in the derived table +/// `(SELECT j1.j1_id FROM j1)` +/// +/// With this logic, the unparsed query will be: +/// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)` +/// +/// Caveat: this won't handle the case like `select * from (select 1, 2) AS a (b, c)` +/// as the parser gives a wrong plan which has mismatched `Int(1)` types: Literal and +/// Column in the Projections. Once the parser side is fixed, this logic should work pub(super) fn subquery_alias_inner_query_and_columns( subquery_alias: &datafusion_expr::SubqueryAlias, ) -> (&LogicalPlan, Vec<Ident>) { @@ -227,13 +227,13 @@ return (plan, vec![]); }; - // check if it's projection inside projection + // Check if it's projection inside projection let Some(inner_projection) = find_projection(outer_projections.input.as_ref()) else { return (plan, vec![]); }; let mut columns: Vec<Ident> = vec![]; - // check if the inner projection and outer projection have a matching pattern like + // Check if the inner projection and outer projection have a matching pattern like // Projection: j1.j1_id AS id // Projection: j1.j1_id for (i, inner_expr) in inner_projection.expr.iter().enumerate() { @@ -241,7 +241,7 @@ return (plan, vec![]); }; - // inner projection schema fields store the projection name which is used in outer + // Inner projection schema fields store the projection name which is used in outer // projection expr let inner_expr_string = match inner_expr { Expr::Column(_) => inner_expr.to_string(), @@ -293,7 +293,7 @@ pub(super) fn inject_column_aliases_into_subquery( /// - `SELECT col1, col2 FROM table` with aliases `["alias_1", "some_alias_2"]` will be transformed to /// - `SELECT col1 AS alias_1, col2 AS some_alias_2 FROM table` pub(super) fn inject_column_aliases( - projection: &datafusion_expr::Projection, + projection: &Projection, aliases: impl IntoIterator<Item = Ident>, ) -> LogicalPlan { let mut updated_projection = projection.clone(); @@ -330,6 +330,7 @@ fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> { _ => None, } } + /// A `TreeNodeRewriter` implementation that rewrites `Expr::Column` expressions by /// replacing the column's name with an alias if the column exists in the provided schema.
/// @@ -342,12 +343,12 @@ fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> { /// from which the columns are referenced. This is used to look up columns by their names. /// * `alias_name`: The alias (`TableReference`) that will replace the table name /// in the column references when applicable. -pub struct TableAliasRewriter { - pub table_schema: SchemaRef, +pub struct TableAliasRewriter<'a> { + pub table_schema: &'a Schema, pub alias_name: TableReference, } -impl TreeNodeRewriter for TableAliasRewriter { +impl TreeNodeRewriter for TableAliasRewriter<'_> { type Node = Expr; fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> { diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index c1b3fe18f7e70..284956cef195e 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -15,85 +15,168 @@ // specific language governing permissions and limitations // under the License. +use std::{cmp::Ordering, sync::Arc, vec}; + use datafusion_common::{ internal_err, - tree_node::{Transformed, TreeNode}, - Result, + tree_node::{Transformed, TransformedResult, TreeNode}, + Column, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::{ + expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, + LogicalPlanBuilder, Projection, SortExpr, Unnest, Window, }; -use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window}; +use sqlparser::ast; -/// One of the possible aggregation plans which can be found within a single select query. -pub(crate) enum AggVariant<'a> { - Aggregate(&'a Aggregate), - Window(Vec<&'a Window>), +use super::{dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser}; + +/// Recursively searches children of [LogicalPlan] to find an Aggregate node if one exists +/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). +/// If an Aggregate node is not found prior to this or at all before reaching the end +/// of the tree, None is returned. +pub(crate) fn find_agg_node_within_select( + plan: &LogicalPlan, + already_projected: bool, +) -> Option<&Aggregate> { + // Note that none of the nodes that have a corresponding node can have more + // than 1 input node. E.g. Projection / Filter always have 1 input node. + let input = plan.inputs(); + let input = if input.len() > 1 { + return None; + } else { + input.first()? + }; + // Agg nodes explicitly return immediately with a single node + if let LogicalPlan::Aggregate(agg) = input { + Some(agg) + } else if let LogicalPlan::TableScan(_) = input { + None + } else if let LogicalPlan::Projection(_) = input { + if already_projected { + None + } else { + find_agg_node_within_select(input, true) + } + } else { + find_agg_node_within_select(input, already_projected) + } } -/// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists +/// Recursively searches children of [LogicalPlan] to find an Unnest node if one exists +pub(crate) fn find_unnest_node_within_select(plan: &LogicalPlan) -> Option<&Unnest> { + // Note that none of the nodes that have a corresponding node can have more + // than 1 input node. E.g. Projection / Filter always have 1 input node. + let input = plan.inputs(); + let input = if input.len() > 1 { + return None; + } else { + input.first()?
+ }; + + if let LogicalPlan::Unnest(unnest) = input { + Some(unnest) + } else if let LogicalPlan::TableScan(_) = input { + None + } else if let LogicalPlan::Projection(_) = input { + None + } else { + find_unnest_node_within_select(input) + } +} + +/// Recursively searches children of [LogicalPlan] to find Window nodes if they exist /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). -/// If an Aggregate or window node is not found prior to this or at all before reaching the end -/// of the tree, None is returned. It is assumed that a Window and Aggregate node cannot both -/// be found in a single select query. -pub(crate) fn find_agg_node_within_select<'a>( +/// If a Window node is not found prior to this or at all before reaching the end +/// of the tree, None is returned. +pub(crate) fn find_window_nodes_within_select<'a>( plan: &'a LogicalPlan, - mut prev_windows: Option<AggVariant<'a>>, + mut prev_windows: Option<Vec<&'a Window>>, already_projected: bool, -) -> Option<AggVariant<'a>> { - // Note that none of the nodes that have a corresponding agg node can have more +) -> Option<Vec<&'a Window>> { - // Note that none of the nodes that have a corresponding node can have more // than 1 input node. E.g. Projection / Filter always have 1 input node. let input = plan.inputs(); let input = if input.len() > 1 { - return None; + return prev_windows; } else { input.first()? }; - // Agg nodes explicitly return immediately with a single node // Window nodes accumulate in a vec until encountering a TableScan or 2nd projection match input { - LogicalPlan::Aggregate(agg) => Some(AggVariant::Aggregate(agg)), LogicalPlan::Window(window) => { prev_windows = match &mut prev_windows { - Some(AggVariant::Window(windows)) => { + Some(windows) => { windows.push(window); prev_windows } - _ => Some(AggVariant::Window(vec![window])), + _ => Some(vec![window]), }; - find_agg_node_within_select(input, prev_windows, already_projected) + find_window_nodes_within_select(input, prev_windows, already_projected) } LogicalPlan::Projection(_) => { if already_projected { prev_windows } else { - find_agg_node_within_select(input, prev_windows, true) + find_window_nodes_within_select(input, prev_windows, true) } } LogicalPlan::TableScan(_) => prev_windows, - _ => find_agg_node_within_select(input, prev_windows, already_projected), + _ => find_window_nodes_within_select(input, prev_windows, already_projected), } } +/// Recursively identify Column expressions and transform them into the appropriate unnest expression +/// +/// For example, if expr contains the column expr "unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" +/// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL]) +pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result<Expr> { + expr.transform(|sub_expr| { + if let Expr::Column(col_ref) = &sub_expr { + // Check if the column is among the columns to run unnest on. + // Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting. + if unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { + if let Ok(idx) = unnest.schema.index_of_column(col_ref) { + if let LogicalPlan::Projection(Projection { expr, ..
}) = unnest.input.as_ref() { + if let Some(unprojected_expr) = expr.get(idx) { + let unnest_expr = Expr::Unnest(expr::Unnest::new(unprojected_expr.clone())); + return Ok(Transformed::yes(unnest_expr)); + } + } + } + return internal_err!( + "Tried to unproject unnest expr for column '{}' that was not found in the provided Unnest!", &col_ref.name + ); + } + } + + Ok(Transformed::no(sub_expr)) + + }).map(|e| e.data) +} + /// Recursively identify all Column expressions and transform them into the appropriate /// aggregate expression contained in agg. /// /// For example, if expr contains the column expr "COUNT(*)" it will be transformed /// into an actual aggregate expression COUNT(*) as identified in the aggregate node. -pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result<Expr> { - expr.clone() - .transform(|sub_expr| { +pub(crate) fn unproject_agg_exprs( + expr: Expr, + agg: &Aggregate, + windows: Option<&[&Window]>, +) -> Result<Expr> { + expr.transform(|sub_expr| { if let Expr::Column(c) = sub_expr { - // find the column in the agg schema - if let Ok(n) = agg.schema.index_of_column(&c) { - let unprojected_expr = agg - .group_expr - .iter() - .chain(agg.aggr_expr.iter()) - .nth(n) - .unwrap(); + if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { Ok(Transformed::yes(unprojected_expr.clone())) + } else if let Some(unprojected_expr) = + windows.and_then(|w| find_window_expr(w, &c.name).cloned()) + { + // Window functions can contain aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that need to be unprojected + return Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)); } else { internal_err!( - "Tried to unproject agg expr not found in provided Aggregate!" + "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name ) } } else { @@ -108,22 +191,257 @@ pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result<Expr> /// /// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed /// into an actual window expression as identified in the window node.
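[Editor's illustration, not part of the patch] The unprojection helpers above reverse the planner's substitution of output-column references for expressions. A hedged usage sketch for unproject_agg_exprs, assuming an Aggregate node `agg` and an `unparser` in scope, as in the Projection handling in plan.rs:

    // The planner refers to an aggregate's output by name, e.g. a column
    // literally called "sum(t.id)".
    let proj_expr = col("sum(t.id)");
    // Swap the column reference back for the aggregate expression recorded in
    // the Aggregate node; `None` because no window expressions are consulted.
    let unprojected = unproject_agg_exprs(proj_expr, agg, None)?;
    // The unparser now emits sum(t.id) instead of a bare identifier.
    let sql = unparser.expr_to_sql(&unprojected)?;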
-pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result<Expr> { - expr.clone() - .transform(|sub_expr| { - if let Expr::Column(c) = sub_expr { - if let Some(unproj) = windows +pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result<Expr> { + expr.transform(|sub_expr| { + if let Expr::Column(c) = sub_expr { + if let Some(unproj) = find_window_expr(windows, &c.name) { + Ok(Transformed::yes(unproj.clone())) + } else { + Ok(Transformed::no(Expr::Column(c))) + } + } else { + Ok(Transformed::no(sub_expr)) + } + }) + .map(|e| e.data) +} + +fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result<Option<&'a Expr>> { + if let Ok(index) = agg.schema.index_of_column(column) { + if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) { + // For grouping set expr, we must operate by expression list from the grouping set + let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?; + match index.cmp(&grouping_expr.len()) { + Ordering::Less => Ok(grouping_expr.into_iter().nth(index)), + Ordering::Equal => { + internal_err!( + "Tried to unproject column referring to internal grouping id" + ) + } + Ordering::Greater => { + Ok(agg.aggr_expr.get(index - grouping_expr.len() - 1)) + } + } + } else { + Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index)) + } + } else { + Ok(None) + } +} + +fn find_window_expr<'a>( + windows: &'a [&'a Window], + column_name: &'a str, +) -> Option<&'a Expr> { + windows + .iter() + .flat_map(|w| w.window_expr.iter()) + .find(|expr| expr.schema_name().to_string() == column_name) +} + +/// Transforms a Column expression into the actual expression from aggregation or projection if found. +/// This is required because if an ORDER BY expression is present in an Aggregate or Select, it is replaced +/// with a Column expression (e.g., "sum(catalog_returns.cr_net_loss)"). We need to transform it back to +/// the actual expression, such as sum("catalog_returns"."cr_net_loss"). +pub(crate) fn unproject_sort_expr( + sort_expr: &SortExpr, + agg: Option<&Aggregate>, + input: &LogicalPlan, +) -> Result<SortExpr> { + let mut sort_expr = sort_expr.clone(); + + // Remove alias if present, because ORDER BY cannot use aliases + if let Expr::Alias(alias) = &sort_expr.expr { + sort_expr.expr = *alias.expr.clone(); + } + + let Expr::Column(ref col_ref) = sort_expr.expr else { + return Ok(sort_expr); + }; + + if col_ref.relation.is_some() { + return Ok(sort_expr); + }; + + // In case of aggregation there could be columns containing aggregation functions we need to unproject + if let Some(agg) = agg { + if agg.schema.is_column_from_schema(col_ref) { + let new_expr = unproject_agg_exprs(sort_expr.expr, agg, None)?; + sort_expr.expr = new_expr; + return Ok(sort_expr); + } + } + + // If SELECT and ORDER BY contain the same expression with a scalar function, the ORDER BY expression will + // be replaced by a Column expression (e.g., "substr(customer.c_last_name, Int64(0), Int64(5))"), and we need + // to transform it back to the actual expression. + if let LogicalPlan::Projection(Projection { expr, schema, ..
}) = input { + if let Ok(idx) = schema.index_of_column(col_ref) { + if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) { + sort_expr.expr = Expr::ScalarFunction(scalar_fn.clone()); + } + } + return Ok(sort_expr); + } + + Ok(sort_expr) +} + +/// Iterates through the children of a [LogicalPlan] to find a TableScan node before encountering +/// a Projection or any unexpected node that indicates the presence of a Projection (SELECT) in the plan. +/// If a TableScan node is found, returns the TableScan node without filters, along with the collected filters separately. +/// If the plan contains a Projection, returns None. +/// +/// Note: If a table alias is present, TableScan filters are rewritten to reference the alias. +/// +/// LogicalPlan example: +/// Filter: ta.j1_id < 5 +/// Alias: ta +/// TableScan: j1, j1_id > 10 +/// +/// Will return LogicalPlan below: +/// Alias: ta +/// TableScan: j1 +/// And filters: [ta.j1_id < 5, ta.j1_id > 10] +pub(crate) fn try_transform_to_simple_table_scan_with_filters( + plan: &LogicalPlan, +) -> Result<Option<(LogicalPlan, Vec<Expr>)>> { + let mut filters: Vec<Expr> = vec![]; + let mut plan_stack = vec![plan]; + let mut table_alias = None; + + while let Some(current_plan) = plan_stack.pop() { + match current_plan { + LogicalPlan::SubqueryAlias(alias) => { + table_alias = Some(alias.alias.clone()); + plan_stack.push(alias.input.as_ref()); + } + LogicalPlan::Filter(filter) => { + filters.push(filter.predicate.clone()); + plan_stack.push(filter.input.as_ref()); + } + LogicalPlan::TableScan(table_scan) => { + let table_schema = table_scan.source.schema(); + // optional rewriter if table has an alias + let mut filter_alias_rewriter = + table_alias.as_ref().map(|alias_name| TableAliasRewriter { + table_schema: &table_schema, + alias_name: alias_name.clone(), + }); + + // rewrite filters to use table alias if present + let table_scan_filters = table_scan + .filters .iter() - .flat_map(|w| w.window_expr.iter()) - .find(|window_expr| window_expr.schema_name().to_string() == c.name) - { - Ok(Transformed::yes(unproj.clone())) - } else { - Ok(Transformed::no(Expr::Column(c))) + .cloned() + .map(|expr| { + if let Some(ref mut rewriter) = filter_alias_rewriter { + expr.rewrite(rewriter).data() + } else { + Ok(expr) + } + }) + .collect::<Result<Vec<Expr>, DataFusionError>>()?; + + filters.extend(table_scan_filters); + + let mut builder = LogicalPlanBuilder::scan( + table_scan.table_name.clone(), + Arc::clone(&table_scan.source), + None, + )?; + + if let Some(alias) = table_alias.take() { + builder = builder.alias(alias)?; } - } else { - Ok(Transformed::no(sub_expr)) + + let plan = builder.build()?; + + return Ok(Some((plan, filters))); } - }) - .map(|e| e.data) + _ => { + return Ok(None); + } + } + } + + Ok(None) +} + +/// Converts a date_part function to SQL, tailoring it to the supported date field extraction style.
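[Editor's illustration, not part of the patch] try_transform_to_simple_table_scan_with_filters feeds the Join handling in plan.rs: the filters lifted out of each join side are folded into a single predicate and merged into the join filter. A hedged sketch of that fold, using the same Expr, BinaryExpr, and Operator names as the patch (`filters` stands for the collected Vec<Expr>):

    // Fold [f1, f2, f3] into f1 AND f2 AND f3; yields None for an empty vec.
    let combined: Option<Expr> = filters.into_iter().reduce(|acc, f| {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(acc),
            op: Operator::And,
            right: Box::new(f),
        })
    });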
+pub(crate) fn date_part_to_sql( + unparser: &Unparser, + style: DateFieldExtractStyle, + date_part_args: &[Expr], +) -> Result<Option<ast::Expr>> { + match (style, date_part_args.len()) { + (DateFieldExtractStyle::Extract, 2) => { + let date_expr = unparser.expr_to_sql(&date_part_args[1])?; + if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + let field = match field.to_lowercase().as_str() { + "year" => ast::DateTimeField::Year, + "month" => ast::DateTimeField::Month, + "day" => ast::DateTimeField::Day, + "hour" => ast::DateTimeField::Hour, + "minute" => ast::DateTimeField::Minute, + "second" => ast::DateTimeField::Second, + _ => return Ok(None), + }; + + return Ok(Some(ast::Expr::Extract { + field, + expr: Box::new(date_expr), + syntax: ast::ExtractSyntax::From, + })); + } + } + (DateFieldExtractStyle::Strftime, 2) => { + let column = unparser.expr_to_sql(&date_part_args[1])?; + + if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + let field = match field.to_lowercase().as_str() { + "year" => "%Y", + "month" => "%m", + "day" => "%d", + "hour" => "%H", + "minute" => "%M", + "second" => "%S", + _ => return Ok(None), + }; + + return Ok(Some(ast::Expr::Function(ast::Function { + name: ast::ObjectName(vec![ast::Ident { + value: "strftime".to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args: vec![ + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + ast::Expr::Value(ast::Value::SingleQuotedString( + field.to_string(), + )), + )), + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(column)), + ], + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + }))); + } + } + (DateFieldExtractStyle::DatePart, _) => { + return Ok(Some( + unparser.scalar_function_to_sql("date_part", date_part_args)?, + )); + } + _ => {} + }; + + Ok(None) } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 656e4b851aa86..14436de018437 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -34,9 +34,9 @@ use datafusion_expr::builder::get_struct_unnested_columns; use datafusion_expr::expr::{Alias, GroupingSet, Unnest, WindowFunction}; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; use datafusion_expr::{ - col, expr_vec_fmt, ColumnUnnestList, ColumnUnnestType, Expr, ExprSchemable, - LogicalPlan, + col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, }; +use indexmap::IndexMap; use sqlparser::ast::{Ident, Value}; /// Make a best-effort attempt at resolving all columns in the expression tree @@ -203,7 +203,7 @@ pub(crate) fn resolve_aliases_to_exprs( .data() } -/// given a slice of window expressions sharing the same sort key, find their common partition +/// Given a slice of window expressions sharing the same sort key, find their common partition /// keys.
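[Editor's illustration, not part of the patch] The three DateFieldExtractStyle branches of date_part_to_sql above render the same logical function differently per dialect. For date_part('year', d) the outputs are, respectively, EXTRACT(YEAR FROM d), strftime('%Y', d), and date_part('year', d). A hedged usage sketch, assuming an `unparser` in scope:

    let args = vec![lit("year"), col("d")];
    let sql = date_part_to_sql(&unparser, DateFieldExtractStyle::Strftime, &args)?;
    // Unsupported fields (e.g. 'millisecond') fall through and return None,
    // letting the caller emit a plain date_part call instead.
    assert!(sql.is_some());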
pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> { let all_partition_keys = window_exprs @@ -295,7 +295,7 @@ pub(crate) fn value_to_string(value: &Value) -> Option<String> { pub(crate) fn rewrite_recursive_unnests_bottom_up( input: &LogicalPlan, - unnest_placeholder_columns: &mut Vec<(Column, ColumnUnnestType)>, + unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>, inner_projection_exprs: &mut Vec<Expr>, original_exprs: &[Expr], ) -> Result<Vec<Expr>> { @@ -322,11 +322,11 @@ A full example of how the transformation works: struct RecursiveUnnestRewriter<'a> { input_schema: &'a DFSchemaRef, root_expr: &'a Expr, - // useful to detect which child expr is a part of/ not a part of unnest operation + // Useful to detect which child expr is a part of/ not a part of unnest operation top_most_unnest: Option<Unnest>, consecutive_unnest: Vec<Option<Unnest>>, inner_projection_exprs: &'a mut Vec<Expr>, - columns_unnestings: &'a mut Vec<(Column, ColumnUnnestType)>, + columns_unnestings: &'a mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>, transformed_root_exprs: Option<Vec<Expr>>, } impl<'a> RecursiveUnnestRewriter<'a> { @@ -360,13 +360,11 @@ impl<'a> RecursiveUnnestRewriter<'a> { // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection // inside unnest execution, each column inside the inner projection // will be transformed into new columns. Thus we need to keep track of these placeholding column names - // let placeholder_name = unnest_expr.display_name()?; let placeholder_name = format!("unnest_placeholder({})", inner_expr_name); let post_unnest_name = format!("unnest_placeholder({},depth={})", inner_expr_name, level); // This is due to the fact that unnest transformation should keep the original // column name as is, to comply with group by and order by - // let post_unnest_alias = print_unnest(&inner_expr_name, level); let placeholder_column = Column::from_name(placeholder_name.clone()); let (data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?; @@ -380,10 +378,8 @@ self.inner_projection_exprs, expr_in_unnest.clone().alias(placeholder_name.clone()), ); - self.columns_unnestings.push(( - Column::from_name(placeholder_name.clone()), - ColumnUnnestType::Struct, - )); + self.columns_unnestings + .insert(Column::from_name(placeholder_name.clone()), None); Ok( get_struct_unnested_columns(&placeholder_name, &inner_fields) .into_iter() @@ -399,39 +395,18 @@ expr_in_unnest.clone().alias(placeholder_name.clone()), ); - // let post_unnest_column = Column::from_name(post_unnest_name); let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name); - match self + let list_unnesting = self .columns_unnestings - .iter_mut() - .find(|(inner_col, _)| inner_col == &placeholder_column) - { - // there is not unnesting done on this column yet - None => { - self.columns_unnestings.push(( - Column::from_name(placeholder_name.clone()), - ColumnUnnestType::List(vec![ColumnUnnestList { - output_column: Column::from_name(post_unnest_name), - depth: level, - }]), - )); - } - // some unnesting(at some level) has been done on this column - // e.g select unnest(column3), unnest(unnest(column3)) - Some((_, unnesting)) => match unnesting { - ColumnUnnestType::List(list) => { - let unnesting = ColumnUnnestList { - output_column: Column::from_name(post_unnest_name), - depth: level, - }; - if !list.contains(&unnesting) { - list.push(unnesting); - } - } - _ => { - return internal_err!("not reached"); - } - }, + .entry(placeholder_column)
.or_insert(Some(vec![])); + let unnesting = ColumnUnnestList { + output_column: Column::from_name(post_unnest_name), + depth: level, + }; + let list_unnestings = list_unnesting.as_mut().unwrap(); + if !list_unnestings.contains(&unnesting) { + list_unnestings.push(unnesting); } Ok(vec![post_unnest_expr]) } @@ -478,8 +453,7 @@ impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> { } /// The rewriting only happens when the traversal has reached the top-most unnest expr - /// within a sequence of consecutive unnest exprs. - /// node, for example given a stack of expr + /// within a sequence of consecutive unnest exprs /// /// For example an expr of **unnest(unnest(column1)) + unnest(unnest(unnest(column2)))** /// ```text @@ -512,7 +486,7 @@ impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> { if traversing_unnest == self.top_most_unnest.as_ref().unwrap() { self.top_most_unnest = None; } - // find inside consecutive_unnest, the sequence of continous unnest exprs + // Find inside consecutive_unnest, the sequence of continuous unnest exprs // Get the latest consecutive unnest exprs // and check if current upward traversal is the returning to the root expr @@ -560,7 +534,7 @@ // For column exprs that are not descendants of any unnest node // retain their projection // e.g given expr tree unnest(col_a) + col_b, we have to retain projection of col_b - // this condition can be checked by maintaining an Option + // this condition can be checked by maintaining an Option if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() { push_projection_dedupl(self.inner_projection_exprs, expr.clone()); } @@ -589,7 +563,7 @@ fn push_projection_dedupl(projection: &mut Vec<Expr>, expr: Expr) { /// is done only for the bottom expression pub(crate) fn rewrite_recursive_unnest_bottom_up( input: &LogicalPlan, - unnest_placeholder_columns: &mut Vec<(Column, ColumnUnnestType)>, + unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>, inner_projection_exprs: &mut Vec<Expr>, original_expr: &Expr, ) -> Result<Vec<Expr>> { @@ -610,8 +584,8 @@ // TODO: This can be resolved after this issue is resolved: https://github.com/apache/datafusion/issues/10102 // // The transformation looks like: - // - unnest(array_col) will be transformed into unnest(array_col) - // - unnest(array_col) + 1 will be transformed into unnest(array_col) + 1 + // - unnest(array_col) will be transformed into Column("unnest_placeholder(array_col)") + // - unnest(array_col) + 1 will be transformed into Column("unnest_placeholder(array_col)") + 1 let Transformed { data: transformed_expr, transformed, } = original_expr.clone().rewrite(&mut rewriter)?; if !transformed { - if matches!(&transformed_expr, Expr::Column(_)) { + if matches!(&transformed_expr, Expr::Column(_)) + || matches!(&transformed_expr, Expr::Wildcard { ..
}) + { push_projection_dedupl(inner_projection_exprs, transformed_expr.clone()); Ok(vec![transformed_expr]) } else { @@ -645,17 +621,33 @@ mod tests { use arrow_schema::Fields; use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::{ - col, lit, unnest, ColumnUnnestType, EmptyRelation, LogicalPlan, + col, lit, unnest, ColumnUnnestList, EmptyRelation, LogicalPlan, }; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::expr_fn::count; + use indexmap::IndexMap; use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up}; - fn column_unnests_eq(l: Vec<(&str, &str)>, r: &[(Column, ColumnUnnestType)]) { - let r_formatted: Vec<String> = - r.iter().map(|i| format!("{}|{}", i.0, i.1)).collect(); - let l_formatted: Vec<String> = - l.iter().map(|i| format!("{}|{}", i.0, i.1)).collect(); + + fn column_unnests_eq( + l: Vec<&str>, + r: &IndexMap<Column, Option<Vec<ColumnUnnestList>>>, + ) { + let r_formatted: Vec<String> = r + .iter() + .map(|i| match i.1 { + None => format!("{}", i.0), + Some(vec) => format!( + "{}=>[{}]", + i.0, + vec.iter() + .map(|i| format!("{}", i)) + .collect::<Vec<String>>() + .join(", ") + ), + }) + .collect(); + let l_formatted: Vec<String> = l.iter().map(|i| i.to_string()).collect(); assert_eq!(l_formatted, r_formatted); } @@ -685,7 +677,7 @@ schema: Arc::new(dfschema), }); - let mut unnest_placeholder_columns = vec![]; + let mut unnest_placeholder_columns = IndexMap::new(); let mut inner_projection_exprs = vec![]; // unnest(unnest(3d_col)) + unnest(unnest(3d_col)) @@ -698,7 +690,7 @@ &mut inner_projection_exprs, &original_expr, )?; - // only the bottom most unnest exprs are transformed + // Only the bottom most unnest exprs are transformed assert_eq!( transformed_exprs, vec![col("unnest_placeholder(3d_col,depth=2)") @@ -710,14 +702,13 @@ .add(col("i64_col"))] ); column_unnests_eq( - vec![( - "unnest_placeholder(3d_col)", - "List([unnest_placeholder(3d_col,depth=2)|depth=2])", - )], + vec![ + "unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2]", + ], &unnest_placeholder_columns, ); - // still reference struct_col in original schema but with alias, + // Still reference struct_col in original schema but with alias, // to avoid colliding with the projection on the column itself if any
assert_eq!( inner_projection_exprs, @@ -831,15 +820,12 @@ mod tests { )?; column_unnests_eq( vec![ - ("unnest_placeholder(struct_col)", "Struct"), - ( - "unnest_placeholder(array_col)", - "List([unnest_placeholder(array_col,depth=1)|depth=1])", - ), + "unnest_placeholder(struct_col)", + "unnest_placeholder(array_col)=>[unnest_placeholder(array_col,depth=1)|depth=1]", ], &unnest_placeholder_columns, ); - // only transform the unnest children + // Only transform the unnest children assert_eq!( transformed_exprs, vec![col("unnest_placeholder(array_col,depth=1)") @@ -847,8 +833,8 @@ mod tests { .add(lit(1i64))] ); - // keep appending to the current vector - // still reference array_col in original schema but with alias, + // Keep appending to the current vector + // Still reference array_col in original schema but with alias, // to avoid colliding with the projection on the column itself if any assert_eq!( inner_projection_exprs, @@ -858,24 +844,44 @@ mod tests { ] ); - // a nested structure struct[[]] + Ok(()) + } + + // Unnest -> field access -> unnest + #[test] + fn test_transform_non_consecutive_unnests() -> Result<()> { + // List of struct + // [struct{'subfield1':list(i64), 'subfield2':list(utf8)}] let schema = Schema::new(vec![ Field::new( - "struct_col", // {array_col: [1,2,3]} - ArrowDataType::Struct(Fields::from(vec![Field::new( - "matrix", - ArrowDataType::List(Arc::new(Field::new( - "matrix_row", - ArrowDataType::List(Arc::new(Field::new( - "item", - ArrowDataType::Int64, + "struct_list", + ArrowDataType::List(Arc::new(Field::new( + "element", + ArrowDataType::Struct(Fields::from(vec![ + Field::new( + // list of i64 + "subfield1", + ArrowDataType::List(Arc::new(Field::new( + "i64_element", + ArrowDataType::Int64, + true, + ))), true, - ))), - true, - ))), + ), + Field::new( + // list of utf8 + "subfield2", + ArrowDataType::List(Arc::new(Field::new( + "utf8_element", + ArrowDataType::Utf8, + true, + ))), + true, + ), + ])), true, - )])), - false, + ))), + true, ), Field::new("int_col", ArrowDataType::Int32, false), ]); @@ -887,39 +893,69 @@ mod tests { schema: Arc::new(dfschema), }); - let mut unnest_placeholder_columns = vec![]; + let mut unnest_placeholder_columns = IndexMap::new(); let mut inner_projection_exprs = vec![]; // An expr with multiple unnest - let original_expr = unnest(unnest(col("struct_col").field("matrix"))); + let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1")); let transformed_exprs = rewrite_recursive_unnest_bottom_up( &input, &mut unnest_placeholder_columns, &mut inner_projection_exprs, - &original_expr, + &select_expr1, )?; // Only the inner most/ bottom most unnest is transformed assert_eq!( transformed_exprs, - vec![col("unnest_placeholder(struct_col[matrix],depth=2)") - .alias("UNNEST(UNNEST(struct_col[matrix]))")] + vec![unnest( + col("unnest_placeholder(struct_list,depth=1)") + .alias("UNNEST(struct_list)") + .field("subfield1") + )] ); - // TODO: add a test case where - // unnest -> field access -> unnest column_unnests_eq( - vec![( - "unnest_placeholder(struct_col[matrix])", - "List([unnest_placeholder(struct_col[matrix],depth=2)|depth=2])", - )], + vec![ + "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + ], + &unnest_placeholder_columns, + ); + + assert_eq!( + inner_projection_exprs, + vec![col("struct_list").alias("unnest_placeholder(struct_list)")] + ); + + // continue rewrite another expr in select + let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2")); + let 
transformed_exprs = rewrite_recursive_unnest_bottom_up( + &input, + &mut unnest_placeholder_columns, + &mut inner_projection_exprs, + &select_expr2, + )?; + // Only the inner most/ bottom most unnest is transformed + assert_eq!( + transformed_exprs, + vec![unnest( + col("unnest_placeholder(struct_list,depth=1)") + .alias("UNNEST(struct_list)") + .field("subfield2") + )] + ); + + // unnest place holder columns remain the same + // because expr1 and expr2 derive from the same unnest result + column_unnests_eq( + vec![ + "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + ], + &unnest_placeholder_columns, + ); assert_eq!( inner_projection_exprs, - vec![col("struct_col") - .field("matrix") - .alias("unnest_placeholder(struct_col[matrix])"),] + vec![col("struct_list").alias("unnest_placeholder(struct_list)")] ); Ok(()) diff --git a/datafusion/sql/src/values.rs b/datafusion/sql/src/values.rs index 9efb75bd60e43..a4001bea7deac 100644 --- a/datafusion/sql/src/values.rs +++ b/datafusion/sql/src/values.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{DFSchema, Result}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; @@ -31,16 +33,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { rows, } = values; - // values should not be based on any other schema - let schema = DFSchema::empty(); + let empty_schema = Arc::new(DFSchema::empty()); let values = rows .into_iter() .map(|row| { row.into_iter() - .map(|v| self.sql_to_expr(v, &schema, planner_context)) + .map(|v| self.sql_to_expr(v, &empty_schema, planner_context)) .collect::<Result<Vec<_>>>() }) .collect::<Result<Vec<_>>>()?; - LogicalPlanBuilder::values(values)?.build() + + let schema = planner_context.table_schema().unwrap_or(empty_schema); + if schema.fields().is_empty() { + LogicalPlanBuilder::values(values)?.build() + } else { + LogicalPlanBuilder::values_with_schema(values, &schema)?.build() + } } } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 49f4720ed1374..ea0ccb8e4b43e 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -22,6 +22,10 @@ use arrow_schema::*; use datafusion_common::{DFSchema, Result, TableReference}; use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_udaf}; use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder}; +use datafusion_functions::unicode; +use datafusion_functions_aggregate::grouping::grouping_udaf; +use datafusion_functions_nested::make_array::make_array_udf; +use datafusion_functions_window::rank::rank_udwf; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect, @@ -71,7 +75,7 @@ fn roundtrip_expr() { let ast = expr_to_sql(&expr)?; - Ok(format!("{}", ast)) + Ok(ast.to_string()) }; for (table, query, expected) in tests { @@ -139,6 +143,13 @@ fn roundtrip_statement() -> Result<()> { SELECT j2_string as string FROM j2 ORDER BY string DESC LIMIT 10"#, + r#"SELECT col1, id FROM ( + SELECT j1_string AS col1, j1_id AS id FROM j1 + UNION ALL + SELECT j2_string AS col1, j2_id AS id FROM j2 + UNION ALL + SELECT j3_string AS col1, j3_id AS id FROM j3 + ) AS subquery GROUP BY col1, id ORDER BY col1 ASC, id ASC"#, "SELECT id, count(*) over (PARTITION BY first_name ROWS
BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), first_name from person", @@ -149,6 +160,26 @@ fn roundtrip_statement() -> Result<()> { "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3", "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col), w4 as (SELECT 'd' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3 UNION ALL SELECT * FROM w4", "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col) SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col", + r#"SELECT id, first_name, + SUM(id) AS total_sum, + SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, + MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total + FROM person JOIN orders ON person.id = orders.customer_id GROUP BY id, first_name"#, + r#"SELECT id, first_name, + SUM(id) AS total_sum, + SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, + MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total + FROM (SELECT id, first_name from person) person JOIN (SELECT customer_id FROM orders) orders ON person.id = orders.customer_id GROUP BY id, first_name"#, + r#"SELECT id, first_name, last_name, customer_id, SUM(id) AS total_sum + FROM person + JOIN orders ON person.id = orders.customer_id + GROUP BY ROLLUP(id, first_name, last_name, customer_id)"#, + r#"SELECT id, first_name, last_name, + SUM(id) AS total_sum, + COUNT(*) AS total_count, + SUM(id) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_total + FROM person + GROUP BY GROUPING SETS ((id, first_name, last_name), (first_name, last_name), (last_name))"#, ]; // For each test sql string, we transform as follows: @@ -164,6 +195,7 @@ let state = MockSessionState::default() .with_aggregate_function(sum_udaf()) .with_aggregate_function(count_udaf()) + .with_aggregate_function(max_udaf()) .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new(&context); let roundtrip_statement = plan_to_sql(&plan)?; - let actual = format!("{}", &roundtrip_statement); + let actual = &roundtrip_statement.to_string(); println!("roundtrip sql: {actual}"); println!("plan {}", plan.display_indent()); @@ -203,7 +235,7 @@ fn roundtrip_crossjoin() -> Result<()> { let roundtrip_statement = plan_to_sql(&plan)?; - let actual = format!("{}", &roundtrip_statement); + let actual = &roundtrip_statement.to_string(); println!("roundtrip sql: {actual}"); println!("plan {}", plan.display_indent()); @@ -212,11 +244,11 @@ .unwrap(); let expected = "Projection: j1.j1_id, j2.j2_string\ - \n Inner Join: Filter: Boolean(true)\ + \n Cross Join: \ \n TableScan: j1\ \n TableScan: j2"; - assert_eq!(format!("{plan_roundtrip}"), expected); + assert_eq!(plan_roundtrip.to_string(), expected); Ok(()) } @@ -230,6 +262,45 @@ fn roundtrip_statement_with_dialect() -> Result<()> { unparser_dialect: Box<dyn UnparserDialect>, } let tests: Vec<TestStatementWithDialect> = vec![
TestStatementWithDialect { + sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", + expected: + // top projection sort gets derived into a subquery + // for MySQL, this subquery needs an alias + "SELECT `j1_min` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, min(`ta`.`j1_id`) FROM `j1` AS `ta` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", + expected: + // top projection sort still gets derived into a subquery in the default dialect, + // but unlike MySQL, the default dialect leaves the subquery non-aliased + "SELECT j1_min FROM (SELECT min(ta.j1_id) AS j1_min, min(ta.j1_id) FROM j1 AS ta ORDER BY min(ta.j1_id) ASC NULLS LAST) LIMIT 10", + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;", + expected: + "SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id from (select 1 as j1_id);", + expected: + "SELECT `j1_id` FROM (SELECT 1 AS `j1_id`) AS `derived_projection`", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select * from (select * from j1 limit 10);", + expected: + "SELECT * FROM (SELECT * FROM `j1` LIMIT 10) AS `derived_limit`", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, TestStatementWithDialect { sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", expected: @@ -457,7 +528,7 fn roundtrip_statement_with_dialect() -> Result<()> { let unparser = Unparser::new(&*query.unparser_dialect); let roundtrip_statement = unparser.plan_to_sql(&plan)?; - let actual = format!("{}", &roundtrip_statement); + let actual = &roundtrip_statement.to_string(); println!("roundtrip sql: {actual}"); println!("plan {}", plan.display_indent()); @@ -487,7 +558,7 @@ Projection: unnest_placeholder(unnest_table.struct_col).field1, unnest_placehold Projection: unnest_table.struct_col AS unnest_placeholder(unnest_table.struct_col), unnest_table.array_col AS unnest_placeholder(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col TableScan: unnest_table"#.trim_start(); - assert_eq!(format!("{plan}"), expected); + assert_eq!(plan.to_string(), expected); Ok(()) } @@ -507,7 +578,7 @@ fn test_table_references_in_plan_to_sql() { .unwrap(); let sql = plan_to_sql(&plan).unwrap(); - assert_eq!(format!("{}", sql), expected_sql) + assert_eq!(sql.to_string(), expected_sql) } test( @@ -537,7 +608,7 @@ fn test_table_scan_with_no_projection_in_plan_to_sql() { .build() .unwrap(); let sql = plan_to_sql(&plan).unwrap(); - assert_eq!(format!("{}", sql), expected_sql) + assert_eq!(sql.to_string(), expected_sql) } test( @@ -636,7 +707,13
@@ where .unwrap(); let context = MockContextProvider { - state: MockSessionState::default(), + state: MockSessionState::default() + .with_aggregate_function(sum_udaf()) + .with_aggregate_function(max_udaf()) + .with_aggregate_function(grouping_udaf()) + .with_window_function(rank_udwf()) + .with_scalar_function(Arc::new(unicode::substr().as_ref().clone())) + .with_scalar_function(make_array_udf()), }; let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -646,27 +723,103 @@ where } #[test] -fn test_table_scan_pushdown() -> Result<()> { +fn test_table_scan_alias() -> Result<()> { let schema = Schema::new(vec![ Field::new("id", DataType::Utf8, false), Field::new("age", DataType::Utf8, false), ]); + let plan = table_scan(Some("t1"), &schema, None)? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let sql = plan_to_sql(&plan)?; + assert_eq!(sql.to_string(), "SELECT * FROM (SELECT t1.id FROM t1) AS a"); + + let plan = table_scan(Some("t1"), &schema, None)? + .project(vec![col("id")])? + .alias("a")? + .build()?; + + let sql = plan_to_sql(&plan)?; + assert_eq!(sql.to_string(), "SELECT * FROM (SELECT t1.id FROM t1) AS a"); + + let plan = table_scan(Some("t1"), &schema, None)? + .filter(col("id").gt(lit(5)))? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let sql = plan_to_sql(&plan)?; + assert_eq!( + sql.to_string(), + "SELECT * FROM (SELECT t1.id FROM t1 WHERE (t1.id > 5)) AS a" + ); + + let table_scan_with_two_filter = table_scan_with_filters( + Some("t1"), + &schema, + None, + vec![col("id").gt(lit(1)), col("age").lt(lit(2))], + )? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let table_scan_with_two_filter = plan_to_sql(&table_scan_with_two_filter)?; + assert_eq!( + table_scan_with_two_filter.to_string(), + "SELECT a.id FROM t1 AS a WHERE ((a.id > 1) AND (a.age < 2))" + ); + + let table_scan_with_fetch = + table_scan_with_filter_and_fetch(Some("t1"), &schema, None, vec![], Some(10))? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let table_scan_with_fetch = plan_to_sql(&table_scan_with_fetch)?; + assert_eq!( + table_scan_with_fetch.to_string(), + "SELECT a.id FROM (SELECT * FROM t1 LIMIT 10) AS a" + ); + + let table_scan_with_pushdown_all = table_scan_with_filter_and_fetch( + Some("t1"), + &schema, + Some(vec![0, 1]), + vec![col("id").gt(lit(1))], + Some(10), + )? + .project(vec![col("id")])? + .alias("a")? 
+ .build()?; + let table_scan_with_pushdown_all = plan_to_sql(&table_scan_with_pushdown_all)?; + assert_eq!( + table_scan_with_pushdown_all.to_string(), + "SELECT a.id FROM (SELECT a.id, a.age FROM t1 AS a WHERE (a.id > 1) LIMIT 10) AS a" + ); + Ok(()) +} + +#[test] +fn test_table_scan_pushdown() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); let scan_with_projection = table_scan(Some("t1"), &schema, Some(vec![0, 1]))?.build()?; let scan_with_projection = plan_to_sql(&scan_with_projection)?; assert_eq!( - format!("{}", scan_with_projection), + scan_with_projection.to_string(), "SELECT t1.id, t1.age FROM t1" ); let scan_with_projection = table_scan(Some("t1"), &schema, Some(vec![1]))?.build()?; let scan_with_projection = plan_to_sql(&scan_with_projection)?; - assert_eq!(format!("{}", scan_with_projection), "SELECT t1.age FROM t1"); + assert_eq!(scan_with_projection.to_string(), "SELECT t1.age FROM t1"); let scan_with_no_projection = table_scan(Some("t1"), &schema, None)?.build()?; let scan_with_no_projection = plan_to_sql(&scan_with_no_projection)?; - assert_eq!(format!("{}", scan_with_no_projection), "SELECT * FROM t1"); + assert_eq!(scan_with_no_projection.to_string(), "SELECT * FROM t1"); let table_scan_with_projection_alias = table_scan(Some("t1"), &schema, Some(vec![0, 1]))? @@ -675,7 +828,7 @@ fn test_table_scan_pushdown() -> Result<()> { let table_scan_with_projection_alias = plan_to_sql(&table_scan_with_projection_alias)?; assert_eq!( - format!("{}", table_scan_with_projection_alias), + table_scan_with_projection_alias.to_string(), "SELECT ta.id, ta.age FROM t1 AS ta" ); @@ -686,7 +839,7 @@ fn test_table_scan_pushdown() -> Result<()> { let table_scan_with_projection_alias = plan_to_sql(&table_scan_with_projection_alias)?; assert_eq!( - format!("{}", table_scan_with_projection_alias), + table_scan_with_projection_alias.to_string(), "SELECT ta.age FROM t1 AS ta" ); @@ -696,7 +849,7 @@ fn test_table_scan_pushdown() -> Result<()> { let table_scan_with_no_projection_alias = plan_to_sql(&table_scan_with_no_projection_alias)?; assert_eq!( - format!("{}", table_scan_with_no_projection_alias), + table_scan_with_no_projection_alias.to_string(), "SELECT * FROM t1 AS ta" ); @@ -708,7 +861,7 @@ fn test_table_scan_pushdown() -> Result<()> { let query_from_table_scan_with_projection = plan_to_sql(&query_from_table_scan_with_projection)?; assert_eq!( - format!("{}", query_from_table_scan_with_projection), + query_from_table_scan_with_projection.to_string(), "SELECT * FROM (SELECT t1.id, t1.age FROM t1)" ); @@ -721,7 +874,7 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_filter = plan_to_sql(&table_scan_with_filter)?; assert_eq!( - format!("{}", table_scan_with_filter), + table_scan_with_filter.to_string(), "SELECT * FROM t1 WHERE (t1.id > t1.age)" ); @@ -734,7 +887,7 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_two_filter = plan_to_sql(&table_scan_with_two_filter)?; assert_eq!( - format!("{}", table_scan_with_two_filter), + table_scan_with_two_filter.to_string(), "SELECT * FROM t1 WHERE ((t1.id > 1) AND (t1.age < 2))" ); @@ -748,7 +901,7 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_filter_alias = plan_to_sql(&table_scan_with_filter_alias)?; assert_eq!( - format!("{}", table_scan_with_filter_alias), + table_scan_with_filter_alias.to_string(), "SELECT * FROM t1 AS ta WHERE (ta.id > ta.age)" ); @@ -762,7 
+915,7 @@ fn test_table_scan_pushdown() -> Result<()> { let table_scan_with_projection_and_filter = plan_to_sql(&table_scan_with_projection_and_filter)?; assert_eq!( - format!("{}", table_scan_with_projection_and_filter), + table_scan_with_projection_and_filter.to_string(), "SELECT t1.id, t1.age FROM t1 WHERE (t1.id > t1.age)" ); @@ -776,7 +929,7 @@ fn test_table_scan_pushdown() -> Result<()> { let table_scan_with_projection_and_filter = plan_to_sql(&table_scan_with_projection_and_filter)?; assert_eq!( - format!("{}", table_scan_with_projection_and_filter), + table_scan_with_projection_and_filter.to_string(), "SELECT t1.age FROM t1 WHERE (t1.id > t1.age)" ); @@ -785,7 +938,7 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_inline_fetch = plan_to_sql(&table_scan_with_inline_fetch)?; assert_eq!( - format!("{}", table_scan_with_inline_fetch), + table_scan_with_inline_fetch.to_string(), "SELECT * FROM t1 LIMIT 10" ); @@ -800,7 +953,7 @@ fn test_table_scan_pushdown() -> Result<()> { let table_scan_with_projection_and_inline_fetch = plan_to_sql(&table_scan_with_projection_and_inline_fetch)?; assert_eq!( - format!("{}", table_scan_with_projection_and_inline_fetch), + table_scan_with_projection_and_inline_fetch.to_string(), "SELECT t1.id, t1.age FROM t1 LIMIT 10" ); @@ -814,9 +967,131 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_all = plan_to_sql(&table_scan_with_all)?; assert_eq!( - format!("{}", table_scan_with_all), + table_scan_with_all.to_string(), "SELECT t1.id, t1.age FROM t1 WHERE (t1.id > t1.age) LIMIT 10" ); + + let table_scan_with_additional_filter = table_scan_with_filters( + Some("t1"), + &schema, + None, + vec![col("id").gt(col("age"))], + )? + .filter(col("id").eq(lit(5)))? + .build()?; + let table_scan_with_filter = plan_to_sql(&table_scan_with_additional_filter)?; + assert_eq!( + table_scan_with_filter.to_string(), + "SELECT * FROM t1 WHERE (t1.id = 5) AND (t1.id > t1.age)" + ); + + Ok(()) +} + +#[test] +fn test_sort_with_push_down_fetch() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + + let plan = table_scan(Some("t1"), &schema, None)? + .project(vec![col("id"), col("age")])? + .sort_with_limit(vec![col("age").sort(true, true)], Some(10))? + .build()?; + + let sql = plan_to_sql(&plan)?; + assert_eq!( + format!("{}", sql), + "SELECT t1.id, t1.age FROM t1 ORDER BY t1.age ASC NULLS FIRST LIMIT 10" + ); + Ok(()) +} + +#[test] +fn test_join_with_table_scan_filters() -> Result<()> { + let schema_left = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("name", DataType::Utf8, false), + ]); + + let schema_right = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + + let left_plan = table_scan_with_filters( + Some("left_table"), + &schema_left, + None, + vec![col("name").like(lit("some_name"))], + )? + .alias("left")? + .build()?; + + let right_plan = table_scan_with_filters( + Some("right_table"), + &schema_right, + None, + vec![col("age").gt(lit(10))], + )? + .build()?; + + let join_plan_with_filter = LogicalPlanBuilder::from(left_plan.clone()) + .join( + right_plan.clone(), + datafusion_expr::JoinType::Inner, + (vec!["left.id"], vec!["right_table.id"]), + Some(col("left.id").gt(lit(5))), + )? 
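+ // Note: the filters embedded in the two table scans above (the LIKE on
+ // left_table and `age > 10` on right_table) survive unparsing; as the
+ // expected SQL below shows, they are AND-ed into the JOIN ... ON condition
+ // together with the explicit `left.id > 5` join filter rather than dropped.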
+ .build()?; + + let sql = plan_to_sql(&join_plan_with_filter)?; + + let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND ("left"."name" LIKE 'some_name' AND (age > 10)))"#; + + assert_eq!(sql.to_string(), expected_sql); + + let join_plan_no_filter = LogicalPlanBuilder::from(left_plan.clone()) + .join( + right_plan, + datafusion_expr::JoinType::Inner, + (vec!["left.id"], vec!["right_table.id"]), + None, + )? + .build()?; + + let sql = plan_to_sql(&join_plan_no_filter)?; + + let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND ("left"."name" LIKE 'some_name' AND (age > 10))"#; + + assert_eq!(sql.to_string(), expected_sql); + + let right_plan_with_filter = table_scan_with_filters( + Some("right_table"), + &schema_right, + None, + vec![col("age").gt(lit(10))], + )? + .filter(col("right_table.name").eq(lit("before_join_filter_val")))? + .build()?; + + let join_plan_multiple_filters = LogicalPlanBuilder::from(left_plan.clone()) + .join( + right_plan_with_filter, + datafusion_expr::JoinType::Inner, + (vec!["left.id"], vec!["right_table.id"]), + Some(col("left.id").gt(lit(5))), + )? + .filter(col("left.name").eq(lit("after_join_filter_val")))? + .build()?; + + let sql = plan_to_sql(&join_plan_multiple_filters)?; + + let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND (("left"."name" LIKE 'some_name' AND (right_table."name" = 'before_join_filter_val')) AND (age > 10))) WHERE ("left"."name" = 'after_join_filter_val')"#; + + assert_eq!(sql.to_string(), expected_sql); + Ok(()) } @@ -845,10 +1120,71 @@ fn test_without_offset() { #[test] fn test_with_offset0() { - sql_round_trip(MySqlDialect {}, "select 1 offset 0", "SELECT 1"); + sql_round_trip(MySqlDialect {}, "select 1 offset 0", "SELECT 1 OFFSET 0"); } #[test] fn test_with_offset95() { sql_round_trip(MySqlDialect {}, "select 1 offset 95", "SELECT 1 OFFSET 95"); } + +#[test] +fn test_order_by_to_sql() { + // order by aggregation function + sql_round_trip( + GenericDialect {}, + r#"SELECT id, first_name, SUM(id) FROM person GROUP BY id, first_name ORDER BY SUM(id) ASC, first_name DESC, id, first_name LIMIT 10"#, + r#"SELECT person.id, person.first_name, sum(person.id) FROM person GROUP BY person.id, person.first_name ORDER BY sum(person.id) ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#, + ); + + // order by aggregation function alias + sql_round_trip( + GenericDialect {}, + r#"SELECT id, first_name, SUM(id) as total_sum FROM person GROUP BY id, first_name ORDER BY total_sum ASC, first_name DESC, id, first_name LIMIT 10"#, + r#"SELECT person.id, person.first_name, sum(person.id) AS total_sum FROM person GROUP BY person.id, person.first_name ORDER BY total_sum ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#, + ); + + // order by scalar function from projection + sql_round_trip( + GenericDialect {}, + r#"SELECT id, first_name, substr(first_name,0,5) FROM person ORDER BY id, substr(first_name,0,5)"#, + r#"SELECT person.id, person.first_name, substr(person.first_name, 0, 5) FROM person ORDER BY person.id ASC NULLS LAST, substr(person.first_name, 0, 5) ASC NULLS LAST"#, + ); +} + +#[test] +fn test_aggregation_to_sql() { + sql_round_trip( + GenericDialect {}, + r#"SELECT id, first_name, + SUM(id) AS 
total_sum, + SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, + MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total, + rank() OVER (PARTITION BY grouping(id) + grouping(age), CASE WHEN grouping(age) = 0 THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_1, + rank() OVER (PARTITION BY grouping(age) + grouping(id), CASE WHEN (CAST(grouping(age) AS BIGINT) = 0) THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_2 + FROM person + GROUP BY id, first_name;"#, + r#"SELECT person.id, person.first_name, +sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, +max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total, +rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_1, +rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_2 +FROM person +GROUP BY person.id, person.first_name"#.replace("\n", " ").as_str(), + ); +} + +#[test] +fn test_unnest_to_sql() { + sql_round_trip( + GenericDialect {}, + r#"SELECT unnest(array_col) as u1, struct_col, array_col FROM unnest_table WHERE array_col != NULL ORDER BY struct_col, array_col"#, + r#"SELECT UNNEST(unnest_table.array_col) AS u1, unnest_table.struct_col, unnest_table.array_col FROM unnest_table WHERE (unnest_table.array_col <> NULL) ORDER BY unnest_table.struct_col ASC NULLS LAST, unnest_table.array_col ASC NULLS LAST"#, + ); + + sql_round_trip( + GenericDialect {}, + r#"SELECT unnest(make_array(1, 2, 2, 5, NULL)) as u1"#, + r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)) AS u1"#, + ); +} diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index fe0e5f7283a47..b0fa170318493 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -54,6 +54,7 @@ pub(crate) struct MockSessionState { scalar_functions: HashMap>, aggregate_functions: HashMap>, expr_planners: Vec>, + window_functions: HashMap>, pub config_options: ConfigOptions, } @@ -80,6 +81,12 @@ impl MockSessionState { ); self } + + pub fn with_window_function(mut self, window_function: Arc) -> Self { + self.window_functions + .insert(window_function.name().to_string(), window_function); + self + } } pub(crate) struct MockContextProvider { @@ -217,18 +224,15 @@ impl ContextProvider for MockContextProvider { unimplemented!() } - fn get_window_meta(&self, _name: &str) -> Option> { - None + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions.get(name).cloned() } fn options(&self) -> &ConfigOptions { &self.state.config_options } - fn get_file_type( - &self, - _ext: &str, - ) -> Result> { + fn get_file_type(&self, _ext: &str) -> Result> { Ok(Arc::new(MockCsvType {})) } @@ -268,7 +272,7 @@ impl EmptyTable { } impl TableSource for EmptyTable { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 5c9655a55606a..698c408e538f5 100644 --- 
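A note on the `window_functions` registry added to `MockSessionState` above: the planner resolves window UDFs by name through `ContextProvider::get_window_meta`, so a test can only plan `rank() OVER (...)` once the UDWF is registered. A minimal sketch, assuming `rank_udwf` and the mock types from this test module are in scope:

    let state = MockSessionState::default().with_window_function(rank_udwf());
    let context = MockContextProvider { state };
    // "rank" now resolves during planning; without this registration,
    // sql_statement_to_plan would reject the window function as unknown
    assert!(context.get_window_meta("rank").is_some());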
a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -48,6 +48,7 @@ use datafusion_functions_aggregate::{ min_max::min_udaf, }; use datafusion_functions_aggregate::{average::avg_udaf, grouping::grouping_udaf}; +use datafusion_functions_window::rank::rank_udwf; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; @@ -897,7 +898,7 @@ fn natural_right_join() { fn natural_join_no_common_becomes_cross_join() { let sql = "SELECT * FROM person a NATURAL JOIN lineitem b"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n SubqueryAlias: a\ \n TableScan: person\ \n SubqueryAlias: b\ @@ -1913,6 +1914,13 @@ fn create_external_table_with_pk() { quick_test(sql, expected); } +#[test] +fn create_external_table_with_schema() { + let sql = "CREATE EXTERNAL TABLE staging.foo STORED AS CSV LOCATION 'foo.csv'"; + let expected = "CreateExternalTable: Partial { schema: \"staging\", table: \"foo\" }"; + quick_test(sql, expected); +} + #[test] fn create_schema_with_quoted_name() { let sql = "CREATE SCHEMA \"quoted_schema_name\""; @@ -2626,6 +2634,7 @@ fn logical_plan_with_dialect_and_options( .with_aggregate_function(min_udaf()) .with_aggregate_function(max_udaf()) .with_aggregate_function(grouping_udaf()) + .with_window_function(rank_udwf()) .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let context = MockContextProvider { state }; @@ -2735,8 +2744,8 @@ fn cross_join_not_to_inner_join() { "select person.id from person, orders, lineitem where person.id = person.age;"; let expected = "Projection: person.id\ \n Filter: person.id = person.age\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: person\ \n TableScan: orders\ \n TableScan: lineitem"; @@ -2833,11 +2842,11 @@ fn exists_subquery_schema_outer_schema_overlap() { \n Subquery:\ \n Projection: person.first_name\ \n Filter: person.id = p2.id AND person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state)\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: person\ \n SubqueryAlias: p2\ \n TableScan: person\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: person\ \n SubqueryAlias: p\ \n TableScan: person"; @@ -2925,10 +2934,10 @@ fn scalar_subquery_reference_outer_field() { \n Projection: count(*)\ \n Aggregate: groupBy=[[]], aggr=[[count(*)]]\ \n Filter: outer_ref(j2.j2_id) = j1.j1_id AND j1.j1_id = j3.j3_id\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n TableScan: j3\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n TableScan: j2"; @@ -3052,8 +3061,8 @@ fn rank_partition_grouping() { from person group by rollup(state, last_name)"; - let expected = "Projection: sum(person.age) AS total_sum, person.state, person.last_name, grouping(person.state) + grouping(person.last_name) AS x, RANK() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS the_rank\ - \n WindowAggr: windowExpr=[[RANK() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + let expected = "Projection: sum(person.age) AS total_sum, person.state, person.last_name, grouping(person.state) + grouping(person.last_name) AS x, rank() PARTITION BY
[grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS the_rank\ + \n WindowAggr: windowExpr=[[rank() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n Aggregate: groupBy=[[ROLLUP (person.state, person.last_name)]], aggr=[[sum(person.age), grouping(person.state), grouping(person.last_name)]]\ \n TableScan: person"; quick_test(sql, expected); @@ -3114,7 +3123,7 @@ fn join_on_complex_condition() { fn lateral_constant() { let sql = "SELECT * FROM j1, LATERAL (SELECT 1) AS j2"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n SubqueryAlias: j2\ \n Subquery:\ @@ -3129,7 +3138,7 @@ fn lateral_comma_join() { j1, \ LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2"; let expected = "Projection: j1.j1_string, j2.j2_string\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n SubqueryAlias: j2\ \n Subquery:\ @@ -3145,7 +3154,7 @@ fn lateral_comma_join_referencing_join_rhs() { \n j1 JOIN (j2 JOIN j3 ON(j2_id = j3_id - 2)) ON(j1_id = j2_id),\ \n LATERAL (SELECT * FROM j3 WHERE j3_string = j2_string) as j4;"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n Inner Join: Filter: j1.j1_id = j2.j2_id\ \n TableScan: j1\ \n Inner Join: Filter: j2.j2_id = j3.j3_id - Int64(2)\ @@ -3169,12 +3178,12 @@ fn lateral_comma_join_with_shadowing() { ) as j2\ ) as j2;"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n SubqueryAlias: j2\ \n Subquery:\ \n Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n SubqueryAlias: j2\ \n Subquery:\ @@ -3206,7 +3215,7 @@ fn lateral_nested_left_join() { j1, \ (j2 LEFT JOIN LATERAL (SELECT * FROM j3 WHERE j1_id + j2_id = j3_id) AS j3 ON(true))"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n Left Join: Filter: Boolean(true)\ \n TableScan: j2\ @@ -4200,6 +4209,29 @@ fn test_prepare_statement_to_plan_having() { prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } +#[test] +fn test_prepare_statement_to_plan_limit() { + let sql = "PREPARE my_plan(BIGINT, BIGINT) AS + SELECT id FROM person \ + OFFSET $1 LIMIT $2"; + + let expected_plan = "Prepare: \"my_plan\" [Int64, Int64] \ + \n Limit: skip=$1, fetch=$2\ + \n Projection: person.id\ + \n TableScan: person"; + + let expected_dt = "[Int64, Int64]"; + + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + // replace params with values + let param_values = vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))]; + let expected_plan = "Limit: skip=10, fetch=200\ + \n Projection: person.id\ + \n TableScan: person"; + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); +} + #[test] fn test_prepare_statement_to_plan_value_list() { let sql = "PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter);"; @@ -4272,7 +4304,7 @@ fn test_table_alias() { let expected = "Projection: *\ \n SubqueryAlias: f\ - \n CrossJoin:\ + \n Cross Join: \ \n SubqueryAlias: t1\ \n Projection: person.id\ \n TableScan: person\ @@ -4290,7 +4322,7 @@ fn test_table_alias() { let expected = "Projection: *\ \n SubqueryAlias: f\ \n Projection: t1.id AS c1, 
t2.age AS c2\ - \n CrossJoin:\ + \n Cross Join: \ \n SubqueryAlias: t1\ \n Projection: person.id\ \n TableScan: person\ diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index baa49057e1b97..2479252a7b5b0 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -61,7 +61,16 @@ async fn run_tests() -> Result<()> { // Enable logging (e.g. set RUST_LOG=debug to see debug logs) env_logger::init(); - let options: Options = clap::Parser::parse(); + let options: Options = Parser::parse(); + if options.list { + // nextest parses stdout, so print messages to stderr + eprintln!("NOTICE: --list option unsupported, quitting"); + // return Ok, not error so that tools like nextest which are listing all + // workspace tests (by running `cargo test ... --list --format terse`) + // do not fail when they encounter this binary. Instead, print nothing + // to stdout and return OK so they can continue listing other tests. + return Ok(()); + } options.warn_on_ignored(); // Run all tests in parallel, reporting failures at the end @@ -255,7 +264,7 @@ fn read_dir_recursive>(path: P) -> Result> { /// Append all paths recursively to dst fn read_dir_recursive_impl(dst: &mut Vec, path: &Path) -> Result<()> { - let entries = std::fs::read_dir(path) + let entries = fs::read_dir(path) .map_err(|e| exec_datafusion_err!("Error reading directory {path:?}: {e}"))?; for entry in entries { let path = entry @@ -276,7 +285,7 @@ fn read_dir_recursive_impl(dst: &mut Vec, path: &Path) -> Result<()> { /// Parsed command line options /// -/// This structure attempts to mimic the command line options +/// This structure attempts to mimic the command line options of the built-in Rust test runner /// accepted by IDEs such as CLion that pass arguments /// /// See for more details @@ -320,6 +329,18 @@ struct Options { help = "IGNORED (for compatibility with built in rust test runner)" )] show_output: bool, + + #[clap( + long, + help = "Quits immediately, not listing anything (for compatibility with built-in Rust test runner)" + )] + list: bool, + + #[clap( + long, + help = "IGNORED (for compatibility with built-in Rust test runner)" + )] + ignored: bool, } impl Options { @@ -354,15 +375,15 @@ impl Options { -/// Logs warning messages to stdout if any ignored options are passed +/// Logs warning messages to stderr if any ignored options are passed fn warn_on_ignored(&self) { if self.format.is_some() { - println!("WARNING: Ignoring `--format` compatibility option"); + eprintln!("WARNING: Ignoring `--format` compatibility option"); } if self.z_options.is_some() { - println!("WARNING: Ignoring `-Z` compatibility option"); + eprintln!("WARNING: Ignoring `-Z` compatibility option"); } if self.show_output { - println!("WARNING: Ignoring `--show-output` compatibility option"); + eprintln!("WARNING: Ignoring `--show-output` compatibility option"); } } } diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index d3ee720467b66..477f225443e28 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -139,7 +139,7 @@ impl TestContext { } #[cfg(feature = "avro")] -pub async fn register_avro_tables(ctx: &mut crate::TestContext) { +pub async fn register_avro_tables(ctx: &mut TestContext) { use datafusion::prelude::AvroReadOptions; ctx.enable_testdir(); @@ -314,17 +314,49 @@ pub async fn register_metadata_tables(ctx: &SessionContext) { String::from("metadata_key"), String::from("the name field"), )])); - - let schema
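Before the new columns are added below, it may help to recall that Arrow metadata attaches at two levels, and these metadata tables exercise both: each `Field` carries its own string map, while the `Schema` carries a separate schema-level map. A minimal sketch with the same `arrow-schema` calls used here:

    use std::collections::HashMap;
    use arrow_schema::{DataType, Field, Schema};

    // field-level metadata travels with the individual column
    let id = Field::new("id", DataType::Int32, true).with_metadata(HashMap::from([(
        String::from("metadata_key"),
        String::from("the id field"),
    )]));

    // schema-level metadata is attached to the schema as a whole
    let schema = Schema::new(vec![id]).with_metadata(HashMap::from([(
        String::from("metadata_key"),
        String::from("the entire schema"),
    )]));
    assert_eq!(schema.metadata()["metadata_key"], "the entire schema");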
= Schema::new(vec![id, name]).with_metadata(HashMap::from([( - String::from("metadata_key"), - String::from("the entire schema"), - )])); + let l_name = + Field::new("l_name", DataType::Utf8, true).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the l_name field"), + )])); + + let ts = Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), false) + .with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("ts non-nullable field"), + )])); + + let nonnull_name = + Field::new("nonnull_name", DataType::Utf8, false).with_metadata(HashMap::from([ + ( + String::from("metadata_key"), + String::from("the nonnull_name field"), + ), + ])); + + let schema = Schema::new(vec![id, name, l_name, ts, nonnull_name]).with_metadata( + HashMap::from([( + String::from("metadata_key"), + String::from("the entire schema"), + )]), + ); let batch = RecordBatch::try_new( Arc::new(schema), vec![ Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _, Arc::new(StringArray::from(vec![None, Some("bar"), Some("baz")])) as _, + Arc::new(StringArray::from(vec![None, Some("l_bar"), Some("l_baz")])) as _, + Arc::new(TimestampNanosecondArray::from(vec![ + 1599572549190855123, + 1599572549190855123, + 1599572549190855123, + ])) as _, + Arc::new(StringArray::from(vec![ + Some("no_foo"), + Some("no_bar"), + Some("no_baz"), + ])) as _, ], ) .unwrap(); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 56756cb2010b0..f03c3700ab9f9 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1124,6 +1124,14 @@ SELECT COUNT(*) FROM aggregate_test_100 ---- 100 +query I +SELECT COUNT(aggregate_test_100.*) FROM aggregate_test_100 +---- +100 + +query error Error during planning: Invalid qualifier foo +SELECT COUNT(foo.*) FROM aggregate_test_100 + # csv_query_count_literal query I SELECT COUNT(2) FROM aggregate_test_100 @@ -1377,6 +1385,24 @@ NaN statement ok DROP TABLE tmp_percentile_cont; +# Test for an issue where approx_percentile_cont_with_weight panicked on 'NaN' input + +statement ok +CREATE TABLE t1(v1 BOOL); + +statement ok +INSERT INTO t1 VALUES (TRUE); + +# ISSUE: https://github.com/apache/datafusion/issues/12716 +# This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' and returns 'inf' +query R +SELECT approx_percentile_cont_with_weight('NaN'::DOUBLE, 0, 0) FROM t1 WHERE t1.v1; +---- +Infinity + +statement ok +DROP TABLE t1; + # csv_query_cube_avg query TIR SELECT c1, c2, AVG(c3) FROM aggregate_test_100 GROUP BY CUBE (c1, c2) ORDER BY c1, c2 @@ -3512,6 +3538,18 @@ SELECT MIN(value), MAX(value) FROM integers_with_nulls ---- 1 5 +# grouping sets (via CUBE) with null values query II rowsort +SELECT value, min(value) FROM integers_with_nulls GROUP BY CUBE(value) +---- +1 1 +3 3 +4 4 +5 5 +NULL 1 +NULL NULL + + statement ok DROP TABLE integers_with_nulls; @@ -3780,6 +3818,180 @@ DROP TABLE min_bool; # Min_Max End # ################# + + +################# +# min_max on strings/binary with null values and groups +################# + +statement ok +CREATE TABLE strings (value TEXT, id int); + +statement ok +INSERT INTO strings VALUES + ('c', 1), + ('d', 1), + ('a', 3), + ('c', 1), + ('b', 1), + (NULL, 1), + (NULL, 4), + ('d', 1), + ('z', 2), + ('c', 1), + ('a', 2); + +############ Utf8 ############ + +query IT +SELECT id, MIN(value) FROM strings GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a +4 NULL + +query IT +SELECT id, MAX(value) FROM
strings GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a +4 NULL + +############ LargeUtf8 ############ + +statement ok +CREATE VIEW large_strings AS SELECT id, arrow_cast(value, 'LargeUtf8') as value FROM strings; + + +query IT +SELECT id, MIN(value) FROM large_strings GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a +4 NULL + +query IT +SELECT id, MAX(value) FROM large_strings GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a +4 NULL + +statement ok +DROP VIEW large_strings + +############ Utf8View ############ + +statement ok +CREATE VIEW string_views AS SELECT id, arrow_cast(value, 'Utf8View') as value FROM strings; + + +query IT +SELECT id, MIN(value) FROM string_views GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a +4 NULL + +query IT +SELECT id, MAX(value) FROM string_views GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a +4 NULL + +statement ok +DROP VIEW string_views + +############ Binary ############ + +statement ok +CREATE VIEW binary AS SELECT id, arrow_cast(value, 'Binary') as value FROM strings; + + +query I? +SELECT id, MIN(value) FROM binary GROUP BY id ORDER BY id; +---- +1 62 +2 61 +3 61 +4 NULL + +query I? +SELECT id, MAX(value) FROM binary GROUP BY id ORDER BY id; +---- +1 64 +2 7a +3 61 +4 NULL + +statement ok +DROP VIEW binary + +############ LargeBinary ############ + +statement ok +CREATE VIEW large_binary AS SELECT id, arrow_cast(value, 'LargeBinary') as value FROM strings; + + +query I? +SELECT id, MIN(value) FROM large_binary GROUP BY id ORDER BY id; +---- +1 62 +2 61 +3 61 +4 NULL + +query I? +SELECT id, MAX(value) FROM large_binary GROUP BY id ORDER BY id; +---- +1 64 +2 7a +3 61 +4 NULL + +statement ok +DROP VIEW large_binary + +############ BinaryView ############ + +statement ok +CREATE VIEW binary_views AS SELECT id, arrow_cast(value, 'BinaryView') as value FROM strings; + + +query I? +SELECT id, MIN(value) FROM binary_views GROUP BY id ORDER BY id; +---- +1 62 +2 61 +3 61 +4 NULL + +query I? 
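+# (sqllogictest renders binary results as hex-encoded bytes, so in the
+# Binary/LargeBinary/BinaryView expectations here 'a' appears as 61,
+# 'b' as 62, 'd' as 64 and 'z' as 7a.)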
+SELECT id, MAX(value) FROM binary_views GROUP BY id ORDER BY id; +---- +1 64 +2 7a +3 61 +4 NULL + +statement ok +DROP VIEW binary_views + +statement ok +DROP TABLE strings; + +################# +# End min_max on strings/binary with null values and groups +################# + + statement ok create table bool_aggregate_functions ( c1 boolean not null, @@ -4871,16 +5083,18 @@ query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; ---- logical_plan -01)Limit: skip=0, fetch=3 -02)--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] -03)----TableScan: aggregate_test_100 projection=[c2, c3] +01)Projection: aggregate_test_100.c2, aggregate_test_100.c3 +02)--Limit: skip=0, fetch=3 +03)----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] +04)------TableScan: aggregate_test_100 projection=[c2, c3] physical_plan -01)GlobalLimitExec: skip=0, fetch=3 -02)--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3] -03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] -05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true +01)ProjectionExec: expr=[c2@0 as c2, c3@1 as c3] +02)--GlobalLimitExec: skip=0, fetch=3 +03)----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3, __grouping_id@2 as __grouping_id], aggr=[], lim=[3] +04)------CoalescePartitionsExec +05)--------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true query II SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; @@ -5863,63 +6077,12 @@ ORDER BY k; 1 1.8125 6.8007813 Float16 Float16 2 8.5 8.5 Float16 Float16 -# The result is 0.19432323191699075 actually -query R -SELECT kurtosis_pop(col) FROM VALUES (1), (10), (100), (10), (1) as tab(col); ----- -0.194323231917 - -# The result is -1.153061224489787 actually -query R -SELECT kurtosis_pop(col) FROM VALUES (1), (2), (3), (2), (1) as tab(col); ----- --1.15306122449 - -query R -SELECT kurtosis_pop(col) FROM VALUES (1.0), (10.0), (100.0), (10.0), (1.0) as tab(col); ----- -0.194323231917 - -query R -SELECT kurtosis_pop(col) FROM VALUES ('1'), ('10'), ('100'), ('10'), ('1') as tab(col); ----- -0.194323231917 - -query R -SELECT kurtosis_pop(col) FROM VALUES (1.0) as tab(col); ----- -NULL - -query R -SELECT kurtosis_pop(1) ----- -NULL - -query R -SELECT kurtosis_pop(1.0) ----- -NULL - -query R -SELECT kurtosis_pop(null) ----- -NULL - statement ok -CREATE TABLE t1(c1 int); - -query R -SELECT kurtosis_pop(c1) FROM t1; ----- -NULL +CREATE TABLE t1(v1 int); -statement ok -INSERT INTO t1 VALUES (1), (10), (100), (10), (1); - -query R -SELECT kurtosis_pop(c1) FROM t1; ----- -0.194323231917 +# issue: https://github.com/apache/datafusion/issues/12814 +statement error DataFusion error: Error during planning: Aggregate functions are not allowed in the WHERE clause. 
Consider using HAVING instead +SELECT v1 FROM t1 WHERE ((count(v1) % 1) << 1) > 0; statement ok DROP TABLE t1; diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt index 2209edc5d1fc4..a67fec695f6c6 100644 --- a/datafusion/sqllogictest/test_files/aggregates_topk.slt +++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt @@ -53,6 +53,11 @@ physical_plan 07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces.timestamp)] 08)--------------MemoryExec: partitions=1, partition_sizes=[1] +query TI +select * from (select trace_id, MAX(timestamp) max_ts from traces t group by trace_id) where trace_id != 'b' order by max_ts desc limit 3; +---- +c 4 +a 1 query TI select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; @@ -89,6 +94,12 @@ c 1 2 statement ok set datafusion.optimizer.enable_topk_aggregation = true; +query TI +select * from (select trace_id, MAX(timestamp) max_ts from traces t group by trace_id) where max_ts != 3 order by max_ts desc limit 2; +---- +c 4 +a 1 + query TT explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; ---- diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index b7d60b50586dd..bfdbfb1bcc5e2 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -6595,7 +6595,7 @@ select make_array(1, 2.0, null, 3) query ? select make_array(1.0, '2', null) ---- -[1.0, 2, ] +[1.0, 2.0, ] ### FixedSizeListArray @@ -7097,6 +7097,19 @@ select array_has(a, 1) from values_all_empty; false false +# Test create table with a fixed-size array column +statement ok +create table fixed_size_col_table (a int[3]) as values ([1,2,3]), ([4,5,6]); + +query T +select arrow_typeof(a) from fixed_size_col_table; +---- +FixedSizeList(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 3) +FixedSizeList(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 3) + +statement error +create table varying_fixed_size_col_table (a int[3]) as values ([1,2,3]), ([4,5]); + ### Delete tables statement ok @@ -7272,3 +7285,6 @@ drop table test_create_array_table; statement ok drop table values_all_empty; + +statement ok +drop table fixed_size_col_table; diff --git a/datafusion/sqllogictest/test_files/arrow_files.slt b/datafusion/sqllogictest/test_files/arrow_files.slt index e66ba7477fc48..e73acc384cb3e 100644 --- a/datafusion/sqllogictest/test_files/arrow_files.slt +++ b/datafusion/sqllogictest/test_files/arrow_files.slt @@ -118,3 +118,8 @@ EXPLAIN SELECT f0 FROM arrow_partitioned WHERE part = 456 ---- logical_plan TableScan: arrow_partitioned projection=[f0], full_filters=[arrow_partitioned.part = Int32(456)] physical_plan ArrowExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_table_arrow/part=456/data.arrow]]}, projection=[f0] + + +# Errors in partition filters should be reported +query error Divide by zero error +SELECT f0 FROM arrow_partitioned WHERE CASE WHEN true THEN 1 / 0 ELSE part END = 1; diff --git a/datafusion/sqllogictest/test_files/binary_view.slt b/datafusion/sqllogictest/test_files/binary_view.slt index 77ec77c5eccee..f973b909aeb6b 100644 --- a/datafusion/sqllogictest/test_files/binary_view.slt +++ b/datafusion/sqllogictest/test_files/binary_view.slt @@
-200,3 +200,18 @@ NULL R NULL NULL NULL NULL statement ok drop table test; + +statement ok +create table bv as values +( + arrow_cast('one', 'BinaryView'), + arrow_cast('two', 'BinaryView') +); + +query B +select column1 like 'o%' from bv; +---- +true + +statement ok +drop table bv; diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index 12b097c3d5d11..ed001cf9f84c5 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -102,9 +102,6 @@ CREATE TEMPORARY TABLE my_temp_table ( name TEXT NOT NULL ); -statement error DataFusion error: This feature is not implemented: Temporary views not supported -CREATE TEMPORARY VIEW my_temp_view AS SELECT id, name FROM my_table; - # Partitioned table on a single file query error DataFusion error: Error during planning: Can't create a partitioned table backed by a single file, perhaps the URL is missing a trailing slash\? CREATE EXTERNAL TABLE single_file_partition(c1 int) @@ -275,3 +272,14 @@ DROP TABLE t; # query should fail with bad column statement error DataFusion error: Error during planning: Column foo is not in schema CREATE EXTERNAL TABLE t STORED AS parquet LOCATION '../../parquet-testing/data/alltypes_plain.parquet' WITH ORDER (foo); + +# Create external table with qualified name should belong to the schema +statement ok +CREATE SCHEMA staging; + +statement ok +CREATE EXTERNAL TABLE staging.foo STORED AS parquet LOCATION '../../parquet-testing/data/alltypes_plain.parquet'; + +# Create external table with qualified name, but no schema should error +statement error DataFusion error: Error during planning: failed to resolve schema: release +CREATE EXTERNAL TABLE release.bar STORED AS parquet LOCATION '../../parquet-testing/data/alltypes_plain.parquet'; diff --git a/datafusion/sqllogictest/test_files/cse.slt b/datafusion/sqllogictest/test_files/cse.slt index 19b47fa50e410..c95e9a1309f8d 100644 --- a/datafusion/sqllogictest/test_files/cse.slt +++ b/datafusion/sqllogictest/test_files/cse.slt @@ -179,8 +179,8 @@ physical_plan # Surely only once but also conditionally evaluated expressions query TT EXPLAIN SELECT - (a = 1 OR random() = 0) AND a = 1 AS c1, - (a = 2 AND random() = 0) OR a = 2 AS c2, + (a = 1 OR random() = 0) AND a = 2 AS c1, + (a = 2 AND random() = 0) OR a = 1 AS c2, CASE WHEN a + 3 = 0 THEN a + 3 ELSE 0 END AS c3, CASE WHEN a + 4 = 0 THEN 0 WHEN a + 4 THEN 0 ELSE 0 END AS c4, CASE WHEN a + 5 = 0 THEN 0 WHEN random() = 0 THEN a + 5 ELSE 0 END AS c5, @@ -188,37 +188,37 @@ EXPLAIN SELECT FROM t1 ---- logical_plan -01)Projection: (__common_expr_1 OR random() = Float64(0)) AND __common_expr_1 AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_2 AS c2, CASE WHEN __common_expr_3 = Float64(0) THEN __common_expr_3 ELSE Float64(0) END AS c3, CASE WHEN __common_expr_4 = Float64(0) THEN Int64(0) WHEN CAST(__common_expr_4 AS Boolean) THEN Int64(0) ELSE Int64(0) END AS c4, CASE WHEN __common_expr_5 = Float64(0) THEN Float64(0) WHEN random() = Float64(0) THEN __common_expr_5 ELSE Float64(0) END AS c5, CASE WHEN __common_expr_6 = Float64(0) THEN Float64(0) ELSE __common_expr_6 END AS c6 +01)Projection: (__common_expr_1 OR random() = Float64(0)) AND __common_expr_2 AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_1 AS c2, CASE WHEN __common_expr_3 = Float64(0) THEN __common_expr_3 ELSE Float64(0) END AS c3, CASE WHEN __common_expr_4 = Float64(0) THEN 
Int64(0) WHEN CAST(__common_expr_4 AS Boolean) THEN Int64(0) ELSE Int64(0) END AS c4, CASE WHEN __common_expr_5 = Float64(0) THEN Float64(0) WHEN random() = Float64(0) THEN __common_expr_5 ELSE Float64(0) END AS c5, CASE WHEN __common_expr_6 = Float64(0) THEN Float64(0) ELSE __common_expr_6 END AS c6 02)--Projection: t1.a = Float64(1) AS __common_expr_1, t1.a = Float64(2) AS __common_expr_2, t1.a + Float64(3) AS __common_expr_3, t1.a + Float64(4) AS __common_expr_4, t1.a + Float64(5) AS __common_expr_5, t1.a + Float64(6) AS __common_expr_6 03)----TableScan: t1 projection=[a] physical_plan -01)ProjectionExec: expr=[(__common_expr_1@0 OR random() = 0) AND __common_expr_1@0 as c1, __common_expr_2@1 AND random() = 0 OR __common_expr_2@1 as c2, CASE WHEN __common_expr_3@2 = 0 THEN __common_expr_3@2 ELSE 0 END as c3, CASE WHEN __common_expr_4@3 = 0 THEN 0 WHEN CAST(__common_expr_4@3 AS Boolean) THEN 0 ELSE 0 END as c4, CASE WHEN __common_expr_5@4 = 0 THEN 0 WHEN random() = 0 THEN __common_expr_5@4 ELSE 0 END as c5, CASE WHEN __common_expr_6@5 = 0 THEN 0 ELSE __common_expr_6@5 END as c6] +01)ProjectionExec: expr=[(__common_expr_1@0 OR random() = 0) AND __common_expr_2@1 as c1, __common_expr_2@1 AND random() = 0 OR __common_expr_1@0 as c2, CASE WHEN __common_expr_3@2 = 0 THEN __common_expr_3@2 ELSE 0 END as c3, CASE WHEN __common_expr_4@3 = 0 THEN 0 WHEN CAST(__common_expr_4@3 AS Boolean) THEN 0 ELSE 0 END as c4, CASE WHEN __common_expr_5@4 = 0 THEN 0 WHEN random() = 0 THEN __common_expr_5@4 ELSE 0 END as c5, CASE WHEN __common_expr_6@5 = 0 THEN 0 ELSE __common_expr_6@5 END as c6] 02)--ProjectionExec: expr=[a@0 = 1 as __common_expr_1, a@0 = 2 as __common_expr_2, a@0 + 3 as __common_expr_3, a@0 + 4 as __common_expr_4, a@0 + 5 as __common_expr_5, a@0 + 6 as __common_expr_6] 03)----MemoryExec: partitions=1, partition_sizes=[0] # Surely only once but also conditionally evaluated subexpressions query TT EXPLAIN SELECT - (a = 1 OR random() = 0) AND (a = 1 OR random() = 1) AS c1, - (a = 2 AND random() = 0) OR (a = 2 AND random() = 1) AS c2, + (a = 1 OR random() = 0) AND (a = 2 OR random() = 1) AS c1, + (a = 2 AND random() = 0) OR (a = 1 AND random() = 1) AS c2, CASE WHEN a + 3 = 0 THEN a + 3 + random() ELSE 0 END AS c3, CASE WHEN a + 4 = 0 THEN 0 ELSE a + 4 + random() END AS c4 FROM t1 ---- logical_plan -01)Projection: (__common_expr_1 OR random() = Float64(0)) AND (__common_expr_1 OR random() = Float64(1)) AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_2 AND random() = Float64(1) AS c2, CASE WHEN __common_expr_3 = Float64(0) THEN __common_expr_3 + random() ELSE Float64(0) END AS c3, CASE WHEN __common_expr_4 = Float64(0) THEN Float64(0) ELSE __common_expr_4 + random() END AS c4 +01)Projection: (__common_expr_1 OR random() = Float64(0)) AND (__common_expr_2 OR random() = Float64(1)) AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_1 AND random() = Float64(1) AS c2, CASE WHEN __common_expr_3 = Float64(0) THEN __common_expr_3 + random() ELSE Float64(0) END AS c3, CASE WHEN __common_expr_4 = Float64(0) THEN Float64(0) ELSE __common_expr_4 + random() END AS c4 02)--Projection: t1.a = Float64(1) AS __common_expr_1, t1.a = Float64(2) AS __common_expr_2, t1.a + Float64(3) AS __common_expr_3, t1.a + Float64(4) AS __common_expr_4 03)----TableScan: t1 projection=[a] physical_plan -01)ProjectionExec: expr=[(__common_expr_1@0 OR random() = 0) AND (__common_expr_1@0 OR random() = 1) as c1, __common_expr_2@1 AND random() = 0 OR __common_expr_2@1 AND random() = 1 as c2, CASE WHEN 
__common_expr_3@2 = 0 THEN __common_expr_3@2 + random() ELSE 0 END as c3, CASE WHEN __common_expr_4@3 = 0 THEN 0 ELSE __common_expr_4@3 + random() END as c4] +01)ProjectionExec: expr=[(__common_expr_1@0 OR random() = 0) AND (__common_expr_2@1 OR random() = 1) as c1, __common_expr_2@1 AND random() = 0 OR __common_expr_1@0 AND random() = 1 as c2, CASE WHEN __common_expr_3@2 = 0 THEN __common_expr_3@2 + random() ELSE 0 END as c3, CASE WHEN __common_expr_4@3 = 0 THEN 0 ELSE __common_expr_4@3 + random() END as c4] 02)--ProjectionExec: expr=[a@0 = 1 as __common_expr_1, a@0 = 2 as __common_expr_2, a@0 + 3 as __common_expr_3, a@0 + 4 as __common_expr_4] 03)----MemoryExec: partitions=1, partition_sizes=[0] # Only conditionally evaluated expressions query TT EXPLAIN SELECT - (random() = 0 OR a = 1) AND a = 1 AS c1, - (random() = 0 AND a = 2) OR a = 2 AS c2, + (random() = 0 OR a = 1) AND a = 2 AS c1, + (random() = 0 AND a = 2) OR a = 1 AS c2, CASE WHEN random() = 0 THEN a + 3 ELSE a + 3 END AS c3, CASE WHEN random() = 0 THEN 0 WHEN a + 4 = 0 THEN a + 4 ELSE 0 END AS c4, CASE WHEN random() = 0 THEN 0 WHEN a + 5 = 0 THEN 0 ELSE a + 5 END AS c5, @@ -226,8 +226,8 @@ EXPLAIN SELECT FROM t1 ---- logical_plan -01)Projection: (random() = Float64(0) OR t1.a = Float64(1)) AND t1.a = Float64(1) AS c1, random() = Float64(0) AND t1.a = Float64(2) OR t1.a = Float64(2) AS c2, CASE WHEN random() = Float64(0) THEN t1.a + Float64(3) ELSE t1.a + Float64(3) END AS c3, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN t1.a + Float64(4) = Float64(0) THEN t1.a + Float64(4) ELSE Float64(0) END AS c4, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN t1.a + Float64(5) = Float64(0) THEN Float64(0) ELSE t1.a + Float64(5) END AS c5, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN random() = Float64(0) THEN t1.a + Float64(6) ELSE t1.a + Float64(6) END AS c6 +01)Projection: (random() = Float64(0) OR t1.a = Float64(1)) AND t1.a = Float64(2) AS c1, random() = Float64(0) AND t1.a = Float64(2) OR t1.a = Float64(1) AS c2, CASE WHEN random() = Float64(0) THEN t1.a + Float64(3) ELSE t1.a + Float64(3) END AS c3, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN t1.a + Float64(4) = Float64(0) THEN t1.a + Float64(4) ELSE Float64(0) END AS c4, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN t1.a + Float64(5) = Float64(0) THEN Float64(0) ELSE t1.a + Float64(5) END AS c5, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN random() = Float64(0) THEN t1.a + Float64(6) ELSE t1.a + Float64(6) END AS c6 02)--TableScan: t1 projection=[a] physical_plan -01)ProjectionExec: expr=[(random() = 0 OR a@0 = 1) AND a@0 = 1 as c1, random() = 0 AND a@0 = 2 OR a@0 = 2 as c2, CASE WHEN random() = 0 THEN a@0 + 3 ELSE a@0 + 3 END as c3, CASE WHEN random() = 0 THEN 0 WHEN a@0 + 4 = 0 THEN a@0 + 4 ELSE 0 END as c4, CASE WHEN random() = 0 THEN 0 WHEN a@0 + 5 = 0 THEN 0 ELSE a@0 + 5 END as c5, CASE WHEN random() = 0 THEN 0 WHEN random() = 0 THEN a@0 + 6 ELSE a@0 + 6 END as c6] +01)ProjectionExec: expr=[(random() = 0 OR a@0 = 1) AND a@0 = 2 as c1, random() = 0 AND a@0 = 2 OR a@0 = 1 as c2, CASE WHEN random() = 0 THEN a@0 + 3 ELSE a@0 + 3 END as c3, CASE WHEN random() = 0 THEN 0 WHEN a@0 + 4 = 0 THEN a@0 + 4 ELSE 0 END as c4, CASE WHEN random() = 0 THEN 0 WHEN a@0 + 5 = 0 THEN 0 ELSE a@0 + 5 END as c5, CASE WHEN random() = 0 THEN 0 WHEN random() = 0 THEN a@0 + 6 ELSE a@0 + 6 END as c6] 02)--MemoryExec: partitions=1, partition_sizes=[0] diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index 
e9fcf07e7739c..60569803322cf 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -722,7 +722,7 @@ logical_plan 03)----Projection: Int64(1) AS val 04)------EmptyRelation 05)----Projection: Int64(2) AS val -06)------CrossJoin: +06)------Cross Join: 07)--------Filter: recursive_cte.val < Int64(2) 08)----------TableScan: recursive_cte 09)--------SubqueryAlias: sub_cte diff --git a/datafusion/sqllogictest/test_files/dates.slt b/datafusion/sqllogictest/test_files/dates.slt index 1ef56b1a7e11d..4425eee333735 100644 --- a/datafusion/sqllogictest/test_files/dates.slt +++ b/datafusion/sqllogictest/test_files/dates.slt @@ -194,6 +194,14 @@ create table ts_utf8_data(ts varchar(100), format varchar(100)) as values ('1926632005', '%s'), ('2000-01-01T01:01:01+07:00', '%+'); +statement ok +create table ts_largeutf8_data as +select arrow_cast(ts, 'LargeUtf8') as ts, arrow_cast(format, 'LargeUtf8') as format from ts_utf8_data; + +statement ok +create table ts_utf8view_data as +select arrow_cast(ts, 'Utf8View') as ts, arrow_cast(format, 'Utf8View') as format from ts_utf8_data; + # verify date data using tables with formatting options query D SELECT to_date(t.ts, t.format) from ts_utf8_data as t @@ -204,6 +212,24 @@ SELECT to_date(t.ts, t.format) from ts_utf8_data as t 2031-01-19 1999-12-31 +query D +SELECT to_date(t.ts, t.format) from ts_largeutf8_data as t +---- +2020-09-08 +2031-01-19 +2020-09-08 +2031-01-19 +1999-12-31 + +query D +SELECT to_date(t.ts, t.format) from ts_utf8view_data as t +---- +2020-09-08 +2031-01-19 +2020-09-08 +2031-01-19 +1999-12-31 + # verify date data using tables with formatting options query D SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8_data as t @@ -214,6 +240,24 @@ SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') 2031-01-19 1999-12-31 +query D +SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') from ts_largeutf8_data as t +---- +2020-09-08 +2031-01-19 +2020-09-08 +2031-01-19 +1999-12-31 + +query D +SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8view_data as t +---- +2020-09-08 +2031-01-19 +2020-09-08 +2031-01-19 +1999-12-31 + # verify date data using tables with formatting options where at least one column cannot be parsed query error Error parsing timestamp from '1926632005' using format '%d-%m-%Y %H:%M:%S%#z': input contains invalid characters SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8_data as t @@ -228,6 +272,24 @@ SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', 2031-01-19 1999-12-31 +query D +SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+') from ts_largeutf8_data as t +---- +2020-09-08 +2031-01-19 +2020-09-08 +2031-01-19 +1999-12-31 + +query D +SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+') from ts_utf8view_data as t +---- +2020-09-08 +2031-01-19 +2020-09-08 +2031-01-19 +1999-12-31 + # timestamp data using tables with formatting options in an array is not supported at this time query error function unsupported data type at index 1: SELECT to_date(t.ts, make_array('%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+')) from ts_utf8_data as t diff --git a/datafusion/sqllogictest/test_files/ddl.slt b/datafusion/sqllogictest/test_files/ddl.slt index 21edb458fe567..3205920d71102 100644 --- 
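One behavior all of the `to_date` calls above rely on: when several format strings are supplied, each is tried in order and the first successful parse wins, whether the input column is Utf8, LargeUtf8, or Utf8View. The fallback semantics can be sketched with plain `chrono` (an illustration of the contract only, not DataFusion's implementation; `to_date_fallback` is a made-up helper):

    use chrono::NaiveDate;

    /// Try each chrono format string in order; the first successful parse wins.
    fn to_date_fallback(s: &str, formats: &[&str]) -> Option<NaiveDate> {
        formats.iter().find_map(|fmt| NaiveDate::parse_from_str(s, fmt).ok())
    }

    fn main() {
        // the first format fails on this input; the second one matches
        let d = to_date_fallback(
            "08-09-2020 13:42:29+0830",
            &["%Y-%m-%d %H/%M/%S%#z", "%d-%m-%Y %H:%M:%S%#z"],
        );
        assert_eq!(d, NaiveDate::from_ymd_opt(2020, 9, 8));
    }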
a/datafusion/sqllogictest/test_files/ddl.slt +++ b/datafusion/sqllogictest/test_files/ddl.slt @@ -775,3 +775,33 @@ physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/te statement ok drop table t; + +statement ok +set datafusion.explain.logical_plan_only=true; + +query TT +explain CREATE TEMPORARY VIEW z AS VALUES (1,2,3); +---- +logical_plan +01)CreateView: Bare { table: "z" } +02)--Values: (Int64(1), Int64(2), Int64(3)) + +query TT +explain CREATE EXTERNAL TEMPORARY TABLE tty STORED as ARROW LOCATION '../core/tests/data/example.arrow'; +---- +logical_plan CreateExternalTable: Bare { table: "tty" } + +statement ok +set datafusion.explain.logical_plan_only=false; + +statement error DataFusion error: This feature is not implemented: Temporary tables not supported +CREATE EXTERNAL TEMPORARY TABLE tty STORED as ARROW LOCATION '../core/tests/data/example.arrow'; + +statement error DataFusion error: This feature is not implemented: Temporary views not supported +CREATE TEMPORARY VIEW y AS VALUES (1,2,3); + +query error DataFusion error: Schema error: No field named a\. +EXPLAIN CREATE TABLE t(a int) AS VALUES (a + a); + +statement error DataFusion error: Schema error: No field named a\. +CREATE TABLE t(a int) AS SELECT x FROM (VALUES (a)) t(x) WHERE false; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/distinct_on.slt b/datafusion/sqllogictest/test_files/distinct_on.slt index 99639d78c3090..604ac95ff476f 100644 --- a/datafusion/sqllogictest/test_files/distinct_on.slt +++ b/datafusion/sqllogictest/test_files/distinct_on.slt @@ -144,6 +144,19 @@ LIMIT 3; 45 15673 -72 -11122 +# use wildcard +query TIIIIIIIITRRT +SELECT DISTINCT ON (c1) * FROM aggregate_test_100 ORDER BY c1 LIMIT 3; +---- +a 1 -85 -15154 1171968280 1919439543497968449 77 52286 774637006 12101411955859039553 0.12285209 0.686439196277 0keZ5G8BffGwgF2RwQD59TFzMStxCB +b 1 29 -18218 994303988 5983957848665088916 204 9489 3275293996 14857091259186476033 0.53840446 0.179090351188 AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz +c 2 1 18109 2033001162 -6513304855495910254 25 43062 1491205016 5863949479783605708 0.110830784 0.929409733247 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW + +# can't distinct on * +query error DataFusion error: SQL error: ParserError\("Expected: an expression:, found: \*"\) +SELECT DISTINCT ON (*) c1 FROM aggregate_test_100 ORDER BY c1 LIMIT 3; + + # test distinct on statement ok create table t(a int, b int, c int) as values (1, 2, 3); diff --git a/datafusion/sqllogictest/test_files/dynamic_file.slt b/datafusion/sqllogictest/test_files/dynamic_file.slt index e177fd3de2437..69f9a43ad4077 100644 --- a/datafusion/sqllogictest/test_files/dynamic_file.slt +++ b/datafusion/sqllogictest/test_files/dynamic_file.slt @@ -25,9 +25,170 @@ SELECT * FROM '../core/tests/data/partitioned_table_arrow/part=123' ORDER BY f0; 1 foo true 2 bar false -# dynamic file query doesn't support partitioned table -statement error DataFusion error: Error during planning: table 'datafusion.public.../core/tests/data/partitioned_table_arrow' not found -SELECT * FROM '../core/tests/data/partitioned_table_arrow' ORDER BY f0; +# Read partitioned file +statement ok +CREATE TABLE src_table_1 ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + partition_col INT +) AS VALUES +(1, 'aaa', 100, 1), +(2, 'bbb', 200, 1), +(3, 'ccc', 300, 1), +(4, 'ddd', 400, 1); + +statement ok +CREATE TABLE src_table_2 ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + partition_col INT +) AS VALUES +(5, 'eee', 500, 2), +(6, 'fff', 600, 
2), +(7, 'ggg', 700, 2), +(8, 'hhh', 800, 2); + +# Read partitioned csv file + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/csv_partitions' +STORED AS CSV +PARTITIONED BY (partition_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/csv_partitions' +STORED AS CSV +PARTITIONED BY (partition_col); +---- +4 + +query ITIT rowsort +SELECT int_col, string_col, bigint_col, partition_col FROM 'test_files/scratch/dynamic_file/csv_partitions'; +---- +1 aaa 100 1 +2 bbb 200 1 +3 ccc 300 1 +4 ddd 400 1 +5 eee 500 2 +6 fff 600 2 +7 ggg 700 2 +8 hhh 800 2 + +# Read partitioned json file + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/json_partitions' +STORED AS JSON +PARTITIONED BY (partition_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/json_partitions' +STORED AS JSON +PARTITIONED BY (partition_col); +---- +4 + +query ITIT rowsort +SELECT int_col, string_col, bigint_col, partition_col FROM 'test_files/scratch/dynamic_file/json_partitions'; +---- +1 aaa 100 1 +2 bbb 200 1 +3 ccc 300 1 +4 ddd 400 1 +5 eee 500 2 +6 fff 600 2 +7 ggg 700 2 +8 hhh 800 2 + +# Read partitioned arrow file + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/arrow_partitions' +STORED AS ARROW +PARTITIONED BY (partition_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/arrow_partitions' +STORED AS ARROW +PARTITIONED BY (partition_col); +---- +4 + +query ITIT rowsort +SELECT int_col, string_col, bigint_col, partition_col FROM 'test_files/scratch/dynamic_file/arrow_partitions'; +---- +1 aaa 100 1 +2 bbb 200 1 +3 ccc 300 1 +4 ddd 400 1 +5 eee 500 2 +6 fff 600 2 +7 ggg 700 2 +8 hhh 800 2 + +# Read partitioned parquet file + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/parquet_partitions' +STORED AS PARQUET +PARTITIONED BY (partition_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/parquet_partitions' +STORED AS PARQUET +PARTITIONED BY (partition_col); +---- +4 + +query ITIT rowsort +select * from 'test_files/scratch/dynamic_file/parquet_partitions'; +---- +1 aaa 100 1 +2 bbb 200 1 +3 ccc 300 1 +4 ddd 400 1 +5 eee 500 2 +6 fff 600 2 +7 ggg 700 2 +8 hhh 800 2 + +# Read partitioned parquet file with multiple partition columns + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/nested_partition' +STORED AS PARQUET +PARTITIONED BY (partition_col, string_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/nested_partition' +STORED AS PARQUET +PARTITIONED BY (partition_col, string_col); +---- +4 + +query IITT rowsort +select * from 'test_files/scratch/dynamic_file/nested_partition'; +---- +1 100 1 aaa +2 200 1 bbb +3 300 1 ccc +4 400 1 ddd +5 500 2 eee +6 600 2 fff +7 700 2 ggg +8 800 2 hhh # read avro file query IT diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index be7fdac71b57d..da46a7e5e6796 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -128,5 +128,12 @@ from aggregate_test_100 order by c9 -statement error Inconsistent data type across values list at row 1 column 0. 
Was Int64 but found Utf8 +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'foo' to value of Int64 type create table foo as values (1), ('foo'); + +query error No function matches +select 1 group by substr(''); + +# Error in filter should be reported +query error Divide by zero +SELECT c2 from aggregate_test_100 where CASE WHEN true THEN 1 / 0 ELSE 0 END = 1; diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 6dc92bae828b8..1340fd490e06f 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -176,6 +176,7 @@ initial_logical_plan 02)--TableScan: simple_explain_test logical_plan after inline_table_scan SAME TEXT AS ABOVE logical_plan after expand_wildcard_rule SAME TEXT AS ABOVE +logical_plan after resolve_grouping_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE logical_plan after count_wildcard_rule SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE @@ -187,8 +188,6 @@ logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE -logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE logical_plan after eliminate_cross_join SAME TEXT AS ABOVE @@ -214,8 +213,6 @@ logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE -logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE logical_plan after eliminate_cross_join SAME TEXT AS ABOVE diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index e887b1934e046..5b6017b08a00a 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -18,46 +18,6 @@ # unicode expressions -query I -SELECT char_length('') ----- -0 - -query I -SELECT char_length('chars') ----- -5 - -query I -SELECT char_length('josé') ----- -4 - -query I -SELECT char_length(NULL) ----- -NULL - -query I -SELECT character_length('') ----- -0 - -query I -SELECT character_length('chars') ----- -5 - -query I -SELECT character_length('josé') ----- -4 - -query I -SELECT character_length(NULL) ----- -NULL - query T SELECT left('abcde', -2) ---- @@ -133,152 +93,6 @@ SELECT length(NULL) ---- NULL -query T -SELECT lpad('hi', -1, 'xy') ----- -(empty) - -query T -SELECT lpad('hi', 5, 'xy') ----- -xyxhi - -query T -SELECT lpad('hi', -1) ----- -(empty) - -query T -SELECT lpad('hi', 0) ----- -(empty) - -query T -SELECT lpad('hi', 21, 'abcdef') ----- -abcdefabcdefabcdefahi - -query T -SELECT lpad('hi', 5, 'xy') ----- -xyxhi - -query T -SELECT lpad('hi', 5, NULL) ----- -NULL - -query T -SELECT lpad('hi', 5) ----- - hi - -query T -SELECT lpad(arrow_cast('hi', 'Dictionary(Int32, Utf8)'), 5) ----- - hi - -query T -SELECT lpad('hi', CAST(NULL AS INT), 
'xy') ----- -NULL - -query T -SELECT lpad('hi', CAST(NULL AS INT)) ----- -NULL - -query T -SELECT lpad('xyxhi', 3) ----- -xyx - -query T -SELECT lpad(NULL, 0) ----- -NULL - -query T -SELECT lpad(NULL, 5, 'xy') ----- -NULL - -# test largeutf8, utf8view for lpad -query T -SELECT lpad(arrow_cast('hi', 'LargeUtf8'), 5, 'xy') ----- -xyxhi - -query T -SELECT lpad(arrow_cast('hi', 'Utf8View'), 5, 'xy') ----- -xyxhi - -query T -SELECT lpad(arrow_cast('hi', 'LargeUtf8'), 5, arrow_cast('xy', 'LargeUtf8')) ----- -xyxhi - -query T -SELECT lpad(arrow_cast('hi', 'Utf8View'), 5, arrow_cast('xy', 'Utf8View')) ----- -xyxhi - -query T -SELECT lpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy') ----- -NULL - -query T -SELECT reverse('abcde') ----- -edcba - -query T -SELECT reverse(arrow_cast('abcde', 'LargeUtf8')) ----- -edcba - -query T -SELECT reverse(arrow_cast('abcde', 'Utf8View')) ----- -edcba - -query T -SELECT reverse(arrow_cast('abcde', 'Dictionary(Int32, Utf8)')) ----- -edcba - -query T -SELECT reverse('loẅks') ----- -sk̈wol - -query T -SELECT reverse(arrow_cast('loẅks', 'LargeUtf8')) ----- -sk̈wol - -query T -SELECT reverse(arrow_cast('loẅks', 'Utf8View')) ----- -sk̈wol - -query T -SELECT reverse(NULL) ----- -NULL - -query T -SELECT reverse(arrow_cast(NULL, 'LargeUtf8')) ----- -NULL - -query T -SELECT reverse(arrow_cast(NULL, 'Utf8View')) ----- -NULL - query T SELECT right('abcde', -2) ---- @@ -324,124 +138,6 @@ SELECT right(NULL, CAST(NULL AS INT)) ---- NULL - -query T -SELECT rpad('hi', -1, 'xy') ----- -(empty) - -query T -SELECT rpad('hi', 5, 'xy') ----- -hixyx - -query T -SELECT rpad('hi', -1) ----- -(empty) - -query T -SELECT rpad('hi', 0) ----- -(empty) - -query T -SELECT rpad('hi', 21, 'abcdef') ----- -hiabcdefabcdefabcdefa - -query T -SELECT rpad('hi', 5, 'xy') ----- -hixyx - -query T -SELECT rpad(arrow_cast('hi', 'Dictionary(Int32, Utf8)'), 5, 'xy') ----- -hixyx - -query T -SELECT rpad('hi', 5, NULL) ----- -NULL - -query T -SELECT rpad('hi', 5) ----- -hi - -query T -SELECT rpad('hi', CAST(NULL AS INT), 'xy') ----- -NULL - -query T -SELECT rpad('hi', CAST(NULL AS INT)) ----- -NULL - -query T -SELECT rpad('xyxhi', 3) ----- -xyx - -# test for rpad with largeutf8 and utf8View - -query T -SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, 'xy') ----- -hixyx - -query T -SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, 'xy') ----- -hixyx - -query T -SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, arrow_cast('xy', 'LargeUtf8')) ----- -hixyx - -query T -SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, arrow_cast('xy', 'Utf8View')) ----- -hixyx - -query T -SELECT rpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy') ----- -NULL - -query I -SELECT strpos('abc', 'c') ----- -3 - -query I -SELECT strpos('josé', 'é') ----- -4 - -query I -SELECT strpos('joséésoj', 'so') ----- -6 - -query I -SELECT strpos('joséésoj', 'abc') ----- -0 - -query I -SELECT strpos(NULL, 'abc') ----- -NULL - -query I -SELECT strpos('joséésoj', NULL) ----- -NULL - query T SELECT substr('alphabet', -3) ---- @@ -796,45 +492,6 @@ SELECT md5(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) ---- acbd18db4cc2f85cedef654fccc4a4d8 -query T -SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') ----- -fooxx - -query T -SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'xx', 'gi') ----- -fooxx - -query T -SELECT repeat('foo', 3) ----- -foofoofoo - -query T -SELECT repeat(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 3) ----- -foofoofoo - -query T -SELECT replace('foobar', 'bar', 'hello') ----- -foohello - -query T -SELECT replace(arrow_cast('foobar', 
'Dictionary(Int32, Utf8)'), 'bar', 'hello') ----- -foohello - -query T -SELECT replace(arrow_cast('foobar', 'Utf8View'), arrow_cast('bar', 'Utf8View'), arrow_cast('hello', 'Utf8View')) ----- -foohello - -query T -SELECT replace(arrow_cast('foobar', 'LargeUtf8'), arrow_cast('bar', 'LargeUtf8'), arrow_cast('hello', 'LargeUtf8')) ----- -foohello query T SELECT rtrim(' foo ') @@ -846,68 +503,6 @@ SELECT rtrim(arrow_cast(' foo ', 'Dictionary(Int32, Utf8)')) ---- foo -query T -SELECT split_part('foo_bar', '_', 2) ----- -bar - -query T -SELECT split_part(arrow_cast('foo_bar', 'Dictionary(Int32, Utf8)'), '_', 2) ----- -bar - -# test largeutf8, utf8view for split_part -query T -SELECT split_part(arrow_cast('large_apple_large_orange_large_banana', 'LargeUtf8'), '_', 3) ----- -large - -query T -SELECT split_part(arrow_cast('view_apple_view_orange_view_banana', 'Utf8View'), '_', 3); ----- -view - -query T -SELECT split_part('test_large_split_large_case', arrow_cast('_large', 'LargeUtf8'), 2) ----- -_split - -query T -SELECT split_part(arrow_cast('huge_large_apple_large_orange_large_banana', 'LargeUtf8'), arrow_cast('_', 'Utf8View'), 2) ----- -large - -query T -SELECT split_part(arrow_cast('view_apple_view_large_banana', 'Utf8View'), arrow_cast('_large', 'LargeUtf8'), 2) ----- -_banana - -query T -SELECT split_part(NULL, '_', 2) ----- -NULL - - -query B -SELECT starts_with('foobar', 'foo') ----- -true - -query B -SELECT starts_with('foobar', 'bar') ----- -false - -query B -SELECT ends_with('foobar', 'bar') ----- -true - -query B -SELECT ends_with('foobar', 'foo') ----- -false - query T SELECT trim(' foo ') ---- @@ -958,6 +553,16 @@ SELECT strpos(arrow_cast('helloworld', 'Dictionary(Int32, Utf8)'), 'world') ---- 6 +query I +SELECT strpos('helloworld', NULL) +---- +NULL + +query I +SELECT strpos(arrow_cast('helloworld', 'Dictionary(Int32, Utf8)'), NULL) +---- +NULL + statement ok CREATE TABLE products ( product_id INT PRIMARY KEY, @@ -1064,279 +669,6 @@ NULL Thomxas NULL -query I -SELECT levenshtein('kitten', 'sitting') ----- -3 - -query I -SELECT levenshtein('kitten', NULL) ----- -NULL - -query I -SELECT levenshtein(NULL, 'sitting') ----- -NULL - -query I -SELECT levenshtein(NULL, NULL) ----- -NULL - -# Test substring_index using '.' as delimiter -# This query is compatible with MySQL(8.0.19 or later), convenient for comparing results -query TIT -SELECT str, n, substring_index(str, '.', n) AS c FROM - (VALUES - ROW('arrow.apache.org'), - ROW('.'), - ROW('...'), - ROW(NULL) - ) AS strings(str), - (VALUES - ROW(1), - ROW(2), - ROW(3), - ROW(100), - ROW(-1), - ROW(-2), - ROW(-3), - ROW(-100) - ) AS occurrences(n) -ORDER BY str DESC, n; ----- -NULL -100 NULL -NULL -3 NULL -NULL -2 NULL -NULL -1 NULL -NULL 1 NULL -NULL 2 NULL -NULL 3 NULL -NULL 100 NULL -arrow.apache.org -100 arrow.apache.org -arrow.apache.org -3 arrow.apache.org -arrow.apache.org -2 apache.org -arrow.apache.org -1 org -arrow.apache.org 1 arrow -arrow.apache.org 2 arrow.apache -arrow.apache.org 3 arrow.apache.org -arrow.apache.org 100 arrow.apache.org -... -100 ... -... -3 .. -... -2 . -... -1 (empty) -... 1 (empty) -... 2 . -... 3 .. -... 100 ... -. -100 . -. -3 . -. -2 . -. -1 (empty) -. 1 (empty) -. 2 . -. 3 . -. 100 . - -query I -SELECT levenshtein(NULL, NULL) ----- -NULL - -# Test substring_index using '.' 
as delimiter with utf8view -query TIT -SELECT str, n, substring_index(arrow_cast(str, 'Utf8View'), '.', n) AS c FROM - (VALUES - ROW('arrow.apache.org'), - ROW('.'), - ROW('...'), - ROW(NULL) - ) AS strings(str), - (VALUES - ROW(1), - ROW(2), - ROW(3), - ROW(100), - ROW(-1), - ROW(-2), - ROW(-3), - ROW(-100) - ) AS occurrences(n) -ORDER BY str DESC, n; ----- -NULL -100 NULL -NULL -3 NULL -NULL -2 NULL -NULL -1 NULL -NULL 1 NULL -NULL 2 NULL -NULL 3 NULL -NULL 100 NULL -arrow.apache.org -100 arrow.apache.org -arrow.apache.org -3 arrow.apache.org -arrow.apache.org -2 apache.org -arrow.apache.org -1 org -arrow.apache.org 1 arrow -arrow.apache.org 2 arrow.apache -arrow.apache.org 3 arrow.apache.org -arrow.apache.org 100 arrow.apache.org -... -100 ... -... -3 .. -... -2 . -... -1 (empty) -... 1 (empty) -... 2 . -... 3 .. -... 100 ... -. -100 . -. -3 . -. -2 . -. -1 (empty) -. 1 (empty) -. 2 . -. 3 . -. 100 . - -# Test substring_index using 'ac' as delimiter -query TIT -SELECT str, n, substring_index(str, 'ac', n) AS c FROM - (VALUES - -- input string does not contain the delimiter - ROW('arrow'), - -- input string contains the delimiter - ROW('arrow.apache.org') - ) AS strings(str), - (VALUES - ROW(1), - ROW(2), - ROW(-1), - ROW(-2) - ) AS occurrences(n) -ORDER BY str DESC, n; ----- -arrow.apache.org -2 arrow.apache.org -arrow.apache.org -1 he.org -arrow.apache.org 1 arrow.ap -arrow.apache.org 2 arrow.apache.org -arrow -2 arrow -arrow -1 arrow -arrow 1 arrow -arrow 2 arrow - -# Test substring_index with NULL values -query TTTT -SELECT - substring_index(NULL, '.', 1), - substring_index('arrow.apache.org', NULL, 1), - substring_index('arrow.apache.org', '.', NULL), - substring_index(NULL, NULL, NULL) ----- -NULL NULL NULL NULL - -# Test substring_index with empty strings -query TT -SELECT - -- input string is empty - substring_index('', '.', 1), - -- delimiter is empty - substring_index('arrow.apache.org', '', 1) ----- -(empty) (empty) - -# Test substring_index with 0 occurrence -query T -SELECT substring_index('arrow.apache.org', 'ac', 0) ----- -(empty) - -# Test substring_index with large occurrences -query TT -SELECT - -- i64::MIN - substring_index('arrow.apache.org', '.', -9223372036854775808) as c1, - -- i64::MAX - substring_index('arrow.apache.org', '.', 9223372036854775807) as c2; ----- -arrow.apache.org arrow.apache.org - -# Test substring_index issue https://github.com/apache/datafusion/issues/9472 -query TTT -SELECT - url, - substring_index(url, '.', 1) AS subdomain, - substring_index(url, '.', -1) AS tld -FROM - (VALUES ROW('docs.apache.com'), - ROW('community.influxdata.com'), - ROW('arrow.apache.org') - ) data(url) ----- -docs.apache.com docs com -community.influxdata.com community com -arrow.apache.org arrow org - -# find_in_set tests -query I -SELECT find_in_set('b', 'a,b,c,d') ----- -2 - - -query I -SELECT find_in_set('a', 'a,b,c,d,a') ----- -1 - -query I -SELECT find_in_set('', 'a,b,c,d,a') ----- -0 - -query I -SELECT find_in_set('a', '') ----- -0 - - -query I -SELECT find_in_set('', '') ----- -1 - -query I -SELECT find_in_set(NULL, 'a,b,c,d') ----- -NULL - -query I -SELECT find_in_set('a', NULL) ----- -NULL - - -query I -SELECT find_in_set(NULL, NULL) ----- -NULL - -# find_in_set tests with utf8view -query I -SELECT find_in_set(arrow_cast('b', 'Utf8View'), 'a,b,c,d') ----- -2 - - -query I -SELECT find_in_set('a', arrow_cast('a,b,c,d,a', 'Utf8View')) ----- -1 - -query I -SELECT find_in_set(arrow_cast('', 'Utf8View'), arrow_cast('a,b,c,d,a', 'Utf8View')) ----- -0 - # Verify that 
multiple calls to volatile functions like `random()` are not combined / optimized away query B SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random()+1 r1, random()+1 r2) WHERE r1 > 0 AND r2 > 0) diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index f561fa9e9ac8d..61b3ad73cd0a5 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -3360,7 +3360,8 @@ physical_plan 05)--------CoalesceBatchesExec: target_batch_size=4 06)----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8 07)------------AggregateExec: mode=Partial, gby=[sn@0 as sn, amount@1 as amount], aggr=[] -08)--------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] +08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +09)----------------MemoryExec: partitions=1, partition_sizes=[1] query IRI SELECT s.sn, s.amount, 2*s.sn @@ -3430,9 +3431,9 @@ physical_plan 07)------------AggregateExec: mode=Partial, gby=[sn@1 as sn, amount@2 as amount], aggr=[sum(l.amount)] 08)--------------ProjectionExec: expr=[amount@1 as amount, sn@2 as sn, amount@3 as amount] 09)----------------NestedLoopJoinExec: join_type=Inner, filter=sn@0 >= sn@1 -10)------------------CoalescePartitionsExec -11)--------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] -12)------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] +10)------------------MemoryExec: partitions=1, partition_sizes=[1] +11)------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +12)--------------------MemoryExec: partitions=1, partition_sizes=[1] query IRR SELECT r.sn, SUM(l.amount), r.amount @@ -3579,8 +3580,7 @@ physical_plan 08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 09)----------------ProjectionExec: expr=[zip_code@0 as zip_code, country@1 as country, sn@2 as sn, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@6 as sum_amount] 10)------------------BoundedWindowAggExec: wdw=[sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -11)--------------------CoalescePartitionsExec -12)----------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] +11)--------------------MemoryExec: partitions=1, partition_sizes=[1] query ITIPTRR @@ -4050,7 +4050,7 @@ EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1 ---- logical_plan 01)Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 -02)--CrossJoin: +02)--Cross Join: 03)----SubqueryAlias: lhs 04)------Projection: multiple_ordered_table_with_pk.c, sum(multiple_ordered_table_with_pk.d) AS sum1 05)--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] @@ -5152,8 +5152,6 @@ drop table test_case_expr statement ok drop table t; -# TODO: Current grouping set result is not align with Postgres and DuckDB, we might want to change the result -# See https://github.com/apache/datafusion/issues/12570 # test multi group by for binary type with nulls statement ok create table t(a int, b 
bytea) as values (1, 0xa), (1, 0xa), (2, null), (null, 0xb), (null, 0xb); @@ -5162,11 +5160,14 @@ query I?I select a, b, count(*) from t group by grouping sets ((a, b), (a), (b)); ---- 1 0a 2 -2 NULL 2 -NULL 0b 4 +2 NULL 1 +NULL 0b 2 1 NULL 2 -NULL NULL 3 +2 NULL 1 +NULL NULL 2 NULL 0a 2 +NULL NULL 1 +NULL 0b 2 statement ok drop table t; @@ -5207,3 +5208,65 @@ NULL a 2 statement ok drop table t; + +# test multi group by int + utf8view +statement ok +create table source as values +-- use some strings that are larger than 12 characters as that goes through a different path +(1, 'a'), +(1, 'a'), +(2, 'thisstringislongerthan12'), +(2, 'thisstring'), +(3, 'abc'), +(3, 'cba'), +(2, 'thisstring'), +(null, null), +(null, 'a'), +(null, null), +(null, 'a'), +(2, 'thisstringisalsolongerthan12'), +(2, 'thisstringislongerthan12'), +(1, 'null') +; + +statement ok +create view t as select column1 as a, arrow_cast(column2, 'Utf8View') as b from source; + +query ITI +select a, b, count(*) from t group by a, b order by a, b; +---- +1 a 2 +1 null 1 +2 thisstring 2 +2 thisstringisalsolongerthan12 1 +2 thisstringislongerthan12 2 +3 abc 1 +3 cba 1 +NULL a 2 +NULL NULL 2 + +statement ok +drop view t + +# test with binary view +statement ok +create view t as select column1 as a, arrow_cast(column2, 'BinaryView') as b from source; + +query I?I +select a, b, count(*) from t group by a, b order by a, b; +---- +1 61 2 +1 6e756c6c 1 +2 74686973737472696e67 2 +2 74686973737472696e676973616c736f6c6f6e6765727468616e3132 1 +2 74686973737472696e6769736c6f6e6765727468616e3132 2 +3 616263 1 +3 636261 1 +NULL 61 2 +NULL NULL 2 + +statement ok +drop view t + +statement ok +drop table source; diff --git a/datafusion/sqllogictest/test_files/grouping.slt b/datafusion/sqllogictest/test_files/grouping.slt new file mode 100644 index 0000000000000..64d040d012f99 --- /dev/null +++ b/datafusion/sqllogictest/test_files/grouping.slt @@ -0,0 +1,214 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
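+
+# Note: grouping(col) is 0 when `col` is a real grouping value in the output
+# row and 1 when it has been aggregated away. The multi-argument form packs
+# those bits with the first argument as the most significant bit, so the
+# expected results below follow from, for example:
+#
+#   grouping(c1, c2)     = 2 * grouping(c1) + grouping(c2)
+#   grouping(c1, c2, c3) = 4 * grouping(c1) + 2 * grouping(c2) + grouping(c3)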
+
+statement ok
+CREATE TABLE test (c1 VARCHAR,c2 VARCHAR,c3 INT) as values
+('a','A',1), ('b','B',2)
+
+# grouping_with_grouping_sets
+query TTIIII
+select
+    c1,
+    c2,
+    grouping(c1) as g0,
+    grouping(c2) as g1,
+    grouping(c1, c2) as g2,
+    grouping(c2, c1) as g3
+from
+    test
+group by
+    grouping sets (
+        (c1, c2),
+        (c1),
+        (c2),
+        ()
+    )
+order by
+    c1, c2, g0, g1, g2, g3;
+----
+a A 0 0 0 0
+a NULL 0 1 1 2
+b B 0 0 0 0
+b NULL 0 1 1 2
+NULL A 1 0 2 1
+NULL B 1 0 2 1
+NULL NULL 1 1 3 3
+
+# grouping_with_cube
+query TTIIII
+select
+    c1,
+    c2,
+    grouping(c1) as g0,
+    grouping(c2) as g1,
+    grouping(c1, c2) as g2,
+    grouping(c2, c1) as g3
+from
+    test
+group by
+    cube(c1, c2)
+order by
+    c1, c2, g0, g1, g2, g3;
+----
+a A 0 0 0 0
+a NULL 0 1 1 2
+b B 0 0 0 0
+b NULL 0 1 1 2
+NULL A 1 0 2 1
+NULL B 1 0 2 1
+NULL NULL 1 1 3 3
+
+# grouping_with_rollup
+query TTIIII
+select
+    c1,
+    c2,
+    grouping(c1) as g0,
+    grouping(c2) as g1,
+    grouping(c1, c2) as g2,
+    grouping(c2, c1) as g3
+from
+    test
+group by
+    rollup(c1, c2)
+order by
+    c1, c2, g0, g1, g2, g3;
+----
+a A 0 0 0 0
+a NULL 0 1 1 2
+b B 0 0 0 0
+b NULL 0 1 1 2
+NULL NULL 1 1 3 3
+
+query TTIIIIIIII
+select
+    c1,
+    c2,
+    c3,
+    grouping(c1) as g0,
+    grouping(c2) as g1,
+    grouping(c1, c2) as g2,
+    grouping(c2, c1) as g3,
+    grouping(c1, c2, c3) as g4,
+    grouping(c2, c3, c1) as g5,
+    grouping(c3, c2, c1) as g6
+from
+    test
+group by
+    rollup(c1, c2, c3)
+order by
+    c1, c2, g0, g1, g2, g3, g4, g5, g6;
+----
+a A 1 0 0 0 0 0 0 0
+a A NULL 0 0 0 0 1 2 4
+a NULL NULL 0 1 1 2 3 6 6
+b B 2 0 0 0 0 0 0 0
+b B NULL 0 0 0 0 1 2 4
+b NULL NULL 0 1 1 2 3 6 6
+NULL NULL NULL 1 1 3 3 7 7 7
+
+# grouping_with_add
+query TTI
+select
+    c1,
+    c2,
+    grouping(c1)+grouping(c2) as g0
+from
+    test
+group by
+    rollup(c1, c2)
+order by
+    c1, c2, g0;
+----
+a A 0
+a NULL 1
+b B 0
+b NULL 1
+NULL NULL 2
+
+# grouping_with_window_function
+query TTIII
+select
+    c1,
+    c2,
+    count(c1) as cnt,
+    grouping(c1)+ grouping(c2) as g0,
+    rank() over (
+        partition by grouping(c1)+grouping(c2),
+        case when grouping(c2) = 0 then c1 end
+        order by
+        count(c1) desc
+    ) as rank_within_parent
+from
+    test
+group by
+    rollup(c1, c2)
+order by
+    c1,
+    c2,
+    cnt,
+    g0 desc,
+    rank_within_parent;
+----
+a A 1 0 1
+a NULL 1 1 1
+b B 1 0 1
+b NULL 1 1 1
+NULL NULL 2 2 1
+
+# grouping_with_non_columns
+query TIIIII
+select
+    c1,
+    c3 + 1 as c3_add_one,
+    grouping(c1) as g0,
+    grouping(c3 + 1) as g1,
+    grouping(c1, c3 + 1) as g2,
+    grouping(c3 + 1, c1) as g3
+from
+    test
+group by
+    grouping sets (
+        (c1, c3 + 1),
+        (c3 + 1),
+        (c1)
+    )
+order by
+    c1, c3_add_one, g0, g1, g2, g3;
+----
+a 2 0 0 0 0
+a NULL 0 1 1 2
+b 3 0 0 0 0
+b NULL 0 1 1 2
+NULL 2 1 0 2 1
+NULL 3 1 0 2 1
+
+# Postgres allows the grouping function in GROUP BY even without GROUPING SETS/ROLLUP/CUBE
+query TI
+select c1, grouping(c1) from test group by c1 order by c1;
+----
+a 0
+b 0
+
+statement error c2.*not in grouping columns
+select c1, grouping(c2) from test group by c1;
+
+statement error c2.*not in grouping columns
+select c1, grouping(c1, c2) from test group by CUBE(c1);
+
+statement error zero arguments
+select c1, grouping() from test group by CUBE(c1);
diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt
index 7acdf25b65967..3630f6c365959 100644
--- a/datafusion/sqllogictest/test_files/information_schema.slt
+++ b/datafusion/sqllogictest/test_files/information_schema.slt
@@ -173,12 +173,14 @@ datafusion.execution.batch_size 8192
datafusion.execution.coalesce_batches true datafusion.execution.collect_statistics false datafusion.execution.enable_recursive_ctes true +datafusion.execution.enforce_batch_size_in_joins false datafusion.execution.keep_partition_by_columns false datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 datafusion.execution.meta_fetch_concurrency 32 datafusion.execution.minimum_parallel_output_files 4 datafusion.execution.parquet.allow_single_file_parallelism true +datafusion.execution.parquet.binary_as_string false datafusion.execution.parquet.bloom_filter_fpp NULL datafusion.execution.parquet.bloom_filter_ndv NULL datafusion.execution.parquet.bloom_filter_on_read true @@ -263,12 +265,14 @@ datafusion.execution.batch_size 8192 Default batch size while creating new batch datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files datafusion.execution.enable_recursive_ctes true Should DataFusion support recursive CTEs +datafusion.execution.enforce_batch_size_in_joins false Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. datafusion.execution.keep_partition_by_columns false Should DataFusion keep the columns used for partition_by in the output RecordBatches datafusion.execution.listing_table_ignore_subdirectory true Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. datafusion.execution.parquet.allow_single_file_parallelism true (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. +datafusion.execution.parquet.binary_as_string false (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. 
datafusion.execution.parquet.bloom_filter_fpp NULL (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting datafusion.execution.parquet.bloom_filter_ndv NULL (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting datafusion.execution.parquet.bloom_filter_on_read true (writing) Use any available bloom filters when reading parquet files diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 8d801b92c3933..39f903a587143 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -671,7 +671,7 @@ query TT explain select * from t1 inner join t2 on true; ---- logical_plan -01)CrossJoin: +01)Cross Join: 02)--TableScan: t1 projection=[t1_id, t1_name, t1_int] 03)--TableScan: t2 projection=[t2_id, t2_name, t2_int] physical_plan @@ -905,7 +905,7 @@ JOIN department AS d ON (e.name = 'Alice' OR e.name = 'Bob'); ---- logical_plan -01)CrossJoin: +01)Cross Join: 02)--SubqueryAlias: e 03)----Filter: employees.name = Utf8("Alice") OR employees.name = Utf8("Bob") 04)------TableScan: employees projection=[emp_id, name] @@ -1152,7 +1152,7 @@ logical_plan 01)Projection: t1.v0, t1.v1, t5.v2, t5.v3, t5.v4, t0.v0, t0.v1 02)--Inner Join: CAST(t1.v0 AS Float64) = t0.v1 Filter: t0.v1 + CAST(t5.v0 AS Float64) > Float64(0) 03)----Projection: t1.v0, t1.v1, t5.v0, t5.v2, t5.v3, t5.v4 -04)------Inner Join: Using t1.v0 = t5.v0, t1.v1 = t5.v1 +04)------Inner Join: t1.v0 = t5.v0, t1.v1 = t5.v1 05)--------TableScan: t1 projection=[v0, v1] 06)--------TableScan: t5 projection=[v0, v1, v2, v3, v4] 07)----TableScan: t0 projection=[v0, v1] @@ -1215,14 +1215,14 @@ statement ok create table t1(v1 int) as values(100); ## Query with Ambiguous column reference -query error DataFusion error: Schema error: Ambiguous reference to unqualified field v1 +query error DataFusion error: Schema error: Schema contains duplicate qualified field name t1\.v1 select count(*) from t1 right outer join t1 on t1.v1 > 0; -query error DataFusion error: Schema error: Ambiguous reference to unqualified field v1 +query error DataFusion error: Schema error: Schema contains duplicate qualified field name t1\.v1 select t1.v1 from t1 join t1 using(v1) cross join (select struct('foo' as v1) as t1); statement ok -drop table t1; \ No newline at end of file +drop table t1; diff --git a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt.temp b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt.temp deleted file mode 100644 index 00e74a207b333..0000000000000 --- a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt.temp +++ /dev/null @@ -1,26 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -########## -## Join Tests -########## - -# turn off repartition_joins -statement ok -set datafusion.optimizer.repartition_joins = false; - -include ./join.slt diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index a7a252cc20d7a..bc40f845cc8ac 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3901,8 +3901,8 @@ SELECT * FROM ( ) AS rhs ON lhs.b=rhs.b ---- 11 1 21 1 -14 2 22 2 12 3 23 3 +14 2 22 2 15 4 24 4 query TT @@ -3922,11 +3922,12 @@ logical_plan 05)----Sort: right_table_no_nulls.b ASC NULLS LAST, fetch=10 06)------TableScan: right_table_no_nulls projection=[a, b] physical_plan -01)CoalesceBatchesExec: target_batch_size=3 -02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(b@1, b@1)] -03)----MemoryExec: partitions=1, partition_sizes=[1] -04)----SortExec: TopK(fetch=10), expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] -05)------MemoryExec: partitions=1, partition_sizes=[1] +01)ProjectionExec: expr=[a@2 as a, b@3 as b, a@0 as a, b@1 as b] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(b@1, b@1)] +04)------SortExec: TopK(fetch=10), expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] +05)--------MemoryExec: partitions=1, partition_sizes=[1] +06)------MemoryExec: partitions=1, partition_sizes=[1] @@ -3979,10 +3980,11 @@ logical_plan 04)--SubqueryAlias: rhs 05)----TableScan: right_table_no_nulls projection=[a, b] physical_plan -01)CoalesceBatchesExec: target_batch_size=3 -02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(b@1, b@1)] -03)----MemoryExec: partitions=1, partition_sizes=[1] -04)----MemoryExec: partitions=1, partition_sizes=[1] +01)ProjectionExec: expr=[a@2 as a, b@3 as b, a@0 as a, b@1 as b] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(b@1, b@1)] +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------MemoryExec: partitions=1, partition_sizes=[1] # Null build indices: @@ -4038,11 +4040,12 @@ logical_plan 05)----Sort: right_table_no_nulls.b ASC NULLS LAST, fetch=10 06)------TableScan: right_table_no_nulls projection=[a, b] physical_plan -01)CoalesceBatchesExec: target_batch_size=3 -02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(b@1, b@1)] -03)----MemoryExec: partitions=1, partition_sizes=[1] -04)----SortExec: TopK(fetch=10), expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] -05)------MemoryExec: partitions=1, partition_sizes=[1] +01)ProjectionExec: expr=[a@2 as a, b@3 as b, a@0 as a, b@1 as b] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(b@1, b@1)] +04)------SortExec: TopK(fetch=10), expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] +05)--------MemoryExec: partitions=1, partition_sizes=[1] +06)------MemoryExec: partitions=1, partition_sizes=[1] # Test CROSS JOIN LATERAL syntax (planning) @@ -4050,7 +4053,7 @@ query TT explain select t1_id, t1_name, i from join_t1 t1 cross join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); ---- logical_plan -01)CrossJoin: +01)Cross Join: 02)--SubqueryAlias: t1 03)----TableScan: join_t1 projection=[t1_id, t1_name] 04)--SubqueryAlias: series @@ -4187,4 +4190,103 @@ physical_plan 02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(b@1, y@1)], 
filter=a@0 < x@1 03)----MemoryExec: partitions=1, partition_sizes=[0] 04)----SortExec: expr=[x@0 ASC NULLS LAST], preserve_partitioning=[false] -05)------MemoryExec: partitions=1, partition_sizes=[0] \ No newline at end of file +05)------MemoryExec: partitions=1, partition_sizes=[0] + +# Test full join with limit +statement ok +CREATE TABLE t0(c1 INT UNSIGNED, c2 INT UNSIGNED) +AS VALUES +(1, 1), +(2, 2), +(3, 3), +(4, 4); + +statement ok +CREATE TABLE t1(c1 INT UNSIGNED, c2 INT UNSIGNED, c3 BOOLEAN) +AS VALUES +(2, 2, true), +(2, 2, false), +(3, 3, true), +(3, 3, false); + +query IIIIB +SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 LIMIT 2; +---- +2 2 2 2 true +2 2 2 2 false + +query IIIIB +SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 2; +---- +2 2 2 2 true +3 3 2 2 true + +query IIIIB +SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 AND t0.c2 >= t1.c2 LIMIT 2; +---- +2 2 2 2 true +2 2 2 2 false + +## Test !join.on.is_empty() && join.filter.is_none() +query TT +EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 LIMIT 2; +---- +logical_plan +01)Limit: skip=0, fetch=2 +02)--Full Join: t0.c1 = t1.c1 +03)----Limit: skip=0, fetch=2 +04)------TableScan: t0 projection=[c1, c2], fetch=2 +05)----Limit: skip=0, fetch=2 +06)------TableScan: t1 projection=[c1, c2, c3], fetch=2 +physical_plan +01)CoalesceBatchesExec: target_batch_size=3, fetch=2 +02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)] +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)----MemoryExec: partitions=1, partition_sizes=[1] + +## Test join.on.is_empty() && join.filter.is_some() +query TT +EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 2; +---- +logical_plan +01)Limit: skip=0, fetch=2 +02)--Full Join: Filter: t0.c2 >= t1.c2 +03)----Limit: skip=0, fetch=2 +04)------TableScan: t0 projection=[c1, c2], fetch=2 +05)----Limit: skip=0, fetch=2 +06)------TableScan: t1 projection=[c1, c2, c3], fetch=2 +physical_plan +01)GlobalLimitExec: skip=0, fetch=2 +02)--NestedLoopJoinExec: join_type=Full, filter=c2@0 >= c2@1 +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)----MemoryExec: partitions=1, partition_sizes=[1] + +## Test !join.on.is_empty() && join.filter.is_some() +query TT +EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 AND t0.c2 >= t1.c2 LIMIT 2; +---- +logical_plan +01)Limit: skip=0, fetch=2 +02)--Full Join: t0.c1 = t1.c1 Filter: t0.c2 >= t1.c2 +03)----Limit: skip=0, fetch=2 +04)------TableScan: t0 projection=[c1, c2], fetch=2 +05)----Limit: skip=0, fetch=2 +06)------TableScan: t1 projection=[c1, c2, c3], fetch=2 +physical_plan +01)CoalesceBatchesExec: target_batch_size=3, fetch=2 +02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)], filter=c2@0 >= c2@1 +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)----MemoryExec: partitions=1, partition_sizes=[1] + +# Test Utf8View as Join Key +# Issue: https://github.com/apache/datafusion/issues/12468 +statement ok +CREATE TABLE table1(v1 STRING) AS VALUES ('foo'), (NULL); + +statement ok +CREATE TABLE table1_stringview AS SELECT arrow_cast(v1, 'Utf8View') AS v1 FROM table1; + +query T +select * from table1 as t1 natural join table1_stringview as t2; +---- +foo diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 45e1b51a09b41..726de75b51411 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -148,18 +148,17 @@ SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']); {[1, 2]: [a, b], [3, 4]: [b]} 
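+# Note: MAKE_MAP coerces all values to a single common type, as the error
+# cases below show; here the other values are Int64, so a lone string value
+# such as 'ab' now fails with a cast error instead of producing a map:
+#
+#   SELECT MAKE_MAP('POST', 41, 'HEAD', 53);  -- values coerce to Int64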
query ? -SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30); +SELECT MAKE_MAP('POST', 41, 'HEAD', 53, 'PATCH', 30); ---- -{POST: 41, HEAD: ab, PATCH: 30} +{POST: 41, HEAD: 53, PATCH: 30} + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'ab' to value of Int64 type +SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30); +# Map keys can not be NULL query error SELECT MAKE_MAP('POST', 41, 'HEAD', 33, null, 30); -query ? -SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30); ----- -{POST: 41, HEAD: ab, PATCH: 30} - query ? SELECT MAKE_MAP() ---- @@ -517,9 +516,12 @@ query error SELECT MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} }[NULL]; query ? -SELECT MAP { 'a': 1, 2: 3 }; +SELECT MAP { 'a': 1, 'b': 3 }; ---- -{a: 1, 2: 3} +{a: 1, b: 3} + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type +SELECT MAP { 'a': 1, 2: 3 }; # TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key # query ? @@ -610,9 +612,12 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) # Tests for map_keys query ? -SELECT map_keys(MAP { 'a': 1, 2: 3 }); +SELECT map_keys(MAP { 'a': 1, 'b': 3 }); ---- -[a, 2] +[a, b] + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type +SELECT map_keys(MAP { 'a': 1, 2: 3 }); query ? SELECT map_keys(MAP {'a':1, 'b':2, 'c':3 }) FROM t; @@ -657,8 +662,11 @@ SELECT map_keys(column1) from map_array_table_1; # Tests for map_values -query ? +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type SELECT map_values(MAP { 'a': 1, 2: 3 }); + +query ? +SELECT map_values(MAP { 'a': 1, 'b': 3 }); ---- [1, 3] diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index eece569423177..1bc972a3e37da 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -102,7 +102,12 @@ SELECT nanvl(asin(10), 1.0), nanvl(1.0, 2.0), nanvl(asin(10), asin(10)) # isnan query BBBB -SELECT isnan(1.0), isnan('NaN'::DOUBLE), isnan(-'NaN'::DOUBLE), isnan(NULL) +SELECT isnan(1.0::DOUBLE), isnan('NaN'::DOUBLE), isnan(-'NaN'::DOUBLE), isnan(NULL) +---- +false true true NULL + +query BBBB +SELECT isnan(1.0::FLOAT), isnan('NaN'::FLOAT), isnan(-'NaN'::FLOAT), isnan(NULL::FLOAT) ---- false true true NULL diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt index 3b2b219244f55..8f787254c0967 100644 --- a/datafusion/sqllogictest/test_files/metadata.slt +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -25,7 +25,7 @@ ## with metadata in SQL. 
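+# Note: `table_with_metadata` also exposes `l_name` and `nonnull_name`
+# columns (exercised by the regression tests below), hence the explicit
+# `id, name` projection instead of `select *`.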
query IT -select * from table_with_metadata; +select id, name from table_with_metadata; ---- 1 NULL NULL bar @@ -58,5 +58,115 @@ WHERE "data"."id" = "samples"."id"; 1 3 + + +# Regression test: prevent field metadata loss per https://github.com/apache/datafusion/issues/12687 +query I +select count(distinct name) from table_with_metadata; +---- +2 + +# Regression test: prevent field metadata loss per https://github.com/apache/datafusion/issues/12687 +query I +select approx_median(distinct id) from table_with_metadata; +---- +2 + +# Regression test: prevent field metadata loss per https://github.com/apache/datafusion/issues/12687 +statement ok +select array_agg(distinct id) from table_with_metadata; + +query I +select distinct id from table_with_metadata order by id; +---- +1 +3 +NULL + +query I +select count(id) from table_with_metadata; +---- +2 + +query I +select count(id) cnt from table_with_metadata group by name order by cnt; +---- +0 +1 +1 + + + +# Regression test: missing schema metadata, when aggregate on cross join +query I +SELECT count("data"."id") +FROM + ( + SELECT "id" FROM "table_with_metadata" + ) as "data", + ( + SELECT "id" FROM "table_with_metadata" + ) as "samples"; +---- +6 + +# Regression test: missing field metadata, from the NULL field on the left side of the union +query ITT +(SELECT id, NULL::string as name, l_name FROM "table_with_metadata") + UNION +(SELECT id, name, NULL::string as l_name FROM "table_with_metadata") +ORDER BY id, name, l_name; +---- +1 NULL NULL +3 baz NULL +3 NULL l_baz +NULL bar NULL +NULL NULL l_bar + +# Regression test: missing field metadata from left side of the union when right side is chosen +query T +select name from ( + SELECT nonnull_name as name FROM "table_with_metadata" + UNION ALL + SELECT NULL::string as name +) group by name order by name; +---- +no_bar +no_baz +no_foo +NULL + +# Regression test: missing schema metadata from union when schema with metadata isn't the first one +# and also ensure it works fine with multiple unions +query T +select name from ( + SELECT NULL::string as name + UNION ALL + SELECT nonnull_name as name FROM "table_with_metadata" + UNION ALL + SELECT NULL::string as name +) group by name order by name; +---- +no_bar +no_baz +no_foo +NULL + +query P rowsort +SELECT ts +FROM (( + SELECT now() AS ts + FROM table_with_metadata +) UNION ALL ( + SELECT ts + FROM table_with_metadata +)) +GROUP BY ts +ORDER BY ts +LIMIT 1; +---- +2020-09-08T13:42:29.190855123Z + + statement ok drop table table_with_metadata; diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index f53363b6eb38c..6cc7ee0403f28 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -958,6 +958,24 @@ drop table foo; statement ok drop table ambiguity_test; +## reproducer for https://github.com/apache/datafusion/issues/12446 +# Ensure union ordering calculations with constants can be optimized + +statement ok +create table t(a0 int, a int, b int, c int) as values (1, 2, 3, 4), (5, 6, 7, 8); + +# expect this query to run successfully, not error +query III +select * from (select c, a, NULL::int as a0 from t order by a, c) t1 +union all +select * from (select c, NULL::int as a, a0 from t order by a0, c) t2 +order by c, a, a0, b +limit 2; +---- +4 2 NULL +4 NULL 1 + + # Casting from numeric to string types breaks the ordering statement ok CREATE EXTERNAL TABLE ordered_table ( @@ -1189,3 +1207,48 @@ physical_plan 02)--RepartitionExec: 
partitioning=RoundRobinBatch(2), input_partitions=1 03)----SortExec: TopK(fetch=1), expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], has_header=true + + +# Test: inputs into union with different orderings +query TT +explain select * from (select b, c, a, NULL::int as a0 from ordered_table order by a, c) t1 +union all +select * from (select b, c, NULL::int as a, a0 from ordered_table order by a0, c) t2 +order by d, c, a, a0, b +limit 2; +---- +logical_plan +01)Projection: t1.b, t1.c, t1.a, t1.a0 +02)--Sort: t1.d ASC NULLS LAST, t1.c ASC NULLS LAST, t1.a ASC NULLS LAST, t1.a0 ASC NULLS LAST, t1.b ASC NULLS LAST, fetch=2 +03)----Union +04)------SubqueryAlias: t1 +05)--------Projection: ordered_table.b, ordered_table.c, ordered_table.a, Int32(NULL) AS a0, ordered_table.d +06)----------TableScan: ordered_table projection=[a, b, c, d] +07)------SubqueryAlias: t2 +08)--------Projection: ordered_table.b, ordered_table.c, Int32(NULL) AS a, ordered_table.a0, ordered_table.d +09)----------TableScan: ordered_table projection=[a0, b, c, d] +physical_plan +01)ProjectionExec: expr=[b@0 as b, c@1 as c, a@2 as a, a0@3 as a0] +02)--SortPreservingMergeExec: [d@4 ASC NULLS LAST,c@1 ASC NULLS LAST,a@2 ASC NULLS LAST,a0@3 ASC NULLS LAST,b@0 ASC NULLS LAST], fetch=2 +03)----UnionExec +04)------SortExec: TopK(fetch=2), expr=[d@4 ASC NULLS LAST,c@1 ASC NULLS LAST,a@2 ASC NULLS LAST,b@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------ProjectionExec: expr=[b@1 as b, c@2 as c, a@0 as a, NULL as a0, d@3 as d] +06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[c@2 ASC NULLS LAST], has_header=true +07)------SortExec: TopK(fetch=2), expr=[d@4 ASC NULLS LAST,c@1 ASC NULLS LAST,a0@3 ASC NULLS LAST,b@0 ASC NULLS LAST], preserve_partitioning=[false] +08)--------ProjectionExec: expr=[b@1 as b, c@2 as c, NULL as a, a0@0 as a0, d@3 as d] +09)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, b, c, d], output_ordering=[c@2 ASC NULLS LAST], has_header=true + +# Test: run the query from above +query IIII +select * from (select b, c, a, NULL::int as a0 from ordered_table order by a, c) t1 +union all +select * from (select b, c, NULL::int as a, a0 from ordered_table order by a0, c) t2 +order by d, c, a, a0, b +limit 2; +---- +0 0 0 NULL +0 0 NULL 1 + + +statement ok +drop table ordered_table; diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index f8b163adc7967..bf68a18511373 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -348,3 +348,204 @@ DROP TABLE list_columns; # Clean up statement ok DROP TABLE listing_table; + +### Tests for binary_ar_string + +# This scenario models the case where a column has been stored in parquet +# "binary" column (without a String logical type annotation) +# this is the case with the `hits_partitioned` ClickBench datasets +# see https://github.com/apache/datafusion/issues/12788 + +## Create a table with a binary column + +query I +COPY ( + SELECT + arrow_cast(string_col, 'Binary') as binary_col, + arrow_cast(string_col, 'LargeBinary') as largebinary_col, + arrow_cast(string_col, 'BinaryView') as binaryview_col + FROM src_table + ) +TO 
'test_files/scratch/parquet/binary_as_string.parquet' +STORED AS PARQUET; +---- +9 + +# Test 1: Read table with default options +statement ok +CREATE EXTERNAL TABLE binary_as_string_default +STORED AS PARQUET LOCATION 'test_files/scratch/parquet/binary_as_string.parquet' + +# NB the data is read and displayed as binary +query T?T?T? +select + arrow_typeof(binary_col), binary_col, + arrow_typeof(largebinary_col), largebinary_col, + arrow_typeof(binaryview_col), binaryview_col + FROM binary_as_string_default; +---- +Binary 616161 Binary 616161 Binary 616161 +Binary 626262 Binary 626262 Binary 626262 +Binary 636363 Binary 636363 Binary 636363 +Binary 646464 Binary 646464 Binary 646464 +Binary 656565 Binary 656565 Binary 656565 +Binary 666666 Binary 666666 Binary 666666 +Binary 676767 Binary 676767 Binary 676767 +Binary 686868 Binary 686868 Binary 686868 +Binary 696969 Binary 696969 Binary 696969 + +# Run an explain plan to show the cast happens in the plan (a CAST is needed for the predicates) +query TT +EXPLAIN + SELECT binary_col, largebinary_col, binaryview_col + FROM binary_as_string_default + WHERE + binary_col LIKE '%a%' AND + largebinary_col LIKE '%a%' AND + binaryview_col LIKE '%a%'; +---- +logical_plan +01)Filter: CAST(binary_as_string_default.binary_col AS Utf8) LIKE Utf8("%a%") AND CAST(binary_as_string_default.largebinary_col AS Utf8) LIKE Utf8("%a%") AND CAST(binary_as_string_default.binaryview_col AS Utf8) LIKE Utf8("%a%") +02)--TableScan: binary_as_string_default projection=[binary_col, largebinary_col, binaryview_col], partial_filters=[CAST(binary_as_string_default.binary_col AS Utf8) LIKE Utf8("%a%"), CAST(binary_as_string_default.largebinary_col AS Utf8) LIKE Utf8("%a%"), CAST(binary_as_string_default.binaryview_col AS Utf8) LIKE Utf8("%a%")] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: CAST(binary_col@0 AS Utf8) LIKE %a% AND CAST(largebinary_col@1 AS Utf8) LIKE %a% AND CAST(binaryview_col@2 AS Utf8) LIKE %a% +03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)------ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/binary_as_string.parquet]]}, projection=[binary_col, largebinary_col, binaryview_col], predicate=CAST(binary_col@0 AS Utf8) LIKE %a% AND CAST(largebinary_col@1 AS Utf8) LIKE %a% AND CAST(binaryview_col@2 AS Utf8) LIKE %a% + + +statement ok +DROP TABLE binary_as_string_default; + +## Test 2: Read table using the binary_as_string option + +statement ok +CREATE EXTERNAL TABLE binary_as_string_option +STORED AS PARQUET LOCATION 'test_files/scratch/parquet/binary_as_string.parquet' +OPTIONS ('binary_as_string' 'true'); + +# NB the data is read and displayed as string +query TTTTTT +select + arrow_typeof(binary_col), binary_col, + arrow_typeof(largebinary_col), largebinary_col, + arrow_typeof(binaryview_col), binaryview_col + FROM binary_as_string_option; +---- +Utf8 aaa Utf8 aaa Utf8 aaa +Utf8 bbb Utf8 bbb Utf8 bbb +Utf8 ccc Utf8 ccc Utf8 ccc +Utf8 ddd Utf8 ddd Utf8 ddd +Utf8 eee Utf8 eee Utf8 eee +Utf8 fff Utf8 fff Utf8 fff +Utf8 ggg Utf8 ggg Utf8 ggg +Utf8 hhh Utf8 hhh Utf8 hhh +Utf8 iii Utf8 iii Utf8 iii + +# Run an explain plan to show the cast happens in the plan (there should be no casts) +query TT +EXPLAIN + SELECT binary_col, largebinary_col, binaryview_col + FROM binary_as_string_option + WHERE + binary_col LIKE '%a%' AND + largebinary_col LIKE '%a%' AND + binaryview_col LIKE '%a%'; +---- +logical_plan +01)Filter: 
binary_as_string_option.binary_col LIKE Utf8("%a%") AND binary_as_string_option.largebinary_col LIKE Utf8("%a%") AND binary_as_string_option.binaryview_col LIKE Utf8("%a%") +02)--TableScan: binary_as_string_option projection=[binary_col, largebinary_col, binaryview_col], partial_filters=[binary_as_string_option.binary_col LIKE Utf8("%a%"), binary_as_string_option.largebinary_col LIKE Utf8("%a%"), binary_as_string_option.binaryview_col LIKE Utf8("%a%")] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: binary_col@0 LIKE %a% AND largebinary_col@1 LIKE %a% AND binaryview_col@2 LIKE %a% +03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)------ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/binary_as_string.parquet]]}, projection=[binary_col, largebinary_col, binaryview_col], predicate=binary_col@0 LIKE %a% AND largebinary_col@1 LIKE %a% AND binaryview_col@2 LIKE %a% + + +statement ok +DROP TABLE binary_as_string_option; + +## Test 3: Read table with binary_as_string option AND schema_force_view_types + +statement ok +CREATE EXTERNAL TABLE binary_as_string_both +STORED AS PARQUET LOCATION 'test_files/scratch/parquet/binary_as_string.parquet' +OPTIONS ( + 'binary_as_string' 'true', + 'schema_force_view_types' 'true' +); + +# NB the data is read and displayed as StringView +query TTTTTT +select + arrow_typeof(binary_col), binary_col, + arrow_typeof(largebinary_col), largebinary_col, + arrow_typeof(binaryview_col), binaryview_col + FROM binary_as_string_both; +---- +Utf8View aaa Utf8View aaa Utf8View aaa +Utf8View bbb Utf8View bbb Utf8View bbb +Utf8View ccc Utf8View ccc Utf8View ccc +Utf8View ddd Utf8View ddd Utf8View ddd +Utf8View eee Utf8View eee Utf8View eee +Utf8View fff Utf8View fff Utf8View fff +Utf8View ggg Utf8View ggg Utf8View ggg +Utf8View hhh Utf8View hhh Utf8View hhh +Utf8View iii Utf8View iii Utf8View iii + +# Run an explain plan to show that no casts are needed in the plan (the columns are read as string views) +query TT +EXPLAIN + SELECT binary_col, largebinary_col, binaryview_col + FROM binary_as_string_both + WHERE + binary_col LIKE '%a%' AND + largebinary_col LIKE '%a%' AND + binaryview_col LIKE '%a%'; +---- +logical_plan +01)Filter: binary_as_string_both.binary_col LIKE Utf8View("%a%") AND binary_as_string_both.largebinary_col LIKE Utf8View("%a%") AND binary_as_string_both.binaryview_col LIKE Utf8View("%a%") +02)--TableScan: binary_as_string_both projection=[binary_col, largebinary_col, binaryview_col], partial_filters=[binary_as_string_both.binary_col LIKE Utf8View("%a%"), binary_as_string_both.largebinary_col LIKE Utf8View("%a%"), binary_as_string_both.binaryview_col LIKE Utf8View("%a%")] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: binary_col@0 LIKE %a% AND largebinary_col@1 LIKE %a% AND binaryview_col@2 LIKE %a% +03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)------ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/binary_as_string.parquet]]}, projection=[binary_col, largebinary_col, binaryview_col], predicate=binary_col@0 LIKE %a% AND largebinary_col@1 LIKE %a% AND binaryview_col@2 LIKE %a% + + +statement ok +drop table binary_as_string_both; + +# Read a parquet file with binary data in a FixedSizeBinary column + +# by default, the data is read as binary +statement ok +CREATE EXTERNAL TABLE test_non_utf8_binary +STORED AS PARQUET LOCATION
'../core/tests/data/test_binary.parquet'; + +query T? +SELECT arrow_typeof(ids), ids FROM test_non_utf8_binary LIMIT 3; +---- +FixedSizeBinary(16) 008c7196f68089ab692e4739c5fd16b5 +FixedSizeBinary(16) 00a51a7bc5ff8eb1627f8f3dc959dce8 +FixedSizeBinary(16) 0166ce1d46129ad104fa4990c6057c91 + +statement ok +DROP TABLE test_non_utf8_binary; + + +# even with the binary_as_string option set, the data is read as binary +statement ok +CREATE EXTERNAL TABLE test_non_utf8_binary +STORED AS PARQUET LOCATION '../core/tests/data/test_binary.parquet' +OPTIONS ('binary_as_string' 'true'); + +query T? +SELECT arrow_typeof(ids), ids FROM test_non_utf8_binary LIMIT 3 +---- +FixedSizeBinary(16) 008c7196f68089ab692e4739c5fd16b5 +FixedSizeBinary(16) 00a51a7bc5ff8eb1627f8f3dc959dce8 +FixedSizeBinary(16) 0166ce1d46129ad104fa4990c6057c91 + +statement ok +DROP TABLE test_non_utf8_binary; diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index eedc3ddb6d59c..800026dd766d2 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -16,18 +16,18 @@ # under the License. statement ok -CREATE TABLE t (str varchar, pattern varchar, flags varchar) AS VALUES - ('abc', '^(a)', 'i'), - ('ABC', '^(A).*', 'i'), - ('aBc', '(b|d)', 'i'), - ('AbC', '(B|D)', null), - ('aBC', '^(b|c)', null), - ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', null), - ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', null), - ('Düsseldorf','[\p{Letter}-]+', null), - ('Москва', '[\p{L}-]+', null), - ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', null), - ('إسرائيل', '^\p{Arabic}+$', null); +CREATE TABLE t (str varchar, pattern varchar, start int, flags varchar) AS VALUES + ('abc', '^(a)', 1, 'i'), + ('ABC', '^(A).*', 1, 'i'), + ('aBc', '(b|d)', 1, 'i'), + ('AbC', '(B|D)', 2, null), + ('aBC', '^(b|c)', 3, null), + ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 1, null), + ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 2, null), + ('Düsseldorf','[\p{Letter}-]+', 3, null), + ('Москва', '[\p{L}-]+', 4, null), + ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', 1, null), + ('إسرائيل', '^\p{Arabic}+$', 2, null); # # regexp_like tests @@ -460,6 +460,313 @@ SELECT NULL not iLIKE NULL; ---- NULL +# regexp_count tests + +# regexp_count tests from postgresql +# https://github.com/postgres/postgres/blob/56d23855c864b7384970724f3ad93fb0fc319e51/src/test/regress/sql/strings.sql#L226-L235 + +query I +SELECT regexp_count('123123123123123', '(12)3'); +---- +5 + +query I +SELECT regexp_count('123123123123', '123', 1); +---- +4 + +query I +SELECT regexp_count('123123123123', '123', 3); +---- +3 + +query I +SELECT regexp_count('123123123123', '123', 33); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, ''); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i'); +---- +4 + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', 0); + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', -3); + +statement error +External error: statement failed: DataFusion error: Arrow error: Compute error: regexp_count() does not support global flag +SELECT regexp_count('123123123123', '123', 1, 'g'); + +query I +SELECT regexp_count(str, '\w') from t; +---- +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT 
regexp_count(str, '\w{2}', start) from t; +---- +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t; +---- +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from t; +---- +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t; +---- +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test string views + +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t; + +query I +SELECT regexp_count(str, '\w') from t_stringview; +---- +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t_stringview; +---- +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t_stringview; +---- +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from t_stringview; +---- +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t_stringview; +---- +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t_stringview; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t_stringview; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# NULL tests + +query I +SELECT regexp_count(NULL, NULL); +---- +0 + +query I +SELECT regexp_count(NULL, 'a'); +---- +0 + +query I +SELECT regexp_count('a', NULL); +---- +0 + +query I +SELECT regexp_count(NULL, NULL, NULL, NULL); +---- +0 + +statement ok +CREATE TABLE empty_table (str varchar, pattern varchar, start int, flags varchar); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- + +statement ok +INSERT INTO empty_table VALUES ('a', NULL, 1, 'i'), (NULL, 'a', 1, 'i'), (NULL, NULL, 1, 'i'), (NULL, NULL, NULL, 'i'); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- +0 +0 +0 +0 + statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 4c86312f9e51a..858e421062213 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -61,7 +61,7 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--FilterExec: column1@0 != 42 -03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..87], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:87..174], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:174..261], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:261..347]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +03)----ParquetExec:
file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..88], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:88..176], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:176..264], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:264..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # disable round robin repartitioning statement ok @@ -77,7 +77,7 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--FilterExec: column1@0 != 42 -03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..87], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:87..174], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:174..261], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:261..347]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..88], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:88..176], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:176..264], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:264..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # enable round robin repartitioning again statement ok @@ -102,7 +102,7 @@ physical_plan 02)--SortExec: expr=[column1@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=8192 04)------FilterExec: column1@0 != 42 -05)--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..172], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:172..338, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..178], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:178..347]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +05)--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..174], 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:174..342, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..180], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:180..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] ## Read the files as though they are ordered @@ -138,7 +138,7 @@ physical_plan 01)SortPreservingMergeExec: [column1@0 ASC NULLS LAST] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----FilterExec: column1@0 != 42 -04)------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..169], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..173], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:173..347], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:169..338]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +04)------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..171], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..175], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:175..351], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:171..342]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # Cleanup statement ok diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 3b9c9a16042ca..145172f31fd7d 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -536,6 +536,37 @@ select log(a, 64) a, log(b), log(10, b) from signed_integers; NaN 2 2 NaN 4 4 +# log overloaded base 10 float64 and float32 casting scalar +query RR rowsort +select log(arrow_cast(10, 'Float64')) a ,log(arrow_cast(100, 'Float32')) b; +---- +1 2 + +# log overloaded base 10 float64 and float32 casting with columns +query RR rowsort +select log(arrow_cast(a, 'Float64')), log(arrow_cast(b, 'Float32')) from signed_integers; +---- +0.301029995664 NaN +0.602059991328 NULL +NaN 2 +NaN 4 + +# log float64 and float32 casting scalar +query RR rowsort +select log(2,arrow_cast(8, 'Float64')) a, log(2,arrow_cast(16, 'Float32')) b; +---- +3 4 + +# log float64 and float32 casting with columns +query RR rowsort +select log(2,arrow_cast(a, 'Float64')), log(4,arrow_cast(b, 'Float32')) from signed_integers; +---- +1 NaN +2 NULL +NaN 3.321928 +NaN 6.643856 + + ## 
log10 # log10 scalar function @@ -1526,6 +1557,9 @@ NULL NULL query error DataFusion error: Error during planning: Negation only supports numeric, interval and timestamp types SELECT -'100' +query error DataFusion error: Error during planning: Unary operator '\+' only supports numeric, interval and timestamp types +SELECT +true + statement ok drop table test_boolean @@ -1906,11 +1940,9 @@ select position('' in '') ---- 1 - -query error POSITION function can only accept strings +query error DataFusion error: Error during planning: Error during planning: Int64 and Int64 are not coercible to a common string select position(1 in 1) - query I select strpos('abc', 'c'); ---- diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 5df5f313af3c4..f2ab4135aaa76 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -348,8 +348,11 @@ VALUES (1),() statement error DataFusion error: Error during planning: Inconsistent data length across values list: got 2 values in row 1 but expected 1 VALUES (1),(1,2) -statement error DataFusion error: Error during planning: Inconsistent data type across values list at row 1 column 0 +query I VALUES (1),('2') +---- +1 +2 query R VALUES (1),(2.0) @@ -357,8 +360,11 @@ VALUES (1),(2.0) 1 2 -statement error DataFusion error: Error during planning: Inconsistent data type across values list at row 1 column 1 +query II VALUES (1,2), (1,'2') +---- +1 2 +1 2 query IT VALUES (1,'a'),(NULL,'b'),(3,'c') @@ -552,7 +558,7 @@ EXPLAIN SELECT * FROM ((SELECT column1 FROM foo) "T1" CROSS JOIN (SELECT column2 ---- logical_plan 01)SubqueryAlias: F -02)--CrossJoin: +02)--Cross Join: 03)----SubqueryAlias: T1 04)------TableScan: foo projection=[column1] 05)----SubqueryAlias: T2 @@ -575,9 +581,32 @@ select * from (select 1 a union all select 2) b order by a limit 1; 1 # select limit clause invalid -statement error DataFusion error: Error during planning: LIMIT must be >= 0, '\-1' was provided\. +statement error Error during planning: LIMIT must be >= 0, '-1' was provided select * from (select 1 a union all select 2) b order by a limit -1; +statement error Error during planning: OFFSET must be >=0, '-1' was provided +select * from (select 1 a union all select 2) b order by a offset -1; + +statement error Unsupported LIMIT expression +select * from (values(1),(2)) limit (select 1); + +statement error Unsupported OFFSET expression +select * from (values(1),(2)) offset (select 1); + +# disallow non-integer limit/offset +statement error Expected LIMIT to be an integer or null, but got Float64 +select * from (values(1),(2)) limit 0.5; + +statement error Expected OFFSET to be an integer or null, but got Utf8 +select * from (values(1),(2)) offset '1'; + +# test with different integer types +query I +select * from (values (1), (2), (3), (4)) limit 2::int OFFSET 1::tinyint +---- +2 +3 + # select limit with basic arithmetic query I select * from (select 1 a union all select 2) b order by a limit 1+1; @@ -591,13 +620,38 @@ select * from (values (1)) LIMIT 10*100; ---- 1 -# More complex expressions in the limit is not supported yet. -# See issue: https://github.com/apache/datafusion/issues/9821 -statement error DataFusion error: Error during planning: Unsupported operator for LIMIT clause +# select limit with complex arithmetic +query I select * from (values (1)) LIMIT 100/10; +---- +1 -# More complex expressions in the limit is not supported yet. 
-statement error DataFusion error: Error during planning: Unexpected expression in LIMIT clause +# test constant-folding of LIMIT expr +query I +select * from (values (1), (2), (3), (4)) LIMIT abs(-4) + 4 / -2; -- LIMIT 2 +---- +1 +2 + +# test constant-folding of OFFSET expr +query I +select * from (values (1), (2), (3), (4)) OFFSET abs(-4) + 4 / -2; -- OFFSET 2 +---- +3 +4 + +# test constant-folding of LIMIT and OFFSET +query I +select * from (values (1), (2), (3), (4)) + -- LIMIT 2 + LIMIT abs(-4) + -1 * 2 + -- OFFSET 1 + OFFSET case when 1 < 2 then 1 else 0 end; +---- +2 +3 + +statement error Schema error: No field named column1. select * from (values (1)) LIMIT cast(column1 as tinyint); # select limit clause query I select * from (select 1 a union all select 2) b order by a limit null; ---- 1 2 +# offset null has no effect +query I +select * from (select 1 a union all select 2) b order by a offset null; +---- +1 +2 + # select limit clause query I select * from (select 1 a union all select 2) b order by a limit 0. diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index ebd53e9690fc2..f4cc888d6b8e7 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -126,22 +126,24 @@ Alice 50 Alice 1 Alice 50 Alice 2 Bob 1 NULL NULL +# Uncomment when filtered FULL join support is moved # full join with join filter -query TITI rowsort -SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b ----- -Alice 100 NULL NULL -Alice 50 Alice 2 -Bob 1 NULL NULL -NULL NULL Alice 1 - -query TITI rowsort -SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 ----- -Alice 100 Alice 1 -Alice 100 Alice 2 -Alice 50 NULL NULL -Bob 1 NULL NULL +#query TITI rowsort +#SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b +#---- +#Alice 100 NULL NULL +#Alice 50 Alice 2 +#Bob 1 NULL NULL +#NULL NULL Alice 1 + +# Uncomment when filtered FULL join support is moved +#query TITI rowsort +#SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 +#---- +#Alice 100 Alice 1 +#Alice 100 Alice 2 +#Alice 50 NULL NULL +#Bob 1 NULL NULL statement ok DROP TABLE t1; @@ -493,6 +495,7 @@ select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.
) order by 1, 2 ---- + query II select * from ( with @@ -576,7 +579,7 @@ query II select * from ( with t1 as ( - select 11 a, 12 b), + select 11 a, 12 b), t2 as ( select 11 a, 12 c union all select 11 a, 11 c union all diff --git a/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt b/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt index 9d24608167095..c181f613ee9a9 100644 --- a/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt +++ b/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt @@ -37,6 +37,22 @@ select arrow_cast(col1, 'Dictionary(Int32, Utf8)') as c1 from test_substr_base; statement ok drop table test_source +# TODO: move it back to `string_query.slt.part` after fixing the issue +# see detail: https://github.com/apache/datafusion/issues/12637 +# Test pattern with wildcard characters +query TTBBBB +select ascii_1, unicode_1, + ascii_1 like 'An%' as ascii_like, + unicode_1 like '%ion数据%' as unicode_like, + ascii_1 ilike 'An%' as ascii_ilike, + unicode_1 ilike '%ion数据%' as unicode_ilik +from test_basic_operator; +---- +Andrew datafusion📊🔥 true false true false +Xiangpeng datafusion数据融合 false true false true +Raphael datafusionДатаФусион false false false false +NULL NULL NULL NULL NULL NULL + # # common test for string-like functions and operators # diff --git a/datafusion/sqllogictest/test_files/string/init_data.slt.part b/datafusion/sqllogictest/test_files/string/init_data.slt.part index d99401f10d205..096e3bb3b330c 100644 --- a/datafusion/sqllogictest/test_files/string/init_data.slt.part +++ b/datafusion/sqllogictest/test_files/string/init_data.slt.part @@ -30,4 +30,3 @@ statement ok create table test_substr_base ( col1 VARCHAR ) as values ('foo'), ('hello🌏世界'), ('💩'), ('ThisIsAVeryLongASCIIString'), (''), (NULL); - diff --git a/datafusion/sqllogictest/test_files/string/large_string.slt b/datafusion/sqllogictest/test_files/string/large_string.slt index a2e570073ff6d..8d8a5711bdb8d 100644 --- a/datafusion/sqllogictest/test_files/string/large_string.slt +++ b/datafusion/sqllogictest/test_files/string/large_string.slt @@ -44,17 +44,20 @@ Raphael R datafusionДатаФусион аФус NULL R NULL 🔥 # TODO: move it back to `string_query.slt.part` after fixing the issue -# https://github.com/apache/datafusion/issues/12618 -query BB -SELECT - ascii_1 ~* '^a.{3}e', - unicode_1 ~* '^d.*Фу' -FROM test_basic_operator; +# see detail: https://github.com/apache/datafusion/issues/12637 +# Test pattern with wildcard characters +query TTBBBB +select ascii_1, unicode_1, + ascii_1 like 'An%' as ascii_like, + unicode_1 like '%ion数据%' as unicode_like, + ascii_1 ilike 'An%' as ascii_ilike, + unicode_1 ilike '%ion数据%' as unicode_ilik +from test_basic_operator; ---- -true false -false false -false true -NULL NULL +Andrew datafusion📊🔥 true false true false +Xiangpeng datafusion数据融合 false true false true +Raphael datafusionДатаФусион false false false false +NULL NULL NULL NULL NULL NULL # # common test for string-like functions and operators diff --git a/datafusion/sqllogictest/test_files/string/string.slt b/datafusion/sqllogictest/test_files/string/string.slt index bc923d5e12c3b..e84342abd3dff 100644 --- a/datafusion/sqllogictest/test_files/string/string.slt +++ b/datafusion/sqllogictest/test_files/string/string.slt @@ -35,17 +35,20 @@ create table test_substr as select arrow_cast(col1, 'Utf8') as c1 from test_substr_base; # TODO: move it back to `string_query.slt.part` after fixing the issue -# https://github.com/apache/datafusion/issues/12618 -query BB -SELECT - 
ascii_1 ~* '^a.{3}e', - unicode_1 ~* '^d.*Фу' -FROM test_basic_operator; +# see detail: https://github.com/apache/datafusion/issues/12637 +# Test pattern with wildcard characters +query TTBBBB +select ascii_1, unicode_1, + ascii_1 like 'An%' as ascii_like, + unicode_1 like '%ion数据%' as unicode_like, + ascii_1 ilike 'An%' as ascii_ilike, + unicode_1 ilike '%ion数据%' as unicode_ilik +from test_basic_operator; ---- -true false -false false -false true -NULL NULL +Andrew datafusion📊🔥 true false true false +Xiangpeng datafusion数据融合 false true false true +Raphael datafusionДатаФусион false false false false +NULL NULL NULL NULL NULL NULL # # common test for string-like functions and operators diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt index 24e03fdb71844..5d847747693d8 100644 --- a/datafusion/sqllogictest/test_files/string/string_literal.slt +++ b/datafusion/sqllogictest/test_files/string/string_literal.slt @@ -167,3 +167,652 @@ query D select make_date(arrow_cast('2024', 'Utf8View'), arrow_cast('01', 'Utf8View'), arrow_cast('23', 'Utf8View')) ---- 2024-01-23 + +query I +SELECT character_length('') +---- +0 + +query I +SELECT character_length('chars') +---- +5 + +query I +SELECT character_length('josé') +---- +4 + +query I +SELECT character_length(NULL) +---- +NULL + +query B +SELECT ends_with('foobar', 'bar') +---- +true + +query B +SELECT ends_with('foobar', 'foo') +---- +false + +query I +SELECT levenshtein('kitten', 'sitting') +---- +3 + +query I +SELECT levenshtein('kitten', NULL) +---- +NULL + +query I +SELECT levenshtein(NULL, 'sitting') +---- +NULL + +query I +SELECT levenshtein(NULL, NULL) +---- +NULL + + +query T +SELECT lpad('hi', -1, 'xy') +---- +(empty) + +query T +SELECT lpad('hi', 5, 'xy') +---- +xyxhi + +query T +SELECT lpad('hi', -1) +---- +(empty) + +query T +SELECT lpad('hi', 0) +---- +(empty) + +query T +SELECT lpad('hi', 21, 'abcdef') +---- +abcdefabcdefabcdefahi + +query T +SELECT lpad('hi', 5, 'xy') +---- +xyxhi + +query T +SELECT lpad('hi', 5, NULL) +---- +NULL + +query T +SELECT lpad('hi', 5) +---- + hi + +query T +SELECT lpad('hi', CAST(NULL AS INT), 'xy') +---- +NULL + +query T +SELECT lpad('hi', CAST(NULL AS INT)) +---- +NULL + +query T +SELECT lpad('xyxhi', 3) +---- +xyx + +query T +SELECT lpad(NULL, 0) +---- +NULL + +query T +SELECT lpad(NULL, 5, 'xy') +---- +NULL + +query T +SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT repeat('foo', 3) +---- +foofoofoo + +query T +SELECT repeat(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 3) +---- +foofoofoo + + +query T +SELECT replace('foobar', 'bar', 'hello') +---- +foohello + +query T +SELECT replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'hello') +---- +foohello + +query T +SELECT replace(arrow_cast('foobar', 'Utf8View'), arrow_cast('bar', 'Utf8View'), arrow_cast('hello', 'Utf8View')) +---- +foohello + +query T +SELECT replace(arrow_cast('foobar', 'LargeUtf8'), arrow_cast('bar', 'LargeUtf8'), arrow_cast('hello', 'LargeUtf8')) +---- +foohello + + +query T +SELECT reverse('abcde') +---- +edcba + +query T +SELECT reverse(arrow_cast('abcde', 'LargeUtf8')) +---- +edcba + +query T +SELECT reverse(arrow_cast('abcde', 'Utf8View')) +---- +edcba + +query T +SELECT reverse(arrow_cast('abcde', 'Dictionary(Int32, Utf8)')) +---- +edcba + +query T +SELECT reverse('loẅks') +---- 
+sk̈wol + +query T +SELECT reverse(arrow_cast('loẅks', 'LargeUtf8')) +---- +sk̈wol + +query T +SELECT reverse(arrow_cast('loẅks', 'Utf8View')) +---- +sk̈wol + +query T +SELECT reverse(NULL) +---- +NULL + +query T +SELECT reverse(arrow_cast(NULL, 'LargeUtf8')) +---- +NULL + +query T +SELECT reverse(arrow_cast(NULL, 'Utf8View')) +---- +NULL + + +query I +SELECT strpos('abc', 'c') +---- +3 + +query I +SELECT strpos('josé', 'é') +---- +4 + +query I +SELECT strpos('joséésoj', 'so') +---- +6 + +query I +SELECT strpos('joséésoj', 'abc') +---- +0 + +query I +SELECT strpos(NULL, 'abc') +---- +NULL + +query I +SELECT strpos('joséésoj', NULL) +---- +NULL + + + +query T +SELECT rpad('hi', -1, 'xy') +---- +(empty) + +query T +SELECT rpad('hi', 5, 'xy') +---- +hixyx + +query T +SELECT rpad('hi', -1) +---- +(empty) + +query T +SELECT rpad('hi', 0) +---- +(empty) + +query T +SELECT rpad('hi', 21, 'abcdef') +---- +hiabcdefabcdefabcdefa + +query T +SELECT rpad('hi', 5, 'xy') +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'Dictionary(Int32, Utf8)'), 5, 'xy') +---- +hixyx + +query T +SELECT rpad('hi', 5, NULL) +---- +NULL + +query T +SELECT rpad('hi', 5) +---- +hi + +query T +SELECT rpad('hi', CAST(NULL AS INT), 'xy') +---- +NULL + +query T +SELECT rpad('hi', CAST(NULL AS INT)) +---- +NULL + +query T +SELECT rpad('xyxhi', 3) +---- +xyx + +# test for rpad with largeutf8 and utf8View + +query T +SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, 'xy') +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, 'xy') +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, arrow_cast('xy', 'LargeUtf8')) +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, arrow_cast('xy', 'Utf8View')) +---- +hixyx + +query T +SELECT rpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy') +---- +NULL + +query I +SELECT char_length('') +---- +0 + +query I +SELECT char_length('chars') +---- +5 + +query I +SELECT char_length('josé') +---- +4 + +query I +SELECT char_length(NULL) +---- +NULL + +# Test substring_index using '.' as delimiter +# This query is compatible with MySQL(8.0.19 or later), convenient for comparing results +query TIT +SELECT str, n, substring_index(str, '.', n) AS c FROM + (VALUES + ROW('arrow.apache.org'), + ROW('.'), + ROW('...'), + ROW(NULL) + ) AS strings(str), + (VALUES + ROW(1), + ROW(2), + ROW(3), + ROW(100), + ROW(-1), + ROW(-2), + ROW(-3), + ROW(-100) + ) AS occurrences(n) +ORDER BY str DESC, n; +---- +NULL -100 NULL +NULL -3 NULL +NULL -2 NULL +NULL -1 NULL +NULL 1 NULL +NULL 2 NULL +NULL 3 NULL +NULL 100 NULL +arrow.apache.org -100 arrow.apache.org +arrow.apache.org -3 arrow.apache.org +arrow.apache.org -2 apache.org +arrow.apache.org -1 org +arrow.apache.org 1 arrow +arrow.apache.org 2 arrow.apache +arrow.apache.org 3 arrow.apache.org +arrow.apache.org 100 arrow.apache.org +... -100 ... +... -3 .. +... -2 . +... -1 (empty) +... 1 (empty) +... 2 . +... 3 .. +... 100 ... +. -100 . +. -3 . +. -2 . +. -1 (empty) +. 1 (empty) +. 2 . +. 3 . +. 100 . + +# Test substring_index using '.' 
as delimiter with utf8view +query TIT +SELECT str, n, substring_index(arrow_cast(str, 'Utf8View'), '.', n) AS c FROM + (VALUES + ROW('arrow.apache.org'), + ROW('.'), + ROW('...'), + ROW(NULL) + ) AS strings(str), + (VALUES + ROW(1), + ROW(2), + ROW(3), + ROW(100), + ROW(-1), + ROW(-2), + ROW(-3), + ROW(-100) + ) AS occurrences(n) +ORDER BY str DESC, n; +---- +NULL -100 NULL +NULL -3 NULL +NULL -2 NULL +NULL -1 NULL +NULL 1 NULL +NULL 2 NULL +NULL 3 NULL +NULL 100 NULL +arrow.apache.org -100 arrow.apache.org +arrow.apache.org -3 arrow.apache.org +arrow.apache.org -2 apache.org +arrow.apache.org -1 org +arrow.apache.org 1 arrow +arrow.apache.org 2 arrow.apache +arrow.apache.org 3 arrow.apache.org +arrow.apache.org 100 arrow.apache.org +... -100 ... +... -3 .. +... -2 . +... -1 (empty) +... 1 (empty) +... 2 . +... 3 .. +... 100 ... +. -100 . +. -3 . +. -2 . +. -1 (empty) +. 1 (empty) +. 2 . +. 3 . +. 100 . + +# Test substring_index using 'ac' as delimiter +query TIT +SELECT str, n, substring_index(str, 'ac', n) AS c FROM + (VALUES + -- input string does not contain the delimiter + ROW('arrow'), + -- input string contains the delimiter + ROW('arrow.apache.org') + ) AS strings(str), + (VALUES + ROW(1), + ROW(2), + ROW(-1), + ROW(-2) + ) AS occurrences(n) +ORDER BY str DESC, n; +---- +arrow.apache.org -2 arrow.apache.org +arrow.apache.org -1 he.org +arrow.apache.org 1 arrow.ap +arrow.apache.org 2 arrow.apache.org +arrow -2 arrow +arrow -1 arrow +arrow 1 arrow +arrow 2 arrow + +# Test substring_index with NULL values +query TTTT +SELECT + substring_index(NULL, '.', 1), + substring_index('arrow.apache.org', NULL, 1), + substring_index('arrow.apache.org', '.', NULL), + substring_index(NULL, NULL, NULL) +---- +NULL NULL NULL NULL + +# Test substring_index with empty strings +query TT +SELECT + -- input string is empty + substring_index('', '.', 1), + -- delimiter is empty + substring_index('arrow.apache.org', '', 1) +---- +(empty) (empty) + +# Test substring_index with 0 occurrence +query T +SELECT substring_index('arrow.apache.org', 'ac', 0) +---- +(empty) + +# Test substring_index with large occurrences +query TT +SELECT + -- i64::MIN + substring_index('arrow.apache.org', '.', -9223372036854775808) as c1, + -- i64::MAX + substring_index('arrow.apache.org', '.', 9223372036854775807) as c2; +---- +arrow.apache.org arrow.apache.org + +# Test substring_index issue https://github.com/apache/datafusion/issues/9472 +query TTT +SELECT + url, + substring_index(url, '.', 1) AS subdomain, + substring_index(url, '.', -1) AS tld +FROM + (VALUES ROW('docs.apache.com'), + ROW('community.influxdata.com'), + ROW('arrow.apache.org') + ) data(url) +---- +docs.apache.com docs com +community.influxdata.com community com +arrow.apache.org arrow org + + +# find_in_set tests +query I +SELECT find_in_set('b', 'a,b,c,d') +---- +2 + + +query I +SELECT find_in_set('a', 'a,b,c,d,a') +---- +1 + +query I +SELECT find_in_set('', 'a,b,c,d,a') +---- +0 + +query I +SELECT find_in_set('a', '') +---- +0 + + +query I +SELECT find_in_set('', '') +---- +1 + +query I +SELECT find_in_set(NULL, 'a,b,c,d') +---- +NULL + +query I +SELECT find_in_set('a', NULL) +---- +NULL + + +query I +SELECT find_in_set(NULL, NULL) +---- +NULL + +# find_in_set tests with utf8view +query I +SELECT find_in_set(arrow_cast('b', 'Utf8View'), 'a,b,c,d') +---- +2 + + +query I +SELECT find_in_set('a', arrow_cast('a,b,c,d,a', 'Utf8View')) +---- +1 + +query I +SELECT find_in_set(arrow_cast('', 'Utf8View'), arrow_cast('a,b,c,d,a', 'Utf8View')) +---- +0 + + +query T 
+SELECT split_part('foo_bar', '_', 2) +---- +bar + +query T +SELECT split_part(arrow_cast('foo_bar', 'Dictionary(Int32, Utf8)'), '_', 2) +---- +bar + +# test largeutf8, utf8view for split_part +query T +SELECT split_part(arrow_cast('large_apple_large_orange_large_banana', 'LargeUtf8'), '_', 3) +---- +large + +query T +SELECT split_part(arrow_cast('view_apple_view_orange_view_banana', 'Utf8View'), '_', 3); +---- +view + +query T +SELECT split_part('test_large_split_large_case', arrow_cast('_large', 'LargeUtf8'), 2) +---- +_split + +query T +SELECT split_part(arrow_cast('huge_large_apple_large_orange_large_banana', 'LargeUtf8'), arrow_cast('_', 'Utf8View'), 2) +---- +large + +query T +SELECT split_part(arrow_cast('view_apple_view_large_banana', 'Utf8View'), arrow_cast('_large', 'LargeUtf8'), 2) +---- +_banana + +query T +SELECT split_part(NULL, '_', 2) +---- +NULL + +query B +SELECT starts_with('foobar', 'foo') +---- +true + +query B +SELECT starts_with('foobar', 'bar') +---- +false diff --git a/datafusion/sqllogictest/test_files/string/string_query.slt.part b/datafusion/sqllogictest/test_files/string/string_query.slt.part index 96d5ddbd992ca..dc5626b7d5734 100644 --- a/datafusion/sqllogictest/test_files/string/string_query.slt.part +++ b/datafusion/sqllogictest/test_files/string/string_query.slt.part @@ -642,18 +642,16 @@ true false false true NULL NULL -# TODO: DictionaryString does not support ~* operator. Enable this after fixing the issue -# see issue: https://github.com/apache/datafusion/issues/12618 -#query BB -#SELECT -# ascii_1 ~* '^a.{3}e', -# unicode_1 ~* '^d.*Фу' -#FROM test_basic_operator; -#---- -#true false -#false false -#false true -#NULL NULL +query BB +SELECT + ascii_1 ~* '^a.{3}e', + unicode_1 ~* '^d.*Фу' +FROM test_basic_operator; +---- +true false +false false +false true +NULL NULL query BB SELECT @@ -694,3 +692,293 @@ Andrew nice Andrew and X datafusion📊🔥 cool datafusion📊🔥 and 🔥 And Xiangpeng nice Xiangpeng and Xiangpeng datafusion数据融合 cool datafusion数据融合 and datafusion数据融合 Xiangpeng 🔥 datafusion数据融合 Raphael nice Raphael and R datafusionДатаФусион cool datafusionДатаФусион and аФус Raphael 🔥 datafusionДатаФусион NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test LIKE / ILIKE +# -------------------------------------- + +# TODO: StringView has wrong behavior for LIKE/ILIKE. 
Enable this after fixing the issue +# see issue: https://github.com/apache/datafusion/issues/12637 +# Test pattern with wildcard characters +#query TTBBBB +#select ascii_1, unicode_1, +# ascii_1 like 'An%' as ascii_like, +# unicode_1 like '%ion数据%' as unicode_like, +# ascii_1 ilike 'An%' as ascii_ilike, +# unicode_1 ilike '%ion数据%' as unicode_ilik +#from test_basic_operator; +#---- +#Andrew datafusion📊🔥 true false true false +#Xiangpeng datafusion数据融合 false true false true +#Raphael datafusionДатаФусион false false false false +#NULL NULL NULL NULL NULL NULL + +# Test pattern without wildcard characters +query TTBBBB +select ascii_1, unicode_1, + ascii_1 like 'An' as ascii_like, + unicode_1 like 'ion数据' as unicode_like, + ascii_1 ilike 'An' as ascii_ilike, + unicode_1 ilike 'ion数据' as unicode_ilik +from test_basic_operator; +---- +Andrew datafusion📊🔥 false false false false +Xiangpeng datafusion数据融合 false false false false +Raphael datafusionДатаФусион false false false false +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test CHARACTER_LENGTH +# -------------------------------------- + +query II +SELECT + CHARACTER_LENGTH(ascii_1), + CHARACTER_LENGTH(unicode_1) +FROM + test_basic_operator +---- +6 12 +9 14 +7 20 +NULL NULL + +# -------------------------------------- +# Test Start_With +# -------------------------------------- + +query BBBB +SELECT + STARTS_WITH(ascii_1, 'And'), + STARTS_WITH(unicode_1, 'data'), + STARTS_WITH(ascii_1, NULL), + STARTS_WITH(unicode_1, NULL) +FROM test_basic_operator; +---- +true true NULL NULL +false true NULL NULL +false true NULL NULL +NULL NULL NULL NULL + +# -------------------------------------- +# Test ENDS_WITH +# -------------------------------------- + +query BBBB +SELECT + ENDS_WITH(ascii_1, 'w'), + ENDS_WITH(unicode_1, 'ион'), + ENDS_WITH(ascii_1, NULL), + ENDS_WITH(unicode_1, NULL) +FROM test_basic_operator; +---- +true false NULL NULL +false false NULL NULL +false true NULL NULL +NULL NULL NULL NULL + +# -------------------------------------- +# Test LEVENSHTEIN +# -------------------------------------- + +query IIII +SELECT + LEVENSHTEIN(ascii_1, 'Andrew'), + LEVENSHTEIN(unicode_1, 'datafusion数据融合'), + LEVENSHTEIN(ascii_1, NULL), + LEVENSHTEIN(unicode_1, NULL) +FROM test_basic_operator; +---- +0 4 NULL NULL +7 0 NULL NULL +6 10 NULL NULL +NULL NULL NULL NULL + +# -------------------------------------- +# Test LPAD +# -------------------------------------- + +query TTTT +SELECT + LPAD(ascii_1, 20, 'x'), + LPAD(ascii_1, 20, NULL), + LPAD(unicode_1, 20, '🔥'), + LPAD(unicode_1, 20, NULL) +FROM test_basic_operator; +---- +xxxxxxxxxxxxxxAndrew NULL 🔥🔥🔥🔥🔥🔥🔥🔥datafusion📊🔥 NULL +xxxxxxxxxxxXiangpeng NULL 🔥🔥🔥🔥🔥🔥datafusion数据融合 NULL +xxxxxxxxxxxxxRaphael NULL datafusionДатаФусион NULL +NULL NULL NULL NULL + +query TT +SELECT + LPAD(ascii_1, 20), + LPAD(unicode_1, 20) +FROM test_basic_operator; +---- + Andrew datafusion📊🔥 + Xiangpeng datafusion数据融合 + Raphael datafusionДатаФусион +NULL NULL + +# -------------------------------------- +# Test RPAD +# -------------------------------------- + +query TTTT +SELECT + RPAD(ascii_1, 20, 'x'), + RPAD(ascii_1, 20, NULL), + RPAD(unicode_1, 20, '🔥'), + RPAD(unicode_1, 20, NULL) +FROM test_basic_operator; +---- +Andrewxxxxxxxxxxxxxx NULL datafusion📊🔥🔥🔥🔥🔥🔥🔥🔥🔥 NULL +Xiangpengxxxxxxxxxxx NULL datafusion数据融合🔥🔥🔥🔥🔥🔥 NULL +Raphaelxxxxxxxxxxxxx NULL datafusionДатаФусион NULL +NULL NULL NULL NULL + +query TT +SELECT + RPAD(ascii_1, 20), + RPAD(unicode_1, 20) +FROM test_basic_operator; +---- +Andrew 
datafusion📊🔥 +Xiangpeng datafusion数据融合 +Raphael datafusionДатаФусион +NULL NULL + +# -------------------------------------- +# Test REGEXP_LIKE +# -------------------------------------- + +query BBBBBBBB +SELECT + -- without flags + REGEXP_LIKE(ascii_1, 'an'), + REGEXP_LIKE(unicode_1, 'таФ'), + REGEXP_LIKE(ascii_1, NULL), + REGEXP_LIKE(unicode_1, NULL), + -- with flags + REGEXP_LIKE(ascii_1, 'AN', 'i'), + REGEXP_LIKE(unicode_1, 'ТаФ', 'i'), + REGEXP_LIKE(ascii_1, NULL, 'i'), + REGEXP_LIKE(unicode_1, NULL, 'i') + FROM test_basic_operator; +---- +false false NULL NULL true false NULL NULL +true false NULL NULL true false NULL NULL +false true NULL NULL false true NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test REGEXP_MATCH +# -------------------------------------- + +query ???????? +SELECT + -- without flags + REGEXP_MATCH(ascii_1, 'an'), + REGEXP_MATCH(unicode_1, 'ТаФ'), + REGEXP_MATCH(ascii_1, NULL), + REGEXP_MATCH(unicode_1, NULL), + -- with flags + REGEXP_MATCH(ascii_1, 'AN', 'i'), + REGEXP_MATCH(unicode_1, 'таФ', 'i'), + REGEXP_MATCH(ascii_1, NULL, 'i'), + REGEXP_MATCH(unicode_1, NULL, 'i') +FROM test_basic_operator; +---- +NULL NULL NULL NULL [An] NULL NULL NULL +[an] NULL NULL NULL [an] NULL NULL NULL +NULL NULL NULL NULL NULL [таФ] NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test REPEAT +# -------------------------------------- + +query TT +SELECT + REPEAT(ascii_1, 3), + REPEAT(unicode_1, 3) +FROM test_basic_operator; +---- +AndrewAndrewAndrew datafusion📊🔥datafusion📊🔥datafusion📊🔥 +XiangpengXiangpengXiangpeng datafusion数据融合datafusion数据融合datafusion数据融合 +RaphaelRaphaelRaphael datafusionДатаФусионdatafusionДатаФусионdatafusionДатаФусион +NULL NULL + +# -------------------------------------- +# Test SPLIT_PART +# -------------------------------------- + +query TTTTTT +SELECT + SPLIT_PART(ascii_1, 'e', 1), + SPLIT_PART(ascii_1, 'e', 2), + SPLIT_PART(ascii_1, NULL, 1), + SPLIT_PART(unicode_1, 'и', 1), + SPLIT_PART(unicode_1, 'и', 2), + SPLIT_PART(unicode_1, NULL, 1) +FROM test_basic_operator; +---- +Andr w NULL datafusion📊🔥 (empty) NULL +Xiangp ng NULL datafusion数据融合 (empty) NULL +Rapha l NULL datafusionДатаФус он NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test REVERSE +# -------------------------------------- + +query TT +SELECT + REVERSE(ascii_1), + REVERSE(unicode_1) +FROM test_basic_operator; +---- +werdnA 🔥📊noisufatad +gnepgnaiX 合融据数noisufatad +leahpaR ноисуФатаДnoisufatad +NULL NULL + +# -------------------------------------- +# Test STRPOS +# -------------------------------------- + +query IIIIII +SELECT + STRPOS(ascii_1, 'e'), + STRPOS(ascii_1, 'ang'), + STRPOS(ascii_1, NULL), + STRPOS(unicode_1, 'и'), + STRPOS(unicode_1, 'ион'), + STRPOS(unicode_1, NULL) +FROM test_basic_operator; +---- +5 0 NULL 0 0 NULL +7 3 NULL 0 0 NULL +6 0 NULL 18 18 NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test SUBSTR_INDEX +# -------------------------------------- + +query TTTTTT +SELECT + SUBSTR_INDEX(ascii_1, 'e', 1), + SUBSTR_INDEX(ascii_1, 'ang', 1), + SUBSTR_INDEX(ascii_1, NULL, 1), + SUBSTR_INDEX(unicode_1, 'и', 1), + SUBSTR_INDEX(unicode_1, '据融', 1), + SUBSTR_INDEX(unicode_1, NULL, 1) +FROM test_basic_operator; +---- +Andr Andrew NULL datafusion📊🔥 datafusion📊🔥 NULL +Xiangp Xi NULL datafusion数据融合 datafusion数 NULL +Rapha Raphael NULL datafusionДатаФус datafusionДатаФусион NULL +NULL NULL NULL NULL NULL NULL diff 
--git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index e7b55c9c1c8c7..997dca7191472 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -37,19 +37,6 @@ select arrow_cast(col1, 'Utf8View') as c1 from test_substr_base; statement ok drop table test_source -# TODO: move it back to `string_query.slt.part` after fixing the issue -# https://github.com/apache/datafusion/issues/12618 -query BB -SELECT - ascii_1 ~* '^a.{3}e', - unicode_1 ~* '^d.*Фу' -FROM test_basic_operator; ----- -true false -false false -false true -NULL NULL - # # common test for string-like functions and operators # @@ -92,6 +79,29 @@ FROM test_source; statement ok drop table test_source +######## +## StringView Function test +######## + +query I +select octet_length(column1_utf8view) from test; +---- +6 +9 +7 +NULL + +query error DataFusion error: Arrow error: Compute error: bit_length not supported for Utf8View +select bit_length(column1_utf8view) from test; + +query T +select btrim(column1_large_utf8) from test; +---- +Andrew +Xiangpeng +Raphael +NULL + ######## ## StringView to Other Types column ######## @@ -299,9 +309,8 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: starts_with(__common_expr_1, test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(__common_expr_1, CAST(test.column2_large_utf8 AS Utf8View)) AS c4 -02)--Projection: CAST(test.column1_utf8 AS Utf8View) AS __common_expr_1, test.column1_utf8, test.column2_utf8, test.column2_large_utf8, test.column2_utf8view -03)----TableScan: test projection=[column1_utf8, column2_utf8, column2_large_utf8, column2_utf8view] +01)Projection: starts_with(CAST(test.column1_utf8 AS Utf8View), test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(CAST(test.column1_utf8 AS LargeUtf8), test.column2_large_utf8) AS c4 +02)--TableScan: test projection=[column1_utf8, column2_utf8, column2_large_utf8, column2_utf8view] query BBB SELECT @@ -591,7 +600,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: contains(test.column1_utf8view, Utf8("foo")) AS c1, contains(test.column1_utf8view, test.column2_utf8view) AS c2, contains(test.column1_utf8view, test.column2_large_utf8) AS c3, contains(test.column1_utf8, test.column2_utf8view) AS c4, contains(test.column1_utf8, test.column2_utf8) AS c5, contains(test.column1_utf8, test.column2_large_utf8) AS c6, contains(test.column1_large_utf8, test.column1_utf8view) AS c7, contains(test.column1_large_utf8, test.column2_utf8) AS c8, contains(test.column1_large_utf8, test.column2_large_utf8) AS c9 +01)Projection: contains(test.column1_utf8view, Utf8View("foo")) AS c1, contains(test.column1_utf8view, test.column2_utf8view) AS c2, contains(test.column1_utf8view, CAST(test.column2_large_utf8 AS Utf8View)) AS c3, contains(CAST(test.column1_utf8 AS Utf8View), test.column2_utf8view) AS c4, contains(test.column1_utf8, test.column2_utf8) AS c5, contains(CAST(test.column1_utf8 AS LargeUtf8), test.column2_large_utf8) AS c6, contains(CAST(test.column1_large_utf8 AS Utf8View), test.column1_utf8view) AS c7, contains(test.column1_large_utf8, CAST(test.column2_utf8 AS LargeUtf8)) AS c8, contains(test.column1_large_utf8, test.column2_large_utf8) AS c9 02)--TableScan: test projection=[column1_utf8, column2_utf8, column1_large_utf8, column2_large_utf8, column1_utf8view, column2_utf8view] ## Ensure 
no casts for ENDS_WITH @@ -713,7 +722,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: regexp_like(CAST(test.column1_utf8view AS Utf8), Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k +01)Projection: regexp_like(test.column1_utf8view, Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for REGEXP_MATCH @@ -835,7 +844,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: strpos(test.column1_utf8view, Utf8("f")) AS c, strpos(test.column1_utf8view, test.column2_utf8view) AS c2 +01)Projection: strpos(test.column1_utf8view, Utf8View("f")) AS c, strpos(test.column1_utf8view, test.column2_utf8view) AS c2 02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for SUBSTR @@ -892,6 +901,26 @@ logical_plan 01)Projection: find_in_set(test.column1_utf8view, Utf8View("a,b,c,d")) AS c 02)--TableScan: test projection=[column1_utf8view] +## Ensure no casts for to_date +query TT +EXPLAIN SELECT + to_date(column1_utf8view, 'a,b,c,d') as c +FROM test; +---- +logical_plan +01)Projection: to_date(test.column1_utf8view, Utf8("a,b,c,d")) AS c +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for to_timestamp +query TT +EXPLAIN SELECT + to_timestamp(column1_utf8view, 'a,b,c,d') as c +FROM test; +---- +logical_plan +01)Projection: to_timestamp(test.column1_utf8view, Utf8("a,b,c,d")) AS c +02)--TableScan: test projection=[column1_utf8view] + ## Ensure no casts for binary operators # `~` operator (regex match) query TT diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index f3ac6549ad066..7596b820c688b 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -282,3 +282,316 @@ drop table values; statement ok drop table struct_values; + +statement ok +CREATE OR REPLACE VIEW complex_view AS +SELECT { + 'user': { + 'info': { + 'personal': { + 'name': 'John Doe', + 'age': 30, + 'email': 'john.doe@example.com' + }, + 'address': { + 'street': '123 Main St', + 'city': 'Anytown', + 'country': 'Countryland', + 'coordinates': [40.7128, -74.0060] + } + }, + 'preferences': { + 'theme': 'dark', + 'notifications': true, + 'languages': ['en', 'es', 'fr'] + }, + 'stats': { + 'logins': 42, + 'last_active': '2023-09-15', + 'scores': [85, 92, 78, 95], + 'achievements': { + 'badges': ['early_bird', 'top_contributor'], + 'levels': { + 'beginner': true, + 'intermediate': true, + 'advanced': false + } + } + } + }, + 'metadata': { + 'version': '1.0', + 'created_at': '2023-09-01T12:00:00Z' + }, + 'deep_nested': { + 'level1': { + 'level2': { + 'level3': { + 'level4': { + 'level5': { + 'level6': { + 'level7': { + 'level8': { + 'level9': { + 'level10': 'You reached the bottom!' + } + } + } + } + } + } + } + } + } + } +} AS complex_data; + +query T +SELECT complex_data.user.info.personal.name FROM complex_view; +---- +John Doe + +query I +SELECT complex_data.user.info.personal.age FROM complex_view; +---- +30 + +query T +SELECT complex_data.user.info.address.city FROM complex_view; +---- +Anytown + +query T +SELECT complex_data.user.preferences.languages[2] FROM complex_view; +---- +es + +query T +SELECT complex_data.deep_nested.level1.level2.level3.level4.level5.level6.level7.level8.level9.level10 FROM complex_view; +---- +You reached the bottom! 
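+
+# The dot-notation field access above composes with list indexing, and list
+# indexes are 1-based (which is why languages[2] returned 'es'). A minimal
+# self-contained sketch of the same behavior; the view name `tiny_view` is
+# illustrative only, and the expected output assumes the semantics shown above:
+statement ok
+CREATE OR REPLACE VIEW tiny_view AS SELECT {'xs': ['en', 'es', 'fr']} AS s;
+
+query T
+SELECT s.xs[2] FROM tiny_view;
+----
+es
+
+statement ok
+drop view tiny_view;
+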
+ +statement ok +drop view complex_view; + +# struct with different keys r1 and r2 is not valid +statement ok +create table t(a struct<r1 varchar, c int>, b struct<r2 varchar, c float>) as values (struct('red', 1), struct('blue', 2.3)); + +# Expect same keys for struct type but got mismatched pair r1,c and r2,c +query error +select [a, b] from t; + +statement ok +drop table t; + +# struct with the same key +statement ok +create table t(a struct<r varchar, c int>, b struct<r varchar, c float>) as values (struct('red', 1), struct('blue', 2.3)); + +query T +select arrow_typeof([a, b]) from t; +---- +List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +query ? +select [a, b] from t; +---- +[{r: red, c: 1.0}, {r: blue, c: 2.3}] + +statement ok +drop table t; + +# Test row alias + +query ? +select row('a', 'b'); +---- +{c0: a, c1: b} + +################################## +# Switch Dialect to DuckDB +################################## + +statement ok +set datafusion.sql_parser.dialect = 'DuckDB'; + +statement ok +CREATE TABLE struct_values ( + s1 struct(a int, b varchar), + s2 struct(a int, b varchar) +) AS VALUES + (row(1, 'red'), row(1, 'string1')), + (row(2, 'blue'), row(2, 'string2')), + (row(3, 'green'), row(3, 'string3')) +; + +statement ok +drop table struct_values; + +statement ok +create table t (c1 struct(r varchar, b int), c2 struct(r varchar, b float)) as values ( + row('red', 2), + row('blue', 2.3) +); + +query ?? +select * from t; +---- +{r: red, b: 2} {r: blue, b: 2.3} + +query T +select arrow_typeof(c1) from t; +---- +Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +query T +select arrow_typeof(c2) from t; +---- +Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; + +statement ok +create table t as values({r: 'a', c: 1}), ({r: 'b', c: 2.3}); + +query ? +select * from t; +---- +{c0: a, c1: 1.0} +{c0: b, c1: 2.3} + +query T +select arrow_typeof(column1) from t; +---- +Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Float64 type +create table t as values({r: 'a', c: 1}), ({c: 2.3, r: 'b'}); + +################################## +## Test Coalesce with Struct +################################## + +statement ok +CREATE TABLE t ( + s1 struct(a int, b varchar), + s2 struct(a float, b varchar) +) AS VALUES + (row(1, 'red'), row(1.1, 'string1')), + (row(2, 'blue'), row(2.2, 'string2')), + (row(3, 'green'), row(33.2, 'string3')) +; + +query ?
+select coalesce(s1) from t; +---- +{a: 1, b: red} +{a: 2, b: blue} +{a: 3, b: green} + +query T +select arrow_typeof(coalesce(s1, s2)) from t; +---- +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; + +statement ok +CREATE TABLE t ( + s1 struct(a int, b varchar), + s2 struct(a float, b varchar) +) AS VALUES + (row(1, 'red'), row(1.1, 'string1')), + (null, row(2.2, 'string2')), + (row(3, 'green'), row(33.2, 'string3')) +; + +query ? +select coalesce(s1, s2) from t; +---- +{a: 1.0, b: red} +{a: 2.2, b: string2} +{a: 3.0, b: green} + +query T +select arrow_typeof(coalesce(s1, s2)) from t; +---- +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; + +# row() with incorrect order +statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'blue' to value of Float32 type +create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values + (row('red', 1), row(2.3, 'blue')), + (row('purple', 1), row('green', 2.3)); + +# out of order struct literal +# TODO: This query should not fail +statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'b' to value of Int32 type +create table t(a struct(r varchar, c int)) as values ({r: 'a', c: 1}), ({c: 2, r: 'b'}); + +################################## +## Test Array of Struct +################################## + +query ? 
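+select [{r: 'x', c: 1}, {r: 'y', c: 2.5}];
+----
+[{r: x, c: 1.0}, {r: y, c: 2.5}]
+
+# (The query above is an untested sketch added for illustration: with field
+# names in matching order, the integer and float `c` values are expected to
+# coerce to a common float type, mirroring the table-based [a, b] test below.)
+query ?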
+select [{r: 'a', c: 1}, {r: 'b', c: 2}];
+----
+[{r: a, c: 1}, {r: b, c: 2}]
+
+# Can't create a list of struct with different field types
+query error
+select [{r: 'a', c: 1}, {c: 2, r: 'b'}];
+
+statement ok
+create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values (row('a', 1), row('b', 2.3));
+
+query T
+select arrow_typeof([a, b]) from t;
+----
+List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+
+statement ok
+drop table t;
+
+# create table with different struct type is fine
+statement ok
+create table t(a struct(r varchar, c int), b struct(c float, r varchar)) as values (row('a', 1), row(2.3, 'b'));
+
+# create array with different struct type is not valid
+query error
+select arrow_typeof([a, b]) from t;
+
+statement ok
+drop table t;
+
+statement ok
+create table t(a struct(r varchar, c int, g float), b struct(r varchar, c float, g int)) as values (row('a', 1, 2.3), row('b', 2.3, 2));
+
+# the type of each column should not be coerced but preserved as is
+query T
+select arrow_typeof(a) from t;
+----
+Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "g", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
+
+# the type of each column should not be coerced but preserved as is
+query T
+select arrow_typeof(b) from t;
+----
+Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "g", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
+
+statement ok
+drop table t;
diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt
index 30b3631681e74..36de19f1c3aa7 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -208,10 +208,12 @@ physical_plan
06)----------CoalesceBatchesExec: target_batch_size=2
07)------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 08)--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int * Float64(1))] -09)----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -10)------CoalesceBatchesExec: target_batch_size=2 -11)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -12)----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------------MemoryExec: partitions=1, partition_sizes=[1] +11)------CoalesceBatchesExec: target_batch_size=2 +12)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +13)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)------------MemoryExec: partitions=1, partition_sizes=[1] query IR rowsort SELECT t1_id, (SELECT sum(t2_int * 1.0) + 1 FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 @@ -276,10 +280,12 @@ physical_plan 06)----------CoalesceBatchesExec: target_batch_size=2 07)------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 08)--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] -09)----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -10)------CoalesceBatchesExec: target_batch_size=2 -11)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -12)----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------------MemoryExec: partitions=1, partition_sizes=[1] +11)------CoalesceBatchesExec: target_batch_size=2 +12)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +13)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)------------MemoryExec: partitions=1, partition_sizes=[1] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id group by t2_id, 'a') as t2_sum from t1 @@ -313,10 +319,12 @@ physical_plan 08)--------------CoalesceBatchesExec: target_batch_size=2 09)----------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 10)------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] -11)--------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -12)------CoalesceBatchesExec: target_batch_size=2 -13)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -14)----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +11)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +12)----------------------MemoryExec: partitions=1, partition_sizes=[1] +13)------CoalesceBatchesExec: target_batch_size=2 +14)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +15)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +16)------------MemoryExec: partitions=1, partition_sizes=[1] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id having sum(t2_int) < 3) as t2_sum from t1 @@ -391,7 +399,7 @@ logical_plan 01)Filter: EXISTS () 02)--Subquery: 03)----Projection: t1.t1_int -04)------Filter: t1.t1_id > t1.t1_int +04)------Filter: t1.t1_int < t1.t1_id 05)--------TableScan: t1 06)--TableScan: t1 projection=[t1_id, t1_name, t1_int] @@ -415,13 +423,13 @@ query TT explain SELECT t1_id, 
t1_name, t1_int FROM t1 WHERE t1_id IN(SELECT t2_id FROM t2 WHERE EXISTS(select * from t1 WHERE t1.t1_int > t2.t2_int)) ---- logical_plan -01)LeftSemi Join: t1.t1_id = __correlated_sq_1.t2_id +01)LeftSemi Join: t1.t1_id = __correlated_sq_2.t2_id 02)--TableScan: t1 projection=[t1_id, t1_name, t1_int] -03)--SubqueryAlias: __correlated_sq_1 +03)--SubqueryAlias: __correlated_sq_2 04)----Projection: t2.t2_id -05)------LeftSemi Join: Filter: __correlated_sq_2.t1_int > t2.t2_int +05)------LeftSemi Join: Filter: __correlated_sq_1.t1_int > t2.t2_int 06)--------TableScan: t2 projection=[t2_id, t2_int] -07)--------SubqueryAlias: __correlated_sq_2 +07)--------SubqueryAlias: __correlated_sq_1 08)----------TableScan: t1 projection=[t1_int] #invalid_scalar_subquery @@ -430,7 +438,7 @@ SELECT t1_id, t1_name, t1_int, (select t2_id, t2_name FROM t2 WHERE t2.t2_id = t #subquery_not_allowed #In/Exist Subquery is not allowed in ORDER BY clause. -statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, Window functions, Aggregate and Join plan nodes +statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, Window functions, Aggregate and Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 WHERE t1.t1_id > t1.t1_int) #non_aggregated_correlated_scalar_subquery @@ -462,8 +470,8 @@ explain SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 1 logical_plan 01)Projection: t1.t1_id, () AS t2_int 02)--Subquery: -03)----Limit: skip=0, fetch=1 -04)------Projection: t2.t2_int +03)----Projection: t2.t2_int +04)------Limit: skip=0, fetch=1 05)--------Filter: t2.t2_int = outer_ref(t1.t1_int) 06)----------TableScan: t2 07)--TableScan: t1 projection=[t1_id, t1_int] @@ -475,8 +483,8 @@ logical_plan 01)Projection: t1.t1_id 02)--Filter: t1.t1_int = () 03)----Subquery: -04)------Limit: skip=0, fetch=1 -05)--------Projection: t2.t2_int +04)------Projection: t2.t2_int +05)--------Limit: skip=0, fetch=1 06)----------Filter: t2.t2_int = outer_ref(t1.t1_int) 07)------------TableScan: t2 08)----TableScan: t1 projection=[t1_id, t1_int] @@ -501,8 +509,18 @@ SELECT t1_id, (SELECT a FROM (select 1 as a) WHERE a = t1.t1_int) as t2_int from 44 NULL #non_equal_correlated_scalar_subquery -statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated column is not allowed in predicate: t2\.t2_id < outer_ref\(t1\.t1_id\) -SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id < t1.t1_id) as t2_sum from t1 +# Currently not supported and should not be decorrelated +query TT +explain SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id < t1.t1_id) as t2_sum from t1 +---- +logical_plan +01)Projection: t1.t1_id, () AS t2_sum +02)--Subquery: +03)----Projection: sum(t2.t2_int) +04)------Aggregate: groupBy=[[]], aggr=[[sum(CAST(t2.t2_int AS Int64))]] +05)--------Filter: t2.t2_id < outer_ref(t1.t1_id) +06)----------TableScan: t2 +07)--TableScan: t1 projection=[t1_id] #aggregated_correlated_scalar_subquery_with_extra_group_by_columns statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns @@ -542,13 +560,13 @@ query TT explain SELECT t0_id, t0_name FROM t0 WHERE EXISTS (SELECT 1 FROM t1 INNER JOIN 
t2 ON(t1.t1_id = t2.t2_id and t1.t1_name = t0.t0_name)) ---- logical_plan -01)Filter: EXISTS () -02)--Subquery: -03)----Projection: Int64(1) -04)------Inner Join: Filter: t1.t1_id = t2.t2_id AND t1.t1_name = outer_ref(t0.t0_name) -05)--------TableScan: t1 -06)--------TableScan: t2 -07)--TableScan: t0 projection=[t0_id, t0_name] +01)LeftSemi Join: t0.t0_name = __correlated_sq_2.t1_name +02)--TableScan: t0 projection=[t0_id, t0_name] +03)--SubqueryAlias: __correlated_sq_2 +04)----Projection: t1.t1_name +05)------Inner Join: t1.t1_id = t2.t2_id +06)--------TableScan: t1 projection=[t1_id, t1_name] +07)--------TableScan: t2 projection=[t2_id] #subquery_contains_join_contains_correlated_columns query TT @@ -656,8 +674,8 @@ explain SELECT t1_id, t1_name FROM t1 WHERE t1_id in (SELECT t2_id FROM t2 where logical_plan 01)Filter: t1.t1_id IN () 02)--Subquery: -03)----Limit: skip=0, fetch=10 -04)------Projection: t2.t2_id +03)----Projection: t2.t2_id +04)------Limit: skip=0, fetch=10 05)--------Filter: outer_ref(t1.t1_name) = t2.t2_name 06)----------TableScan: t2 07)--TableScan: t1 projection=[t1_id, t1_name] @@ -1028,6 +1046,168 @@ false true true +# in_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists +04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int > Int32(0) +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------Projection: t2.t2_id, Boolean(true) AS __exists +08)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0) +---- +11 a 1 +22 b 2 +44 d 4 + +# not_in_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id = 11 or t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t1.t1_int > 0) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id = Int32(11) OR __correlated_sq_1.__exists IS NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists +04)------Left Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0) +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------Projection: CAST(t2.t2_id AS Int64) + Int64(1), Boolean(true) AS __exists +08)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id = 11 or t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t1.t1_int > 0) +---- +11 a 1 +22 b 2 + +# exists_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists 
+04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------Projection: t2.t2_id, Boolean(true) AS __exists +08)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +11 a 1 +22 b 2 +44 d 4 + +# not_exists_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or not exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists +04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------Projection: t2.t2_id, Boolean(true) AS __exists +08)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or not exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +33 c 3 +44 d 4 + +# in_subquery_to_join_with_correlated_outer_filter_and_or +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0)) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.__exists IS NOT NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_2.__exists +04)------Left Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int > Int32(0) +05)--------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id +06)----------TableScan: t1 projection=[t1_id, t1_name, t1_int] +07)----------SubqueryAlias: __correlated_sq_1 +08)------------TableScan: t3 projection=[t3_id] +09)--------SubqueryAlias: __correlated_sq_2 +10)----------Projection: t2.t2_id, Boolean(true) AS __exists +11)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0)) +---- +11 a 1 +22 b 2 +44 d 4 + +# Nested subqueries +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where exists ( + select * from t2 where t1.t1_id = t2.t2_id OR exists ( + select * from t3 where t2.t2_id = t3.t3_id + ) +) +---- +11 a 1 +22 b 2 +33 c 3 +44 d 4 # issue: https://github.com/apache/datafusion/issues/7027 query TTTT rowsort diff --git a/datafusion/sqllogictest/test_files/subquery_sort.slt b/datafusion/sqllogictest/test_files/subquery_sort.slt index 17affbc0acadc..e4360a9269ca6 100644 --- a/datafusion/sqllogictest/test_files/subquery_sort.slt +++ b/datafusion/sqllogictest/test_files/subquery_sort.slt @@ -93,14 +93,14 @@ logical_plan 02)--Sort: t2.c1 ASC NULLS LAST, t2.c3 ASC NULLS LAST, t2.c9 ASC NULLS LAST 03)----SubqueryAlias: t2 04)------Sort: sink_table.c1 ASC NULLS LAST, sink_table.c3 ASC NULLS LAST, fetch=2 -05)--------Projection: sink_table.c1, RANK() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS r, sink_table.c3, sink_table.c9 
-06)----------WindowAggr: windowExpr=[[RANK() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +05)--------Projection: sink_table.c1, rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS r, sink_table.c3, sink_table.c9 +06)----------WindowAggr: windowExpr=[[rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 07)------------TableScan: sink_table projection=[c1, c3, c9] physical_plan 01)ProjectionExec: expr=[c1@0 as c1, r@1 as r] 02)--SortExec: TopK(fetch=2), expr=[c1@0 ASC NULLS LAST,c3@2 ASC NULLS LAST,c9@3 ASC NULLS LAST], preserve_partitioning=[false] -03)----ProjectionExec: expr=[c1@0 as c1, RANK() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as r, c3@1 as c3, c9@2 as c9] -04)------BoundedWindowAggExec: wdw=[RANK() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "RANK() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Utf8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----ProjectionExec: expr=[c1@0 as c1, rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as r, c3@1 as c3, c9@2 as c9] +04)------BoundedWindowAggExec: wdw=[rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Utf8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 05)--------SortExec: expr=[c1@0 DESC], preserve_partitioning=[false] 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c3, c9], has_header=true diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 7a7a8a8703ec4..38c2a66472731 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -84,6 +84,11 @@ select case when current_time() = (now()::bigint % 86400000000000)::time then 'O ---- OK +query B +select now() = current_timestamp; +---- +true + ########## ## Timestamp Handling Tests ########## @@ -2191,6 +2196,14 @@ create table ts_utf8_data(ts varchar(100), format varchar(100)) as values ('1926632005', '%s'), ('2000-01-01T01:01:01+07:00', '%+'); +statement ok +create table ts_largeutf8_data as +select arrow_cast(ts, 'LargeUtf8') as ts, arrow_cast(format, 'LargeUtf8') as format from ts_utf8_data; + +statement ok +create table ts_utf8view_data as +select arrow_cast(ts, 'Utf8View') as ts, arrow_cast(format, 'Utf8View') as format from ts_utf8_data; + # verify timestamp data using tables with formatting options query P SELECT to_timestamp(t.ts, t.format) from ts_utf8_data as t @@ -2201,9 +2214,84 @@ SELECT to_timestamp(t.ts, t.format) from ts_utf8_data as t 2031-01-19T23:33:25 1999-12-31T18:01:01 +query PPPPP +SELECT to_timestamp(t.ts, t.format), + to_timestamp_seconds(t.ts, t.format), + to_timestamp_millis(t.ts, t.format), + 
to_timestamp_micros(t.ts, t.format), + to_timestamp_nanos(t.ts, t.format) + from ts_largeutf8_data as t +---- +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 +1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 + +query PPPPP +SELECT to_timestamp(t.ts, t.format), + to_timestamp_seconds(t.ts, t.format), + to_timestamp_millis(t.ts, t.format), + to_timestamp_micros(t.ts, t.format), + to_timestamp_nanos(t.ts, t.format) + from ts_utf8view_data as t +---- +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 +1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 + # verify timestamp data using tables with formatting options +query PPPPP +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_seconds(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_millis(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_micros(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_nanos(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') + from ts_utf8_data as t +---- +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 +1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 + +query PPPPP +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_seconds(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_millis(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_micros(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_nanos(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') + from ts_largeutf8_data as t +---- +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 +1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 + +query PPPPP +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_seconds(t.ts, 
'%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_millis(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_micros(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_nanos(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') + from ts_utf8view_data as t +---- +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 +1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 + +# verify timestamp data using tables with formatting options where at least one column cannot be parsed +query error Error parsing timestamp from '1926632005' using format '%d-%m-%Y %H:%M:%S%#z': input contains invalid characters +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8_data as t + +# verify timestamp data using tables with formatting options where one of the formats is invalid query P -SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8_data as t +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+') from ts_utf8_data as t ---- 2020-09-08T12:00:00 2031-01-19T18:33:25 @@ -2211,13 +2299,17 @@ SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S 2031-01-19T23:33:25 1999-12-31T18:01:01 -# verify timestamp data using tables with formatting options where at least one column cannot be parsed -query error Error parsing timestamp from '1926632005' using format '%d-%m-%Y %H:%M:%S%#z': input contains invalid characters -SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8_data as t +query P +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+') from ts_largeutf8_data as t +---- +2020-09-08T12:00:00 +2031-01-19T18:33:25 +2020-09-08T12:00:00 +2031-01-19T23:33:25 +1999-12-31T18:01:01 -# verify timestamp data using tables with formatting options where one of the formats is invalid query P -SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+') from ts_utf8_data as t +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+') from ts_utf8view_data as t ---- 2020-09-08T12:00:00 2031-01-19T18:33:25 @@ -2688,6 +2780,11 @@ FROM NULL 01:01:2025 23-59-58 +query T +select to_char('2020-01-01 00:10:20.123'::timestamp at time zone 'America/New_York', '%Y-%m-%d %H:%M:%S.%3f'); +---- +2020-01-01 00:10:20.123 + statement ok drop table formats; diff --git a/datafusion/sqllogictest/test_files/tpch/q20.slt.part b/datafusion/sqllogictest/test_files/tpch/q20.slt.part index 67ea87b6ee61c..177e38e51ca47 100644 --- a/datafusion/sqllogictest/test_files/tpch/q20.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q20.slt.part @@ -58,19 +58,19 @@ order by logical_plan 01)Sort: supplier.s_name ASC NULLS LAST 02)--Projection: supplier.s_name, supplier.s_address -03)----LeftSemi Join: supplier.s_suppkey = __correlated_sq_1.ps_suppkey +03)----LeftSemi Join: supplier.s_suppkey = __correlated_sq_2.ps_suppkey 04)------Projection: supplier.s_suppkey, supplier.s_name, 
supplier.s_address 05)--------Inner Join: supplier.s_nationkey = nation.n_nationkey 06)----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey] 07)----------Projection: nation.n_nationkey 08)------------Filter: nation.n_name = Utf8("CANADA") 09)--------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("CANADA")] -10)------SubqueryAlias: __correlated_sq_1 +10)------SubqueryAlias: __correlated_sq_2 11)--------Projection: partsupp.ps_suppkey 12)----------Inner Join: partsupp.ps_partkey = __scalar_sq_3.l_partkey, partsupp.ps_suppkey = __scalar_sq_3.l_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_3.Float64(0.5) * sum(lineitem.l_quantity) -13)------------LeftSemi Join: partsupp.ps_partkey = __correlated_sq_2.p_partkey +13)------------LeftSemi Join: partsupp.ps_partkey = __correlated_sq_1.p_partkey 14)--------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] -15)--------------SubqueryAlias: __correlated_sq_2 +15)--------------SubqueryAlias: __correlated_sq_1 16)----------------Projection: part.p_partkey 17)------------------Filter: part.p_name LIKE Utf8("forest%") 18)--------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")] diff --git a/datafusion/sqllogictest/test_files/tpch/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/q22.slt.part index d2168b0136ba4..2955748160eac 100644 --- a/datafusion/sqllogictest/test_files/tpch/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q22.slt.part @@ -72,7 +72,7 @@ logical_plan 14)--------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] 15)----------------Projection: customer.c_acctbal 16)------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) -17)--------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2) AS customer.c_acctbal > Decimal128(Some(0),30,15), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]), customer.c_acctbal > Decimal128(Some(0),15,2)] +17)--------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])] physical_plan 01)SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST] 02)--SortExec: expr=[cntrycode@0 ASC NULLS LAST], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/type_coercion.slt b/datafusion/sqllogictest/test_files/type_coercion.slt index 0f9399cede2ec..43e7c2f7bc250 100644 --- a/datafusion/sqllogictest/test_files/type_coercion.slt +++ b/datafusion/sqllogictest/test_files/type_coercion.slt @@ -103,11 +103,11 @@ CREATE TABLE orders( ); # union_different_num_columns_error() / UNION -query error Error during planning: Union schemas have different number of fields: query 1 has 1 fields whereas query 2 has 2 fields +query error DataFusion error: Error during planning: UNION queries have different number of columns: left has 1 columns whereas right has 2 columns SELECT order_id FROM orders UNION SELECT customer_id, o_item_id FROM orders # union_different_num_columns_error() / UNION ALL 
-query error Error during planning: Union schemas have different number of fields: query 1 has 1 fields whereas query 2 has 2 fields +query error DataFusion error: Error during planning: UNION queries have different number of columns: left has 1 columns whereas right has 2 columns SELECT order_id FROM orders UNION ALL SELECT customer_id, o_item_id FROM orders # union_with_different_column_names() diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index a3d0ff4383ae6..fb7afdda2ea82 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -503,9 +503,9 @@ logical_plan 12)----Projection: Int64(1) AS cnt 13)------Limit: skip=0, fetch=3 14)--------EmptyRelation -15)----Projection: LEAD(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS cnt +15)----Projection: lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS cnt 16)------Limit: skip=0, fetch=3 -17)--------WindowAggr: windowExpr=[[LEAD(b.c1, Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +17)--------WindowAggr: windowExpr=[[lead(b.c1, Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 18)----------SubqueryAlias: b 19)------------Projection: Int64(1) AS c1 20)--------------EmptyRelation @@ -528,8 +528,8 @@ physical_plan 16)------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true 17)------ProjectionExec: expr=[1 as cnt] 18)--------PlaceholderRowExec -19)------ProjectionExec: expr=[LEAD(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] -20)--------BoundedWindowAggExec: wdw=[LEAD(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "LEAD(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +19)------ProjectionExec: expr=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] +20)--------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] 21)----------ProjectionExec: expr=[1 as c1] 22)------------PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index 63ca74e9714c7..947eb8630b523 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -33,7 +33,7 @@ AS VALUES statement ok CREATE TABLE nested_unnest_table AS VALUES - (struct('a', 'b', struct('c')), (struct('a', 'b', [10,20])), [struct('a', 'b')]), + (struct('a', 'b', struct('c')), (struct('a', 'b', [10,20])), [struct('a', 'b')]), (struct('d', 'e', struct('f')), (struct('x', 'y', [30,40, 50])), null) ; @@ -511,10 +511,19 @@ x y [30, 40, 50] query error DataFusion error: type_coercion\ncaused by\nThis feature is not implemented: Unnest should be 
rewritten to LogicalPlan::Unnest before type coercion select sum(unnest(generate_series(1,10))); -## TODO: support unnest as a child expr query error DataFusion error: Internal error: unnest on struct can only be applied at the root level of select expression select arrow_typeof(unnest(column5)) from unnest_table; +query T +select arrow_typeof(unnest(column1)) from unnest_table; +---- +Int64 +Int64 +Int64 +Int64 +Int64 +Int64 +Int64 ## unnest from a result of a logical plan with limit and offset query I @@ -524,10 +533,19 @@ select unnest(column1) from (select * from (values([1,2,3]), ([4,5,6])) limit 1 5 6 -## FIXME: https://github.com/apache/datafusion/issues/11198 query error DataFusion error: Error during planning: Projections require unique expression names but the expression "UNNEST\(unnest_table.column1\)" at position 0 and "UNNEST\(unnest_table.column1\)" at position 1 have the same name. Consider aliasing \("AS"\) one of them. select unnest(column1), unnest(column1) from unnest_table; +query II +select unnest(column1), unnest(column1) u1 from unnest_table; +---- +1 1 +2 2 +3 3 +4 4 +5 5 +6 6 +12 12 ## the same unnest expr is referened multiple times (unnest is the bottom-most expr) query ??II @@ -625,7 +643,7 @@ NULL [4] [{c0: [2], c1: [[3], [4]]}] 4 [3] [{c0: [2], c1: [[3], [4]]}] NULL [4] [{c0: [2], c1: [[3], [4]]}] -## demonstrate where recursive unnest is impossible +## demonstrate where recursive unnest is impossible ## and need multiple unnesting logical plans ## e.g unnest -> field_access -> unnest query TT @@ -777,6 +795,61 @@ select unnest(unnest(column2)) c2, count(column3) from recursive_unnest_table gr [, 6] 1 NULL 1 -### TODO: group by unnest struct query error DataFusion error: Error during planning: Projection references non\-aggregate values select unnest(column1) c1 from nested_unnest_table group by c1.c0; + +# TODO: this query should work. see issue: https://github.com/apache/datafusion/issues/12794 +query error DataFusion error: Internal error: unnest on struct can only be applied at the root level of select expression +select unnest(column1) c1 from nested_unnest_table + +query II??I?? +select unnest(column5), * from unnest_table; +---- +1 2 [1, 2, 3] [7] 1 [13, 14] {c0: 1, c1: 2} +3 4 [4, 5] [8, 9, 10] 2 [15, 16] {c0: 3, c1: 4} +NULL NULL [6] [11, 12] 3 NULL NULL +7 8 [12] [, 42, ] NULL NULL {c0: 7, c1: 8} +NULL NULL NULL NULL 4 [17, 18] NULL + +query TT???? +select unnest(column1), * from nested_unnest_table +---- +a b {c0: c} {c0: a, c1: b, c2: {c0: c}} {c0: a, c1: b, c2: [10, 20]} [{c0: a, c1: b}] +d e {c0: f} {c0: d, c1: e, c2: {c0: f}} {c0: x, c1: y, c2: [30, 40, 50]} NULL + +query ????? +select unnest(unnest(column3)), * from recursive_unnest_table +---- +[1] [[1, 2]] {c0: [1], c1: a} [[[1], [2]], [[1, 1]]] [{c0: [1], c1: [[1, 2]]}] +[2] [[3], [4]] {c0: [2], c1: b} [[[3, 4], [5]], [[, 6], , [7, 8]]] [{c0: [2], c1: [[3], [4]]}] + +statement ok +CREATE TABLE join_table +AS VALUES + (1, 2, 3), + (2, 3, 4), + (4, 5, 6) +; + +query IIIII +select unnest(u.column5), j.* from unnest_table u join join_table j on u.column3 = j.column1 +---- +1 2 1 2 3 +3 4 2 3 4 +NULL NULL 4 5 6 + +query II?I? 
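+select unnest(column5), column2, column3, column4 from unnest_table;
+----
+1 2 [7] 1 [13, 14]
+3 4 [8, 9, 10] 2 [15, 16]
+NULL NULL [11, 12] 3 NULL
+7 8 [, 42, ] NULL NULL
+NULL NULL NULL 4 [17, 18]
+
+# (Untested sketch for comparison: spelling the projection out column by
+# column should give the same result as the `* except` form below, which
+# hides column5 and column1.)
+query II?I?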
+select unnest(column5), * except (column5, column1) from unnest_table; +---- +1 2 [7] 1 [13, 14] +3 4 [8, 9, 10] 2 [15, 16] +NULL NULL [11, 12] 3 NULL +7 8 [, 42, ] NULL NULL +NULL NULL NULL 4 [17, 18] + +query III +select unnest(u.column5), j.* except(column2, column3) from unnest_table u join join_table j on u.column3 = j.column1 +---- +1 2 1 +3 4 2 +NULL NULL 4 diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index 59133379d4431..aaba6998ee63c 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -67,7 +67,7 @@ logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d 03)----Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1) -04)------CrossJoin: +04)------Cross Join: 05)--------TableScan: t1 06)--------TableScan: t2 @@ -86,7 +86,7 @@ logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d 03)----Filter: t.a = t2.a AND t.b > Utf8("foo") AND t2.c > Float64(1) -04)------CrossJoin: +04)------Cross Join: 05)--------SubqueryAlias: t 06)----------TableScan: t1 07)--------TableScan: t2 diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 7fee84f9bcd92..4a2d9e1d68641 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -1135,8 +1135,8 @@ SELECT query IRR SELECT c8, - CUME_DIST() OVER(ORDER BY c9) as cd1, - CUME_DIST() OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as cd2 + cume_dist() OVER(ORDER BY c9) as cd1, + cume_dist() OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as cd2 FROM aggregate_test_100 ORDER BY c8 LIMIT 5 @@ -1376,16 +1376,16 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -01)Projection: aggregate_test_100.c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2 +01)Projection: aggregate_test_100.c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv2, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lag1, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY 
[aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lead1, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -04)------WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +03)----WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, lag(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, lead(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, lag(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] 05)--------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)ProjectionExec: expr=[c9@0 as c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lead2] +01)ProjectionExec: expr=[c9@0 as c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 
CURRENT ROW@5 as lag1, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lag2, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as lead1, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lead2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: 
Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 05)--------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true @@ -2208,7 +2208,7 @@ physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2] 
02)--SortExec: TopK(fetch=5), expr=[c9@2 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum1, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING@4 as sum2, c9@1 as c9] -04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)), is_causal: true }], mode=[Sorted] 05)--------ProjectionExec: expr=[c1@0 as c1, c9@2 as c9, c12@3 as c12, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING] 06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 07)------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST], preserve_partitioning=[false] @@ -2378,17 +2378,41 @@ SELECT c9, rn1 FROM (SELECT c9, # invalid window frame. null as preceding -statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers select row_number() over (rows between null preceding and current row) from (select 1 a) x # invalid window frame. null as preceding -statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers select row_number() over (rows between null preceding and current row) from (select 1 a) x # invalid window frame. 
negative as following
-statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers
+statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers
select row_number() over (rows between current row and -1 following) from (select 1 a) x

+# invalid window frame. null as preceding
+statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers
+select row_number() over (order by a groups between null preceding and current row) from (select 1 a) x
+
+# invalid window frame. null as preceding
+statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers
+select row_number() over (order by a groups between null preceding and current row) from (select 1 a) x
+
+# invalid window frame. negative as following
+statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers
+select row_number() over (order by a groups between current row and -1 following) from (select 1 a) x
+
+# interval for rows
+query I
+select row_number() over (rows between '1' preceding and current row) from (select 1 a) x
+----
+1
+
+# interval for groups
+query I
+select row_number() over (order by a groups between '1' preceding and current row) from (select 1 a) x
+----
+1
+
# This test shows that ordering satisfaction considers ordering equivalences,
# and can simplify (reduce expression size) multi-expression requirements during normalization.
# For the example below, the requirement rn1 ASC, c9 DESC should be simplified to rn1 ASC.
@@ -2636,15 +2660,15 @@ EXPLAIN SELECT ---- logical_plan 01)Sort: annotated_data_finite.ts DESC NULLS FIRST, fetch=5 -02)--Projection: annotated_data_finite.ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS 
FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 -03)----WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] -04)------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +02)--Projection: annotated_data_finite.ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 
PRECEDING AND 1 FOLLOWING AS fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 +03)----WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) 
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +04)------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] 05)--------TableScan: annotated_data_finite projection=[ts, inc_col] physical_plan 01)SortExec: TopK(fetch=5), expr=[ts@0 DESC], preserve_partitioning=[false] -02)--ProjectionExec: expr=[ts@0 as ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS 
LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] -03)----BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC 
NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: 
false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) 
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: 
Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }], mode=[Sorted] +02)--ProjectionExec: expr=[ts@0 as ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS 
FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] +03)----BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, 
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: 
"lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, 
last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }], mode=[Sorted] 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true query IIIIIIIIIIIIIIIIIIIIIIIII @@ -4894,3 +4918,131 @@ NULL a4 5 statement ok drop table t + +## test handle NULL and 0 value of nth_value +statement ok +CREATE TABLE t(v1 int, v2 int); + +statement ok +INSERT INTO t VALUES (1,1), (1,2),(1,3),(2,1),(2,2); + +query II +SELECT v1, NTH_VALUE(v2, null) OVER (PARTITION BY v1 ORDER BY v2) FROM t; +---- +1 NULL +1 NULL +1 NULL +2 NULL +2 NULL + +query II +SELECT v1, NTH_VALUE(v2, v2*null) OVER (PARTITION BY v1 ORDER BY v2) FROM t; +---- +1 NULL +1 NULL +1 NULL +2 NULL +2 NULL + +query II +SELECT v1, NTH_VALUE(v2, 0) OVER (PARTITION BY v1 ORDER BY v2) FROM t; +---- +1 NULL +1 NULL +1 NULL +2 NULL +2 NULL + +query I +SELECT NTH_VALUE(tt0.v1, NULL) OVER 
(PARTITION BY tt0.v2 ORDER BY tt0.v1) FROM t AS tt0;
+----
+NULL
+NULL
+NULL
+NULL
+NULL
+
+statement ok
+DROP TABLE t;
+
+## end test handle NULL and 0 value of nth_value
+
+## test handling NULL for LEAD and LAG
+
+statement ok
+create table t1(v1 int);
+
+statement ok
+insert into t1 values (1);
+
+query B
+SELECT LEAD(NULL, 0, false) OVER () FROM t1;
+----
+NULL
+
+query B
+SELECT LAG(NULL, 0, false) OVER () FROM t1;
+----
+NULL
+
+query B
+SELECT LEAD(NULL, 1, false) OVER () FROM t1;
+----
+false
+
+query B
+SELECT LAG(NULL, 1, false) OVER () FROM t1;
+----
+false
+
+query B
+SELECT LEAD(NULL, 0, true) OVER () FROM t1;
+----
+NULL
+
+query B
+SELECT LAG(NULL, 0, true) OVER () FROM t1;
+----
+NULL
+
+query B
+SELECT LEAD(NULL, 1, true) OVER () FROM t1;
+----
+true
+
+query B
+SELECT LAG(NULL, 1, true) OVER () FROM t1;
+----
+true
+
+statement ok
+insert into t1 values (2);
+
+query B
+SELECT LEAD(NULL, 1, false) OVER () FROM t1;
+----
+NULL
+false
+
+query B
+SELECT LAG(NULL, 1, false) OVER () FROM t1;
+----
+false
+NULL
+
+query B
+SELECT LEAD(NULL, 1, true) OVER () FROM t1;
+----
+NULL
+true
+
+query B
+SELECT LAG(NULL, 1, true) OVER () FROM t1;
+----
+true
+NULL
+
+statement ok
+DROP TABLE t1;
+
+## end test handling NULL for LEAD and LAG
diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml
index 6f8f81401f3b6..b0aa6acf3c7c8 100644
--- a/datafusion/substrait/Cargo.toml
+++ b/datafusion/substrait/Cargo.toml
@@ -26,7 +26,7 @@ repository = { workspace = true }
license = { workspace = true }
authors = { workspace = true }
# Specify MSRV here as `cargo msrv` doesn't support workspace version
-rust-version = "1.78"
+rust-version = "1.79"

[lints]
workspace = true
@@ -41,7 +41,7 @@ object_store = { workspace = true }
pbjson-types = "0.7"
# TODO use workspace version
prost = "0.13"
-substrait = { version = "0.42", features = ["serde"] }
+substrait = { version = "0.45", features = ["serde"] }
url = { workspace = true }

[dev-dependencies]
diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs
index 0b1c796553c0a..a6f7c033f9d0b 100644
--- a/datafusion/substrait/src/lib.rs
+++ b/datafusion/substrait/src/lib.rs
@@ -68,6 +68,7 @@
//!
//! // Receive a substrait protobuf from somewhere, and turn it into a LogicalPlan
//! let logical_round_trip = logical_plan::consumer::from_substrait_plan(&ctx, &substrait_plan).await?;
+//! let logical_round_trip = ctx.state().optimize(&logical_round_trip)?;
//! assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip));
//! # Ok(())
//! # }
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index e6bfc67eda81c..2aaf8ec0aa06b 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -34,6 +34,7 @@ use datafusion::logical_expr::{
    ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, Values,
};
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
+use substrait::proto::expression_reference::ExprType;
use url::Url;

use crate::extensions::Extensions;
@@ -41,20 +42,20 @@ use crate::variation_const::{
    DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
    DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF,
    DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
-    INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF,
-    UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF,
+    LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF,
+    VIEW_CONTAINER_TYPE_VARIATION_REF,
};
#[allow(deprecated)]
use crate::variation_const::{
-    INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
-    INTERVAL_YEAR_MONTH_TYPE_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF,
-    TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF,
-    TIMESTAMP_SECOND_TYPE_VARIATION_REF,
+    INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME,
+    INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF,
+    TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF,
+    TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF,
};
use datafusion::arrow::array::{new_empty_array, AsArray};
+use datafusion::arrow::temporal_conversions::NANOSECONDS;
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::dataframe::DataFrame;
-use datafusion::logical_expr::builder::project;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{
    col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
@@ -68,16 +69,16 @@ use datafusion::{
    prelude::{Column, SessionContext},
    scalar::ScalarValue,
};
-use std::collections::{HashMap, HashSet};
+use std::collections::HashSet;
use std::sync::Arc;
use substrait::proto::exchange_rel::ExchangeKind;
-use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode;
use substrait::proto::expression::literal::user_defined::Val;
use substrait::proto::expression::literal::{
-    IntervalDayToSecond, IntervalYearToMonth, UserDefined,
+    interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth,
+    UserDefined,
};
use substrait::proto::expression::subquery::SubqueryType;
-use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction};
+use substrait::proto::expression::{FieldReference, Literal, ScalarFunction};
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
use substrait::proto::{
    aggregate_function::AggregationInvocation,
@@ -96,7 +97,7 @@ use substrait::proto::{
    sort_field::{SortDirection, SortKind::*},
    AggregateFunction, Expression, NamedStruct, Plan, Rel, Type,
};
-use substrait::proto::{FunctionArgument, SortField};
+use substrait::proto::{ExtendedExpression, FunctionArgument, SortField};

// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which
// is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone
@@ -118,6 +119,7 @@ pub fn name_to_op(name: &str) -> Option<Operator> {
        "multiply" =>
            Some(Operator::Multiply),
        "divide" => Some(Operator::Divide),
        "mod" => Some(Operator::Modulo),
+        "modulus" => Some(Operator::Modulo),
        "and" => Some(Operator::And),
        "or" => Some(Operator::Or),
        "is_distinct_from" => Some(Operator::IsDistinctFrom),
@@ -196,6 +198,65 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality(
    (accum_join_keys, nulls_equal_nulls, join_filter)
}

+async fn union_rels(
+    rels: &[Rel],
+    ctx: &SessionContext,
+    extensions: &Extensions,
+    is_all: bool,
+) -> Result<LogicalPlan> {
+    let mut union_builder = Ok(LogicalPlanBuilder::from(
+        from_substrait_rel(ctx, &rels[0], extensions).await?,
+    ));
+    for input in &rels[1..] {
+        let rel_plan = from_substrait_rel(ctx, input, extensions).await?;
+
+        union_builder = if is_all {
+            union_builder?.union(rel_plan)
+        } else {
+            union_builder?.union_distinct(rel_plan)
+        };
+    }
+    union_builder?.build()
+}
+
+async fn intersect_rels(
+    rels: &[Rel],
+    ctx: &SessionContext,
+    extensions: &Extensions,
+    is_all: bool,
+) -> Result<LogicalPlan> {
+    let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?;
+
+    for input in &rels[1..] {
+        rel = LogicalPlanBuilder::intersect(
+            rel,
+            from_substrait_rel(ctx, input, extensions).await?,
+            is_all,
+        )?
+    }
+
+    Ok(rel)
+}
+
+async fn except_rels(
+    rels: &[Rel],
+    ctx: &SessionContext,
+    extensions: &Extensions,
+    is_all: bool,
+) -> Result<LogicalPlan> {
+    let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?;
+
+    for input in &rels[1..] {
+        rel = LogicalPlanBuilder::except(
+            rel,
+            from_substrait_rel(ctx, input, extensions).await?,
+            is_all,
+        )?
+    }
+
+    Ok(rel)
+}
+
/// Convert Substrait Plan to DataFusion LogicalPlan
pub async fn from_substrait_plan(
    ctx: &SessionContext,
@@ -226,18 +287,19 @@ pub async fn from_substrait_plan(
                // Nothing to do if the schema is already equivalent
                return Ok(plan);
            }
-
            match plan {
                // If the last node of the plan produces expressions, bake the renames into those expressions.
                // This isn't necessary for correctness, but helps with roundtrip tests.
-                LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, p.input)?)),
+                LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema.fields())?, p.input)?)),
                LogicalPlan::Aggregate(a) => {
-                    let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), &renamed_schema)?;
-                    Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?))
+                    let (group_fields, expr_fields) = renamed_schema.fields().split_at(a.group_expr.len());
+                    let new_group_exprs = rename_expressions(a.group_expr, a.input.schema(), group_fields)?;
+                    let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), expr_fields)?;
+                    Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, new_group_exprs, new_aggr_exprs)?))
                },
                // There are probably more plans where we could bake things in, can add them later as needed.
                // Otherwise, add a new Project to handle the renaming.
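                // Hypothetical illustration (the column names below are made
                // up, not from this change): if the plan's output columns are
                // [col0, col1] but the Substrait plan's top-level names are
                // [a, b], the fallback arm below wraps the plan in a
                // `Projection: col0 AS a, col1 AS b`. rename_expressions may
                // additionally insert a Cast when nested struct field names
                // differ, since those names are part of the struct's DataType
                // rather than of the column name itself.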
-            _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), &renamed_schema)?, Arc::new(plan))?))
+            _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema.fields())?, Arc::new(plan))?))
                 }
             }
         },
@@ -251,12 +313,86 @@ pub async fn from_substrait_plan(
     }
 }
-/// parse projection
-pub fn extract_projection(
-    t: LogicalPlan,
-    projection: &::core::option::Option<expression::MaskExpression>,
-) -> Result<LogicalPlan> {
-    match projection {
+/// An ExprContainer is a container for a collection of expressions with a common input schema
+///
+/// In addition, each expression is associated with a field, which defines the
+/// expression's output. The data type and nullability of the field are calculated from the
+/// expression and the input schema. However, the names of the field (and its nested fields) are
+/// derived from the Substrait message.
+pub struct ExprContainer {
+    /// The input schema for the expressions
+    pub input_schema: DFSchemaRef,
+    /// The expressions
+    ///
+    /// Each item contains an expression and the field that defines the expected nullability and name of the expr's output
+    pub exprs: Vec<(Expr, Field)>,
+}
+
+/// Convert Substrait ExtendedExpression to ExprContainer
+///
+/// A Substrait ExtendedExpression message contains one or more expressions,
+/// with names for the outputs, and an input schema. These pieces are all included
+/// in the ExprContainer.
+///
+/// This is a top-level message and can be used to send expressions (not plans)
+/// between systems. This is often useful for scenarios like pushdown where filter
+/// expressions need to be sent to remote systems.
+pub async fn from_substrait_extended_expr(
+    ctx: &SessionContext,
+    extended_expr: &ExtendedExpression,
+) -> Result<ExprContainer> {
+    // Register function extension
+    let extensions = Extensions::try_from(&extended_expr.extensions)?;
+    if !extensions.type_variations.is_empty() {
+        return not_impl_err!("Type variation extensions are not supported");
+    }
+
+    let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
+        Some(base_schema) => from_substrait_named_struct(base_schema, &extensions),
+        None => {
+            plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message")
+        }
+    }?);
+
+    // Parse expressions
+    let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len());
+    for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() {
+        let scalar_expr = match &substrait_expr.expr_type {
+            Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr),
+            Some(ExprType::Measure(_)) => {
+                not_impl_err!("Measure expressions are not yet supported")
+            }
+            None => {
+                plan_err!("required property `expr_type` missing from Substrait ExpressionReference message")
+            }
+        }?;
+        let expr =
+            from_substrait_rex(ctx, scalar_expr, &input_schema, &extensions).await?;
+        let (output_type, expected_nullability) =
+            expr.data_type_and_nullable(&input_schema)?;
+        let output_field = Field::new("", output_type, expected_nullability);
+        let mut names_idx = 0;
+        let output_field = rename_field(
+            &output_field,
+            &substrait_expr.output_names,
+            expr_idx,
+            &mut names_idx,
+            /*rename_self=*/ true,
+        )?;
+        exprs.push((expr, output_field));
+    }
+
+    Ok(ExprContainer {
+        input_schema,
+        exprs,
+    })
+}
+
+pub fn apply_masking(
+    schema: DFSchema,
+    mask_expression: &::core::option::Option<MaskExpression>,
+) -> Result<DFSchema> {
+    match mask_expression {
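+        // A MaskExpression selects a subset of the input's columns by field index;
+        // when no mask (or no select) is present, the schema is passed through unchanged.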
Some(MaskExpression { select, .. }) => match &select.as_ref() { Some(projection) => { let column_indices: Vec = projection @@ -264,41 +400,23 @@ pub fn extract_projection( .iter() .map(|item| item.field as usize) .collect(); - match t { - LogicalPlan::TableScan(mut scan) => { - let fields = column_indices - .iter() - .map(|i| scan.projected_schema.qualified_field(*i)) - .map(|(qualifier, field)| { - (qualifier.cloned(), Arc::new(field.clone())) - }) - .collect(); - scan.projection = Some(column_indices); - scan.projected_schema = DFSchemaRef::new( - DFSchema::new_with_metadata(fields, HashMap::new())?, - ); - Ok(LogicalPlan::TableScan(scan)) - } - LogicalPlan::Projection(projection) => { - // create another Projection around the Projection to handle the field masking - let fields: Vec = column_indices - .into_iter() - .map(|i| { - let (qualifier, field) = - projection.schema.qualified_field(i); - let column = - Column::new(qualifier.cloned(), field.name()); - Expr::Column(column) - }) - .collect(); - project(LogicalPlan::Projection(projection), fields) - } - _ => plan_err!("unexpected plan for table"), - } + + let fields = column_indices + .iter() + .map(|i| schema.qualified_field(*i)) + .map(|(qualifier, field)| { + (qualifier.cloned(), Arc::new(field.clone())) + }) + .collect(); + + Ok(DFSchema::new_with_metadata( + fields, + schema.metadata().clone(), + )?) } - _ => Ok(t), + None => Ok(schema), }, - _ => Ok(t), + None => Ok(schema), } } @@ -309,11 +427,11 @@ pub fn extract_projection( fn rename_expressions( exprs: impl IntoIterator, input_schema: &DFSchema, - new_schema: &DFSchema, + new_schema_fields: &[Arc], ) -> Result> { exprs .into_iter() - .zip(new_schema.fields()) + .zip(new_schema_fields) .map(|(old_expr, new_field)| { // Check if type (i.e. nested struct field names) match, use Cast to rename if needed let new_expr = if &old_expr.get_type(input_schema)? != new_field.data_type() { @@ -334,6 +452,68 @@ fn rename_expressions( .collect() } +fn rename_field( + field: &Field, + dfs_names: &Vec, + unnamed_field_suffix: usize, // If Substrait doesn't provide a name, we'll use this "c{unnamed_field_suffix}" + name_idx: &mut usize, // Index into dfs_names + rename_self: bool, // Some fields (e.g. list items) don't have names in Substrait and this will be false to keep old name +) -> Result { + let name = if rename_self { + next_struct_field_name(unnamed_field_suffix, dfs_names, name_idx)? + } else { + field.name().to_string() + }; + match field.data_type() { + DataType::Struct(children) => { + let children = children + .iter() + .enumerate() + .map(|(child_idx, f)| { + rename_field( + f.as_ref(), + dfs_names, + child_idx, + name_idx, + /*rename_self=*/ true, + ) + }) + .collect::>()?; + Ok(field + .to_owned() + .with_name(name) + .with_data_type(DataType::Struct(children))) + } + DataType::List(inner) => { + let renamed_inner = rename_field( + inner.as_ref(), + dfs_names, + 0, + name_idx, + /*rename_self=*/ false, + )?; + Ok(field + .to_owned() + .with_data_type(DataType::List(FieldRef::new(renamed_inner))) + .with_name(name)) + } + DataType::LargeList(inner) => { + let renamed_inner = rename_field( + inner.as_ref(), + dfs_names, + 0, + name_idx, + /*rename_self= */ false, + )?; + Ok(field + .to_owned() + .with_data_type(DataType::LargeList(FieldRef::new(renamed_inner))) + .with_name(name)) + } + _ => Ok(field.to_owned().with_name(name)), + } +} + /// Produce a version of the given schema with names matching the given list of names. /// Substrait doesn't deal with column (incl. 
nested struct field) names within the schema, /// but it does give us the list of expected names at the end of the plan, so we use this @@ -342,59 +522,20 @@ fn make_renamed_schema( schema: &DFSchemaRef, dfs_names: &Vec, ) -> Result { - fn rename_inner_fields( - dtype: &DataType, - dfs_names: &Vec, - name_idx: &mut usize, - ) -> Result { - match dtype { - DataType::Struct(fields) => { - let fields = fields - .iter() - .map(|f| { - let name = next_struct_field_name(0, dfs_names, name_idx)?; - Ok((**f).to_owned().with_name(name).with_data_type( - rename_inner_fields(f.data_type(), dfs_names, name_idx)?, - )) - }) - .collect::>()?; - Ok(DataType::Struct(fields)) - } - DataType::List(inner) => Ok(DataType::List(FieldRef::new( - (**inner).to_owned().with_data_type(rename_inner_fields( - inner.data_type(), - dfs_names, - name_idx, - )?), - ))), - DataType::LargeList(inner) => Ok(DataType::LargeList(FieldRef::new( - (**inner).to_owned().with_data_type(rename_inner_fields( - inner.data_type(), - dfs_names, - name_idx, - )?), - ))), - _ => Ok(dtype.to_owned()), - } - } - let mut name_idx = 0; let (qualifiers, fields): (_, Vec) = schema .iter() - .map(|(q, f)| { - let name = next_struct_field_name(0, dfs_names, &mut name_idx)?; - Ok(( - q.cloned(), - (**f) - .to_owned() - .with_name(name) - .with_data_type(rename_inner_fields( - f.data_type(), - dfs_names, - &mut name_idx, - )?), - )) + .enumerate() + .map(|(field_idx, (q, f))| { + let renamed_f = rename_field( + f.as_ref(), + dfs_names, + field_idx, + &mut name_idx, + /*rename_self=*/ true, + )?; + Ok((q.cloned(), renamed_f)) }) .collect::>>()? .into_iter() @@ -414,6 +555,7 @@ fn make_renamed_schema( } /// Convert Substrait Rel to DataFusion DataFrame +#[allow(deprecated)] #[async_recursion] pub async fn from_substrait_rel( ctx: &SessionContext, @@ -483,8 +625,8 @@ pub async fn from_substrait_rel( from_substrait_rel(ctx, input, extensions).await?, ); let offset = fetch.offset as usize; - // Since protobuf can't directly distinguish `None` vs `0` `None` is encoded as `MAX` - let count = if fetch.count as usize == usize::MAX { + // -1 means that ALL records should be returned + let count = if fetch.count == -1 { None } else { Some(fetch.count as usize) @@ -573,14 +715,27 @@ pub async fn from_substrait_rel( } _ => false, }; + let order_by = if !f.sorts.is_empty() { + Some( + from_substrait_sorts( + ctx, + &f.sorts, + input.schema(), + extensions, + ) + .await?, + ) + } else { + None + }; + from_substrait_agg_func( ctx, f, input.schema(), extensions, filter, - // TODO: Add parsing of order_by also - None, + order_by, distinct, ) .await @@ -640,7 +795,17 @@ pub async fn from_substrait_rel( )? .build() } - None => plan_err!("JoinRel without join condition is not allowed"), + None => { + let on: Vec = vec![]; + left.join_detailed( + right.build()?, + join_type, + (on.clone(), on), + None, + false, + )? 
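+                    // A JoinRel with no expression is mapped to a join with an empty set
+                    // of equijoin keys and no filter, i.e. a cross-join-like shape under
+                    // the given join type.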
+ .build() + } } } Some(RelType::Cross(cross)) => { @@ -654,54 +819,61 @@ pub async fn from_substrait_rel( let (left, right) = requalify_sides_if_needed(left, right)?; left.cross_join(right.build()?)?.build() } - Some(RelType::Read(read)) => match &read.as_ref().read_type { - Some(ReadType::NamedTable(nt)) => { - let named_struct = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for Named Table") - })?; - - let table_reference = match nt.names.len() { - 0 => { - return plan_err!("No table name found in NamedTable"); - } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, - }; + Some(RelType::Read(read)) => { + fn read_with_schema( + df: DataFrame, + schema: DFSchema, + projection: &Option, + ) -> Result { + ensure_schema_compatability(df.schema().to_owned(), schema.clone())?; - let substrait_schema = - from_substrait_named_struct(named_struct, extensions)? - .replace_qualifier(table_reference.clone()); + let schema = apply_masking(schema, projection)?; - let t = ctx.table(table_reference.clone()).await?; - let t = ensure_schema_compatability(t, substrait_schema)?; - let t = t.into_optimized_plan()?; - extract_projection(t, &read.projection) + apply_projection(df, schema) } - Some(ReadType::VirtualTable(vt)) => { - let base_schema = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for Virtual Table") - })?; - let schema = from_substrait_named_struct(base_schema, extensions)?; + let named_struct = read.base_schema.as_ref().ok_or_else(|| { + substrait_datafusion_err!("No base schema provided for Read Relation") + })?; + + let substrait_schema = from_substrait_named_struct(named_struct, extensions)?; + + match &read.as_ref().read_type { + Some(ReadType::NamedTable(nt)) => { + let table_reference = match nt.names.len() { + 0 => { + return plan_err!("No table name found in NamedTable"); + } + 1 => TableReference::Bare { + table: nt.names[0].clone().into(), + }, + 2 => TableReference::Partial { + schema: nt.names[0].clone().into(), + table: nt.names[1].clone().into(), + }, + _ => TableReference::Full { + catalog: nt.names[0].clone().into(), + schema: nt.names[1].clone().into(), + table: nt.names[2].clone().into(), + }, + }; + + let t = ctx.table(table_reference.clone()).await?; - if vt.values.is_empty() { - return Ok(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: DFSchemaRef::new(schema), - })); + let substrait_schema = + substrait_schema.replace_qualifier(table_reference); + + read_with_schema(t, substrait_schema, &read.projection) } + Some(ReadType::VirtualTable(vt)) => { + if vt.values.is_empty() { + return Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(substrait_schema), + })); + } - let values = vt + let values = vt .values .iter() .map(|row| { @@ -714,82 +886,108 @@ pub async fn from_substrait_rel( Ok(Expr::Literal(from_substrait_literal( lit, extensions, - &base_schema.names, + &named_struct.names, &mut name_idx, )?)) }) .collect::>()?; - if name_idx != base_schema.names.len() { + if name_idx != named_struct.names.len() { return substrait_err!( "Names list must match exactly to nested schema, but found {} uses for {} names", name_idx, - 
base_schema.names.len() + named_struct.names.len() ); } Ok(lits) }) .collect::>()?; - Ok(LogicalPlan::Values(Values { - schema: DFSchemaRef::new(schema), - values, - })) - } - Some(ReadType::LocalFiles(lf)) => { - fn extract_filename(name: &str) -> Option { - let corrected_url = - if name.starts_with("file://") && !name.starts_with("file:///") { + Ok(LogicalPlan::Values(Values { + schema: DFSchemaRef::new(substrait_schema), + values, + })) + } + Some(ReadType::LocalFiles(lf)) => { + fn extract_filename(name: &str) -> Option { + let corrected_url = if name.starts_with("file://") + && !name.starts_with("file:///") + { name.replacen("file://", "file:///", 1) } else { name.to_string() }; - Url::parse(&corrected_url).ok().and_then(|url| { - let path = url.path(); - std::path::Path::new(path) - .file_name() - .map(|filename| filename.to_string_lossy().to_string()) - }) - } + Url::parse(&corrected_url).ok().and_then(|url| { + let path = url.path(); + std::path::Path::new(path) + .file_name() + .map(|filename| filename.to_string_lossy().to_string()) + }) + } + + // we could use the file name to check the original table provider + // TODO: currently does not support multiple local files + let filename: Option = + lf.items.first().and_then(|x| match x.path_type.as_ref() { + Some(UriFile(name)) => extract_filename(name), + _ => None, + }); + + if lf.items.len() > 1 || filename.is_none() { + return not_impl_err!("Only single file reads are supported"); + } + let name = filename.unwrap(); + // directly use unwrap here since we could determine it is a valid one + let table_reference = TableReference::Bare { table: name.into() }; + let t = ctx.table(table_reference.clone()).await?; - // we could use the file name to check the original table provider - // TODO: currently does not support multiple local files - let filename: Option = - lf.items.first().and_then(|x| match x.path_type.as_ref() { - Some(UriFile(name)) => extract_filename(name), - _ => None, - }); + let substrait_schema = + substrait_schema.replace_qualifier(table_reference); - if lf.items.len() > 1 || filename.is_none() { - return not_impl_err!("Only single file reads are supported"); + read_with_schema(t, substrait_schema, &read.projection) + } + _ => { + not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type) } - let name = filename.unwrap(); - // directly use unwrap here since we could determine it is a valid one - let table_reference = TableReference::Bare { table: name.into() }; - let t = ctx.table(table_reference).await?; - let t = t.into_optimized_plan()?; - extract_projection(t, &read.projection) } - _ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type), - }, + } Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) { - Ok(set_op) => match set_op { - set_rel::SetOp::UnionAll => { - if !set.inputs.is_empty() { - let mut union_builder = Ok(LogicalPlanBuilder::from( - from_substrait_rel(ctx, &set.inputs[0], extensions).await?, - )); - for input in &set.inputs[1..] { - union_builder = union_builder? 
- .union(from_substrait_rel(ctx, input, extensions).await?); + Ok(set_op) => { + if set.inputs.len() < 2 { + substrait_err!("Set operation requires at least two inputs") + } else { + match set_op { + set_rel::SetOp::UnionAll => { + union_rels(&set.inputs, ctx, extensions, true).await } - union_builder?.build() - } else { - not_impl_err!("Union relation requires at least one input") + set_rel::SetOp::UnionDistinct => { + union_rels(&set.inputs, ctx, extensions, false).await + } + set_rel::SetOp::IntersectionPrimary => { + LogicalPlanBuilder::intersect( + from_substrait_rel(ctx, &set.inputs[0], extensions) + .await?, + union_rels(&set.inputs[1..], ctx, extensions, true) + .await?, + false, + ) + } + set_rel::SetOp::IntersectionMultiset => { + intersect_rels(&set.inputs, ctx, extensions, false).await + } + set_rel::SetOp::IntersectionMultisetAll => { + intersect_rels(&set.inputs, ctx, extensions, true).await + } + set_rel::SetOp::MinusPrimary => { + except_rels(&set.inputs, ctx, extensions, false).await + } + set_rel::SetOp::MinusPrimaryAll => { + except_rels(&set.inputs, ctx, extensions, true).await + } + _ => not_impl_err!("Unsupported set operator: {set_op:?}"), } } - _ => not_impl_err!("Unsupported set operator: {set_op:?}"), - }, + } Err(e) => not_impl_err!("Invalid set operation type {}: {e}", set.op), }, Some(RelType::ExtensionLeaf(extension)) => { @@ -885,30 +1083,61 @@ pub async fn from_substrait_rel( /// 1. All fields present in the Substrait schema are present in the DataFusion schema. The /// DataFusion schema may have MORE fields, but not the other way around. /// 2. All fields are compatible. See [`ensure_field_compatability`] for details -/// -/// This function returns a DataFrame with fields adjusted if necessary in the event that the -/// Substrait schema is a subset of the DataFusion schema. fn ensure_schema_compatability( - table: DataFrame, + table_schema: DFSchema, substrait_schema: DFSchema, -) -> Result { - let df_schema = table.schema().to_owned().strip_qualifiers(); - if df_schema.logically_equivalent_names_and_types(&substrait_schema) { - return Ok(table); - } - let selected_columns = substrait_schema +) -> Result<()> { + substrait_schema .strip_qualifiers() .fields() .iter() - .map(|substrait_field| { + .try_for_each(|substrait_field| { let df_field = - df_schema.field_with_unqualified_name(substrait_field.name())?; - ensure_field_compatability(df_field, substrait_field)?; - Ok(col(format!("\"{}\"", df_field.name()))) + table_schema.field_with_unqualified_name(substrait_field.name())?; + ensure_field_compatability(df_field, substrait_field) }) - .collect::>()?; +} + +/// This function returns a DataFrame with fields adjusted if necessary in the event that the +/// Substrait schema is a subset of the DataFusion schema. 
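+/// Note that this only handles `LogicalPlan::TableScan` inputs: the column selection is
+/// folded into the scan's `projection` indices rather than wrapped in a separate
+/// Projection node; any other plan results in a plan error.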
+fn apply_projection(table: DataFrame, substrait_schema: DFSchema) -> Result { + let df_schema = table.schema().to_owned(); + + let t = table.into_unoptimized_plan(); + + if df_schema.logically_equivalent_names_and_types(&substrait_schema) { + return Ok(t); + } + + match t { + LogicalPlan::TableScan(mut scan) => { + let column_indices: Vec = substrait_schema + .strip_qualifiers() + .fields() + .iter() + .map(|substrait_field| { + Ok(df_schema + .index_of_column_by_name(None, substrait_field.name().as_str()) + .unwrap()) + }) + .collect::>()?; - table.select(selected_columns) + let fields = column_indices + .iter() + .map(|i| df_schema.qualified_field(*i)) + .map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone()))) + .collect(); + + scan.projected_schema = DFSchemaRef::new(DFSchema::new_with_metadata( + fields, + df_schema.metadata().clone(), + )?); + scan.projection = Some(column_indices); + + Ok(LogicalPlan::TableScan(scan)) + } + _ => plan_err!("DataFrame passed to apply_projection must be a TableScan"), + } } /// Ensures that the given Substrait field is compatible with the given DataFusion field @@ -1617,9 +1846,14 @@ fn from_substrait_type( Ok(DataType::Interval(IntervalUnit::YearMonth)) } r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)), + r#type::Kind::IntervalCompound(_) => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } r#type::Kind::UserDefined(u) => { if let Some(name) = extensions.types.get(&u.type_reference) { + #[allow(deprecated)] match name.as_ref() { + // Kept for backwards compatibility, producers should use IntervalCompound instead INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), _ => not_impl_err!( "Unsupported Substrait user defined type with ref {} and variation {}", @@ -1628,18 +1862,17 @@ fn from_substrait_type( ), } } else { - // Kept for backwards compatibility, new plans should include the extension instead #[allow(deprecated)] match u.type_reference { - // Kept for backwards compatibility, use IntervalYear instead + // Kept for backwards compatibility, producers should use IntervalYear instead INTERVAL_YEAR_MONTH_TYPE_REF => { Ok(DataType::Interval(IntervalUnit::YearMonth)) } - // Kept for backwards compatibility, use IntervalDay instead + // Kept for backwards compatibility, producers should use IntervalDay instead INTERVAL_DAY_TIME_TYPE_REF => { Ok(DataType::Interval(IntervalUnit::DayTime)) } - // Not supported yet by Substrait + // Kept for backwards compatibility, producers should use IntervalCompound instead INTERVAL_MONTH_DAY_NANO_TYPE_REF => { Ok(DataType::Interval(IntervalUnit::MonthDayNano)) } @@ -1681,14 +1914,14 @@ fn from_substrait_struct_type( } fn next_struct_field_name( - i: usize, + column_idx: usize, dfs_names: &[String], name_idx: &mut usize, ) -> Result { if dfs_names.is_empty() { // If names are not given, create dummy names // c0, c1, ... align with e.g. 
SqlToRel::create_named_struct - Ok(format!("c{i}")) + Ok(format!("c{column_idx}")) } else { let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| { substrait_datafusion_err!("Named schema must contain names for all fields") @@ -1902,11 +2135,7 @@ fn from_substrait_literal( let s = d.scale.try_into().map_err(|e| { substrait_datafusion_err!("Failed to parse decimal scale: {e}") })?; - ScalarValue::Decimal128( - Some(std::primitive::i128::from_le_bytes(value)), - p, - s, - ) + ScalarValue::Decimal128(Some(i128::from_le_bytes(value)), p, s) } Some(LiteralType::List(l)) => { // Each element should start the name index from the same value, then we increase it @@ -2061,6 +2290,7 @@ fn from_substrait_literal( subseconds, precision_mode, })) => { + use interval_day_to_second::PrecisionMode; // DF only supports millisecond precision, so for any more granular type we lose precision let milliseconds = match precision_mode { Some(PrecisionMode::Microseconds(ms)) => ms / 1000, @@ -2085,6 +2315,35 @@ fn from_substrait_literal( Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { ScalarValue::new_interval_ym(*years, *months) } + Some(LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month, + interval_day_to_second, + })) => match (interval_year_to_month, interval_day_to_second) { + ( + Some(IntervalYearToMonth { years, months }), + Some(IntervalDayToSecond { + days, + seconds, + subseconds, + precision_mode: + Some(interval_day_to_second::PrecisionMode::Precision(p)), + }), + ) => { + if *p < 0 || *p > 9 { + return plan_err!( + "Unsupported Substrait interval day to second precision: {}", + p + ); + } + let nanos = *subseconds * i64::pow(10, (9 - p) as u32); + ScalarValue::new_interval_mdn( + *years * 12 + months, + *days, + *seconds as i64 * NANOSECONDS + nanos, + ) + } + _ => return plan_err!("Substrait compound interval missing components"), + }, Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), Some(LiteralType::UserDefined(user_defined)) => { // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed @@ -2115,6 +2374,8 @@ fn from_substrait_literal( if let Some(name) = extensions.types.get(&user_defined.type_reference) { match name.as_ref() { + // Kept for backwards compatibility - producers should use IntervalCompound instead + #[allow(deprecated)] INTERVAL_MONTH_DAY_NANO_TYPE_NAME => { interval_month_day_nano(user_defined)? 
 }
@@ -2127,10 +2388,9 @@ fn from_substrait_literal(
                 }
             }
         } else {
-            // Kept for backwards compatibility - new plans should include extension instead
             #[allow(deprecated)]
             match user_defined.type_reference {
-                // Kept for backwards compatibility, use IntervalYearToMonth instead
+                // Kept for backwards compatibility, producers should use IntervalYearToMonth instead
                 INTERVAL_YEAR_MONTH_TYPE_REF => {
                     let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
                         return substrait_err!("Interval year month value is empty");
                     };
                     let value_slice: [u8; 4] =
                         (*raw_val.value).try_into().map_err(|_| {
                             substrait_datafusion_err!(
                                 "Failed to parse interval year month value"
                             )
                         })?;
                     Ok(ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes(
                         value_slice,
                     ))))
                 }
-                // Kept for backwards compatibility, use IntervalDayToSecond instead
+                // Kept for backwards compatibility, producers should use IntervalDayToSecond instead
                 INTERVAL_DAY_TIME_TYPE_REF => {
                     let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
                         return substrait_err!("Interval day time value is empty");
                     };
                     let value_slice: [u8; 12] =
                         (*raw_val.value).try_into().map_err(|_| {
                             substrait_datafusion_err!("Failed to parse interval day time value")
                         })?;
                     let days = i32::from_le_bytes(value_slice[0..4].try_into().unwrap());
                     let seconds = i32::from_le_bytes(value_slice[4..8].try_into().unwrap());
                     let milliseconds = i32::from_le_bytes(value_slice[8..12].try_into().unwrap());
                     Ok(ScalarValue::IntervalDayTime(Some(IntervalDayTime {
                         days,
                         milliseconds,
                     })))
                 }
+                // Kept for backwards compatibility, producers should use IntervalCompound instead
                 INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
                     interval_month_day_nano(user_defined)?
                 }
@@ -2389,7 +2650,7 @@ impl BuiltinExprBuilder {
         match name {
             "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true"
             | "is_false" | "is_not_true" | "is_not_false" | "is_unknown"
-            | "is_not_unknown" | "negative" => Some(Self {
+            | "is_not_unknown" | "negative" | "negate" => Some(Self {
                 expr_name: name.to_string(),
             }),
             _ => None,
         }
     }
@@ -2410,8 +2671,9 @@ impl BuiltinExprBuilder {
             "ilike" => {
                 Self::build_like_expr(ctx, true, f, input_schema, extensions).await
             }
-            "not" | "negative" | "is_null" | "is_not_null" | "is_true" | "is_false"
-            | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => {
+            "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true"
+            | "is_false" | "is_not_true" | "is_not_false" | "is_unknown"
+            | "is_not_unknown" => {
                 Self::build_unary_expr(ctx, &self.expr_name, f, input_schema, extensions)
                     .await
             }
@@ -2440,7 +2702,7 @@ impl BuiltinExprBuilder {
         let expr = match fn_name {
             "not" => Expr::Not(arg),
-            "negative" => Expr::Negative(arg),
+            "negative" | "negate" => Expr::Negative(arg),
             "is_null" => Expr::IsNull(arg),
             "is_not_null" => Expr::IsNotNull(arg),
             "is_true" => Expr::IsTrue(arg),
@@ -2513,3 +2775,52 @@ impl BuiltinExprBuilder {
         }))
     }
 }
+
+#[cfg(test)]
+mod test {
+    use crate::extensions::Extensions;
+    use crate::logical_plan::consumer::from_substrait_literal_without_names;
+    use arrow_buffer::IntervalMonthDayNano;
+    use datafusion::error::Result;
+    use datafusion::scalar::ScalarValue;
+    use substrait::proto::expression::literal::{
+        interval_day_to_second, IntervalCompound, IntervalDayToSecond,
+        IntervalYearToMonth, LiteralType,
+    };
+    use substrait::proto::expression::Literal;
+
+    #[test]
+    fn interval_compound_different_precision() -> Result<()> {
+        // DF producer (and thus roundtrip) always uses precision = 9,
+        // this test exists to test with some other value.
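+        // Worked example: with precision = 6 the subseconds field is in microseconds,
+        // so the expected value below is 4 * 1_000_000_000 + 5 * 10^(9-6) = 4_000_005_000
+        // nanoseconds, and 1 year + 2 months = 14 months.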
+ let substrait = Literal { + nullable: false, + type_variation_reference: 0, + literal_type: Some(LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month: Some(IntervalYearToMonth { + years: 1, + months: 2, + }), + interval_day_to_second: Some(IntervalDayToSecond { + days: 3, + seconds: 4, + subseconds: 5, + precision_mode: Some( + interval_day_to_second::PrecisionMode::Precision(6), + ), + }), + })), + }; + + assert_eq!( + from_substrait_literal_without_names(&substrait, &Extensions::default())?, + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { + months: 14, + days: 3, + nanoseconds: 4_000_005_000 + })) + ); + + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index fada827875b09..408885f70687f 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -15,13 +15,15 @@ // specific language governing permissions and limitations // under the License. -use itertools::Itertools; +use datafusion::config::ConfigOptions; +use datafusion::optimizer::analyzer::expand_wildcard_rule::ExpandWildcardRule; +use datafusion::optimizer::AnalyzerRule; use std::sync::Arc; +use substrait::proto::expression_reference::ExprType; -use arrow_buffer::ToByteSlice; -use datafusion::arrow::datatypes::IntervalUnit; +use datafusion::arrow::datatypes::{Field, IntervalUnit}; use datafusion::logical_expr::{ - CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits, + Distinct, FetchType, Like, Partitioning, SkipType, WindowFrameUnits, }; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, @@ -36,10 +38,11 @@ use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, - UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, }; use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; +use datafusion::arrow::temporal_conversions::NANOSECONDS; use datafusion::common::{ exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, substrait_err, DFSchemaRef, ToDFSchema, @@ -55,15 +58,17 @@ use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; use substrait::proto::expression::literal::map::KeyValue; use substrait::proto::expression::literal::{ - user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map, - PrecisionTimestamp, Struct, UserDefined, + IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, List, Map, + PrecisionTimestamp, Struct, }; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; use substrait::proto::read_rel::VirtualTable; use substrait::proto::rel_common::EmitKind; use substrait::proto::rel_common::EmitKind::Emit; -use substrait::proto::{rel_common, CrossRel, ExchangeRel, RelCommon}; +use substrait::proto::{ + rel_common, ExchangeRel, ExpressionReference, ExtendedExpression, RelCommon, +}; use substrait::{ proto::{ aggregate_function::AggregationInvocation, @@ -101,10 +106,15 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result Result 
Result> { + let mut extensions = Extensions::default(); + + let substrait_exprs = exprs + .iter() + .map(|(expr, field)| { + let substrait_expr = to_substrait_rex( + ctx, + expr, + schema, + /*col_ref_offset=*/ 0, + &mut extensions, + )?; + let mut output_names = Vec::new(); + flatten_names(field, false, &mut output_names)?; + Ok(ExpressionReference { + output_names, + expr_type: Some(ExprType::Expression(substrait_expr)), + }) + }) + .collect::>>()?; + let substrait_schema = to_substrait_named_struct(schema)?; + + Ok(Box::new(ExtendedExpression { + advanced_extensions: None, + expected_type_urls: vec![], + extension_uris: vec![], + extensions: extensions.into(), + version: Some(version::version_with_producer("datafusion")), + referred_expr: substrait_exprs, + base_schema: Some(substrait_schema), + })) +} + /// Convert DataFusion LogicalPlan to Substrait Rel +#[allow(deprecated)] pub fn to_substrait_rel( plan: &LogicalPlan, ctx: &SessionContext, @@ -142,7 +203,7 @@ pub fn to_substrait_rel( }); let table_schema = scan.source.schema().to_dfschema_ref()?; - let base_schema = to_substrait_named_struct(&table_schema, extensions)?; + let base_schema = to_substrait_named_struct(&table_schema)?; Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { @@ -168,13 +229,14 @@ pub fn to_substrait_rel( Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: Some(to_substrait_named_struct(&e.schema, extensions)?), + base_schema: Some(to_substrait_named_struct(&e.schema)?), filter: None, best_effort_filter: None, projection: None, advanced_extension: None, read_type: Some(ReadType::VirtualTable(VirtualTable { values: vec![], + expressions: vec![], })), }))), })) @@ -206,12 +268,15 @@ pub fn to_substrait_rel( Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: Some(to_substrait_named_struct(&v.schema, extensions)?), + base_schema: Some(to_substrait_named_struct(&v.schema)?), filter: None, best_effort_filter: None, projection: None, advanced_extension: None, - read_type: Some(ReadType::VirtualTable(VirtualTable { values })), + read_type: Some(ReadType::VirtualTable(VirtualTable { + values, + expressions: vec![], + })), }))), })) } @@ -261,14 +326,19 @@ pub fn to_substrait_rel( } LogicalPlan::Limit(limit) => { let input = to_substrait_rel(limit.input.as_ref(), ctx, extensions)?; - // Since protobuf can't directly distinguish `None` vs `0` encode `None` as `MAX` - let limit_fetch = limit.fetch.unwrap_or(usize::MAX); + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return not_impl_err!("Non-literal limit fetch"); + }; + let SkipType::Literal(skip) = limit.get_skip_type()? 
else { + return not_impl_err!("Non-literal limit skip"); + }; Ok(Box::new(Rel { rel_type: Some(RelType::Fetch(Box::new(FetchRel { common: None, input: Some(input), - offset: limit.skip as i64, - count: limit_fetch as i64, + offset: skip as i64, + // use -1 to signal that ALL records should be returned + count: fetch.map(|f| f as i64).unwrap_or(-1), advanced_extension: None, }))), })) @@ -307,6 +377,7 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, input: Some(input), + grouping_expressions: vec![], groupings, measures, advanced_extension: None, @@ -325,8 +396,10 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, input: Some(input), + grouping_expressions: vec![], groupings: vec![Grouping { grouping_expressions: grouping, + expression_references: vec![], }], measures: vec![], advanced_extension: None, @@ -403,23 +476,6 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::CrossJoin(cross_join) => { - let CrossJoin { - left, - right, - schema: _, - } = cross_join; - let left = to_substrait_rel(left.as_ref(), ctx, extensions)?; - let right = to_substrait_rel(right.as_ref(), ctx, extensions)?; - Ok(Box::new(Rel { - rel_type: Some(RelType::Cross(Box::new(CrossRel { - common: None, - left: Some(left), - right: Some(right), - advanced_extension: None, - }))), - })) - } LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait @@ -435,7 +491,7 @@ pub fn to_substrait_rel( .map(|ptr| *ptr) .collect(); Ok(Box::new(Rel { - rel_type: Some(substrait::proto::rel::RelType::Set(SetRel { + rel_type: Some(RelType::Set(SetRel { common: None, inputs: input_rels, op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL @@ -580,55 +636,45 @@ fn create_project_remapping(expr_count: usize, input_field_count: usize) -> Emit Emit(rel_common::Emit { output_mapping }) } -fn to_substrait_named_struct( - schema: &DFSchemaRef, - extensions: &mut Extensions, -) -> Result { - // Substrait wants a list of all field names, including nested fields from structs, - // also from within e.g. lists and maps. However, it does not want the list and map field names - // themselves - only proper structs fields are considered to have useful names. - fn names_dfs(dtype: &DataType) -> Result> { - match dtype { - DataType::Struct(fields) => { - let mut names = Vec::new(); - for field in fields { - names.push(field.name().to_string()); - names.extend(names_dfs(field.data_type())?); - } - Ok(names) +// Substrait wants a list of all field names, including nested fields from structs, +// also from within e.g. lists and maps. However, it does not want the list and map field names +// themselves - only proper structs fields are considered to have useful names. 
+fn flatten_names(field: &Field, skip_self: bool, names: &mut Vec) -> Result<()> { + if !skip_self { + names.push(field.name().to_string()); + } + match field.data_type() { + DataType::Struct(fields) => { + for field in fields { + flatten_names(field, false, names)?; } - DataType::List(l) => names_dfs(l.data_type()), - DataType::LargeList(l) => names_dfs(l.data_type()), - DataType::Map(m, _) => match m.data_type() { - DataType::Struct(key_and_value) if key_and_value.len() == 2 => { - let key_names = - names_dfs(key_and_value.first().unwrap().data_type())?; - let value_names = - names_dfs(key_and_value.last().unwrap().data_type())?; - Ok([key_names, value_names].concat()) - } - _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), - }, - _ => Ok(Vec::new()), + Ok(()) } - } + DataType::List(l) => flatten_names(l, true, names), + DataType::LargeList(l) => flatten_names(l, true, names), + DataType::Map(m, _) => match m.data_type() { + DataType::Struct(key_and_value) if key_and_value.len() == 2 => { + flatten_names(&key_and_value[0], true, names)?; + flatten_names(&key_and_value[1], true, names) + } + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + }, + _ => Ok(()), + }?; + Ok(()) +} - let names = schema - .fields() - .iter() - .map(|f| { - let mut names = vec![f.name().to_string()]; - names.extend(names_dfs(f.data_type())?); - Ok(names) - }) - .flatten_ok() - .collect::>()?; +fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { + let mut names = Vec::with_capacity(schema.fields().len()); + for field in schema.fields() { + flatten_names(field, false, &mut names)?; + } let field_types = r#type::Struct { types: schema .fields() .iter() - .map(|f| to_substrait_type(f.data_type(), f.is_nullable(), extensions)) + .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) .collect::>()?, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability: r#type::Nullability::Unspecified as i32, @@ -695,7 +741,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { Operator::Minus => "subtract", Operator::Multiply => "multiply", Operator::Divide => "divide", - Operator::Modulo => "mod", + Operator::Modulo => "modulus", Operator::And => "and", Operator::Or => "or", Operator::IsDistinctFrom => "is_distinct_from", @@ -719,6 +765,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { } } +#[allow(deprecated)] pub fn parse_flat_grouping_exprs( ctx: &SessionContext, exprs: &[Expr], @@ -731,6 +778,7 @@ pub fn parse_flat_grouping_exprs( .collect::>>()?; Ok(Grouping { grouping_expressions, + expression_references: vec![], }) } @@ -1099,7 +1147,7 @@ pub fn to_substrait_rex( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true, extensions)?), + r#type: Some(to_substrait_type(data_type, true)?), input: Some(Box::new(to_substrait_rex( ctx, expr, @@ -1293,7 +1341,7 @@ pub fn to_substrait_rex( ), Expr::Negative(arg) => to_substrait_unary_scalar_fn( ctx, - "negative", + "negate", arg, schema, col_ref_offset, @@ -1305,11 +1353,7 @@ pub fn to_substrait_rex( } } -fn to_substrait_type( - dt: &DataType, - nullable: bool, - extensions: &mut Extensions, -) -> Result { +fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { let nullability = if nullable { r#type::Nullability::Nullable as i32 } else { @@ -1438,16 +1482,14 @@ fn to_substrait_type( })), }), IntervalUnit::MonthDayNano => { - // Substrait doesn't currently support this type, so we represent it as a 
UDT Ok(substrait::proto::Type { - kind: Some(r#type::Kind::UserDefined(r#type::UserDefined { - type_reference: extensions.register_type( - INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string(), - ), - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - type_parameters: vec![], - })), + kind: Some(r#type::Kind::IntervalCompound( + r#type::IntervalCompound { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision: 9, // nanos + }, + )), }) } } @@ -1496,8 +1538,7 @@ fn to_substrait_type( })), }), DataType::List(inner) => { - let inner_type = - to_substrait_type(inner.data_type(), inner.is_nullable(), extensions)?; + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), @@ -1507,8 +1548,7 @@ fn to_substrait_type( }) } DataType::LargeList(inner) => { - let inner_type = - to_substrait_type(inner.data_type(), inner.is_nullable(), extensions)?; + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), @@ -1522,12 +1562,10 @@ fn to_substrait_type( let key_type = to_substrait_type( key_and_value[0].data_type(), key_and_value[0].is_nullable(), - extensions, )?; let value_type = to_substrait_type( key_and_value[1].data_type(), key_and_value[1].is_nullable(), - extensions, )?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Map(Box::new(r#type::Map { @@ -1543,9 +1581,7 @@ fn to_substrait_type( DataType::Struct(fields) => { let field_types = fields .iter() - .map(|field| { - to_substrait_type(field.data_type(), field.is_nullable(), extensions) - }) + .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { @@ -1667,98 +1703,38 @@ fn make_substrait_like_expr( } } +fn to_substrait_bound_offset(value: &ScalarValue) -> Option { + match value { + ScalarValue::UInt8(Some(v)) => Some(*v as i64), + ScalarValue::UInt16(Some(v)) => Some(*v as i64), + ScalarValue::UInt32(Some(v)) => Some(*v as i64), + ScalarValue::UInt64(Some(v)) => Some(*v as i64), + ScalarValue::Int8(Some(v)) => Some(*v as i64), + ScalarValue::Int16(Some(v)) => Some(*v as i64), + ScalarValue::Int32(Some(v)) => Some(*v as i64), + ScalarValue::Int64(Some(v)) => Some(*v), + _ => None, + } +} + fn to_substrait_bound(bound: &WindowFrameBound) -> Bound { match bound { WindowFrameBound::CurrentRow => Bound { kind: Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})), }, - WindowFrameBound::Preceding(s) => match s { - ScalarValue::UInt8(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::UInt16(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::UInt32(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::UInt64(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), + WindowFrameBound::Preceding(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { offset })), }, - ScalarValue::Int8(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - 
offset: *v as i64, - })), - }, - ScalarValue::Int16(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::Int32(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::Int64(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v, - })), - }, - _ => Bound { + None => Bound { kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), }, }, - WindowFrameBound::Following(s) => match s { - ScalarValue::UInt8(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::UInt16(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::UInt32(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::UInt64(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int8(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int16(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), + WindowFrameBound::Following(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { offset })), }, - ScalarValue::Int32(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int64(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v, - })), - }, - _ => Bound { + None => Bound { kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), }, }, @@ -1792,7 +1768,6 @@ fn to_substrait_literal( literal_type: Some(LiteralType::Null(to_substrait_type( &value.data_type(), true, - extensions, )?)), }); } @@ -1901,23 +1876,21 @@ fn to_substrait_literal( }), DEFAULT_TYPE_VARIATION_REF, ), - ScalarValue::IntervalMonthDayNano(Some(i)) => { - // IntervalMonthDayNano is internally represented as a 128-bit integer, containing - // months (32bit), days (32bit), and nanoseconds (64bit) - let bytes = i.to_byte_slice(); - ( - LiteralType::UserDefined(UserDefined { - type_reference: extensions - .register_type(INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()), - type_parameters: vec![], - val: Some(user_defined::Val::Value(ProtoAny { - type_url: INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string(), - value: bytes.to_vec().into(), - })), + ScalarValue::IntervalMonthDayNano(Some(i)) => ( + LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month: Some(IntervalYearToMonth { + years: i.months / 12, + months: i.months % 12, }), - DEFAULT_TYPE_VARIATION_REF, - ) - } + interval_day_to_second: Some(IntervalDayToSecond { + days: i.days, + seconds: (i.nanoseconds / NANOSECONDS) as i32, + subseconds: i.nanoseconds % NANOSECONDS, + precision_mode: Some(PrecisionMode::Precision(9)), // nanoseconds + }), + }), + DEFAULT_TYPE_VARIATION_REF, + ), ScalarValue::IntervalDayTime(Some(i)) => ( LiteralType::IntervalDayToSecond(IntervalDayToSecond { days: i.days, @@ -1973,7 +1946,7 @@ fn to_substrait_literal( ), ScalarValue::Map(m) => { let map = if m.is_empty() || m.value(0).is_empty() { - let mt = to_substrait_type(m.data_type(), m.is_nullable(), extensions)?; + 
let mt = to_substrait_type(m.data_type(), m.is_nullable())?; let mt = match mt { substrait::proto::Type { kind: Some(r#type::Kind::Map(mt)), @@ -2058,11 +2031,7 @@ fn convert_array_to_literal_list( .collect::>>()?; if values.is_empty() { - let lt = match to_substrait_type( - array.data_type(), - array.is_nullable(), - extensions, - )? { + let lt = match to_substrait_type(array.data_type(), array.is_nullable())? { substrait::proto::Type { kind: Some(r#type::Kind::List(lt)), } => lt.as_ref().to_owned(), @@ -2178,15 +2147,16 @@ fn substrait_field_ref(index: usize) -> Result { mod test { use super::*; use crate::logical_plan::consumer::{ - from_substrait_literal_without_names, from_substrait_type_without_names, + from_substrait_extended_expr, from_substrait_literal_without_names, + from_substrait_named_struct, from_substrait_type_without_names, }; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::arrow::array::{ GenericListArray, Int64Builder, MapBuilder, StringBuilder, }; - use datafusion::arrow::datatypes::Field; + use datafusion::arrow::datatypes::{Field, Fields, Schema}; use datafusion::common::scalar::ScalarStructBuilder; - use std::collections::HashMap; + use datafusion::common::DFSchema; #[test] fn round_trip_literals() -> Result<()> { @@ -2317,39 +2287,6 @@ mod test { Ok(()) } - #[test] - fn custom_type_literal_extensions() -> Result<()> { - let mut extensions = Extensions::default(); - // IntervalMonthDayNano is represented as a custom type in Substrait - let scalar = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::new( - 17, 25, 1234567890, - ))); - let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; - let roundtrip_scalar = - from_substrait_literal_without_names(&substrait_literal, &extensions)?; - assert_eq!(scalar, roundtrip_scalar); - - assert_eq!( - extensions, - Extensions { - functions: HashMap::new(), - types: HashMap::from([( - 0, - INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() - )]), - type_variations: HashMap::new(), - } - ); - - // Check we fail if we don't propagate extensions - assert!(from_substrait_literal_without_names( - &substrait_literal, - &Extensions::default() - ) - .is_err()); - Ok(()) - } - #[test] fn round_trip_types() -> Result<()> { round_trip_type(DataType::Boolean)?; @@ -2421,44 +2358,109 @@ mod test { fn round_trip_type(dt: DataType) -> Result<()> { println!("Checking round trip of {dt:?}"); - let mut extensions = Extensions::default(); - // As DataFusion doesn't consider nullability as a property of the type, but field, // it doesn't matter if we set nullability to true or false here. 
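+        // (Substrait, by contrast, does track nullability on the type itself, which is
+        // why to_substrait_type takes an explicit nullability flag.)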
- let substrait = to_substrait_type(&dt, true, &mut extensions)?; - let roundtrip_dt = from_substrait_type_without_names(&substrait, &extensions)?; + let substrait = to_substrait_type(&dt, true)?; + let roundtrip_dt = + from_substrait_type_without_names(&substrait, &Extensions::default())?; assert_eq!(dt, roundtrip_dt); Ok(()) } #[test] - fn custom_type_extensions() -> Result<()> { - let mut extensions = Extensions::default(); - // IntervalMonthDayNano is represented as a custom type in Substrait - let dt = DataType::Interval(IntervalUnit::MonthDayNano); + fn named_struct_names() -> Result<()> { + let schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ + Field::new("int", DataType::Int32, true), + Field::new( + "struct", + DataType::Struct(Fields::from(vec![Field::new( + "inner", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + )])), + true, + ), + Field::new("trailer", DataType::Float64, true), + ]))?); - let substrait = to_substrait_type(&dt, true, &mut extensions)?; - let roundtrip_dt = from_substrait_type_without_names(&substrait, &extensions)?; - assert_eq!(dt, roundtrip_dt); + let named_struct = to_substrait_named_struct(&schema)?; + // Struct field names should be flattened DFS style + // List field names should be omitted assert_eq!( - extensions, - Extensions { - functions: HashMap::new(), - types: HashMap::from([( - 0, - INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() - )]), - type_variations: HashMap::new(), - } + named_struct.names, + vec!["int", "struct", "inner", "trailer"] ); - // Check we fail if we don't propagate extensions - assert!( - from_substrait_type_without_names(&substrait, &Extensions::default()) - .is_err() - ); + let roundtrip_schema = + from_substrait_named_struct(&named_struct, &Extensions::default())?; + assert_eq!(schema.as_ref(), &roundtrip_schema); + Ok(()) + } + + #[tokio::test] + async fn extended_expressions() -> Result<()> { + let ctx = SessionContext::new(); + + // One expression, empty input schema + let expr = Expr::Literal(ScalarValue::Int32(Some(42))); + let field = Field::new("out", DataType::Int32, false); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let substrait = + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &ctx)?; + let roundtrip_expr = from_substrait_extended_expr(&ctx, &substrait).await?; + + assert_eq!(roundtrip_expr.input_schema, empty_schema); + assert_eq!(roundtrip_expr.exprs.len(), 1); + + let (rt_expr, rt_field) = roundtrip_expr.exprs.first().unwrap(); + assert_eq!(rt_field, &field); + assert_eq!(rt_expr, &expr); + + // Multiple expressions, with column references + let expr1 = Expr::Column("c0".into()); + let expr2 = Expr::Column("c1".into()); + let out1 = Field::new("out1", DataType::Int32, true); + let out2 = Field::new("out2", DataType::Utf8, true); + let input_schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ + Field::new("c0", DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + ]))?); + + let substrait = to_substrait_extended_expr( + &[(&expr1, &out1), (&expr2, &out2)], + &input_schema, + &ctx, + )?; + let roundtrip_expr = from_substrait_extended_expr(&ctx, &substrait).await?; + + assert_eq!(roundtrip_expr.input_schema, input_schema); + assert_eq!(roundtrip_expr.exprs.len(), 2); + + let mut exprs = roundtrip_expr.exprs.into_iter(); + + let (rt_expr, rt_field) = exprs.next().unwrap(); + assert_eq!(rt_field, out1); + assert_eq!(rt_expr, expr1); + + let (rt_expr, rt_field) = exprs.next().unwrap(); + assert_eq!(rt_field, out2); + 
assert_eq!(rt_expr, expr2);
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn invalid_extended_expression() {
+        let ctx = SessionContext::new();
+
+        // Not ok if input schema is missing field referenced by expr
+        let expr = Expr::Column("missing".into());
+        let field = Field::new("out", DataType::Int32, false);
+        let empty_schema = DFSchemaRef::new(DFSchema::empty());
+
+        let err = to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &ctx);
+
+        assert!(matches!(err, Err(DataFusionError::SchemaError(_, _))));
+    }
 }
diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs
index a3e76389d5104..58774db424da1 100644
--- a/datafusion/substrait/src/variation_const.rs
+++ b/datafusion/substrait/src/variation_const.rs
@@ -96,7 +96,7 @@ pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2;
 /// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano
 #[deprecated(
     since = "41.0.0",
-    note = "Use Substrait `UserDefinedType` with name `INTERVAL_MONTH_DAY_NANO_TYPE_NAME` instead"
+    note = "Use Substrait `IntervalCompound` type instead"
 )]
 pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3;
 
@@ -104,4 +104,8 @@ pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3;
 ///
 /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
 /// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano
+#[deprecated(
+    since = "43.0.0",
+    note = "Use Substrait `IntervalCompound` type instead"
+)]
 pub const INTERVAL_MONTH_DAY_NANO_TYPE_NAME: &str = "interval-month-day-nano";
diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs
index b1cc763050311..bc38ef82977f3 100644
--- a/datafusion/substrait/tests/cases/consumer_integration.rs
+++ b/datafusion/substrait/tests/cases/consumer_integration.rs
@@ -55,7 +55,7 @@ mod tests {
 \n Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY), sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY), avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count(Int64(1))]]\
 \n Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, LINEITEM.L_QUANTITY, LINEITEM.L_EXTENDEDPRICE, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT), LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) * (CAST(Int32(1) AS Decimal128(15, 2)) + LINEITEM.L_TAX), LINEITEM.L_DISCOUNT\
 \n Filter: LINEITEM.L_SHIPDATE <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 0, milliseconds: 10368000 }\")\
-\n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]"
+\n TableScan: LINEITEM"
         );
         Ok(())
     }
@@ -73,22 +73,22 @@ mod tests {
 \n Aggregate: groupBy=[[]], aggr=[[min(PARTSUPP.PS_SUPPLYCOST)]]\
 \n Projection: PARTSUPP.PS_SUPPLYCOST\
 \n Filter: PARTSUPP.PS_PARTKEY = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8(\"EUROPE\")\
-\n CrossJoin:\
-\n CrossJoin:\
-\n CrossJoin:\
-\n TableScan: PARTSUPP projection=[PS_PARTKEY,
PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ - \n TableScan: REGION projection=[R_REGIONKEY, R_NAME, R_COMMENT]\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ - \n TableScan: PART projection=[P_PARTKEY, P_NAME, P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: PARTSUPP projection=[PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ - \n TableScan: REGION projection=[R_REGIONKEY, R_NAME, R_COMMENT]" + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: PARTSUPP\ + \n TableScan: SUPPLIER\ + \n TableScan: NATION\ + \n TableScan: REGION\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: PART\ + \n TableScan: SUPPLIER\ + \n TableScan: PARTSUPP\ + \n TableScan: NATION\ + \n TableScan: REGION" ); Ok(()) } @@ -105,11 +105,11 @@ mod tests { \n Aggregate: groupBy=[[LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ \n Projection: LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: CUSTOMER.C_MKTSEGMENT = Utf8(\"BUILDING\") AND CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1995-03-15\") AS Date32) AND LINEITEM.L_SHIPDATE > CAST(Utf8(\"1995-03-15\") AS Date32)\ - \n CrossJoin:\ - \n CrossJoin:\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]" + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: LINEITEM\ + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS" ); Ok(()) } @@ -126,8 +126,8 @@ mod tests { \n Filter: ORDERS.O_ORDERDATE >= CAST(Utf8(\"1993-07-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1993-10-01\") AS Date32) AND EXISTS ()\ \n Subquery:\ \n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]" + \n TableScan: LINEITEM\ + \n TableScan: ORDERS" ); Ok(()) } @@ -142,17 +142,17 @@ mod tests { \n Aggregate: groupBy=[[NATION.N_NAME]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ \n Projection: NATION.N_NAME, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - 
LINEITEM.L_DISCOUNT)\ \n Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND LINEITEM.L_SUPPKEY = SUPPLIER.S_SUPPKEY AND CUSTOMER.C_NATIONKEY = SUPPLIER.S_NATIONKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8(\"ASIA\") AND ORDERS.O_ORDERDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ - \n TableScan: REGION projection=[R_REGIONKEY, R_NAME, R_COMMENT]" + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS\ + \n TableScan: LINEITEM\ + \n TableScan: SUPPLIER\ + \n TableScan: NATION\ + \n TableScan: REGION" ); Ok(()) } @@ -165,7 +165,7 @@ mod tests { "Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * LINEITEM.L_DISCOUNT) AS REVENUE]]\ \n Projection: LINEITEM.L_EXTENDEDPRICE * LINEITEM.L_DISCOUNT\ \n Filter: LINEITEM.L_SHIPDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-01-01\") AS Date32) AND LINEITEM.L_DISCOUNT >= Decimal128(Some(5),3,2) AND LINEITEM.L_DISCOUNT <= Decimal128(Some(7),3,2) AND LINEITEM.L_QUANTITY < CAST(Int32(24) AS Decimal128(15, 2))\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]" + \n TableScan: LINEITEM" ); Ok(()) } @@ -206,13 +206,13 @@ mod tests { \n Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ \n Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE >= CAST(Utf8(\"1993-10-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_RETURNFLAG = Utf8(\"R\") AND CUSTOMER.C_NATIONKEY = NATION.N_NATIONKEY\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, 
L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]" + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS\ + \n TableScan: LINEITEM\ + \n TableScan: NATION" ); Ok(()) } @@ -230,19 +230,19 @@ mod tests { \n Aggregate: groupBy=[[]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]]\ \n Projection: PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0))\ \n Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"JAPAN\")\ - \n CrossJoin:\ - \n CrossJoin:\ - \n TableScan: PARTSUPP projection=[PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: PARTSUPP\ + \n TableScan: SUPPLIER\ + \n TableScan: NATION\ \n Aggregate: groupBy=[[PARTSUPP.PS_PARTKEY]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]]\ \n Projection: PARTSUPP.PS_PARTKEY, PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0))\ \n Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"JAPAN\")\ - \n CrossJoin:\ - \n CrossJoin:\ - \n TableScan: PARTSUPP projection=[PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]" + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: PARTSUPP\ + \n TableScan: SUPPLIER\ + \n TableScan: NATION" ); Ok(()) } @@ -257,9 +257,9 @@ mod tests { \n Aggregate: groupBy=[[LINEITEM.L_SHIPMODE]], aggr=[[sum(CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8(\"1-URGENT\") OR ORDERS.O_ORDERPRIORITY = Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END), sum(CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8(\"1-URGENT\") AND ORDERS.O_ORDERPRIORITY != Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END)]]\ \n Projection: LINEITEM.L_SHIPMODE, CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8(\"1-URGENT\") OR ORDERS.O_ORDERPRIORITY = Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END, CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8(\"1-URGENT\") AND ORDERS.O_ORDERPRIORITY != Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END\ \n Filter: ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"MAIL\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"SHIP\") AS Utf8)) AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE AND LINEITEM.L_SHIPDATE < LINEITEM.L_COMMITDATE AND LINEITEM.L_RECEIPTDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_RECEIPTDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ - \n CrossJoin:\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]" + \n Cross Join: \ + \n TableScan: ORDERS\ + \n 
TableScan: LINEITEM" ); Ok(()) } @@ -277,8 +277,8 @@ mod tests { \n Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY]], aggr=[[count(ORDERS.O_ORDERKEY)]]\ \n Projection: CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY\ \n Left Join: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY Filter: NOT ORDERS.O_COMMENT LIKE CAST(Utf8(\"%special%requests%\") AS Utf8)\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]" + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS" ); Ok(()) } @@ -292,9 +292,9 @@ mod tests { \n Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN PART.P_TYPE LIKE Utf8(\"PROMO%\") THEN LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ \n Projection: CASE WHEN PART.P_TYPE LIKE CAST(Utf8(\"PROMO%\") AS Utf8) THEN LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) ELSE Decimal128(Some(0),19,4) END, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: LINEITEM.L_PARTKEY = PART.P_PARTKEY AND LINEITEM.L_SHIPDATE >= Date32(\"1995-09-01\") AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-10-01\") AS Date32)\ - \n CrossJoin:\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: PART projection=[P_PARTKEY, P_NAME, P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]" + \n Cross Join: \ + \n TableScan: LINEITEM\ + \n TableScan: PART" ); Ok(()) } @@ -320,10 +320,10 @@ mod tests { \n Subquery:\ \n Projection: SUPPLIER.S_SUPPKEY\ \n Filter: SUPPLIER.S_COMMENT LIKE CAST(Utf8(\"%Customer%Complaints%\") AS Utf8)\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n CrossJoin:\ - \n TableScan: PARTSUPP projection=[PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\ - \n TableScan: PART projection=[P_PARTKEY, P_NAME, P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]" + \n TableScan: SUPPLIER\ + \n Cross Join: \ + \n TableScan: PARTSUPP\ + \n TableScan: PART" ); Ok(()) } @@ -352,12 +352,12 @@ mod tests { \n Filter: sum(LINEITEM.L_QUANTITY) > CAST(Int32(300) AS Decimal128(15, 2))\ \n Aggregate: groupBy=[[LINEITEM.L_ORDERKEY]], aggr=[[sum(LINEITEM.L_QUANTITY)]]\ \n Projection: LINEITEM.L_ORDERKEY, LINEITEM.L_QUANTITY\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n CrossJoin:\ - \n CrossJoin:\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]" + \n TableScan: 
LINEITEM\ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS\ + \n TableScan: LINEITEM" ); Ok(()) } @@ -369,9 +369,9 @@ mod tests { "Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE]]\ \n Projection: LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#12\") AND (PART.P_CONTAINER = CAST(Utf8(\"SM CASE\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM PACK\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM PKG\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(1) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(1) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(5) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#23\") AND (PART.P_CONTAINER = CAST(Utf8(\"MED BAG\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED PKG\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED PACK\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(10) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(10) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(10) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#34\") AND (PART.P_CONTAINER = CAST(Utf8(\"LG CASE\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG PACK\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG PKG\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(20) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(20) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(15) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\")\ - \n CrossJoin:\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: PART projection=[P_PARTKEY, P_NAME, P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]" + \n Cross Join: \ + \n TableScan: LINEITEM\ + \n TableScan: PART" ); Ok(()) } @@ -390,17 +390,17 @@ mod tests { \n Subquery:\ \n Projection: PART.P_PARTKEY\ \n Filter: PART.P_NAME LIKE CAST(Utf8(\"forest%\") AS Utf8)\ - \n TableScan: PART projection=[P_PARTKEY, P_NAME, P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]\ + \n TableScan: PART\ \n Subquery:\ \n Projection: Decimal128(Some(5),2,1) * sum(LINEITEM.L_QUANTITY)\ \n Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_QUANTITY)]]\ \n Projection: LINEITEM.L_QUANTITY\ \n Filter: LINEITEM.L_PARTKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_SUPPKEY = LINEITEM.L_PARTKEY AND LINEITEM.L_SHIPDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ - \n TableScan: LINEITEM 
projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: PARTSUPP projection=[PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\ - \n CrossJoin:\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]" + \n TableScan: LINEITEM\ + \n TableScan: PARTSUPP\ + \n Cross Join: \ + \n TableScan: SUPPLIER\ + \n TableScan: NATION" ); Ok(()) } @@ -418,17 +418,17 @@ mod tests { \n Filter: SUPPLIER.S_SUPPKEY = LINEITEM.L_SUPPKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND ORDERS.O_ORDERSTATUS = Utf8(\"F\") AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE AND EXISTS () AND NOT EXISTS () AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"SAUDI ARABIA\")\ \n Subquery:\ \n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ + \n TableScan: LINEITEM\ \n Subquery:\ \n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]" + \n TableScan: LINEITEM\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: SUPPLIER\ + \n TableScan: LINEITEM\ + \n TableScan: ORDERS\ + \n TableScan: NATION" ); Ok(()) } @@ -447,11 +447,11 @@ mod tests { \n Aggregate: groupBy=[[]], aggr=[[avg(CUSTOMER.C_ACCTBAL)]]\ \n Projection: CUSTOMER.C_ACCTBAL\ \n Filter: CUSTOMER.C_ACCTBAL > Decimal128(Some(0),3,2) AND (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"13\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"31\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"23\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"29\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"30\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"18\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"17\") AS Utf8))\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\ + \n TableScan: CUSTOMER\ \n Subquery:\ \n Filter: ORDERS.O_CUSTKEY = 
ORDERS.O_ORDERKEY\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]" + \n TableScan: ORDERS\ + \n TableScan: CUSTOMER" ); Ok(()) } diff --git a/datafusion/substrait/tests/cases/function_test.rs b/datafusion/substrait/tests/cases/function_test.rs index 5806b55d84c46..b136b0af19c29 100644 --- a/datafusion/substrait/tests/cases/function_test.rs +++ b/datafusion/substrait/tests/cases/function_test.rs @@ -37,7 +37,7 @@ mod tests { plan_str, "Projection: nation.n_name\ \n Filter: contains(nation.n_name, Utf8(\"IA\"))\ - \n TableScan: nation projection=[n_nationkey, n_name, n_regionkey, n_comment]" + \n TableScan: nation" ); Ok(()) } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 6794b32838a83..f4e34af35d78e 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -43,7 +43,7 @@ mod tests { assert_eq!( format!("{}", plan), "Projection: NOT DATA.D AS EXPR$0\ - \n TableScan: DATA projection=[D]" + \n TableScan: DATA" ); Ok(()) } @@ -69,7 +69,7 @@ mod tests { format!("{}", plan), "Projection: sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\ \n WindowAggr: windowExpr=[[sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: DATA projection=[D, PART, ORD]" + \n TableScan: DATA" ); Ok(()) } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index ea85092f7a6cb..04530dd34d4bf 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
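The hunks below rework `roundtrip_logical_plan.rs` around two entry points, `to_substrait_plan` and `from_substrait_plan`. As a rough sketch of the flow these tests exercise (the `data` table registration and the helper name are assumptions for illustration, not part of this change; the optimize-then-compare step mirrors the helpers added later in this file):

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use datafusion_substrait::logical_plan::producer::to_substrait_plan;

/// Hypothetical round-trip helper: plan SQL against a context that already
/// has a `data` table registered, encode the plan to Substrait, decode it
/// back, and check that the logical plan survives the trip.
async fn roundtrip_sketch(ctx: &SessionContext, sql: &str) -> Result<()> {
    let plan = ctx.sql(sql).await?.into_optimized_plan()?;
    // Producer: DataFusion `LogicalPlan` -> Substrait protobuf `Plan`
    let proto = to_substrait_plan(&plan, ctx)?;
    // Consumer: Substrait protobuf `Plan` -> DataFusion `LogicalPlan`
    let plan2 = from_substrait_plan(ctx, &proto).await?;
    // Re-optimize before comparing, since the consumer does not reproduce
    // optimizer artifacts such as TableScan projections
    let plan2 = ctx.state().optimize(&plan2)?;
    assert_eq!(plan.schema(), plan2.schema());
    assert_eq!(format!("{plan}"), format!("{plan2}"));
    Ok(())
}
```
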
+use crate::utils::test::read_json; use datafusion::arrow::array::ArrayRef; use datafusion::physical_plan::Accumulator; use datafusion::scalar::ScalarValue; @@ -38,13 +39,11 @@ use std::hash::Hash; use std::sync::Arc; use datafusion::execution::session_state::SessionStateBuilder; -use substrait::proto::extensions::simple_extension_declaration::{ - ExtensionType, MappingType, -}; -use substrait::proto::extensions::SimpleExtensionDeclaration; +use substrait::proto::extensions::simple_extension_declaration::MappingType; use substrait::proto::rel::RelType; use substrait::proto::{plan_rel, Plan, Rel}; +#[derive(Debug)] struct MockSerializerRegistry; impl SerializerRegistry for MockSerializerRegistry { @@ -67,8 +66,7 @@ impl SerializerRegistry for MockSerializerRegistry { &self, name: &str, bytes: &[u8], - ) -> Result<Arc<dyn UserDefinedLogicalNode>> - { + ) -> Result<Arc<dyn UserDefinedLogicalNode>> { if name == "MockUserDefinedLogicalPlan" { MockUserDefinedLogicalPlan::deserialize(bytes) } else { @@ -148,6 +146,10 @@ impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan { fn dyn_ord(&self, _: &dyn UserDefinedLogicalNode) -> Option<Ordering> { unimplemented!() } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } impl MockUserDefinedLogicalPlan { @@ -178,7 +180,13 @@ async fn simple_select() -> Result<()> { #[tokio::test] async fn wildcard_select() -> Result<()> { - roundtrip("SELECT * FROM data").await + assert_expected_plan_unoptimized( + "SELECT * FROM data", + "Projection: data.a, data.b, data.c, data.d, data.e, data.f\ + \n TableScan: data", + true, + ) + .await } #[tokio::test] @@ -219,23 +227,6 @@ async fn select_with_reused_functions() -> Result<()> { Ok(()) } -#[tokio::test] -async fn roundtrip_udt_extensions() -> Result<()> { - let ctx = create_context().await?; - let proto = - roundtrip_with_ctx("SELECT INTERVAL '1 YEAR 1 DAY 1 SECOND' FROM data", ctx) - .await?; - let expected_type = SimpleExtensionDeclaration { - mapping_type: Some(MappingType::ExtensionType(ExtensionType { - extension_uri_reference: u32::MAX, - type_anchor: 0, - name: "interval-month-day-nano".to_string(), - })), - }; - assert_eq!(proto.extensions, vec![expected_type]); - Ok(()) -} - #[tokio::test] async fn select_with_filter_date() -> Result<()> { roundtrip("SELECT * FROM data WHERE c > CAST('2020-01-01' AS DATE)").await } @@ -289,8 +280,9 @@ async fn aggregate_grouping_sets() -> Result<()> { async fn aggregate_grouping_rollup() -> Result<()> { assert_expected_plan( "SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)", - "Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\ - \n TableScan: data projection=[a, b, c, e]", + "Projection: data.a, data.c, data.e, avg(data.b)\ + \n Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\ + \n TableScan: data projection=[a, b, c, e]", true ).await } @@ -462,16 +454,14 @@ async fn roundtrip_inlist_5() -> Result<()> { // using assert_expected_plan here as a workaround assert_expected_plan( "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", - "Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()\ - \n Subquery:\ - \n Projection: data2.a\ - \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ - \n TableScan: data2 projection=[a, b, c, d, e, f]\ - \n TableScan: data projection=[a, f], partial_filters=[data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") 
OR data.f = Utf8(\"c\") OR data.a IN ()]\ - \n Subquery:\ - \n Projection: data2.a\ - \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ - \n TableScan: data2 projection=[a, b, c, d, e, f]", + "Projection: data.a, data.f\ + \n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR Boolean(true) IS NOT NULL\ + \n Projection: data.a, data.f, Boolean(true)\ + \n Left Join: data.a = data2.a\ + \n TableScan: data projection=[a, f]\ + \n Projection: data2.a, Boolean(true)\ + \n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")\ + \n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]", true).await } @@ -583,6 +573,11 @@ async fn roundtrip_ilike() -> Result<()> { roundtrip("SELECT f FROM data WHERE f ILIKE 'a%b'").await } +#[tokio::test] +async fn roundtrip_modulus() -> Result<()> { + roundtrip("SELECT a%3 from data").await +} + #[tokio::test] async fn roundtrip_not() -> Result<()> { roundtrip("SELECT * FROM data WHERE NOT d").await @@ -657,6 +652,109 @@ async fn simple_intersect() -> Result<()> { .await } +#[tokio::test] +async fn aggregate_wo_projection_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/aggregate_no_project.substrait.json"); + + assert_expected_plan_substrait( + proto_plan, + "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\ + \n TableScan: data projection=[a]", + ) + .await +} + +#[tokio::test] +async fn aggregate_wo_projection_sorted_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json"); + + assert_expected_plan_substrait( + proto_plan, + "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) ORDER BY [data.a DESC NULLS FIRST] AS countA]]\ + \n TableScan: data projection=[a]", + ) + .await +} + +#[tokio::test] +async fn simple_intersect_consume() -> Result<()> { + let proto_plan = read_json("tests/testdata/test_plans/intersect.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn primary_intersect_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/intersect_primary.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT (SELECT a FROM data2 UNION ALL SELECT a FROM data2)", + ) + .await +} + +#[tokio::test] +async fn multiset_intersect_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/intersect_multiset.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT SELECT a FROM data2 INTERSECT SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn multiset_intersect_all_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/intersect_multiset_all.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT ALL SELECT a FROM data2 INTERSECT ALL SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn primary_except_consume() -> Result<()> { + let proto_plan = read_json("tests/testdata/test_plans/minus_primary.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data EXCEPT SELECT a FROM data2 EXCEPT SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn primary_except_all_consume() -> Result<()> { + let proto_plan = + 
read_json("tests/testdata/test_plans/minus_primary_all.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data EXCEPT ALL SELECT a FROM data2 EXCEPT ALL SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn union_distinct_consume() -> Result<()> { + let proto_plan = read_json("tests/testdata/test_plans/union_distinct.substrait.json"); + + assert_substrait_sql(proto_plan, "SELECT a FROM data UNION SELECT a FROM data2").await +} + #[tokio::test] async fn simple_intersect_table_reuse() -> Result<()> { // Substrait does currently NOT maintain the alias of the tables. @@ -900,7 +998,7 @@ async fn roundtrip_aggregate_udf() -> Result<()> { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -920,8 +1018,9 @@ async fn roundtrip_aggregate_udf() -> Result<()> { let ctx = create_context().await?; ctx.register_udaf(dummy_agg); + roundtrip_with_ctx("select dummy_agg(a) from data", ctx.clone()).await?; + roundtrip_with_ctx("select dummy_agg(a order by a) from data", ctx.clone()).await?; - roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await?; Ok(()) } @@ -1078,6 +1177,32 @@ async fn verify_post_join_filter_value(proto: Box) -> Result<()> { Ok(()) } +async fn assert_expected_plan_unoptimized( + sql: &str, + expected_plan_str: &str, + assert_schema: bool, +) -> Result<()> { + let ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.into_unoptimized_plan(); + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + + println!("{plan}"); + println!("{plan2}"); + + println!("{proto:?}"); + + if assert_schema { + assert_eq!(plan.schema(), plan2.schema()); + } + + let plan2str = format!("{plan2}"); + assert_eq!(expected_plan_str, &plan2str); + + Ok(()) +} + async fn assert_expected_plan( sql: &str, expected_plan_str: &str, @@ -1105,6 +1230,38 @@ async fn assert_expected_plan( Ok(()) } +async fn assert_expected_plan_substrait( + substrait_plan: Plan, + expected_plan_str: &str, +) -> Result<()> { + let ctx = create_context().await?; + + let plan = from_substrait_plan(&ctx, &substrait_plan).await?; + + let plan = ctx.state().optimize(&plan)?; + + let planstr = format!("{plan}"); + assert_eq!(planstr, expected_plan_str); + + Ok(()) +} + +async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> { + let ctx = create_context().await?; + + let expected = ctx.sql(sql).await?.into_optimized_plan()?; + + let plan = from_substrait_plan(&ctx, &substrait_plan).await?; + + let plan = ctx.state().optimize(&plan)?; + + let planstr = format!("{plan}"); + let expectedstr = format!("{expected}"); + assert_eq!(planstr, expectedstr); + + Ok(()) +} + async fn roundtrip_fill_na(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; diff --git a/datafusion/substrait/tests/cases/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs index da0898d222c42..54d55d1b6f10e 100644 --- a/datafusion/substrait/tests/cases/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -20,13 +20,12 @@ mod tests { use datafusion::datasource::provider_as_source; use datafusion::logical_expr::LogicalPlanBuilder; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; - use datafusion_substrait::logical_plan::producer; + use datafusion_substrait::logical_plan::producer::to_substrait_plan; use datafusion_substrait::serializer; use datafusion::error::Result; use datafusion::prelude::*; - use 
datafusion_substrait::logical_plan::producer::to_substrait_plan; use std::fs; use substrait::proto::plan_rel::RelType; use substrait::proto::rel_common::{Emit, EmitKind}; @@ -61,7 +60,7 @@ mod tests { let ctx = create_context().await?; let table = provider_as_source(ctx.table_provider("data").await?); let table_scan = LogicalPlanBuilder::scan("data", table, None)?.build()?; - let convert_result = producer::to_substrait_plan(&table_scan, &ctx); + let convert_result = to_substrait_plan(&table_scan, &ctx); assert!(convert_result.is_ok()); Ok(()) @@ -117,8 +116,8 @@ mod tests { let datafusion_plan = df.into_optimized_plan()?; assert_eq!( format!("{}", datafusion_plan), - "Projection: data.b, RANK() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, data.c\ - \n WindowAggr: windowExpr=[[RANK() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + "Projection: data.b, rank() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, data.c\ + \n WindowAggr: windowExpr=[[rank() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: data projection=[a, b, c]", ); diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs index cb1fb67fc0442..5ae586afe56f8 100644 --- a/datafusion/substrait/tests/cases/substrait_validations.rs +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -70,7 +70,7 @@ mod tests { assert_eq!( format!("{}", plan), "Projection: DATA.a, DATA.b\ - \n TableScan: DATA projection=[a, b]" + \n TableScan: DATA" ); Ok(()) } @@ -91,8 +91,7 @@ mod tests { assert_eq!( format!("{}", plan), "Projection: DATA.a, DATA.b\ - \n Projection: DATA.a, DATA.b\ - \n TableScan: DATA projection=[b, a]" + \n TableScan: DATA projection=[a, b]" ); Ok(()) } @@ -102,12 +101,12 @@ mod tests { let proto_plan = read_json( "tests/testdata/test_plans/simple_select_with_mask.substrait.json", ); - // the DataFusion schema { b, a, c, d } contains the Substrait schema { a, b, c } + // the DataFusion schema { d, a, c, b } contains the Substrait schema { a, b, c } let df_schema = vec![ - ("b", DataType::Int32, true), + ("d", DataType::Int32, true), ("a", DataType::Int32, false), ("c", DataType::Int32, false), - ("d", DataType::Int32, false), + ("b", DataType::Int32, false), ]; let ctx = generate_context_with_table("DATA", df_schema)?; let plan = from_substrait_plan(&ctx, &proto_plan).await?; @@ -115,9 +114,7 @@ mod tests { assert_eq!( format!("{}", plan), "Projection: DATA.a, DATA.b\ - \n Projection: DATA.a, DATA.b\ - \n Projection: DATA.a, DATA.b, DATA.c\ - \n TableScan: DATA projection=[b, a, c]" + \n TableScan: DATA projection=[a, b]" ); Ok(()) } diff --git a/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project.substrait.json b/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project.substrait.json new file mode 100644 index 0000000000000..ed8675b968269 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project.substrait.json @@ -0,0 +1,97 @@ +{ + "extensionUris": [ + { + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 185, + "name": "count:any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "aggregate": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + 
"struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 185, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": {} + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + }, + "names": [ + "a", + "countA" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json b/datafusion/substrait/tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json new file mode 100644 index 0000000000000..d5170223cd65b --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json @@ -0,0 +1,113 @@ +{ + "extensionUris": [ + { + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 185, + "name": "count:any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "aggregate": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 185, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": {} + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + } + ], + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + } + ] + } + } + ] + } + }, + "names": [ + "a", + "countA" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "manual" + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect.substrait.json new file mode 100644 index 0000000000000..b9a2e4ad14038 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect.substrait.json @@ -0,0 +1,118 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + 
"common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_PRIMARY" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect_multiset.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset.substrait.json new file mode 100644 index 0000000000000..8ff69bd82c3a7 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_MULTISET" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect_multiset_all.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset_all.substrait.json new file mode 100644 index 0000000000000..56daf6ed46f46 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset_all.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + 
], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_MULTISET_ALL" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect_primary.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect_primary.substrait.json new file mode 100644 index 0000000000000..229dd7251705d --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect_primary.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_PRIMARY" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 
54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/minus_primary.substrait.json b/datafusion/substrait/tests/testdata/test_plans/minus_primary.substrait.json new file mode 100644 index 0000000000000..33b0e2ab8c801 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/minus_primary.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_MINUS_PRIMARY" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/minus_primary_all.substrait.json b/datafusion/substrait/tests/testdata/test_plans/minus_primary_all.substrait.json new file mode 100644 index 0000000000000..229f78ab5bf6b --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/minus_primary_all.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + 
"selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_MINUS_PRIMARY_ALL" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/union_distinct.substrait.json b/datafusion/substrait/tests/testdata/test_plans/union_distinct.substrait.json new file mode 100644 index 0000000000000..e8b02749660dd --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/union_distinct.substrait.json @@ -0,0 +1,118 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_UNION_DISTINCT" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs index 9f63b74ef0fc7..00cbfb0c412cf 100644 --- a/datafusion/substrait/tests/utils.rs +++ b/datafusion/substrait/tests/utils.rs @@ -147,6 +147,7 @@ pub mod test { Ok(()) } + #[allow(deprecated)] fn collect_schemas_from_rel(&mut self, rel: &Rel) -> Result<()> { let rel_type = rel .rel_type diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index 46e157aecfd9c..2440244d08c33 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -60,4 +60,5 @@ wasm-bindgen = "0.2.87" wasm-bindgen-futures = "0.4.40" [dev-dependencies] -wasm-bindgen-test = "0.3" +tokio = { workspace = true } +wasm-bindgen-test = "0.3.44" diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index 0af07e176f0f3..37512e8278a74 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -1095,9 +1095,9 @@ } }, "node_modules/cookie": { - "version": "0.6.0", - 
"resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", - "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", + "version": "0.7.1", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.1.tgz", + "integrity": "sha512-6DnInpx7SJ2AK3+CTUE/ZM0vWTUboZCegxhC2xiIydHR9jNuTAASBrfEpHhiGOZw/nX51bHt6YQl8jsGo4y/0w==", "dev": true, "engines": { "node": ">= 0.6" @@ -1459,9 +1459,9 @@ } }, "node_modules/express": { - "version": "4.21.0", - "resolved": "https://registry.npmjs.org/express/-/express-4.21.0.tgz", - "integrity": "sha512-VqcNGcj/Id5ZT1LZ/cfihi3ttTn+NJmkli2eZADigjq29qTlWi/hAQ43t/VLPq8+UX06FCEx3ByOYet6ZFblng==", + "version": "4.21.1", + "resolved": "https://registry.npmjs.org/express/-/express-4.21.1.tgz", + "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==", "dev": true, "dependencies": { "accepts": "~1.3.8", @@ -1469,7 +1469,7 @@ "body-parser": "1.20.3", "content-disposition": "0.5.4", "content-type": "~1.0.4", - "cookie": "0.6.0", + "cookie": "0.7.1", "cookie-signature": "1.0.6", "debug": "2.6.9", "depd": "2.0.0", @@ -5247,9 +5247,9 @@ "dev": true }, "cookie": { - "version": "0.6.0", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", - "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", + "version": "0.7.1", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.1.tgz", + "integrity": "sha512-6DnInpx7SJ2AK3+CTUE/ZM0vWTUboZCegxhC2xiIydHR9jNuTAASBrfEpHhiGOZw/nX51bHt6YQl8jsGo4y/0w==", "dev": true }, "cookie-signature": { @@ -5524,9 +5524,9 @@ } }, "express": { - "version": "4.21.0", - "resolved": "https://registry.npmjs.org/express/-/express-4.21.0.tgz", - "integrity": "sha512-VqcNGcj/Id5ZT1LZ/cfihi3ttTn+NJmkli2eZADigjq29qTlWi/hAQ43t/VLPq8+UX06FCEx3ByOYet6ZFblng==", + "version": "4.21.1", + "resolved": "https://registry.npmjs.org/express/-/express-4.21.1.tgz", + "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==", "dev": true, "requires": { "accepts": "~1.3.8", @@ -5534,7 +5534,7 @@ "body-parser": "1.20.3", "content-disposition": "0.5.4", "content-type": "~1.0.4", - "cookie": "0.6.0", + "cookie": "0.7.1", "cookie-signature": "1.0.6", "debug": "2.6.9", "depd": "2.0.0", diff --git a/datafusion/wasmtest/src/lib.rs b/datafusion/wasmtest/src/lib.rs index 0f24449cbed3c..085064d16d947 100644 --- a/datafusion/wasmtest/src/lib.rs +++ b/datafusion/wasmtest/src/lib.rs @@ -87,13 +87,14 @@ mod test { wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); - #[wasm_bindgen_test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + #[cfg_attr(not(target_arch = "wasm32"), allow(dead_code))] fn datafusion_test() { basic_exprs(); basic_parse(); } - #[wasm_bindgen_test] + #[wasm_bindgen_test(unsupported = tokio::test)] async fn basic_execute() { let sql = "SELECT 2 + 2;"; diff --git a/dev/changelog/42.1.0.md b/dev/changelog/42.1.0.md new file mode 100644 index 0000000000000..cf4f911150acb --- /dev/null +++ b/dev/changelog/42.1.0.md @@ -0,0 +1,42 @@ + + +# Apache DataFusion 42.1.0 Changelog + +This release consists of 5 commits from 4 contributors. See credits at the end of this changelog for more information. 
+ +**Other:** + +- Backport update to arrow 53.1.0 on branch-42 [#12977](https://github.com/apache/datafusion/pull/12977) (alamb) +- Backport "Provide field and schema metadata missing on cross joins, and union with null fields" (#12729) [#12974](https://github.com/apache/datafusion/pull/12974) (matthewmturner) +- Backport "physical-plan: Cast nested group values back to dictionary if necessary" (#12586) [#12976](https://github.com/apache/datafusion/pull/12976) (matthewmturner) +- backport-to-DF-42: Provide field and schema metadata missing on distinct aggregations [#12975](https://github.com/apache/datafusion/pull/12975) (Xuanwo) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 2 Matthew Turner + 1 Andrew Lamb + 1 Andy Grove + 1 Xuanwo +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/release/README.md b/dev/release/README.md index bd9c0621fdbc0..0e0daa9d6c407 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -260,19 +260,22 @@ Verify that the Cargo.toml in the tarball contains the correct version ```shell (cd datafusion/common && cargo publish) +(cd datafusion/expr-common && cargo publish) +(cd datafusion/physical-expr-common && cargo publish) +(cd datafusion/functions-aggregate-common && cargo publish) (cd datafusion/expr && cargo publish) (cd datafusion/execution && cargo publish) -(cd datafusion/physical-expr-common && cargo publish) -(cd datafusion/functions-aggregate && cargo publish) (cd datafusion/physical-expr && cargo publish) (cd datafusion/functions && cargo publish) +(cd datafusion/functions-aggregate && cargo publish) +(cd datafusion/functions-window && cargo publish) (cd datafusion/functions-nested && cargo publish) (cd datafusion/sql && cargo publish) (cd datafusion/optimizer && cargo publish) (cd datafusion/common-runtime && cargo publish) -(cd datafusion/catalog && cargo publish) (cd datafusion/physical-plan && cargo publish) (cd datafusion/physical-optimizer && cargo publish) +(cd datafusion/catalog && cargo publish) (cd datafusion/core && cargo publish) (cd datafusion/proto-common && cargo publish) (cd datafusion/proto && cargo publish) diff --git a/dev/update_config_docs.sh b/dev/update_config_docs.sh index 836ba6772eacd..585cb77839f98 100755 --- a/dev/update_config_docs.sh +++ b/dev/update_config_docs.sh @@ -24,7 +24,7 @@ SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "${SOURCE_DIR}/../" && pwd TARGET_FILE="docs/source/user-guide/configs.md" -PRINT_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_config_docs" +PRINT_CONFIG_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_config_docs" echo "Inserting header" cat <<'EOF' > "$TARGET_FILE" @@ -67,8 +67,8 @@ Environment variables are read during `SessionConfig` initialisation so they mus EOF -echo "Running CLI and inserting docs table" -$PRINT_DOCS_COMMAND >> "$TARGET_FILE" +echo "Running CLI and inserting config docs table" +$PRINT_CONFIG_DOCS_COMMAND >> "$TARGET_FILE" echo "Running prettier" npx prettier@2.3.2 --write "$TARGET_FILE" diff --git a/dev/update_function_docs.sh b/dev/update_function_docs.sh new file mode 100755 index 0000000000000..13bc22afcc135 --- /dev/null +++ b/dev/update_function_docs.sh @@ -0,0 +1,299 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license 
agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+set -e
+
+SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+cd "${SOURCE_DIR}/../" && pwd
+
+
+TARGET_FILE="docs/source/user-guide/sql/aggregate_functions_new.md"
+PRINT_AGGREGATE_FUNCTION_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_functions_docs -- aggregate"
+
+echo "Inserting header"
+cat <<'EOF' > "$TARGET_FILE"
+
+
+
+
+# Aggregate Functions (NEW)
+
+Note: this documentation is in the process of being migrated to be [automatically created from the codebase].
+Please see the [Aggregate Functions (old)](aggregate_functions.md) page for
+the rest of the documentation.
+
+[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740
+
+Aggregate functions operate on a set of values to compute a single result.
+EOF
+
+echo "Running CLI and inserting aggregate function docs table"
+$PRINT_AGGREGATE_FUNCTION_DOCS_COMMAND >> "$TARGET_FILE"
+
+echo "Running prettier"
+npx prettier@2.3.2 --write "$TARGET_FILE"
+
+echo "'$TARGET_FILE' successfully updated!"
+
+TARGET_FILE="docs/source/user-guide/sql/scalar_functions_new.md"
+PRINT_SCALAR_FUNCTION_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_functions_docs -- scalar"
+
+echo "Inserting header"
+cat <<'EOF' > "$TARGET_FILE"
+
+
+
+
+# Scalar Functions (NEW)
+
+Note: this documentation is in the process of being migrated to be [automatically created from the codebase].
+Please see the [Scalar Functions (old)](scalar_functions.md) page for
+the rest of the documentation.
+
+[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740
+
+EOF
+
+echo "Running CLI and inserting scalar function docs table"
+$PRINT_SCALAR_FUNCTION_DOCS_COMMAND >> "$TARGET_FILE"
+
+echo "Running prettier"
+npx prettier@2.3.2 --write "$TARGET_FILE"
+
+echo "'$TARGET_FILE' successfully updated!"
+
+TARGET_FILE="docs/source/user-guide/sql/window_functions_new.md"
+PRINT_WINDOW_FUNCTION_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_functions_docs -- window"
+
+echo "Inserting header"
+cat <<'EOF' > "$TARGET_FILE"
+
+
+
+
+
+# Window Functions (NEW)
+
+Note: this documentation is in the process of being migrated to be [automatically created from the codebase].
+Please see the [Window Functions (old)](window_functions.md) page for
+the rest of the documentation.
+
+[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740
+
+A _window function_ performs a calculation across a set of table rows that are somehow related to the current row.
+This is comparable to the type of calculation that can be done with an aggregate function.
+However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would.
+Instead, the rows retain their separate identities. Behind the scenes, the window function is able to access more than just the current row of the query result.
+
+Here is an example that shows how to compare each employee's salary with the average salary in his or her department:
+
+```sql
+SELECT depname, empno, salary, avg(salary) OVER (PARTITION BY depname) FROM empsalary;
+
++-----------+-------+--------+-------------------+
+| depname   | empno | salary | avg               |
++-----------+-------+--------+-------------------+
+| personnel | 2     | 3900   | 3700.0            |
+| personnel | 5     | 3500   | 3700.0            |
+| develop   | 8     | 6000   | 5020.0            |
+| develop   | 10    | 5200   | 5020.0            |
+| develop   | 11    | 5200   | 5020.0            |
+| develop   | 9     | 4500   | 5020.0            |
+| develop   | 7     | 4200   | 5020.0            |
+| sales     | 1     | 5000   | 4866.666666666667 |
+| sales     | 4     | 4800   | 4866.666666666667 |
+| sales     | 3     | 4800   | 4866.666666666667 |
++-----------+-------+--------+-------------------+
+```
+
+A window function call always contains an OVER clause directly following the window function's name and argument(s). This is what syntactically distinguishes it from a normal function or non-window aggregate. The OVER clause determines exactly how the rows of the query are split up for processing by the window function. The PARTITION BY clause within OVER divides the rows into groups, or partitions, that share the same values of the PARTITION BY expression(s). For each row, the window function is computed across the rows that fall into the same partition as the current row. The previous example showed how to compute the average of a column per partition.
+
+You can also control the order in which rows are processed by window functions using ORDER BY within OVER. (The window ORDER BY does not even have to match the order in which the rows are output.) Here is an example:
+
+```sql
+SELECT depname, empno, salary,
+       rank() OVER (PARTITION BY depname ORDER BY salary DESC)
+FROM empsalary;
+
++-----------+-------+--------+--------+
+| depname   | empno | salary | rank   |
++-----------+-------+--------+--------+
+| personnel | 2     | 3900   | 1      |
+| develop   | 8     | 6000   | 1      |
+| develop   | 10    | 5200   | 2      |
+| develop   | 11    | 5200   | 2      |
+| develop   | 9     | 4500   | 4      |
+| develop   | 7     | 4200   | 5      |
+| sales     | 1     | 5000   | 1      |
+| sales     | 4     | 4800   | 2      |
+| personnel | 5     | 3500   | 2      |
+| sales     | 3     | 4800   | 2      |
++-----------+-------+--------+--------+
+```
+
+There is another important concept associated with window functions: for each row, there is a set of rows within its partition called its window frame. Some window functions act only on the rows of the window frame, rather than of the whole partition.
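+
+When the OVER clause contains an ORDER BY but no explicit frame clause, the frame defaults to
+RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW. The following running-total query is an
+editorial sketch rather than generated output; it reuses the hypothetical `empsalary` table from
+the examples above and assumes the PostgreSQL-style default frame, which DataFusion follows:
+
+```sql
+-- With ORDER BY and no frame clause, each row's sum covers the rows
+-- from the start of its partition through the current row (and its peers).
+SELECT depname, empno, salary,
+       sum(salary) OVER (PARTITION BY depname ORDER BY empno) AS running_sum
+FROM empsalary;
+```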
+Here is an example of using window frames in queries:
+
+```sql
+SELECT depname, empno, salary,
+    avg(salary) OVER(ORDER BY salary ASC ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS avg,
+    min(salary) OVER(ORDER BY empno ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cum_min
+FROM empsalary
+ORDER BY empno ASC;
+
++-----------+-------+--------+--------------------+---------+
+| depname   | empno | salary | avg                | cum_min |
++-----------+-------+--------+--------------------+---------+
+| sales     | 1     | 5000   | 5000.0             | 5000    |
+| personnel | 2     | 3900   | 3866.6666666666665 | 3900    |
+| sales     | 3     | 4800   | 4700.0             | 3900    |
+| sales     | 4     | 4800   | 4866.666666666667  | 3900    |
+| personnel | 5     | 3500   | 3700.0             | 3500    |
+| develop   | 7     | 4200   | 4200.0             | 3500    |
+| develop   | 8     | 6000   | 5600.0             | 3500    |
+| develop   | 9     | 4500   | 4500.0             | 3500    |
+| develop   | 10    | 5200   | 5133.333333333333  | 3500    |
+| develop   | 11    | 5200   | 5466.666666666667  | 3500    |
++-----------+-------+--------+--------------------+---------+
+```
+
+When a query involves multiple window functions, it is possible to write out each one with a separate OVER clause, but this is duplicative and error-prone if the same windowing behavior is wanted for several functions. Instead, each windowing behavior can be named in a WINDOW clause and then referenced in OVER. For example:
+
+```sql
+SELECT sum(salary) OVER w, avg(salary) OVER w
+FROM empsalary
+WINDOW w AS (PARTITION BY depname ORDER BY salary DESC);
+```
+
+## Syntax
+
+The syntax for the OVER clause is
+
+```
+function([expr])
+  OVER(
+    [PARTITION BY expr[, …]]
+    [ORDER BY expr [ ASC | DESC ][, …]]
+    [ frame_clause ]
+    )
+```
+
+where **frame_clause** is one of:
+
+```
+  { RANGE | ROWS | GROUPS } frame_start
+  { RANGE | ROWS | GROUPS } BETWEEN frame_start AND frame_end
+```
+
+and **frame_start** and **frame_end** can be one of:
+
+```sql
+UNBOUNDED PRECEDING
+offset PRECEDING
+CURRENT ROW
+offset FOLLOWING
+UNBOUNDED FOLLOWING
+```
+
+where **offset** is a non-negative integer.
+
+RANGE and GROUPS modes require an ORDER BY clause (with RANGE the ORDER BY must specify exactly one column).
+
+## Aggregate functions
+
+All [aggregate functions](aggregate_functions.md) can be used as window functions.
+
+EOF
+
+echo "Running CLI and inserting window function docs table"
+$PRINT_WINDOW_FUNCTION_DOCS_COMMAND >> "$TARGET_FILE"
+
+echo "Running prettier"
+npx prettier@2.3.2 --write "$TARGET_FILE"
+
+echo "'$TARGET_FILE' successfully updated!"
diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md
index 79a9298798336..4645fe5c8804b 100644
--- a/docs/source/contributor-guide/index.md
+++ b/docs/source/contributor-guide/index.md
@@ -116,6 +116,20 @@ If you are concerned that a larger design will be lost in a string of small PRs,
Note all commits in a PR are squashed when merged to the `main` branch so there is one commit per PR after merge.
+## Conventional Commits & Labeling PRs
+
+We generate change logs for each release using an automated process that will categorize PRs based on the title
+and/or the GitHub labels attached to the PR.
+
+We follow the [Conventional Commits] specification to categorize PRs based on the title. This most often simply means
+looking for titles starting with prefixes such as `fix:`, `feat:`, `docs:`, or `chore:`. We do not enforce this
+convention but encourage its use if you want your PR to feature in the correct section of the changelog.
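+
+For illustration, here are a few hypothetical PR titles that follow this convention (invented examples, not taken from the repository history):
+
+```
+feat: support UNION DISTINCT in the Substrait consumer
+fix: preserve field metadata on cross joins
+docs: update the SQL function reference
+```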
+
+The change log generator will also look at GitHub labels such as `bug`, `enhancement`, or `api change`, and labels
+do take priority over the conventional commit approach, allowing maintainers to re-categorize PRs after they have been merged.
+
+[conventional commits]: https://www.conventionalcommits.org/en/v1.0.0/
+
# Reviewing Pull Requests
Some helpful links:
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 32a5dce323f23..9008950d3dd69 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -34,10 +34,21 @@ Apache DataFusion
DataFusion is an extensible query engine written in `Rust `_ that
-uses `Apache Arrow `_ as its in-memory format. DataFusion's target users are
-developers building fast and feature rich database and analytic systems,
+uses `Apache Arrow `_ as its in-memory format.
+
+The documentation on this site is for the `core DataFusion project `_, which contains
+libraries and binaries for developers building fast and feature rich database and analytic systems,
customized to particular workloads. See `use cases `_ for examples.
+The following related subprojects target end users and have separate documentation.
+
+- `DataFusion Python `_ offers a Python interface for SQL and DataFrame
+  queries.
+- `DataFusion Ray `_ provides a distributed version of DataFusion
+  that scales out on `Ray `_ clusters.
+- `DataFusion Comet `_ is an accelerator for Apache Spark based on
+  DataFusion.
+
"Out of the box,"
DataFusion offers `SQL `_ and `Dataframe `_ APIs,
excellent `performance `_, built-in support for CSV, Parquet, JSON, and Avro,
@@ -119,6 +130,7 @@ To get started, see
   library-user-guide/extending-operators
   library-user-guide/profiling
   library-user-guide/query-optimizer
+  library-user-guide/api-health
.. _toc.contributor-guide:
.. toctree::
diff --git a/docs/source/library-user-guide/api-health.md b/docs/source/library-user-guide/api-health.md
new file mode 100644
index 0000000000000..943a370e81723
--- /dev/null
+++ b/docs/source/library-user-guide/api-health.md
@@ -0,0 +1,37 @@
+
+
+# API health policy
+
+To maintain API health, developers must track and properly deprecate outdated methods.
+When deprecating a method:
+
+- clearly mark the API as deprecated and specify the exact DataFusion version in which it was deprecated.
+- concisely describe the preferred API, if relevant.
+
+API deprecation example:
+
+```rust
+    #[deprecated(since = "41.0.0", note = "Use SessionStateBuilder")]
+    pub fn new_with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self
+```
+
+Deprecated methods will remain in the codebase for a period of 6 major versions or 6 months, whichever is longer, to provide users ample time to transition away from them.
+
+Please refer to [DataFusion releases](https://crates.io/crates/datafusion/versions) to plan API migrations ahead of time.
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md
index f34d148f092f3..91a2e8b4389a8 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -57,6 +57,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus
| datafusion.execution.parquet.pushdown_filters | false | (reading) If true, filter expressions are applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization".
|
| datafusion.execution.parquet.reorder_filters | false | (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query |
| datafusion.execution.parquet.schema_force_view_types | false | (reading) If true, parquet reader will read columns of `Utf8/LargeUtf8` with `Utf8View`, and `Binary/LargeBinary` with `BinaryView`. |
+| datafusion.execution.parquet.binary_as_string | false | (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. |
| datafusion.execution.parquet.data_pagesize_limit | 1048576 | (writing) Sets best effort maximum size of data page in bytes |
| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in bytes |
| datafusion.execution.parquet.writer_version | 1.0 | (writing) Sets parquet writer version. Valid values are "1.0" and "2.0" |
@@ -66,7 +67,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus
| datafusion.execution.parquet.statistics_enabled | page | (writing) Sets if statistics are enabled for any column. Valid values are: "none", "chunk", and "page". These values are not case sensitive. If NULL, uses default parquet writer setting |
| datafusion.execution.parquet.max_statistics_size | 4096 | (writing) Sets max statistics size for any column. If NULL, uses default parquet writer setting |
| datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. |
-| datafusion.execution.parquet.created_by | datafusion version 42.0.0 | (writing) Sets "created by" property |
+| datafusion.execution.parquet.created_by | datafusion version 42.1.0 | (writing) Sets "created by" property |
| datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length |
| datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page |
| datafusion.execution.parquet.encoding | NULL | (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting |
@@ -91,6 +92,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus
| datafusion.execution.skip_partial_aggregation_probe_ratio_threshold | 0.8 | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation.
If the value is greater, then partial aggregation will skip aggregation for further input |
| datafusion.execution.skip_partial_aggregation_probe_rows_threshold | 100000 | Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode |
| datafusion.execution.use_row_number_estimates_to_optimize_partitioning | false | Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans if the source of statistics is accurate. We plan to make this the default in the future. |
+| datafusion.execution.enforce_batch_size_in_joins | false | Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. |
| datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. |
| datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores |
| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible |
diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md
index c8f0ffbec701e..ababb001f5c5e 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -69,7 +69,7 @@ value
:::
:::{note}
-Since `&&` and `||` are existed as logical operators in Rust, but those are not overloadable and not works with expression API.
+Since `&&` and `||` are logical operators in Rust and cannot be overloaded, they are not available in the expression API.
:::
## Bitwise Expressions
@@ -151,7 +151,7 @@ but these operators always return a `bool` which makes them not work with the ex
| trunc(x) | truncate toward zero |
:::{note}
-Unlike to some databases the math functions in Datafusion works the same way as Rust math functions, avoiding failing on corner cases e.g
+Unlike some databases, the math functions in DataFusion work the same way as Rust math functions, avoiding failures on corner cases, e.g.
```sql select log(-1), log(0), sqrt(-1); diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 8f8983061eb69..7c975055d152f 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -96,6 +96,7 @@ Here are some active projects using DataFusion: - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust - [Ballista](https://github.com/apache/datafusion-ballista) Distributed SQL Query Engine +- [Blaze](https://github.com/kwai/blaze) The Blaze accelerator for Apache Spark leverages native vectorized execution to accelerate query processing - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database - [Comet](https://github.com/apache/datafusion-comet) Apache Spark native query execution plugin - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) @@ -124,7 +125,6 @@ Here are some active projects using DataFusion: Here are some less active projects that used DataFusion: - [bdt](https://github.com/datafusion-contrib/bdt) Boring Data Tool -- [Blaze](https://github.com/blaze-init/blaze) Spark accelerator with DataFusion at its core - [Cloudfuse Buzz](https://github.com/cloudfuse-io/buzz-rust) - [datafusion-tui](https://github.com/datafusion-contrib/datafusion-tui) Text UI for DataFusion - [Flock](https://github.com/flock-lab/flock) diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 1c214084b3faf..77f527c92cdae 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -19,595 +19,4 @@ # Aggregate Functions -Aggregate functions operate on a set of values to compute a single result. - -## General - -- [avg](#avg) -- [bit_and](#bit_and) -- [bit_or](#bit_or) -- [bit_xor](#bit_xor) -- [bool_and](#bool_and) -- [bool_or](#bool_or) -- [count](#count) -- [max](#max) -- [mean](#mean) -- [median](#median) -- [min](#min) -- [sum](#sum) -- [array_agg](#array_agg) -- [first_value](#first_value) -- [last_value](#last_value) - -### `avg` - -Returns the average of numeric values in the specified column. - -``` -avg(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -#### Aliases - -- `mean` - -### `bit_and` - -Computes the bitwise AND of all non-null input values. - -``` -bit_and(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `bit_or` - -Computes the bitwise OR of all non-null input values. - -``` -bit_or(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `bit_xor` - -Computes the bitwise exclusive OR of all non-null input values. - -``` -bit_xor(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `bool_and` - -Returns true if all non-null input values are true, otherwise false. - -``` -bool_and(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. 
- -### `bool_or` - -Returns true if any non-null input value is true, otherwise false. - -``` -bool_or(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `count` - -Returns the number of non-null values in the specified column. - -To include _null_ values in the total count, use `count(*)`. - -``` -count(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `max` - -Returns the maximum value in the specified column. - -``` -max(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `mean` - -_Alias of [avg](#avg)._ - -### `median` - -Returns the median value in the specified column. - -``` -median(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `min` - -Returns the minimum value in the specified column. - -``` -min(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `sum` - -Returns the sum of all values in the specified column. - -``` -sum(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `array_agg` - -Returns an array created from the expression elements. If ordering requirement is given, elements are inserted in the order of required ordering. - -``` -array_agg(expression [ORDER BY expression]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `first_value` - -Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group. - -``` -first_value(expression [ORDER BY expression]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `last_value` - -Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group. - -``` -last_value(expression [ORDER BY expression]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -## Statistical - -- [corr](#corr) -- [covar](#covar) -- [covar_pop](#covar_pop) -- [covar_samp](#covar_samp) -- [stddev](#stddev) -- [stddev_pop](#stddev_pop) -- [stddev_samp](#stddev_samp) -- [var](#var) -- [var_pop](#var_pop) -- [var_samp](#var_samp) -- [regr_avgx](#regr_avgx) -- [regr_avgy](#regr_avgy) -- [regr_count](#regr_count) -- [regr_intercept](#regr_intercept) -- [regr_r2](#regr_r2) -- [regr_slope](#regr_slope) -- [regr_sxx](#regr_sxx) -- [regr_syy](#regr_syy) -- [regr_sxy](#regr_sxy) -- [kurtosis_pop](#kurtosis_pop) - -### `corr` - -Returns the coefficient of correlation between two numeric values. - -``` -corr(expression1, expression2) -``` - -#### Arguments - -- **expression1**: First expression to operate on. 
- Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: Second expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `covar` - -Returns the covariance of a set of number pairs. - -``` -covar(expression1, expression2) -``` - -#### Arguments - -- **expression1**: First expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: Second expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `covar_pop` - -Returns the population covariance of a set of number pairs. - -``` -covar_pop(expression1, expression2) -``` - -#### Arguments - -- **expression1**: First expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: Second expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `covar_samp` - -Returns the sample covariance of a set of number pairs. - -``` -covar_samp(expression1, expression2) -``` - -#### Arguments - -- **expression1**: First expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: Second expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `stddev` - -Returns the standard deviation of a set of numbers. - -``` -stddev(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `stddev_pop` - -Returns the population standard deviation of a set of numbers. - -``` -stddev_pop(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `stddev_samp` - -Returns the sample standard deviation of a set of numbers. - -``` -stddev_samp(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `var` - -Returns the statistical variance of a set of numbers. - -``` -var(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `var_pop` - -Returns the statistical population variance of a set of numbers. - -``` -var_pop(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `var_samp` - -Returns the statistical sample variance of a set of numbers. - -``` -var_samp(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_slope` - -Returns the slope of the linear regression line for non-null pairs in aggregate columns. -Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X + b) using minimal RSS fitting. - -``` -regr_slope(expression1, expression2) -``` - -#### Arguments - -- **expression_y**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. 
-- **expression_x**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_avgx` - -Computes the average of the independent variable (input) `expression_x` for the non-null paired data points. - -``` -regr_avgx(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_avgy` - -Computes the average of the dependent variable (output) `expression_y` for the non-null paired data points. - -``` -regr_avgy(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_count` - -Counts the number of non-null paired data points. - -``` -regr_count(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_intercept` - -Computes the y-intercept of the linear regression line. For the equation \(y = kx + b\), this function returns `b`. - -``` -regr_intercept(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_r2` - -Computes the square of the correlation coefficient between the independent and dependent variables. - -``` -regr_r2(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_sxx` - -Computes the sum of squares of the independent variable. - -``` -regr_sxx(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_syy` - -Computes the sum of squares of the dependent variable. - -``` -regr_syy(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_sxy` - -Computes the sum of products of paired data points. - -``` -regr_sxy(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. 
- Can be a constant, column, or function, and any combination of arithmetic operators. - -### `kurtosis_pop` - -Computes the excess kurtosis (Fisher’s definition) without bias correction. - -``` -kurtois_pop(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -## Approximate - -- [approx_distinct](#approx_distinct) -- [approx_median](#approx_median) -- [approx_percentile_cont](#approx_percentile_cont) -- [approx_percentile_cont_with_weight](#approx_percentile_cont_with_weight) - -### `approx_distinct` - -Returns the approximate number of distinct input values calculated using the -HyperLogLog algorithm. - -``` -approx_distinct(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `approx_median` - -Returns the approximate median (50th percentile) of input values. -It is an alias of `approx_percentile_cont(x, 0.5)`. - -``` -approx_median(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `approx_percentile_cont` - -Returns the approximate percentile of input values using the t-digest algorithm. - -``` -approx_percentile_cont(expression, percentile, centroids) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive). -- **centroids**: Number of centroids to use in the t-digest algorithm. _Default is 100_. - - If there are this number or fewer unique values, you can expect an exact result. - A higher number of centroids results in a more accurate approximation, but - requires more memory to compute. - -### `approx_percentile_cont_with_weight` - -Returns the weighted approximate percentile of input values using the -t-digest algorithm. - -``` -approx_percentile_cont_with_weight(expression, weight, percentile) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **weight**: Expression to use as weight. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive). +Note: this documentation has been migrated to [Aggregate Functions (new)](aggregate_functions_new.md) diff --git a/docs/source/user-guide/sql/aggregate_functions_new.md b/docs/source/user-guide/sql/aggregate_functions_new.md new file mode 100644 index 0000000000000..ad6d15b94ee53 --- /dev/null +++ b/docs/source/user-guide/sql/aggregate_functions_new.md @@ -0,0 +1,865 @@ + + + + +# Aggregate Functions (NEW) + +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Aggregate Functions (old)](aggregate_functions.md) page for +the rest of the documentation. + +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 + +Aggregate functions operate on a set of values to compute a single result. 
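+
+As a quick orientation before the reference entries, the following is an editorial sketch (it assumes a hypothetical `orders` table with `customer_id` and `amount` columns) showing how an aggregate collapses each group of rows into one value:
+
+```sql
+-- One output row per customer: count(*) and sum(amount) each reduce
+-- that customer's rows to a single value.
+SELECT customer_id,
+       count(*)    AS order_count,
+       sum(amount) AS total_spent
+FROM orders
+GROUP BY customer_id;
+```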
+
+## General Functions
+
+- [array_agg](#array_agg)
+- [avg](#avg)
+- [bit_and](#bit_and)
+- [bit_or](#bit_or)
+- [bit_xor](#bit_xor)
+- [bool_and](#bool_and)
+- [bool_or](#bool_or)
+- [count](#count)
+- [first_value](#first_value)
+- [grouping](#grouping)
+- [last_value](#last_value)
+- [max](#max)
+- [mean](#mean)
+- [median](#median)
+- [min](#min)
+- [string_agg](#string_agg)
+- [sum](#sum)
+- [var](#var)
+- [var_pop](#var_pop)
+- [var_population](#var_population)
+- [var_samp](#var_samp)
+- [var_sample](#var_sample)
+
+### `array_agg`
+
+Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order.
+
+```
+array_agg(expression [ORDER BY expression])
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT array_agg(column_name ORDER BY other_column) FROM table_name;
++-----------------------------------------------+
+| array_agg(column_name ORDER BY other_column)  |
++-----------------------------------------------+
+| [element1, element2, element3]                |
++-----------------------------------------------+
+```
+
+### `avg`
+
+Returns the average of numeric values in the specified column.
+
+```
+avg(expression)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT avg(column_name) FROM table_name;
++---------------------------+
+| avg(column_name)          |
++---------------------------+
+| 42.75                     |
++---------------------------+
+```
+
+#### Aliases
+
+- mean
+
+### `bit_and`
+
+Computes the bitwise AND of all non-null input values.
+
+```
+bit_and(expression)
+```
+
+#### Arguments
+
+- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+### `bit_or`
+
+Computes the bitwise OR of all non-null input values.
+
+```
+bit_or(expression)
+```
+
+#### Arguments
+
+- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+### `bit_xor`
+
+Computes the bitwise exclusive OR of all non-null input values.
+
+```
+bit_xor(expression)
+```
+
+#### Arguments
+
+- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+### `bool_and`
+
+Returns true if all non-null input values are true, otherwise false.
+
+```
+bool_and(expression)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT bool_and(column_name) FROM table_name;
++----------------------------+
+| bool_and(column_name)      |
++----------------------------+
+| true                       |
++----------------------------+
+```
+
+### `bool_or`
+
+Returns true if any non-null input value is true, otherwise false.
+
+```
+bool_or(expression)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT bool_or(column_name) FROM table_name;
++----------------------------+
+| bool_or(column_name)       |
++----------------------------+
+| true                       |
++----------------------------+
+```
+
+### `count`
+
+Returns the number of non-null values in the specified column.
To include null values in the total count, use `count(*)`.
+
+```
+count(expression)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT count(column_name) FROM table_name;
++-----------------------+
+| count(column_name)    |
++-----------------------+
+| 100                   |
++-----------------------+
+
+> SELECT count(*) FROM table_name;
++------------------+
+| count(*)         |
++------------------+
+| 120              |
++------------------+
+```
+
+### `first_value`
+
+Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.
+
+```
+first_value(expression [ORDER BY expression])
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT first_value(column_name ORDER BY other_column) FROM table_name;
++-----------------------------------------------+
+| first_value(column_name ORDER BY other_column)|
++-----------------------------------------------+
+| first_element                                 |
++-----------------------------------------------+
+```
+
+### `grouping`
+
+Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set.
+
+```
+grouping(expression)
+```
+
+#### Arguments
+
+- **expression**: Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function.
+
+#### Example
+
+```sql
+> SELECT column_name, GROUPING(column_name) AS group_column
+  FROM table_name
+  GROUP BY GROUPING SETS ((column_name), ());
++-------------+--------------+
+| column_name | group_column |
++-------------+--------------+
+| value1      | 0            |
+| value2      | 0            |
+| NULL        | 1            |
++-------------+--------------+
+```
+
+### `last_value`
+
+Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.
+
+```
+last_value(expression [ORDER BY expression])
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT last_value(column_name ORDER BY other_column) FROM table_name;
++-----------------------------------------------+
+| last_value(column_name ORDER BY other_column) |
++-----------------------------------------------+
+| last_element                                  |
++-----------------------------------------------+
+```
+
+### `max`
+
+Returns the maximum value in the specified column.
+
+```
+max(expression)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT max(column_name) FROM table_name;
++----------------------+
+| max(column_name)     |
++----------------------+
+| 150                  |
++----------------------+
+```
+
+### `mean`
+
+_Alias of [avg](#avg)._
+
+### `median`
+
+Returns the median value in the specified column.
+
+```
+median(expression)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT median(column_name) FROM table_name;
++----------------------+
+| median(column_name)  |
++----------------------+
+| 45.5                 |
++----------------------+
+```
+
+### `min`
+
+Returns the minimum value in the specified column.
+
+```
+min(expression)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT min(column_name) FROM table_name;
++----------------------+
+| min(column_name)     |
++----------------------+
+| 12                   |
++----------------------+
+```
+
+### `string_agg`
+
+Concatenates the values of string expressions and places separator values between them.
+
+```
+string_agg(expression, delimiter)
+```
+
+#### Arguments
+
+- **expression**: The string expression to concatenate. Can be a column or any valid string expression.
+- **delimiter**: A literal string used as a separator between the concatenated values.
+
+#### Example
+
+```sql
+> SELECT string_agg(name, ', ') AS names_list
+  FROM employee;
++--------------------------+
+| names_list               |
++--------------------------+
+| Alice, Bob, Charlie      |
++--------------------------+
+```
+
+### `sum`
+
+Returns the sum of all values in the specified column.
+
+```
+sum(expression)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT sum(column_name) FROM table_name;
++-----------------------+
+| sum(column_name)      |
++-----------------------+
+| 12345                 |
++-----------------------+
+```
+
+### `var`
+
+Returns the statistical sample variance of a set of numbers.
+
+```
+var(expression)
+```
+
+#### Arguments
+
+- **expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Aliases
+
+- var_sample
+- var_samp
+
+### `var_pop`
+
+Returns the statistical population variance of a set of numbers.
+
+```
+var_pop(expression)
+```
+
+#### Arguments
+
+- **expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Aliases
+
+- var_population
+
+### `var_population`
+
+_Alias of [var_pop](#var_pop)._
+
+### `var_samp`
+
+_Alias of [var](#var)._
+
+### `var_sample`
+
+_Alias of [var](#var)._
+
+## Statistical Functions
+
+- [corr](#corr)
+- [covar](#covar)
+- [covar_pop](#covar_pop)
+- [covar_samp](#covar_samp)
+- [nth_value](#nth_value)
+- [regr_avgx](#regr_avgx)
+- [regr_avgy](#regr_avgy)
+- [regr_count](#regr_count)
+- [regr_intercept](#regr_intercept)
+- [regr_r2](#regr_r2)
+- [regr_slope](#regr_slope)
+- [regr_sxx](#regr_sxx)
+- [regr_sxy](#regr_sxy)
+- [regr_syy](#regr_syy)
+- [stddev](#stddev)
+- [stddev_pop](#stddev_pop)
+- [stddev_samp](#stddev_samp)
+
+### `corr`
+
+Returns the coefficient of correlation between two numeric values.
+
+```
+corr(expression1, expression2)
+```
+
+#### Arguments
+
+- **expression1**: First expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **expression2**: Second expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT corr(column1, column2) FROM table_name;
++--------------------------------+
+| corr(column1, column2)         |
++--------------------------------+
+| 0.85                           |
++--------------------------------+
+```
+
+### `covar`
+
+_Alias of [covar_samp](#covar_samp)._
+
+### `covar_pop`
+
+Returns the population covariance of a set of number pairs.
+
+```
+covar_pop(expression1, expression2)
+```
+
+#### Arguments
+
+- **expression1**: First expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **expression2**: Second expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT covar_pop(column1, column2) FROM table_name;
++-----------------------------------+
+| covar_pop(column1, column2)       |
++-----------------------------------+
+| 8.25                              |
++-----------------------------------+
+```
+
+### `covar_samp`
+
+Returns the sample covariance of a set of number pairs.
+
+```
+covar_samp(expression1, expression2)
+```
+
+#### Arguments
+
+- **expression1**: First expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **expression2**: Second expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT covar_samp(column1, column2) FROM table_name;
++-----------------------------------+
+| covar_samp(column1, column2)      |
++-----------------------------------+
+| 8.25                              |
++-----------------------------------+
+```
+
+#### Aliases
+
+- covar
+
+### `nth_value`
+
+Returns the nth value in a group of values.
+
+```
+nth_value(expression, n ORDER BY expression)
+```
+
+#### Arguments
+
+- **expression**: The column or expression to retrieve the nth value from.
+- **n**: The position (nth) of the value to retrieve, based on the ordering.
+
+#### Example
+
+```sql
+> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept
+  FROM employee;
++---------+--------+-------------------------+
+| dept_id | salary | second_salary_by_dept   |
++---------+--------+-------------------------+
+| 1       | 30000  | NULL                    |
+| 1       | 40000  | 40000                   |
+| 1       | 50000  | 40000                   |
+| 2       | 35000  | NULL                    |
+| 2       | 45000  | 45000                   |
++---------+--------+-------------------------+
+```
+
+### `regr_avgx`
+
+Computes the average of the independent variable (input) expression_x for the non-null paired data points.
+
+```
+regr_avgx(expression_y, expression_x)
+```
+
+#### Arguments
+
+- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+### `regr_avgy`
+
+Computes the average of the dependent variable (output) expression_y for the non-null paired data points.
+
+```
+regr_avgy(expression_y, expression_x)
+```
+
+#### Arguments
+
+- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+### `regr_count`
+
+Counts the number of non-null paired data points.
+
+```
+regr_count(expression_y, expression_x)
+```
+
+#### Arguments
+
+- **expression_y**: Dependent variable expression to operate on.
Can be a constant, column, or function, and any combination of operators.
+- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+### `regr_intercept`
+
+Computes the y-intercept of the linear regression line. For the equation (y = kx + b), this function returns b.
+
+```
+regr_intercept(expression_y, expression_x)
+```
+
+#### Arguments
+
+- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+### `regr_r2`
+
+Computes the square of the correlation coefficient between the independent and dependent variables.
+
+```
+regr_r2(expression_y, expression_x)
+```
+
+#### Arguments
+
+- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+### `regr_slope`
+
+Returns the slope of the linear regression line for non-null pairs in aggregate columns. Given input columns Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X + b) using minimal RSS fitting.
+
+```
+regr_slope(expression_y, expression_x)
+```
+
+#### Arguments
+
+- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+### `regr_sxx`
+
+Computes the sum of squares of the independent variable.
+
+```
+regr_sxx(expression_y, expression_x)
+```
+
+#### Arguments
+
+- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+### `regr_sxy`
+
+Computes the sum of products of paired data points.
+
+```
+regr_sxy(expression_y, expression_x)
+```
+
+#### Arguments
+
+- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+### `regr_syy`
+
+Computes the sum of squares of the dependent variable.
+
+```
+regr_syy(expression_y, expression_x)
+```
+
+#### Arguments
+
+- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+### `stddev`
+
+Returns the standard deviation of a set of numbers.
+
+```
+stddev(expression)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT stddev(column_name) FROM table_name;
++----------------------+
+| stddev(column_name)  |
++----------------------+
+| 12.34                |
++----------------------+
+```
+
+#### Aliases
+
+- stddev_samp
+
+### `stddev_pop`
+
+Returns the population standard deviation of a set of numbers.
+
+```
+stddev_pop(expression)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT stddev_pop(column_name) FROM table_name;
++--------------------------+
+| stddev_pop(column_name)  |
++--------------------------+
+| 12.34                    |
++--------------------------+
+```
+
+### `stddev_samp`
+
+_Alias of [stddev](#stddev)._
+
+## Approximate Functions
+
+- [approx_distinct](#approx_distinct)
+- [approx_median](#approx_median)
+- [approx_percentile_cont](#approx_percentile_cont)
+- [approx_percentile_cont_with_weight](#approx_percentile_cont_with_weight)
+
+### `approx_distinct`
+
+Returns the approximate number of distinct input values calculated using the HyperLogLog algorithm.
+
+```
+approx_distinct(expression)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT approx_distinct(column_name) FROM table_name;
++-----------------------------------+
+| approx_distinct(column_name)      |
++-----------------------------------+
+| 42                                |
++-----------------------------------+
+```
+
+### `approx_median`
+
+Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`.
+
+```
+approx_median(expression)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> SELECT approx_median(column_name) FROM table_name;
++-----------------------------------+
+| approx_median(column_name)        |
++-----------------------------------+
+| 23.5                              |
++-----------------------------------+
+```
+
+### `approx_percentile_cont`
+
+Returns the approximate percentile of input values using the t-digest algorithm.
+
+```
+approx_percentile_cont(expression, percentile, centroids)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive).
+- **centroids**: Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory.
+
+#### Example
+
+```sql
+> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name;
++-------------------------------------------------+
+| approx_percentile_cont(column_name, 0.75, 100)  |
++-------------------------------------------------+
+| 65.0                                            |
++-------------------------------------------------+
+```
+
+### `approx_percentile_cont_with_weight`
+
+Returns the weighted approximate percentile of input values using the t-digest algorithm.
+
+```
+approx_percentile_cont_with_weight(expression, weight, percentile)
+```
+
+#### Arguments
+
+- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **weight**: Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators.
+- **percentile**: Percentile to compute.
Must be a float value between 0 and 1 (inclusive). + +#### Example + +```sql +> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++----------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++----------------------------------------------------------------------+ +| 78.5 | ++----------------------------------------------------------------------+ +``` diff --git a/docs/source/user-guide/sql/data_types.md b/docs/source/user-guide/sql/data_types.md index 0e974550a84dc..18c95cdea70ed 100644 --- a/docs/source/user-guide/sql/data_types.md +++ b/docs/source/user-guide/sql/data_types.md @@ -97,7 +97,7 @@ select arrow_cast(now(), 'Timestamp(Second, None)'); | `BYTEA` | `Binary` | You can create binary literals using a hex string literal such as -`X'1234` to create a `Binary` value of two bytes, `0x12` and `0x34`. +`X'1234'` to create a `Binary` value of two bytes, `0x12` and `0x34`. ## Unsupported SQL Types diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst index 04d1fc228f816..8b8afc7b048aa 100644 --- a/docs/source/user-guide/sql/index.rst +++ b/docs/source/user-guide/sql/index.rst @@ -30,7 +30,11 @@ SQL Reference information_schema operators aggregate_functions + aggregate_functions_new window_functions + window_functions_new scalar_functions + scalar_functions_new + special_functions sql_status write_options diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 469fb705b71f4..a8e25930bef7a 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -19,3534 +19,72 @@ # Scalar Functions -## Math Functions +Scalar functions operate on a single row at a time and return a single value. -- [abs](#abs) -- [acos](#acos) -- [acosh](#acosh) -- [asin](#asin) -- [asinh](#asinh) -- [atan](#atan) -- [atanh](#atanh) -- [atan2](#atan2) -- [cbrt](#cbrt) -- [ceil](#ceil) -- [cos](#cos) -- [cosh](#cosh) -- [degrees](#degrees) -- [exp](#exp) -- [factorial](#factorial) -- [floor](#floor) -- [gcd](#gcd) -- [isnan](#isnan) -- [iszero](#iszero) -- [lcm](#lcm) -- [ln](#ln) -- [log](#log) -- [log10](#log10) -- [log2](#log2) -- [nanvl](#nanvl) -- [pi](#pi) -- [power](#power) -- [pow](#pow) -- [radians](#radians) -- [random](#random) -- [round](#round) -- [signum](#signum) -- [sin](#sin) -- [sinh](#sinh) -- [sqrt](#sqrt) -- [tan](#tan) -- [tanh](#tanh) -- [trunc](#trunc) +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Scalar Functions (new)](scalar_functions_new.md) page for +the rest of the documentation. -### `abs` +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 -Returns the absolute value of a number. - -``` -abs(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `acos` - -Returns the arc cosine or inverse cosine of a number. - -``` -acos(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `acosh` - -Returns the area hyperbolic cosine or inverse hyperbolic cosine of a number. 
- -``` -acosh(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `asin` - -Returns the arc sine or inverse sine of a number. - -``` -asin(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `asinh` - -Returns the area hyperbolic sine or inverse hyperbolic sine of a number. - -``` -asinh(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `atan` - -Returns the arc tangent or inverse tangent of a number. - -``` -atan(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `atanh` - -Returns the area hyperbolic tangent or inverse hyperbolic tangent of a number. - -``` -atanh(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `atan2` - -Returns the arc tangent or inverse tangent of `expression_y / expression_x`. - -``` -atan2(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: First numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Second numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `cbrt` - -Returns the cube root of a number. - -``` -cbrt(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `ceil` - -Returns the nearest integer greater than or equal to a number. - -``` -ceil(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `cos` - -Returns the cosine of a number. - -``` -cos(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `cosh` - -Returns the hyperbolic cosine of a number. - -``` -cosh(numeric_expression) -``` - -### `degrees` - -Converts radians to degrees. - -``` -degrees(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `exp` - -Returns the base-e exponential of a number. - -``` -exp(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to use as the exponent. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `factorial` - -Factorial. Returns 1 if value is less than 2. 
- -``` -factorial(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `floor` - -Returns the nearest integer less than or equal to a number. - -``` -floor(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `gcd` - -Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero. - -``` -gcd(expression_x, expression_y) -``` - -#### Arguments - -- **expression_x**: First numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_y**: Second numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `isnan` - -Returns true if a given number is +NaN or -NaN otherwise returns false. - -``` -isnan(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `iszero` - -Returns true if a given number is +0.0 or -0.0 otherwise returns false. - -``` -iszero(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `lcm` - -Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero. - -``` -lcm(expression_x, expression_y) -``` - -#### Arguments - -- **expression_x**: First numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_y**: Second numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `ln` - -Returns the natural logarithm of a number. - -``` -ln(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `log` - -Returns the base-x logarithm of a number. -Can either provide a specified base, or if omitted then takes the base-10 of a number. - -``` -log(base, numeric_expression) -log(numeric_expression) -``` - -#### Arguments - -- **base**: Base numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `log10` - -Returns the base-10 logarithm of a number. - -``` -log10(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `log2` - -Returns the base-2 logarithm of a number. - -``` -log2(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `nanvl` - -Returns the first argument if it's not _NaN_. -Returns the second argument otherwise. 
- -``` -nanvl(expression_x, expression_y) -``` - -#### Arguments - -- **expression_x**: Numeric expression to return if it's not _NaN_. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_y**: Numeric expression to return if the first expression is _NaN_. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `pi` - -Returns an approximate value of π. - -``` -pi() -``` - -### `power` - -Returns a base expression raised to the power of an exponent. - -``` -power(base, exponent) -``` - -#### Arguments - -- **base**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **exponent**: Exponent numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -#### Aliases - -- pow - -### `pow` - -_Alias of [power](#power)._ - -### `radians` - -Converts degrees to radians. - -``` -radians(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `random` - -Returns a random float value in the range [0, 1). -The random seed is unique to each row. - -``` -random() -``` - -### `round` - -Rounds a number to the nearest integer. - -``` -round(numeric_expression[, decimal_places]) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **decimal_places**: Optional. The number of decimal places to round to. - Defaults to 0. - -### `signum` - -Returns the sign of a number. -Negative numbers return `-1`. -Zero and positive numbers return `1`. - -``` -signum(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `sin` - -Returns the sine of a number. - -``` -sin(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `sinh` - -Returns the hyperbolic sine of a number. - -``` -sinh(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `sqrt` - -Returns the square root of a number. - -``` -sqrt(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `tan` - -Returns the tangent of a number. - -``` -tan(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `tanh` - -Returns the hyperbolic tangent of a number. - -``` -tanh(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `trunc` - -Truncates a number to a whole number or truncated to the specified decimal places. 
- -``` -trunc(numeric_expression[, decimal_places]) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -- **decimal_places**: Optional. The number of decimal places to - truncate to. Defaults to 0 (truncate to a whole number). If - `decimal_places` is a positive integer, truncates digits to the - right of the decimal point. If `decimal_places` is a negative - integer, replaces digits to the left of the decimal point with `0`. - -## Conditional Functions - -- [coalesce](#coalesce) -- [nullif](#nullif) -- [nvl](#nvl) -- [nvl2](#nvl2) -- [ifnull](#ifnull) - -### `coalesce` - -Returns the first of its arguments that is not _null_. -Returns _null_ if all arguments are _null_. -This function is often used to substitute a default value for _null_ values. - -``` -coalesce(expression1[, ..., expression_n]) -``` - -#### Arguments - -- **expression1, expression_n**: - Expression to use if previous expressions are _null_. - Can be a constant, column, or function, and any combination of arithmetic operators. - Pass as many expression arguments as necessary. - -### `nullif` - -Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_. -This can be used to perform the inverse operation of [`coalesce`](#coalesce). - -``` -nullif(expression1, expression2) -``` - -#### Arguments - -- **expression1**: Expression to compare and return if equal to expression2. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: Expression to compare to expression1. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `nvl` - -Returns _expression2_ if _expression1_ is NULL; otherwise it returns _expression1_. - -``` -nvl(expression1, expression2) -``` - -#### Arguments - -- **expression1**: return if expression1 not is NULL. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: return if expression1 is NULL. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `nvl2` - -Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_. - -``` -nvl2(expression1, expression2, expression3) -``` - -#### Arguments - -- **expression1**: conditional expression. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: return if expression1 is not NULL. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression3**: return if expression1 is NULL. - Can be a constant, column, or function, and any combination of arithmetic operators. 
- -### `ifnull` - -_Alias of [nvl](#nvl)._ - -## String Functions - -- [ascii](#ascii) -- [bit_length](#bit_length) -- [btrim](#btrim) -- [char_length](#char_length) -- [character_length](#character_length) -- [concat](#concat) -- [concat_ws](#concat_ws) -- [chr](#chr) -- [ends_with](#ends_with) -- [initcap](#initcap) -- [instr](#instr) -- [left](#left) -- [length](#length) -- [lower](#lower) -- [lpad](#lpad) -- [ltrim](#ltrim) -- [octet_length](#octet_length) -- [repeat](#repeat) -- [replace](#replace) -- [reverse](#reverse) -- [right](#right) -- [rpad](#rpad) -- [rtrim](#rtrim) -- [split_part](#split_part) -- [starts_with](#starts_with) -- [strpos](#strpos) -- [substr](#substr) -- [to_hex](#to_hex) -- [translate](#translate) -- [trim](#trim) -- [upper](#upper) -- [uuid](#uuid) -- [overlay](#overlay) -- [levenshtein](#levenshtein) -- [substr_index](#substr_index) -- [find_in_set](#find_in_set) -- [position](#position) -- [contains](#contains) - -### `ascii` - -Returns the ASCII value of the first character in a string. - -``` -ascii(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[chr](#chr) - -### `bit_length` - -Returns the bit length of a string. - -``` -bit_length(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[length](#length), -[octet_length](#octet_length) - -### `btrim` - -Trims the specified trim string from the start and end of a string. -If no trim string is provided, all whitespace is removed from the start and end -of the input string. - -``` -btrim(str[, trim_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **trim_str**: String expression to trim from the beginning and end of the input string. - Can be a constant, column, or function, and any combination of arithmetic operators. - _Default is whitespace characters._ - -**Related functions**: -[ltrim](#ltrim), -[rtrim](#rtrim) - -#### Aliases - -- trim - -### `char_length` - -_Alias of [length](#length)._ - -### `character_length` - -_Alias of [length](#length)._ - -### `concat` - -Concatenates multiple strings together. - -``` -concat(str[, ..., str_n]) -``` - -#### Arguments - -- **str**: String expression to concatenate. - Can be a constant, column, or function, and any combination of string operators. -- **str_n**: Subsequent string column or literal string to concatenate. - -**Related functions**: -[concat_ws](#concat_ws) - -### `concat_ws` - -Concatenates multiple strings together with a specified separator. - -``` -concat_ws(separator, str[, ..., str_n]) -``` - -#### Arguments - -- **separator**: Separator to insert between concatenated strings. -- **str**: String expression to concatenate. - Can be a constant, column, or function, and any combination of string operators. -- **str_n**: Subsequent string column or literal string to concatenate. - -**Related functions**: -[concat](#concat) - -### `chr` - -Returns the character with the specified ASCII or Unicode code value. - -``` -chr(expression) -``` - -#### Arguments - -- **expression**: Expression containing the ASCII or Unicode code value to operate on. - Can be a constant, column, or function, and any combination of arithmetic or - string operators. 
- -**Related functions**: -[ascii](#ascii) - -### `ends_with` - -Tests if a string ends with a substring. - -``` -ends_with(str, substr) -``` - -#### Arguments - -- **str**: String expression to test. - Can be a constant, column, or function, and any combination of string operators. -- **substr**: Substring to test for. - -### `initcap` - -Capitalizes the first character in each word in the input string. -Words are delimited by non-alphanumeric characters. - -``` -initcap(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[lower](#lower), -[upper](#upper) - -### `instr` - -_Alias of [strpos](#strpos)._ - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **substr**: Substring expression to search for. - Can be a constant, column, or function, and any combination of string operators. - -### `left` - -Returns a specified number of characters from the left side of a string. - -``` -left(str, n) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **n**: Number of characters to return. - -**Related functions**: -[right](#right) - -### `length` - -Returns the number of characters in a string. - -``` -length(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -#### Aliases - -- char_length -- character_length - -**Related functions**: -[bit_length](#bit_length), -[octet_length](#octet_length) - -### `lower` - -Converts a string to lower-case. - -``` -lower(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[initcap](#initcap), -[upper](#upper) - -### `lpad` - -Pads the left side of a string with another string to a specified string length. - -``` -lpad(str, n[, padding_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **n**: String length to pad to. -- **padding_str**: String expression to pad with. - Can be a constant, column, or function, and any combination of string operators. - _Default is a space._ - -**Related functions**: -[rpad](#rpad) - -### `ltrim` - -Trims the specified trim string from the beginning of a string. -If no trim string is provided, all whitespace is removed from the start -of the input string. - -``` -ltrim(str[, trim_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **trim_str**: String expression to trim from the beginning of the input string. - Can be a constant, column, or function, and any combination of arithmetic operators. - _Default is whitespace characters._ - -**Related functions**: -[btrim](#btrim), -[rtrim](#rtrim) - -### `octet_length` - -Returns the length of a string in bytes. - -``` -octet_length(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. 
- -**Related functions**: -[bit_length](#bit_length), -[length](#length) - -### `repeat` - -Returns a string with an input string repeated a specified number. - -``` -repeat(str, n) -``` - -#### Arguments - -- **str**: String expression to repeat. - Can be a constant, column, or function, and any combination of string operators. -- **n**: Number of times to repeat the input string. - -### `replace` - -Replaces all occurrences of a specified substring in a string with a new substring. - -``` -replace(str, substr, replacement) -``` - -#### Arguments - -- **str**: String expression to repeat. - Can be a constant, column, or function, and any combination of string operators. -- **substr**: Substring expression to replace in the input string. - Can be a constant, column, or function, and any combination of string operators. -- **replacement**: Replacement substring expression. - Can be a constant, column, or function, and any combination of string operators. - -### `reverse` - -Reverses the character order of a string. - -``` -reverse(str) -``` - -#### Arguments - -- **str**: String expression to repeat. - Can be a constant, column, or function, and any combination of string operators. - -### `right` - -Returns a specified number of characters from the right side of a string. - -``` -right(str, n) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **n**: Number of characters to return. - -**Related functions**: -[left](#left) - -### `rpad` - -Pads the right side of a string with another string to a specified string length. - -``` -rpad(str, n[, padding_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **n**: String length to pad to. -- **padding_str**: String expression to pad with. - Can be a constant, column, or function, and any combination of string operators. - _Default is a space._ - -**Related functions**: -[lpad](#lpad) - -### `rtrim` - -Trims the specified trim string from the end of a string. -If no trim string is provided, all whitespace is removed from the end -of the input string. - -``` -rtrim(str[, trim_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **trim_str**: String expression to trim from the end of the input string. - Can be a constant, column, or function, and any combination of arithmetic operators. - _Default is whitespace characters._ - -**Related functions**: -[btrim](#btrim), -[ltrim](#ltrim) - -### `split_part` - -Splits a string based on a specified delimiter and returns the substring in the -specified position. - -``` -split_part(str, delimiter, pos) -``` - -#### Arguments - -- **str**: String expression to spit. - Can be a constant, column, or function, and any combination of string operators. -- **delimiter**: String or character to split on. -- **pos**: Position of the part to return. - -### `starts_with` - -Tests if a string starts with a substring. - -``` -starts_with(str, substr) -``` - -#### Arguments - -- **str**: String expression to test. - Can be a constant, column, or function, and any combination of string operators. -- **substr**: Substring to test for. - -### `strpos` - -Returns the starting position of a specified substring in a string. -Positions begin at 1. 
-If the substring does not exist in the string, the function returns 0. - -``` -strpos(str, substr) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **substr**: Substring expression to search for. - Can be a constant, column, or function, and any combination of string operators. - -#### Aliases - -- instr - -### `substr` - -Extracts a substring of a specified number of characters from a specific -starting position in a string. - -``` -substr(str, start_pos[, length]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **start_pos**: Character position to start the substring at. - The first character in the string has a position of 1. -- **length**: Number of characters to extract. - If not specified, returns the rest of the string after the start position. - -#### Aliases - -- substring - -### `substring` - -_Alias of [substr](#substr)._ - -### `translate` - -Translates characters in a string to specified translation characters. - -``` -translate(str, chars, translation) -``` - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **chars**: Characters to translate. -- **translation**: Translation characters. Translation characters replace only - characters at the same position in the **chars** string. - -### `to_hex` - -Converts an integer to a hexadecimal string. - -``` -to_hex(int) -``` - -#### Arguments - -- **int**: Integer expression to convert. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `trim` - -_Alias of [btrim](#btrim)._ - -### `upper` - -Converts a string to upper-case. - -``` -upper(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[initcap](#initcap), -[lower](#lower) - -### `uuid` - -Returns UUID v4 string value which is unique per row. - -``` -uuid() -``` - -### `overlay` - -Returns the string which is replaced by another string from the specified position and specified count length. -For example, `overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas` - -``` -overlay(str PLACING substr FROM pos [FOR count]) -``` - -#### Arguments - -- **str**: String expression to operate on. -- **substr**: the string to replace part of str. -- **pos**: the start position to replace of str. -- **count**: the count of characters to be replaced from start position of str. If not specified, will use substr length instead. - -### `levenshtein` - -Returns the Levenshtein distance between the two given strings. -For example, `levenshtein('kitten', 'sitting') = 3` - -``` -levenshtein(str1, str2) -``` - -#### Arguments - -- **str1**: String expression to compute Levenshtein distance with str2. -- **str2**: String expression to compute Levenshtein distance with str1. - -### `substr_index` - -Returns the substring from str before count occurrences of the delimiter delim. -If count is positive, everything to the left of the final delimiter (counting from the left) is returned. -If count is negative, everything to the right of the final delimiter (counting from the right) is returned. 
-For example, `substr_index('www.apache.org', '.', 1) = www`, `substr_index('www.apache.org', '.', -1) = org` - -``` -substr_index(str, delim, count) -``` - -#### Arguments - -- **str**: String expression to operate on. -- **delim**: the string to find in str to split str. -- **count**: The number of times to search for the delimiter. Can be both a positive or negative number. - -### `find_in_set` - -Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings. -For example, `find_in_set('b', 'a,b,c,d') = 2` - -``` -find_in_set(str, strlist) -``` - -#### Arguments - -- **str**: String expression to find in strlist. -- **strlist**: A string list is a string composed of substrings separated by , characters. - -## Binary String Functions - -- [decode](#decode) -- [encode](#encode) - -### `encode` - -Encode binary data into a textual representation. - -``` -encode(expression, format) -``` - -#### Arguments - -- **expression**: Expression containing string or binary data - -- **format**: Supported formats are: `base64`, `hex` - -**Related functions**: -[decode](#decode) - -### `decode` - -Decode binary data from textual representation in string. - -``` -decode(expression, format) -``` - -#### Arguments - -- **expression**: Expression containing encoded string data - -- **format**: Same arguments as [encode](#encode) - -**Related functions**: -[encode](#encode) - -## Regular Expression Functions - -Apache DataFusion uses a [PCRE-like] regular expression [syntax] -(minus support for several features including look-around and backreferences). -The following regular expression functions are supported: - -- [regexp_like](#regexp_like) -- [regexp_match](#regexp_match) -- [regexp_replace](#regexp_replace) - -[pcre-like]: https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions -[syntax]: https://docs.rs/regex/latest/regex/#syntax - -### `regexp_like` - -Returns true if a [regular expression] has at least one match in a string, -false otherwise. - -[regular expression]: https://docs.rs/regex/latest/regex/#syntax - -``` -regexp_like(str, regexp[, flags]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **regexp**: Regular expression to test against the string expression. - Can be a constant, column, or function. -- **flags**: Optional regular expression flags that control the behavior of the - regular expression. The following flags are supported: - - **i**: case-insensitive: letters match both upper and lower case - - **m**: multi-line mode: ^ and $ match begin/end of line - - **s**: allow . to match \n - - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - - **U**: swap the meaning of x* and x*? 
- -#### Example - -```sql -select regexp_like('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); -+--------------------------------------------------------+ -| regexp_like(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | -+--------------------------------------------------------+ -| true | -+--------------------------------------------------------+ -SELECT regexp_like('aBc', '(b|d)', 'i'); -+--------------------------------------------------+ -| regexp_like(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | -+--------------------------------------------------+ -| true | -+--------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) - -### `regexp_match` - -Returns a list of [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. - -``` -regexp_match(str, regexp[, flags]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **regexp**: Regular expression to match against. - Can be a constant, column, or function. -- **flags**: Optional regular expression flags that control the behavior of the - regular expression. The following flags are supported: - - **i**: case-insensitive: letters match both upper and lower case - - **m**: multi-line mode: ^ and $ match begin/end of line - - **s**: allow . to match \n - - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - - **U**: swap the meaning of x* and x*? - -#### Example - -```sql -select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); -+---------------------------------------------------------+ -| regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | -+---------------------------------------------------------+ -| [Köln] | -+---------------------------------------------------------+ -SELECT regexp_match('aBc', '(b|d)', 'i'); -+---------------------------------------------------+ -| regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | -+---------------------------------------------------+ -| [B] | -+---------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) - -### `regexp_replace` - -Replaces substrings in a string that match a [regular expression](https://docs.rs/regex/latest/regex/#syntax). - -``` -regexp_replace(str, regexp, replacement[, flags]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **regexp**: Regular expression to match against. - Can be a constant, column, or function. -- **replacement**: Replacement string expression. - Can be a constant, column, or function, and any combination of string operators. -- **flags**: Optional regular expression flags that control the behavior of the - regular expression. The following flags are supported: - - **g**: (global) Search globally and don't return after the first match - - **i**: case-insensitive: letters match both upper and lower case - - **m**: multi-line mode: ^ and $ match begin/end of line - - **s**: allow . to match \n - - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - - **U**: swap the meaning of x* and x*? 
- -#### Example - -```sql -SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); -+------------------------------------------------------------------------+ -| regexp_replace(Utf8("foobarbaz"),Utf8("b(..)"),Utf8("X\1Y"),Utf8("g")) | -+------------------------------------------------------------------------+ -| fooXarYXazY | -+------------------------------------------------------------------------+ -SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); -+-------------------------------------------------------------------+ -| regexp_replace(Utf8("aBc"),Utf8("(b|d)"),Utf8("Ab\1a"),Utf8("i")) | -+-------------------------------------------------------------------+ -| aAbBac | -+-------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) - -### `position` - -Returns the position of `substr` in `origstr` (counting from 1). If `substr` does -not appear in `origstr`, return 0. - -``` -position(substr in origstr) -``` - -#### Arguments - -- **substr**: The pattern string. -- **origstr**: The model string. - -### `contains` - -Return true if search_string is found within string (case-sensitive). - -``` -contains(string, search_string) -``` - -#### Arguments - -- **string**: The pattern string. -- **search_string**: The model string. - -## Time and Date Functions - -- [now](#now) -- [current_date](#current_date) -- [current_time](#current_time) -- [date_bin](#date_bin) -- [date_trunc](#date_trunc) -- [datetrunc](#datetrunc) -- [date_part](#date_part) -- [datepart](#datepart) -- [extract](#extract) -- [today](#today) -- [make_date](#make_date) -- [to_char](#to_char) -- [to_date](#to_date) -- [to_local_time](#to_local_time) -- [to_timestamp](#to_timestamp) -- [to_timestamp_millis](#to_timestamp_millis) -- [to_timestamp_micros](#to_timestamp_micros) -- [to_timestamp_seconds](#to_timestamp_seconds) -- [to_timestamp_nanos](#to_timestamp_nanos) -- [from_unixtime](#from_unixtime) -- [to_unixtime](#to_unixtime) - -### `now` - -Returns the current UTC timestamp. - -The `now()` return value is determined at query time and will return the same timestamp, -no matter when in the query plan the function executes. - -``` -now() -``` - -### `current_date` - -Returns the current UTC date. - -The `current_date()` return value is determined at query time and will return the same date, -no matter when in the query plan the function executes. - -``` -current_date() -``` - -#### Aliases - -- today - -### `today` - -_Alias of [current_date](#current_date)._ - -### `current_time` - -Returns the current UTC time. - -The `current_time()` return value is determined at query time and will return the same time, -no matter when in the query plan the function executes. - -``` -current_time() -``` - -### `date_bin` - -Calculates time intervals and returns the start of the interval nearest to the specified timestamp. -Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" -and applying an aggregate or selector function to each window. - -For example, if you "bin" or "window" data into 15 minute intervals, an input -timestamp of `2023-01-01T18:18:18Z` will be updated to the start time of the 15 -minute bin it is in: `2023-01-01T18:15:00Z`. - -``` -date_bin(interval, expression, origin-timestamp) -``` - -#### Arguments - -- **interval**: Bin interval. -- **expression**: Time expression to operate on. - Can be a constant, column, or function. 
-- **origin-timestamp**: Optional. Starting point used to determine bin boundaries. If not specified - defaults `1970-01-01T00:00:00Z` (the UNIX epoch in UTC). - -The following intervals are supported: - -- nanoseconds -- microseconds -- milliseconds -- seconds -- minutes -- hours -- days -- weeks -- months -- years -- century - -### `date_trunc` - -Truncates a timestamp value to a specified precision. - -``` -date_trunc(precision, expression) -``` - -#### Arguments - -- **precision**: Time precision to truncate to. - The following precisions are supported: - - - year / YEAR - - quarter / QUARTER - - month / MONTH - - week / WEEK - - day / DAY - - hour / HOUR - - minute / MINUTE - - second / SECOND - -- **expression**: Time expression to operate on. - Can be a constant, column, or function. - -#### Aliases - -- datetrunc - -### `datetrunc` - -_Alias of [date_trunc](#date_trunc)._ - -### `date_part` - -Returns the specified part of the date as an integer. - -``` -date_part(part, expression) -``` - -#### Arguments - -- **part**: Part of the date to return. - The following date parts are supported: - - - year - - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_ - - month - - week _(week of the year)_ - - day _(day of the month)_ - - hour - - minute - - second - - millisecond - - microsecond - - nanosecond - - dow _(day of the week)_ - - doy _(day of the year)_ - - epoch _(seconds since Unix epoch)_ - -- **expression**: Time expression to operate on. - Can be a constant, column, or function. - -#### Aliases - -- datepart - -### `datepart` - -_Alias of [date_part](#date_part)._ - -### `extract` - -Returns a sub-field from a time value as an integer. - -``` -extract(field FROM source) -``` - -Equivalent to calling `date_part('field', source)`. For example, these are equivalent: - -```sql -extract(day FROM '2024-04-13'::date) -date_part('day', '2024-04-13'::date) -``` - -See [date_part](#date_part). - -### `make_date` - -Make a date from year/month/day component parts. - -``` -make_date(year, month, day) -``` - -#### Arguments - -- **year**: Year to use when making the date. - Can be a constant, column or function, and any combination of arithmetic operators. -- **month**: Month to use when making the date. - Can be a constant, column or function, and any combination of arithmetic operators. -- **day**: Day to use when making the date. - Can be a constant, column or function, and any combination of arithmetic operators. - -#### Example - -``` -> select make_date(2023, 1, 31); -+-------------------------------------------+ -| make_date(Int64(2023),Int64(1),Int64(31)) | -+-------------------------------------------+ -| 2023-01-31 | -+-------------------------------------------+ -> select make_date('2023', '01', '31'); -+-----------------------------------------------+ -| make_date(Utf8("2023"),Utf8("01"),Utf8("31")) | -+-----------------------------------------------+ -| 2023-01-31 | -+-----------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) - -### `to_char` - -Returns a string representation of a date, time, timestamp or duration based -on a [Chrono format]. Unlike the PostgreSQL equivalent of this function -numerical formatting is not supported. - -``` -to_char(expression, format) -``` - -#### Arguments - -- **expression**: Expression to operate on. 
- Can be a constant, column, or function that results in a - date, time, timestamp or duration. -- **format**: A [Chrono format] string to use to convert the expression. - -#### Example - -``` -> select to_char('2023-03-01'::date, '%d-%m-%Y'); -+----------------------------------------------+ -| to_char(Utf8("2023-03-01"),Utf8("%d-%m-%Y")) | -+----------------------------------------------+ -| 01-03-2023 | -+----------------------------------------------+ -``` - -Additional examples can be found [here] - -[here]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs - -#### Aliases - -- date_format - -### `to_date` - -Converts a value to a date (`YYYY-MM-DD`). -Supports strings, integer and double types as input. -Strings are parsed as YYYY-MM-DD (e.g. '2023-07-20') if no [Chrono format]s are provided. -Integers and doubles are interpreted as days since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding date. - -Note: `to_date` returns Date32, which represents its values as the number of days since unix epoch(`1970-01-01`) stored as signed 32 bit value. The largest supported date value is `9999-12-31`. - -``` -to_date(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. - -[chrono format]: https://docs.rs/chrono/latest/chrono/format/strftime/index.html - -#### Example - -``` -> select to_date('2023-01-31'); -+-----------------------------+ -| to_date(Utf8("2023-01-31")) | -+-----------------------------+ -| 2023-01-31 | -+-----------------------------+ -> select to_date('2023/01/31', '%Y-%m-%d', '%Y/%m/%d'); -+---------------------------------------------------------------+ -| to_date(Utf8("2023/01/31"),Utf8("%Y-%m-%d"),Utf8("%Y/%m/%d")) | -+---------------------------------------------------------------+ -| 2023-01-31 | -+---------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) - -### `to_local_time` - -Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or -timezone information). This function handles daylight saving time changes. - -``` -to_local_time(expression) -``` - -#### Arguments - -- **expression**: Time expression to operate on. Can be a constant, column, or function. 
- -#### Example - -``` -> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp); -+---------------------------------------------+ -| to_local_time(Utf8("2024-04-01T00:00:20Z")) | -+---------------------------------------------+ -| 2024-04-01T00:00:20 | -+---------------------------------------------+ - -> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); -+---------------------------------------------+ -| to_local_time(Utf8("2024-04-01T00:00:20Z")) | -+---------------------------------------------+ -| 2024-04-01T00:00:20 | -+---------------------------------------------+ - -> SELECT - time, - arrow_typeof(time) as type, - to_local_time(time) as to_local_time, - arrow_typeof(to_local_time(time)) as to_local_time_type -FROM ( - SELECT '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' AS time -); -+---------------------------+------------------------------------------------+---------------------+-----------------------------+ -| time | type | to_local_time | to_local_time_type | -+---------------------------+------------------------------------------------+---------------------+-----------------------------+ -| 2024-04-01T00:00:20+02:00 | Timestamp(Nanosecond, Some("Europe/Brussels")) | 2024-04-01T00:00:20 | Timestamp(Nanosecond, None) | -+---------------------------+------------------------------------------------+---------------------+-----------------------------+ - -# combine `to_local_time()` with `date_bin()` to bin on boundaries in the timezone rather -# than UTC boundaries - -> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AS date_bin; -+---------------------+ -| date_bin | -+---------------------+ -| 2024-04-01T00:00:00 | -+---------------------+ - -> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels' AS date_bin_with_timezone; -+---------------------------+ -| date_bin_with_timezone | -+---------------------------+ -| 2024-04-01T00:00:00+02:00 | -+---------------------------+ -``` - -### `to_timestamp` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). -Supports strings, integer, unsigned integer, and double types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. -Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding timestamp. - -Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. -Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` -for the input outside of supported bounds. - -``` -to_timestamp(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. 
- -[chrono format]: https://docs.rs/chrono/latest/chrono/format/strftime/index.html - -#### Example - -``` -> select to_timestamp('2023-01-31T09:26:56.123456789-05:00'); -+-----------------------------------------------------------+ -| to_timestamp(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+-----------------------------------------------------------+ -| 2023-01-31T14:26:56.123456789 | -+-----------------------------------------------------------+ -> select to_timestamp('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+--------------------------------------------------------------------------------------------------------+ -| to_timestamp(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+--------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123456789 | -+--------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `to_timestamp_millis` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). -Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. -Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding timestamp. - -``` -to_timestamp_millis(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. - -#### Example - -``` -> select to_timestamp_millis('2023-01-31T09:26:56.123456789-05:00'); -+------------------------------------------------------------------+ -| to_timestamp_millis(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+------------------------------------------------------------------+ -| 2023-01-31T14:26:56.123 | -+------------------------------------------------------------------+ -> select to_timestamp_millis('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+---------------------------------------------------------------------------------------------------------------+ -| to_timestamp_millis(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+---------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123 | -+---------------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `to_timestamp_micros` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). -Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. 
-Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) -Returns the corresponding timestamp. - -``` -to_timestamp_micros(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. - -#### Example - -``` -> select to_timestamp_micros('2023-01-31T09:26:56.123456789-05:00'); -+------------------------------------------------------------------+ -| to_timestamp_micros(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+------------------------------------------------------------------+ -| 2023-01-31T14:26:56.123456 | -+------------------------------------------------------------------+ -> select to_timestamp_micros('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+---------------------------------------------------------------------------------------------------------------+ -| to_timestamp_micros(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+---------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123456 | -+---------------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `to_timestamp_nanos` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). -Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. -Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding timestamp. - -``` -to_timestamp_nanos(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. 
- -#### Example - -``` -> select to_timestamp_nanos('2023-01-31T09:26:56.123456789-05:00'); -+-----------------------------------------------------------------+ -| to_timestamp_nanos(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+-----------------------------------------------------------------+ -| 2023-01-31T14:26:56.123456789 | -+-----------------------------------------------------------------+ -> select to_timestamp_nanos('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+--------------------------------------------------------------------------------------------------------------+ -| to_timestamp_nanos(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+--------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123456789 | -+---------------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `to_timestamp_seconds` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). -Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. -Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding timestamp. - -``` -to_timestamp_seconds(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. - -#### Example - -``` -> select to_timestamp_seconds('2023-01-31T09:26:56.123456789-05:00'); -+-------------------------------------------------------------------+ -| to_timestamp_seconds(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+-------------------------------------------------------------------+ -| 2023-01-31T14:26:56 | -+-------------------------------------------------------------------+ -> select to_timestamp_seconds('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+----------------------------------------------------------------------------------------------------------------+ -| to_timestamp_seconds(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+----------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00 | -+----------------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `from_unixtime` - -Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). -Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) -return the corresponding timestamp. - -``` -from_unixtime(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. 
- Can be a constant, column, or function, and any combination of arithmetic operators. - -### `to_unixtime` - -Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00Z`). -Supports strings, dates, timestamps and double types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. - -``` -to_unixtime(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. - -#### Example - -``` -> select to_unixtime('2020-09-08T12:00:00+00:00'); -+------------------------------------------------+ -| to_unixtime(Utf8("2020-09-08T12:00:00+00:00")) | -+------------------------------------------------+ -| 1599566400 | -+------------------------------------------------+ -> select to_unixtime('01-14-2023 01:01:30+05:30', '%q', '%d-%m-%Y %H/%M/%S', '%+', '%m-%d-%Y %H:%M:%S%#z'); -+-----------------------------------------------------------------------------------------------------------------------------+ -| to_unixtime(Utf8("01-14-2023 01:01:30+05:30"),Utf8("%q"),Utf8("%d-%m-%Y %H/%M/%S"),Utf8("%+"),Utf8("%m-%d-%Y %H:%M:%S%#z")) | -+-----------------------------------------------------------------------------------------------------------------------------+ -| 1673638290 | -+-----------------------------------------------------------------------------------------------------------------------------+ -``` - -## Array Functions - -- [array_any_value](#array_any_value) -- [array_append](#array_append) -- [array_sort](#array_sort) -- [array_cat](#array_cat) -- [array_concat](#array_concat) -- [array_contains](#array_contains) -- [array_dims](#array_dims) -- [array_distance](#array_distance) -- [array_distinct](#array_distinct) -- [array_has](#array_has) -- [array_has_all](#array_has_all) -- [array_has_any](#array_has_any) -- [array_element](#array_element) -- [array_empty](#array_empty) -- [array_except](#array_except) -- [array_extract](#array_extract) -- [array_fill](#array_fill) -- [array_indexof](#array_indexof) -- [array_intersect](#array_intersect) -- [array_join](#array_join) -- [array_length](#array_length) -- [array_ndims](#array_ndims) -- [array_prepend](#array_prepend) -- [array_pop_front](#array_pop_front) -- [array_pop_back](#array_pop_back) -- [array_position](#array_position) -- [array_positions](#array_positions) -- [array_push_back](#array_push_back) -- [array_push_front](#array_push_front) -- [array_repeat](#array_repeat) -- [array_resize](#array_resize) -- [array_remove](#array_remove) -- [array_remove_n](#array_remove_n) -- [array_remove_all](#array_remove_all) -- [array_replace](#array_replace) -- [array_replace_n](#array_replace_n) -- [array_replace_all](#array_replace_all) -- [array_reverse](#array_reverse) -- [array_slice](#array_slice) -- [array_to_string](#array_to_string) -- [array_union](#array_union) -- [cardinality](#cardinality) -- [empty](#empty) -- [flatten](#flatten) -- [generate_series](#generate_series) -- [list_any_value](#list_any_value) -- [list_append](#list_append) -- [list_sort](#list_sort) -- [list_cat](#list_cat) -- [list_concat](#list_concat) -- [list_dims](#list_dims) -- [list_distance](#list_distance) -- 
[list_distinct](#list_distinct) -- [list_element](#list_element) -- [list_except](#list_except) -- [list_extract](#list_extract) -- [list_has](#list_has) -- [list_has_all](#list_has_all) -- [list_has_any](#list_has_any) -- [list_indexof](#list_indexof) -- [list_intersect](#list_intersect) -- [list_join](#list_join) -- [list_length](#list_length) -- [list_ndims](#list_ndims) -- [list_prepend](#list_prepend) -- [list_pop_back](#list_pop_back) -- [list_pop_front](#list_pop_front) -- [list_position](#list_position) -- [list_positions](#list_positions) -- [list_push_back](#list_push_back) -- [list_push_front](#list_push_front) -- [list_repeat](#list_repeat) -- [list_resize](#list_resize) -- [list_remove](#list_remove) -- [list_remove_n](#list_remove_n) -- [list_remove_all](#list_remove_all) -- [list_replace](#list_replace) -- [list_replace_n](#list_replace_n) -- [list_replace_all](#list_replace_all) -- [list_slice](#list_slice) -- [list_to_string](#list_to_string) -- [list_union](#list_union) -- [make_array](#make_array) -- [make_list](#make_list) -- [string_to_array](#string_to_array) -- [string_to_list](#string_to_list) -- [trim_array](#trim_array) -- [unnest](#unnest) -- [range](#range) - -### `array_any_value` - -Returns the first non-null element in the array. - -``` -array_any_value(array) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Example - -``` -> select array_any_value([NULL, 1, 2, 3]); -+--------------------------------------------------------------+ -| array_any_value(List([NULL,1,2,3])) | -+--------------------------------------------------------------+ -| 1 | -+--------------------------------------------------------------+ -``` - -### `array_append` - -Appends an element to the end of an array. - -``` -array_append(array, element) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to append to the array. - -#### Example - -``` -> select array_append([1, 2, 3], 4); -+--------------------------------------+ -| array_append(List([1,2,3]),Int64(4)) | -+--------------------------------------+ -| [1, 2, 3, 4] | -+--------------------------------------+ -``` - -#### Aliases - -- array_push_back -- list_append -- list_push_back - -### `array_sort` - -Sort array. - -``` -array_sort(array, desc, nulls_first) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **desc**: Whether to sort in descending order(`ASC` or `DESC`). -- **nulls_first**: Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`). - -#### Example - -``` -> select array_sort([3, 1, 2]); -+-----------------------------+ -| array_sort(List([3,1,2])) | -+-----------------------------+ -| [1, 2, 3] | -+-----------------------------+ -``` - -#### Aliases - -- list_sort - -### `array_resize` - -Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set. - -``` -array_resize(array, size, value) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **size**: New size of given array. -- **value**: Defines new elements' value or empty if value is not set. 
- -#### Example - -``` -> select array_resize([1, 2, 3], 5, 0); -+-------------------------------------+ -| array_resize(List([1,2,3],5,0)) | -+-------------------------------------+ -| [1, 2, 3, 0, 0] | -+-------------------------------------+ -``` - -#### Aliases - -- list_resize - -### `array_cat` - -_Alias of [array_concat](#array_concat)._ - -### `array_concat` - -Concatenates arrays. - -``` -array_concat(array[, ..., array_n]) -``` - -#### Arguments - -- **array**: Array expression to concatenate. - Can be a constant, column, or function, and any combination of array operators. -- **array_n**: Subsequent array column or literal array to concatenate. - -#### Example - -``` -> select array_concat([1, 2], [3, 4], [5, 6]); -+---------------------------------------------------+ -| array_concat(List([1,2]),List([3,4]),List([5,6])) | -+---------------------------------------------------+ -| [1, 2, 3, 4, 5, 6] | -+---------------------------------------------------+ -``` - -#### Aliases - -- array_cat -- list_cat -- list_concat - -### `array_contains` - -_Alias of [array_has](#array_has)._ - -### `array_has` - -Returns true if the array contains the element - -``` -array_has(array, element) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Scalar or Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Aliases - -- list_has - -### `array_has_all` - -Returns true if all elements of sub-array exist in array - -``` -array_has_all(array, sub-array) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **sub-array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Aliases - -- list_has_all - -### `array_has_any` - -Returns true if any elements exist in both arrays - -``` -array_has_any(array, sub-array) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **sub-array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Aliases - -- list_has_any - -### `array_dims` - -Returns an array of the array's dimensions. - -``` -array_dims(array) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Example - -``` -> select array_dims([[1, 2, 3], [4, 5, 6]]); -+---------------------------------+ -| array_dims(List([1,2,3,4,5,6])) | -+---------------------------------+ -| [2, 3] | -+---------------------------------+ -``` - -#### Aliases - -- list_dims - -### `array_distance` - -Returns the Euclidean distance between two input arrays of equal length. - -``` -array_distance(array1, array2) -``` - -#### Arguments - -- **array1**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **array2**: Array expression. - Can be a constant, column, or function, and any combination of array operators. 
- -#### Example - -``` -> select array_distance([1, 2], [1, 4]); -+------------------------------------+ -| array_distance(List([1,2], [1,4])) | -+------------------------------------+ -| 2.0 | -+------------------------------------+ -``` - -#### Aliases - -- list_distance - -### `array_distinct` - -Returns distinct values from the array after removing duplicates. - -``` -array_distinct(array) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Example - -``` -> select array_distinct([1, 3, 2, 3, 1, 2, 4]); -+---------------------------------+ -| array_distinct(List([1,2,3,4])) | -+---------------------------------+ -| [1, 2, 3, 4] | -+---------------------------------+ -``` - -#### Aliases - -- list_distinct - -### `array_element` - -Extracts the element with the index n from the array. - -``` -array_element(array, index) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **index**: Index to extract the element from the array. - -#### Example - -``` -> select array_element([1, 2, 3, 4], 3); -+-----------------------------------------+ -| array_element(List([1,2,3,4]),Int64(3)) | -+-----------------------------------------+ -| 3 | -+-----------------------------------------+ -``` - -#### Aliases - -- array_extract -- list_element -- list_extract - -### `array_extract` - -_Alias of [array_element](#array_element)._ - -### `array_fill` - -Returns an array filled with copies of the given value. - -DEPRECATED: use `array_repeat` instead! - -``` -array_fill(element, array) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to copy to the array. - -### `flatten` - -Converts an array of arrays to a flat array - -- Applies to any depth of nested arrays -- Does not change arrays that are already flat - -The flattened array contains all the elements from all source arrays. - -#### Arguments - -- **array**: Array expression - Can be a constant, column, or function, and any combination of array operators. - -``` -flatten(array) -``` - -### `array_indexof` - -_Alias of [array_position](#array_position)._ - -### `array_intersect` - -Returns an array of elements in the intersection of array1 and array2. - -``` -array_intersect(array1, array2) -``` - -#### Arguments - -- **array1**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **array2**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Example - -``` -> select array_intersect([1, 2, 3, 4], [5, 6, 3, 4]); -+----------------------------------------------------+ -| array_intersect([1, 2, 3, 4], [5, 6, 3, 4]); | -+----------------------------------------------------+ -| [3, 4] | -+----------------------------------------------------+ -> select array_intersect([1, 2, 3, 4], [5, 6, 7, 8]); -+----------------------------------------------------+ -| array_intersect([1, 2, 3, 4], [5, 6, 7, 8]); | -+----------------------------------------------------+ -| [] | -+----------------------------------------------------+ -``` - ---- - -#### Aliases - -- list_intersect - -### `array_join` - -_Alias of [array_to_string](#array_to_string)._ - -### `array_length` - -Returns the length of the array dimension. 
- -``` -array_length(array, dimension) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **dimension**: Array dimension. - -#### Example - -``` -> select array_length([1, 2, 3, 4, 5]); -+---------------------------------+ -| array_length(List([1,2,3,4,5])) | -+---------------------------------+ -| 5 | -+---------------------------------+ -``` - -#### Aliases - -- list_length - -### `array_ndims` - -Returns the number of dimensions of the array. - -``` -array_ndims(array) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Example - -``` -> select array_ndims([[1, 2, 3], [4, 5, 6]]); -+----------------------------------+ -| array_ndims(List([1,2,3,4,5,6])) | -+----------------------------------+ -| 2 | -+----------------------------------+ -``` - -#### Aliases - -- list_ndims - -### `array_prepend` - -Prepends an element to the beginning of an array. - -``` -array_prepend(element, array) -``` - -#### Arguments - -- **element**: Element to prepend to the array. -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Example - -``` -> select array_prepend(1, [2, 3, 4]); -+---------------------------------------+ -| array_prepend(Int64(1),List([2,3,4])) | -+---------------------------------------+ -| [1, 2, 3, 4] | -+---------------------------------------+ -``` - -#### Aliases - -- array_push_front -- list_prepend -- list_push_front - -### `array_pop_front` - -Returns the array without the first element. - -``` -array_pop_front(array) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Example - -``` -> select array_pop_front([1, 2, 3]); -+--------------------------------+ -| array_pop_front(List([1,2,3])) | -+--------------------------------+ -| [2, 3] | -+--------------------------------+ -``` - -#### Aliases - -- list_pop_front - -### `array_pop_back` - -Returns the array without the last element. - -``` -array_pop_back(array) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Example - -``` -> select array_pop_back([1, 2, 3]); -+-------------------------------+ -| array_pop_back(List([1,2,3])) | -+-------------------------------+ -| [1, 2] | -+-------------------------------+ -``` - -#### Aliases - -- list_pop_back - -### `array_position` - -Returns the position of the first occurrence of the specified element in the array. - -``` -array_position(array, element) -array_position(array, element, index) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to search for position in the array. -- **index**: Index at which to start searching. - -#### Example - -``` -> select array_position([1, 2, 2, 3, 1, 4], 2); -+----------------------------------------------+ -| array_position(List([1,2,2,3,1,4]),Int64(2)) | -+----------------------------------------------+ -| 2 | -+----------------------------------------------+ -``` - -#### Aliases - -- array_indexof -- list_indexof -- list_position - -### `array_positions` - -Searches for an element in the array, returns all occurrences.
- -``` -array_positions(array, element) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to search for positions in the array. - -#### Example - -``` -> select array_positions([1, 2, 2, 3, 1, 4], 2); -+-----------------------------------------------+ -| array_positions(List([1,2,2,3,1,4]),Int64(2)) | -+-----------------------------------------------+ -| [2, 3] | -+-----------------------------------------------+ -``` - -#### Aliases - -- list_positions - -### `array_push_back` - -_Alias of [array_append](#array_append)._ - -### `array_push_front` - -_Alias of [array_prepend](#array_prepend)._ - -### `array_repeat` - -Returns an array containing element `count` times. - -``` -array_repeat(element, count) -``` - -#### Arguments - -- **element**: Element expression. - Can be a constant, column, or function, and any combination of array operators. -- **count**: Value of how many times to repeat the element. - -#### Example - -``` -> select array_repeat(1, 3); -+---------------------------------+ -| array_repeat(Int64(1),Int64(3)) | -+---------------------------------+ -| [1, 1, 1] | -+---------------------------------+ -``` - -``` -> select array_repeat([1, 2], 2); -+------------------------------------+ -| array_repeat(List([1,2]),Int64(2)) | -+------------------------------------+ -| [[1, 2], [1, 2]] | -+------------------------------------+ -``` - -#### Aliases - -- list_repeat - -### `array_remove` - -Removes the first element from the array equal to the given value. - -``` -array_remove(array, element) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to be removed from the array. - -#### Example - -``` -> select array_remove([1, 2, 2, 3, 2, 1, 4], 2); -+----------------------------------------------+ -| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) | -+----------------------------------------------+ -| [1, 2, 3, 2, 1, 4] | -+----------------------------------------------+ -``` - -#### Aliases - -- list_remove - -### `array_remove_n` - -Removes the first `max` elements from the array equal to the given value. - -``` -array_remove_n(array, element, max) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to be removed from the array. -- **max**: Number of first occurrences to remove. - -#### Example - -``` -> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2); -+---------------------------------------------------------+ -| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) | -+---------------------------------------------------------+ -| [1, 3, 2, 1, 4] | -+---------------------------------------------------------+ -``` - -#### Aliases - -- list_remove_n - -### `array_remove_all` - -Removes all elements from the array equal to the given value. - -``` -array_remove_all(array, element) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to be removed from the array. 
- -#### Example - -``` -> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2); -+--------------------------------------------------+ -| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) | -+--------------------------------------------------+ -| [1, 3, 1, 4] | -+--------------------------------------------------+ -``` - -#### Aliases - -- list_remove_all - -### `array_replace` - -Replaces the first occurrence of the specified element with another specified element. - -``` -array_replace(array, from, to) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **from**: Initial element. -- **to**: Final element. - -#### Example - -``` -> select array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5); -+--------------------------------------------------------+ -| array_replace(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) | -+--------------------------------------------------------+ -| [1, 5, 2, 3, 2, 1, 4] | -+--------------------------------------------------------+ -``` - -#### Aliases - -- list_replace - -### `array_replace_n` - -Replaces the first `max` occurrences of the specified element with another specified element. - -``` -array_replace_n(array, from, to, max) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **from**: Initial element. -- **to**: Final element. -- **max**: Number of first occurrences to replace. - -#### Example - -``` -> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2); -+-------------------------------------------------------------------+ -| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) | -+-------------------------------------------------------------------+ -| [1, 5, 5, 3, 2, 1, 4] | -+-------------------------------------------------------------------+ -``` - -#### Aliases - -- list_replace_n - -### `array_replace_all` - -Replaces all occurrences of the specified element with another specified element. - -``` -array_replace_all(array, from, to) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **from**: Initial element. -- **to**: Final element. - -#### Example - -``` -> select array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5); -+------------------------------------------------------------+ -| array_replace_all(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) | -+------------------------------------------------------------+ -| [1, 5, 5, 3, 5, 1, 4] | -+------------------------------------------------------------+ -``` - -#### Aliases - -- list_replace_all - -### `array_reverse` - -Returns the array with the order of the elements reversed. - -``` -array_reverse(array) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Example - -``` -> select array_reverse([1, 2, 3, 4]); -+------------------------------------------------------------+ -| array_reverse(List([1, 2, 3, 4])) | -+------------------------------------------------------------+ -| [4, 3, 2, 1] | -+------------------------------------------------------------+ -``` - -#### Aliases - -- list_reverse - -### `array_slice` - -Returns a slice of the array based on 1-indexed start and end positions. - -``` -array_slice(array, begin, end[, stride]) -``` - -#### Arguments - -- **array**: Array expression.
- Can be a constant, column, or function, and any combination of array operators. -- **begin**: Index of the first element. - If negative, it counts backward from the end of the array. -- **end**: Index of the last element. - If negative, it counts backward from the end of the array. -- **stride**: Optional stride of the array slice. The default is 1. - -#### Example - -``` -> select array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6); -+--------------------------------------------------------+ -| array_slice(List([1,2,3,4,5,6,7,8]),Int64(3),Int64(6)) | -+--------------------------------------------------------+ -| [3, 4, 5, 6] | -+--------------------------------------------------------+ -``` - -#### Aliases - -- list_slice - -### `array_to_string` - -Converts each element to its text representation. - -``` -array_to_string(array, delimiter) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **delimiter**: Array element separator. - -#### Example - -``` -> select array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], ','); -+----------------------------------------------------+ -| array_to_string(List([1,2,3,4,5,6,7,8]),Utf8(",")) | -+----------------------------------------------------+ -| 1,2,3,4,5,6,7,8 | -+----------------------------------------------------+ -``` - -#### Aliases - -- array_join -- list_join -- list_to_string - -### `array_union` - -Returns an array of elements that are present in either array (all elements from both arrays), without duplicates. - -``` -array_union(array1, array2) -``` - -#### Arguments - -- **array1**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **array2**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Example - -``` -> select array_union([1, 2, 3, 4], [5, 6, 3, 4]); -+----------------------------------------------------+ -| array_union([1, 2, 3, 4], [5, 6, 3, 4]); | -+----------------------------------------------------+ -| [1, 2, 3, 4, 5, 6] | -+----------------------------------------------------+ -> select array_union([1, 2, 3, 4], [5, 6, 7, 8]); -+----------------------------------------------------+ -| array_union([1, 2, 3, 4], [5, 6, 7, 8]); | -+----------------------------------------------------+ -| [1, 2, 3, 4, 5, 6, 7, 8] | -+----------------------------------------------------+ -``` - ---- - -#### Aliases - -- list_union - -### `array_except` - -Returns an array of the elements that appear in the first array but not in the second. - -``` -array_except(array1, array2) -``` - -#### Arguments - -- **array1**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **array2**: Array expression. - Can be a constant, column, or function, and any combination of array operators.
- -#### Example - -``` -> select array_except([1, 2, 3, 4], [5, 6, 3, 4]); -+----------------------------------------------------+ -| array_except([1, 2, 3, 4], [5, 6, 3, 4]); | -+----------------------------------------------------+ -| [1, 2] | -+----------------------------------------------------+ -> select array_except([1, 2, 3, 4], [3, 4, 5, 6]); -+----------------------------------------------------+ -| array_except([1, 2, 3, 4], [3, 4, 5, 6]); | -+----------------------------------------------------+ -| [1, 2] | -+----------------------------------------------------+ -``` - ---- - -#### Aliases - -- list_except - -### `cardinality` - -Returns the total number of elements in the array. - -``` -cardinality(array) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Example - -``` -> select cardinality([[1, 2, 3, 4], [5, 6, 7, 8]]); -+--------------------------------------+ -| cardinality(List([1,2,3,4,5,6,7,8])) | -+--------------------------------------+ -| 8 | -+--------------------------------------+ -``` - -### `empty` - -Returns 1 for an empty array or 0 for a non-empty array. - -``` -empty(array) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Example - -``` -> select empty([1]); -+------------------+ -| empty(List([1])) | -+------------------+ -| 0 | -+------------------+ -``` - -#### Aliases - -- array_empty -- list_empty - -### `generate_series` - -Similar to the range function, but it includes the upper bound. - -``` -generate_series(start, stop, step) -``` - -#### Arguments - -- **start**: Start of the series. Ints, timestamps, dates or string types that can be coerced to Date32 are supported. -- **stop**: End of the series (included). Type must be the same as start. -- **step**: Increase by step (cannot be 0). Steps less than a day are supported only for timestamp ranges.
- -#### Example - -``` -> select generate_series(1,3); -+------------------------------------+ -| generate_series(Int64(1),Int64(3)) | -+------------------------------------+ -| [1, 2, 3] | -+------------------------------------+ -``` - -### `list_any_value` - -_Alias of [array_any_value](#array_any_value)._ - -### `list_append` - -_Alias of [array_append](#array_append)._ - -### `list_cat` - -_Alias of [array_concat](#array_concat)._ - -### `list_concat` - -_Alias of [array_concat](#array_concat)._ - -### `list_dims` - -_Alias of [array_dims](#array_dims)._ - -### `list_distance` - -_Alias of [array_distance](#array_distance)._ - -### `list_distinct` - -_Alias of [array_distinct](#array_distinct)._ - -### `list_element` - -_Alias of [array_element](#array_element)._ - -### `list_empty` - -_Alias of [empty](#empty)._ - -### `list_except` - -_Alias of [array_except](#array_except)._ - -### `list_extract` - -_Alias of [array_element](#array_element)._ - -### `list_has` - -_Alias of [array_has](#array_has)._ - -### `list_has_all` - -_Alias of [array_has_all](#array_has_all)._ - -### `list_has_any` - -_Alias of [array_has_any](#array_has_any)._ - -### `list_indexof` - -_Alias of [array_position](#array_position)._ - -### `list_intersect` - -_Alias of [array_intersect](#array_intersect)._ - -### `list_join` - -_Alias of [array_to_string](#array_to_string)._ - -### `list_length` - -_Alias of [array_length](#array_length)._ - -### `list_ndims` - -_Alias of [array_ndims](#array_ndims)._ - -### `list_prepend` - -_Alias of [array_prepend](#array_prepend)._ - -### `list_pop_back` - -_Alias of [array_pop_back](#array_pop_back)._ - -### `list_pop_front` - -_Alias of [array_pop_front](#array_pop_front)._ - -### `list_position` - -_Alias of [array_position](#array_position)._ - -### `list_positions` - -_Alias of [array_positions](#array_positions)._ - -### `list_push_back` - -_Alias of [array_append](#array_append)._ - -### `list_push_front` - -_Alias of [array_prepend](#array_prepend)._ - -### `list_repeat` - -_Alias of [array_repeat](#array_repeat)._ - -### `list_resize` - -_Alias of [array_resize](#array_resize)._ - -### `list_remove` - -_Alias of [array_remove](#array_remove)._ - -### `list_remove_n` - -_Alias of [array_remove_n](#array_remove_n)._ - -### `list_remove_all` - -_Alias of [array_remove_all](#array_remove_all)._ - -### `list_replace` - -_Alias of [array_replace](#array_replace)._ - -### `list_replace_n` - -_Alias of [array_replace_n](#array_replace_n)._ - -### `list_replace_all` - -_Alias of [array_replace_all](#array_replace_all)._ - -### `list_reverse` - -_Alias of [array_reverse](#array_reverse)._ - -### `list_slice` - -_Alias of [array_slice](#array_slice)._ - -### `list_sort` - -_Alias of [array_sort](#array_sort)._ - -### `list_to_string` +## Conditional Functions -_Alias of [array_to_string](#array_to_string)._ +See the new documentation [`here`](https://datafusion.apache.org/user-guide/sql/scalar_functions_new.html) -### `list_union` +## String Functions -_Alias of [array_union](#array_union)._ +See the new documentation [`here`](https://datafusion.apache.org/user-guide/sql/scalar_functions_new.html) -### `make_array` +### `position` -Returns an Arrow array using the specified input expressions. +Returns the position of `substr` in `origstr` (counting from 1). If `substr` does +not appear in `origstr`, returns 0.
``` -make_array(expression1[, ..., expression_n]) +position(substr in origstr) ``` -### `array_empty` - -_Alias of [empty](#empty)._ - #### Arguments -- **expression_n**: Expression to include in the output array. - Can be a constant, column, or function, and any combination of arithmetic or - string operators. - -#### Example - -``` -> select make_array(1, 2, 3, 4, 5); -+----------------------------------------------------------+ -| make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)) | -+----------------------------------------------------------+ -| [1, 2, 3, 4, 5] | -+----------------------------------------------------------+ -``` - -#### Aliases - -- make_list +- **substr**: The substring to search for. +- **origstr**: The string to search within. -### `make_list` +## Time and Date Functions -_Alias of [make_array](#make_array)._ +- [extract](#extract) -### `string_to_array` +### `extract` -Splits a string into an array of substrings based on a delimiter. Any substrings matching the optional `null_str` argument are replaced with NULL. -`SELECT string_to_array('abc##def', '##')` or `SELECT string_to_array('abc def', ' ', 'def')` +Returns a sub-field from a time value as an integer. ``` -string_to_array(str, delimiter[, null_str]) +extract(field FROM source) ``` -#### Arguments - -- **str**: String expression to split. -- **delimiter**: Delimiter string to split on. -- **null_str**: Substring values to be replaced with `NULL` - -#### Aliases - -- string_to_list - -### `string_to_list` - -_Alias of [string_to_array](#string_to_array)._ - -### `trim_array` - -Removes the last n elements from the array. - -DEPRECATED: use `array_slice` instead! +Equivalent to calling `date_part('field', source)`. For example, these are equivalent: +```sql +extract(day FROM '2024-04-13'::date) +date_part('day', '2024-04-13'::date) ``` -trim_array(array, n) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **n**: Number of elements to remove from the end of the array. - -### `unnest` - -Transforms an array into rows. - -#### Arguments -- **array**: Array expression to unnest. - Can be a constant, column, or function, and any combination of array operators. - -#### Examples +See [date_part](#date_part). -``` -> select unnest(make_array(1, 2, 3, 4, 5)); -+------------------------------------------------------------------+ -| unnest(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5))) | -+------------------------------------------------------------------+ -| 1 | -| 2 | -| 3 | -| 4 | -| 5 | -+------------------------------------------------------------------+ -``` +## Array Functions -``` -> select unnest(range(0, 10)); -+-----------------------------------+ -| unnest(range(Int64(0),Int64(10))) | -+-----------------------------------+ -| 0 | -| 1 | -| 2 | -| 3 | -| 4 | -| 5 | -| 6 | -| 7 | -| 8 | -| 9 | -+-----------------------------------+ -``` +- [range](#range) ### `range` -Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` or `SELECT range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH);` +Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` or +`SELECT range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH);` The range start..end contains all values with start <= x < end. It is empty if start >= end. Step cannot be 0 (such a range would be nonsensical).
-Note that when the required range is a number, it accepts (stop), (start, stop), and (start, stop, step) as parameters, +Note that when the required range is a number, it accepts (stop), (start, stop), and (start, stop, step) as parameters, +but when the required range is a date or timestamp, it must be called with 3 non-NULL parameters. For example, ``` @@ -3584,445 +122,6 @@ are not allowed - generate_series -## Struct Functions - -- [struct](#struct) -- [named_struct](#named_struct) -- [unnest](#unnest-struct) - -### `struct` - -Returns an Arrow struct using the specified input expressions optionally named. -Fields in the returned struct use the optional name or the `cN` naming convention. -For example: `c0`, `c1`, `c2`, etc. - -``` -struct(expression1[, ..., expression_n]) -``` - -For example, this query converts two columns `a` and `b` to a single column with -a struct type of fields `field_a` and `c1`: - -``` -select * from t; -+---+---+ -| a | b | -+---+---+ -| 1 | 2 | -| 3 | 4 | -+---+---+ - --- use default names `c0`, `c1` -> select struct(a, b) from t; -+-----------------+ -| struct(t.a,t.b) | -+-----------------+ -| {c0: 1, c1: 2} | -| {c0: 3, c1: 4} | -+-----------------+ - --- name the first field `field_a` -select struct(a as field_a, b) from t; -+--------------------------------------------------+ -| named_struct(Utf8("field_a"),t.a,Utf8("c1"),t.b) | -+--------------------------------------------------+ -| {field_a: 1, c1: 2} | -| {field_a: 3, c1: 4} | -+--------------------------------------------------+ -``` - -#### Arguments - -- **expression_n**: Expression to include in the output struct. - Can be a constant, column, or function, any combination of arithmetic or - string operators, or a named expression of the kinds previously listed. - -### `named_struct` - -Returns an Arrow struct using the specified name and input expressions pairs. - -``` -named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input]) -``` - -For example, this query converts two columns `a` and `b` to a single column with -a struct type of fields `field_a` and `field_b`: - -``` -select * from t; -+---+---+ -| a | b | -+---+---+ -| 1 | 2 | -| 3 | 4 | -+---+---+ - -select named_struct('field_a', a, 'field_b', b) from t; -+-------------------------------------------------------+ -| named_struct(Utf8("field_a"),t.a,Utf8("field_b"),t.b) | -+-------------------------------------------------------+ -| {field_a: 1, field_b: 2} | -| {field_a: 3, field_b: 4} | -+-------------------------------------------------------+ -``` - -#### Arguments - -- **expression_n_name**: Name of the column field. - Must be a constant string. -- **expression_n_input**: Expression to include in the output struct. - Can be a constant, column, or function, and any combination of arithmetic or - string operators. - -### `unnest (struct)` - -Unwraps struct fields into columns. - -#### Arguments - -- **struct**: Object expression to unnest. - Can be a constant, column, or function, and any combination of object operators.
- -#### Examples - -``` -> select * from foo; -+---------------------+ -| column1 | -+---------------------+ -| {a: 5, b: a string} | -+---------------------+ - -> select unnest(column1) from foo; -+-----------------------+-----------------------+ -| unnest(foo.column1).a | unnest(foo.column1).b | -+-----------------------+-----------------------+ -| 5 | a string | -+-----------------------+-----------------------+ -``` - -## Map Functions - -- [map](#map) -- [make_map](#make_map) -- [map_extract](#map_extract) -- [map_keys](#map_keys) -- [map_values](#map_values) - -### `map` - -Returns an Arrow map with the specified key-value pairs. - -``` -map(key, value) -map(key: value) -``` - -#### Arguments - -- **key**: Expression to be used for key. - Can be a constant, column, or function, any combination of arithmetic or - string operators, or a named expression of previous listed. -- **value**: Expression to be used for value. - Can be a constant, column, or function, any combination of arithmetic or - string operators, or a named expression of previous listed. - -#### Example - -``` -SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]); ----- -{POST: 41, HEAD: 33, PATCH: } - -SELECT MAP([[1,2], [3,4]], ['a', 'b']); ----- -{[1, 2]: a, [3, 4]: b} - -SELECT MAP { 'a': 1, 'b': 2 }; ----- -{a: 1, b: 2} -``` - -### `make_map` - -Returns an Arrow map with the specified key-value pairs. - -``` -make_map(key_1, value_1, ..., key_n, value_n) -``` - -#### Arguments - -- **key_n**: Expression to be used for key. - Can be a constant, column, or function, any combination of arithmetic or - string operators, or a named expression of previous listed. -- **value_n**: Expression to be used for value. - Can be a constant, column, or function, any combination of arithmetic or - string operators, or a named expression of previous listed. - -#### Example - -``` -SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', null); ----- -{POST: 41, HEAD: 33, PATCH: } -``` - -### `map_extract` - -Return a list containing the value for a given key or an empty list if the key is not contained in the map. - -``` -map_extract(map, key) -``` - -#### Arguments - -- `map`: Map expression. - Can be a constant, column, or function, and any combination of map operators. -- `key`: Key to extract from the map. - Can be a constant, column, or function, any combination of arithmetic or - string operators, or a named expression of previous listed. - -#### Example - -``` -SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'); ----- -[1] -``` - -#### Aliases - -- element_at - -### `map_keys` - -Return a list of all keys in the map. - -``` -map_keys(map) -``` - -#### Arguments - -- `map`: Map expression. - Can be a constant, column, or function, and any combination of map operators. - -#### Example - -``` -SELECT map_keys(MAP {'a': 1, 'b': NULL, 'c': 3}); ----- -[a, b, c] - -select map_keys(map([100, 5], [42,43])); ----- -[100, 5] -``` - -### `map_values` - -Return a list of all values in the map. - -``` -map_values(map) -``` - -#### Arguments - -- `map`: Map expression. - Can be a constant, column, or function, and any combination of map operators. - -#### Example - -``` -SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3}); ----- -[1, , 3] - -select map_values(map([100, 5], [42,43])); ----- -[42, 43] -``` - -## Hashing Functions - -- [digest](#digest) -- [md5](#md5) -- [sha224](#sha224) -- [sha256](#sha256) -- [sha384](#sha384) -- [sha512](#sha512) - -### `digest` - -Computes the binary hash of an expression using the specified algorithm. 
- -``` -digest(expression, algorithm) -``` - -#### Arguments - -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **algorithm**: String expression specifying algorithm to use. - Must be one of: - - - md5 - - sha224 - - sha256 - - sha384 - - sha512 - - blake2s - - blake2b - - blake3 - -### `md5` - -Computes an MD5 128-bit checksum for a string expression. - -``` -md5(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -### `sha224` - -Computes the SHA-224 hash of a binary string. - -``` -sha224(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -### `sha256` - -Computes the SHA-256 hash of a binary string. - -``` -sha256(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -### `sha384` - -Computes the SHA-384 hash of a binary string. - -``` -sha384(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -### `sha512` - -Computes the SHA-512 hash of a binary string. - -``` -sha512(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - ## Other Functions -- [arrow_cast](#arrow_cast) -- [arrow_typeof](#arrow_typeof) -- [version](#version) - -### `arrow_cast` - -Casts a value to a specific Arrow data type: - -``` -arrow_cast(expression, datatype) -``` - -#### Arguments - -- **expression**: Expression to cast. - Can be a constant, column, or function, and any combination of arithmetic or - string operators. -- **datatype**: [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name - to cast to, as a string. The format is the same as that returned by [`arrow_typeof`] - -#### Example - -``` -> select arrow_cast(-5, 'Int8') as a, - arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b, - arrow_cast('bar', 'LargeUtf8') as c, - arrow_cast('2023-01-02T12:53:02', 'Timestamp(Microsecond, Some("+08:00"))') as d - ; -+----+-----+-----+---------------------------+ -| a | b | c | d | -+----+-----+-----+---------------------------+ -| -5 | foo | bar | 2023-01-02T12:53:02+08:00 | -+----+-----+-----+---------------------------+ -1 row in set. Query took 0.001 seconds. -``` - -### `arrow_typeof` - -Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression: - -``` -arrow_typeof(expression) -``` - -#### Arguments - -- **expression**: Expression to evaluate. - Can be a constant, column, or function, and any combination of arithmetic or - string operators. - -#### Example - -``` -> select arrow_typeof('foo'), arrow_typeof(1); -+---------------------------+------------------------+ -| arrow_typeof(Utf8("foo")) | arrow_typeof(Int64(1)) | -+---------------------------+------------------------+ -| Utf8 | Int64 | -+---------------------------+------------------------+ -1 row in set. Query took 0.001 seconds. -``` - -### `version` - -Returns the version of DataFusion. 
- -``` -version() -``` - -#### Example - -``` -> select version(); -+--------------------------------------------+ -| version() | -+--------------------------------------------+ -| Apache DataFusion 41.0.0, aarch64 on macos | -+--------------------------------------------+ -``` +See the new documentation [`here`](https://datafusion.apache.org/user-guide/sql/scalar_functions_new.html) diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md new file mode 100644 index 0000000000000..56173b97b4055 --- /dev/null +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -0,0 +1,4365 @@ + + + + +# Scalar Functions (NEW) + +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Scalar Functions (old)](scalar_functions.md) page for +the rest of the documentation. + +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 + +## Math Functions + +- [abs](#abs) +- [acos](#acos) +- [acosh](#acosh) +- [asin](#asin) +- [asinh](#asinh) +- [atan](#atan) +- [atan2](#atan2) +- [atanh](#atanh) +- [cbrt](#cbrt) +- [ceil](#ceil) +- [cos](#cos) +- [cosh](#cosh) +- [cot](#cot) +- [degrees](#degrees) +- [exp](#exp) +- [factorial](#factorial) +- [floor](#floor) +- [gcd](#gcd) +- [isnan](#isnan) +- [iszero](#iszero) +- [lcm](#lcm) +- [ln](#ln) +- [log](#log) +- [log10](#log10) +- [log2](#log2) +- [nanvl](#nanvl) +- [pi](#pi) +- [pow](#pow) +- [power](#power) +- [radians](#radians) +- [random](#random) +- [round](#round) +- [signum](#signum) +- [sin](#sin) +- [sinh](#sinh) +- [sqrt](#sqrt) +- [tan](#tan) +- [tanh](#tanh) +- [trunc](#trunc) + +### `abs` + +Returns the absolute value of a number. + +``` +abs(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `acos` + +Returns the arc cosine or inverse cosine of a number. + +``` +acos(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `acosh` + +Returns the area hyperbolic cosine or inverse hyperbolic cosine of a number. + +``` +acosh(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `asin` + +Returns the arc sine or inverse sine of a number. + +``` +asin(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `asinh` + +Returns the area hyperbolic sine or inverse hyperbolic sine of a number. + +``` +asinh(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `atan` + +Returns the arc tangent or inverse tangent of a number. + +``` +atan(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `atan2` + +Returns the arc tangent or inverse tangent of `expression_y / expression_x`. + +``` +atan2(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: First numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Second numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators.
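+
+#### Example
+
+A minimal worked example; the result is π/4, and the output formatting below is illustrative (it follows the style of the other examples on this page and may differ slightly in practice):
+
+```sql
+> select atan2(1, 1);
++--------------------------+
+| atan2(Int64(1),Int64(1)) |
++--------------------------+
+| 0.7853981633974483       |
++--------------------------+
+```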
+ +### `atanh` + +Returns the area hyperbolic tangent or inverse hyperbolic tangent of a number. + +``` +atanh(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `cbrt` + +Returns the cube root of a number. + +``` +cbrt(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `ceil` + +Returns the nearest integer greater than or equal to a number. + +``` +ceil(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `cos` + +Returns the cosine of a number. + +``` +cos(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `cosh` + +Returns the hyperbolic cosine of a number. + +``` +cosh(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `cot` + +Returns the cotangent of a number. + +``` +cot(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `degrees` + +Converts radians to degrees. + +``` +degrees(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `exp` + +Returns the base-e exponential of a number. + +``` +exp(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `factorial` + +Factorial. Returns 1 if value is less than 2. + +``` +factorial(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `floor` + +Returns the nearest integer less than or equal to a number. + +``` +floor(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `gcd` + +Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero. + +``` +gcd(expression_x, expression_y) +``` + +#### Arguments + +- **expression_x**: First numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_y**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of operators.
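+
+#### Example
+
+A short worked example (illustrative output):
+
+```sql
+> select gcd(12, 18);
++--------------------------+
+| gcd(Int64(12),Int64(18)) |
++--------------------------+
+| 6                        |
++--------------------------+
+```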
+ +### `isnan` + +Returns true if a given number is +NaN or -NaN, otherwise returns false. + +``` +isnan(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `iszero` + +Returns true if a given number is +0.0 or -0.0, otherwise returns false. + +``` +iszero(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `lcm` + +Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero. + +``` +lcm(expression_x, expression_y) +``` + +#### Arguments + +- **expression_x**: First numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_y**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `ln` + +Returns the natural logarithm of a number. + +``` +ln(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `log` + +Returns the base-`x` logarithm of a number. A base can be supplied as the first argument; if omitted, base 10 is used. + +``` +log(base, numeric_expression) +log(numeric_expression) +``` + +#### Arguments + +- **base**: Base numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators.
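+
+#### Example
+
+A short worked example of both forms (illustrative output; the column headers follow the naming style used elsewhere on this page):
+
+```sql
+> select log(2, 64);
++-------------------------+
+| log(Int64(2),Int64(64)) |
++-------------------------+
+| 6.0                     |
++-------------------------+
+> select log(100);
++-----------------+
+| log(Int64(100)) |
++-----------------+
+| 2.0             |
++-----------------+
+```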
+ +### `log10` + +Returns the base-10 logarithm of a number. + +``` +log10(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `log2` + +Returns the base-2 logarithm of a number. + +``` +log2(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `nanvl` + +Returns the first argument if it's not _NaN_. +Returns the second argument otherwise. + +``` +nanvl(expression_x, expression_y) +``` + +#### Arguments + +- **expression_x**: Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators. + +### `pi` + +Returns an approximate value of π. + +``` +pi() +``` + +### `pow` + +_Alias of [power](#power)._ + +### `power` + +Returns a base expression raised to the power of an exponent. + +``` +power(base, exponent) +``` + +#### Arguments + +- **base**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **exponent**: Exponent numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Aliases + +- pow + +### `radians` + +Converts degrees to radians. + +``` +radians(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `random` + +Returns a random float value in the range [0, 1). +The random seed is unique to each row. + +``` +random() +``` + +### `round` + +Rounds a number to the nearest integer. + +``` +round(numeric_expression[, decimal_places]) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **decimal_places**: Optional. The number of decimal places to round to. Defaults to 0.
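+
+#### Example
+
+A short worked example (illustrative output):
+
+```sql
+> select round(3.14159, 2);
++----------------------------------+
+| round(Float64(3.14159),Int64(2)) |
++----------------------------------+
+| 3.14                             |
++----------------------------------+
+```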
+ +## Conditional Functions + +- [coalesce](#coalesce) +- [ifnull](#ifnull) +- [nullif](#nullif) +- [nvl](#nvl) +- [nvl2](#nvl2) + +### `coalesce` + +Returns the first of its arguments that is not _null_. Returns _null_ if all arguments are _null_. This function is often used to substitute a default value for _null_ values. + +``` +coalesce(expression1[, ..., expression_n]) +``` + +#### Arguments + +- **expression1, expression_n**: Expression to use if previous expressions are _null_. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary. + +#### Example + +```sql +> select coalesce(null, null, 'datafusion'); ++----------------------------------------+ +| coalesce(NULL,NULL,Utf8("datafusion")) | ++----------------------------------------+ +| datafusion | ++----------------------------------------+ +``` + +### `ifnull` + +_Alias of [nvl](#nvl)._ + +### `nullif` + +Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_. +This can be used to perform the inverse operation of [`coalesce`](#coalesce). + +``` +nullif(expression1, expression2) +``` + +#### Arguments + +- **expression1**: Expression to compare and return if equal to expression2. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Expression to compare to expression1. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select nullif('datafusion', 'data'); ++-----------------------------------------+ +| nullif(Utf8("datafusion"),Utf8("data")) | ++-----------------------------------------+ +| datafusion | ++-----------------------------------------+ +> select nullif('datafusion', 'datafusion'); ++-----------------------------------------------+ +| nullif(Utf8("datafusion"),Utf8("datafusion")) | ++-----------------------------------------------+ +| | ++-----------------------------------------------+ +``` + +### `nvl` + +Returns _expression2_ if _expression1_ is NULL; otherwise it returns _expression1_. + +``` +nvl(expression1, expression2) +``` + +#### Arguments + +- **expression1**: Expression to return if not null. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select nvl(null, 'a'); ++---------------------+ +| nvl(NULL,Utf8("a")) | ++---------------------+ +| a | ++---------------------+ +> select nvl('b', 'a'); ++--------------------------+ +| nvl(Utf8("b"),Utf8("a")) | ++--------------------------+ +| b | ++--------------------------+ +``` + +#### Aliases + +- ifnull + +### `nvl2` + +Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_. + +``` +nvl2(expression1, expression2, expression3) +``` + +#### Arguments + +- **expression1**: Expression to test for null. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Expression to return if expr1 is not null. Can be a constant, column, or function, and any combination of operators. +- **expression3**: Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators.
+ +#### Example + +```sql +> select nvl2(null, 'a', 'b'); ++--------------------------------+ +| nvl2(NULL,Utf8("a"),Utf8("b")) | ++--------------------------------+ +| b | ++--------------------------------+ +> select nvl2('data', 'a', 'b'); ++----------------------------------------+ +| nvl2(Utf8("data"),Utf8("a"),Utf8("b")) | ++----------------------------------------+ +| a | ++----------------------------------------+ +``` + +## String Functions + +- [ascii](#ascii) +- [bit_length](#bit_length) +- [btrim](#btrim) +- [char_length](#char_length) +- [character_length](#character_length) +- [chr](#chr) +- [concat](#concat) +- [concat_ws](#concat_ws) +- [contains](#contains) +- [ends_with](#ends_with) +- [find_in_set](#find_in_set) +- [initcap](#initcap) +- [instr](#instr) +- [left](#left) +- [length](#length) +- [levenshtein](#levenshtein) +- [lower](#lower) +- [lpad](#lpad) +- [ltrim](#ltrim) +- [octet_length](#octet_length) +- [position](#position) +- [repeat](#repeat) +- [replace](#replace) +- [reverse](#reverse) +- [right](#right) +- [rpad](#rpad) +- [rtrim](#rtrim) +- [split_part](#split_part) +- [starts_with](#starts_with) +- [strpos](#strpos) +- [substr](#substr) +- [substr_index](#substr_index) +- [substring](#substring) +- [substring_index](#substring_index) +- [to_hex](#to_hex) +- [translate](#translate) +- [trim](#trim) +- [upper](#upper) +- [uuid](#uuid) + +### `ascii` + +Returns the Unicode character code of the first character in a string. + +``` +ascii(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select ascii('abc'); ++--------------------+ +| ascii(Utf8("abc")) | ++--------------------+ +| 97 | ++--------------------+ +> select ascii('🚀'); ++-------------------+ +| ascii(Utf8("🚀")) | ++-------------------+ +| 128640 | ++-------------------+ +``` + +**Related functions**: + +- [chr](#chr) + +### `bit_length` + +Returns the bit length of a string. + +``` +bit_length(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select bit_length('datafusion'); ++--------------------------------+ +| bit_length(Utf8("datafusion")) | ++--------------------------------+ +| 80 | ++--------------------------------+ +``` + +**Related functions**: + +- [length](#length) +- [octet_length](#octet_length) + +### `btrim` + +Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string. + +``` +btrim(str[, trim_str]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **trim_str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. 
_Default is whitespace characters._ + +#### Example + +```sql +> select btrim('__datafusion____', '_'); ++-------------------------------------------+ +| btrim(Utf8("__datafusion____"),Utf8("_")) | ++-------------------------------------------+ +| datafusion | ++-------------------------------------------+ +``` + +#### Alternative Syntax + +```sql +trim(BOTH trim_str FROM str) +``` + +```sql +trim(trim_str FROM str) +``` + +#### Aliases + +- trim + +**Related functions**: + +- [ltrim](#ltrim) +- [rtrim](#rtrim) + +### `char_length` + +_Alias of [character_length](#character_length)._ + +### `character_length` + +Returns the number of characters in a string. + +``` +character_length(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select character_length('Ångström'); ++------------------------------------+ +| character_length(Utf8("Ångström")) | ++------------------------------------+ +| 8 | ++------------------------------------+ +``` + +#### Aliases + +- length +- char_length + +**Related functions**: + +- [bit_length](#bit_length) +- [octet_length](#octet_length) + +### `chr` + +Returns the character with the specified ASCII or Unicode code value. + +``` +chr(expression) +``` + +#### Arguments + +- **expression**: Integer expression containing the code value to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select chr(128640); ++--------------------+ +| chr(Int64(128640)) | ++--------------------+ +| 🚀 | ++--------------------+ +``` + +**Related functions**: + +- [ascii](#ascii) + +### `concat` + +Concatenates multiple strings together. + +``` +concat(str[, ..., str_n]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **str_n**: Subsequent string expressions to concatenate. + +#### Example + +```sql +> select concat('data', 'f', 'us', 'ion'); ++-------------------------------------------------------+ +| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) | ++-------------------------------------------------------+ +| datafusion | ++-------------------------------------------------------+ +``` + +**Related functions**: + +- [concat_ws](#concat_ws) + +### `concat_ws` + +Concatenates multiple strings together with a specified separator. + +``` +concat_ws(separator, str[, ..., str_n]) +``` + +#### Arguments + +- **separator**: Separator to insert between concatenated strings. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **str_n**: Subsequent string expressions to concatenate. + +#### Example + +```sql +> select concat_ws('_', 'data', 'fusion'); ++--------------------------------------------------+ +| concat_ws(Utf8("_"),Utf8("data"),Utf8("fusion")) | ++--------------------------------------------------+ +| data_fusion | ++--------------------------------------------------+ +``` + +**Related functions**: + +- [concat](#concat) + +### `contains` + +Returns true if search_str is found within str (case-sensitive). + +``` +contains(str, search_str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **search_str**: The string to search for in str.
+ +#### Example + +```sql +> select contains('the quick brown fox', 'row'); ++---------------------------------------------------+ +| contains(Utf8("the quick brown fox"),Utf8("row")) | ++---------------------------------------------------+ +| true | ++---------------------------------------------------+ +``` + +### `ends_with` + +Tests if a string ends with a substring. + +``` +ends_with(str, substr) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **substr**: Substring to test for. + +#### Example + +```sql +> select ends_with('datafusion', 'soin'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("soin")) | ++--------------------------------------------+ +| false | ++--------------------------------------------+ +> select ends_with('datafusion', 'sion'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("sion")) | ++--------------------------------------------+ +| true | ++--------------------------------------------+ +``` + +### `find_in_set` + +Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings. + +``` +find_in_set(str, strlist) +``` + +#### Arguments + +- **str**: String expression to find in strlist. +- **strlist**: A string list is a string composed of substrings separated by , characters. + +#### Example + +```sql +> select find_in_set('b', 'a,b,c,d'); ++----------------------------------------+ +| find_in_set(Utf8("b"),Utf8("a,b,c,d")) | ++----------------------------------------+ +| 2 | ++----------------------------------------+ +``` + +### `initcap` + +Capitalizes the first character in each word in the input string. Words are delimited by non-alphanumeric characters. + +``` +initcap(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select initcap('apache datafusion'); ++------------------------------------+ +| initcap(Utf8("apache datafusion")) | ++------------------------------------+ +| Apache Datafusion | ++------------------------------------+ +``` + +**Related functions**: + +- [lower](#lower) +- [upper](#upper) + +### `instr` + +_Alias of [strpos](#strpos)._ + +### `left` + +Returns a specified number of characters from the left side of a string. + +``` +left(str, n) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **n**: Number of characters to return. + +#### Example + +```sql +> select left('datafusion', 4); ++-----------------------------------+ +| left(Utf8("datafusion"),Int64(4)) | ++-----------------------------------+ +| data | ++-----------------------------------+ +``` + +**Related functions**: + +- [right](#right) + +### `length` + +_Alias of [character_length](#character_length)._ + +### `levenshtein` + +Returns the [`Levenshtein distance`](https://en.wikipedia.org/wiki/Levenshtein_distance) between the two given strings. + +``` +levenshtein(str1, str2) +``` + +#### Arguments + +- **str1**: String expression to compute Levenshtein distance with str2. +- **str2**: String expression to compute Levenshtein distance with str1. 
+ +#### Example + +```sql +> select levenshtein('kitten', 'sitting'); ++---------------------------------------------+ +| levenshtein(Utf8("kitten"),Utf8("sitting")) | ++---------------------------------------------+ +| 3 | ++---------------------------------------------+ +``` + +### `lower` + +Converts a string to lower-case. + +``` +lower(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select lower('Ångström'); ++-------------------------+ +| lower(Utf8("Ångström")) | ++-------------------------+ +| ångström | ++-------------------------+ +``` + +**Related functions**: + +- [initcap](#initcap) +- [upper](#upper) + +### `lpad` + +Pads the left side of a string with another string to a specified string length. + +``` +lpad(str, n[, padding_str]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **n**: String length to pad to. +- **padding_str**: Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._ + +#### Example + +```sql +> select lpad('Dolly', 10, 'hello'); ++---------------------------------------------+ +| lpad(Utf8("Dolly"),Int64(10),Utf8("hello")) | ++---------------------------------------------+ +| helloDolly | ++---------------------------------------------+ +``` + +**Related functions**: + +- [rpad](#rpad) + +### `ltrim` + +Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string. + +``` +ltrim(str[, trim_str]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **trim_str**: String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._ + +#### Example + +```sql +> select ltrim(' datafusion '); ++-------------------------------+ +| ltrim(Utf8(" datafusion ")) | ++-------------------------------+ +| datafusion | ++-------------------------------+ +> select ltrim('___datafusion___', '_'); ++-------------------------------------------+ +| ltrim(Utf8("___datafusion___"),Utf8("_")) | ++-------------------------------------------+ +| datafusion___ | ++-------------------------------------------+ +``` + +#### Alternative Syntax + +```sql +trim(LEADING trim_str FROM str) +``` + +**Related functions**: + +- [btrim](#btrim) +- [rtrim](#rtrim) + +### `octet_length` + +Returns the length of a string in bytes. + +``` +octet_length(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select octet_length('Ångström'); ++--------------------------------+ +| octet_length(Utf8("Ångström")) | ++--------------------------------+ +| 10 | ++--------------------------------+ +``` + +**Related functions**: + +- [bit_length](#bit_length) +- [length](#length) + +### `position` + +_Alias of [strpos](#strpos)._ + +### `repeat` + +Returns a string with an input string repeated a specified number of times. + +``` +repeat(str, n) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **n**: Number of times to repeat the input string. + +#### Example + +```sql +> select repeat('data', 3); ++-------------------------------+ +| repeat(Utf8("data"),Int64(3)) | ++-------------------------------+ +| datadatadata | ++-------------------------------+ +``` + +### `replace` + +Replaces all occurrences of a specified substring in a string with a new substring. + +``` +replace(str, substr, replacement) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **substr**: Substring expression to replace in the input string. Can be a constant, column, or function, and any combination of operators. +- **replacement**: Replacement substring expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select replace('ABabbaBA', 'ab', 'cd'); ++-------------------------------------------------+ +| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) | ++-------------------------------------------------+ +| ABcdbaBA | ++-------------------------------------------------+ +``` + +### `reverse` + +Reverses the character order of a string. + +``` +reverse(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select reverse('datafusion'); ++-----------------------------+ +| reverse(Utf8("datafusion")) | ++-----------------------------+ +| noisufatad | ++-----------------------------+ +``` + +### `right` + +Returns a specified number of characters from the right side of a string. + +``` +right(str, n) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **n**: Number of characters to return. + +#### Example + +```sql +> select right('datafusion', 6); ++------------------------------------+ +| right(Utf8("datafusion"),Int64(6)) | ++------------------------------------+ +| fusion | ++------------------------------------+ +``` + +**Related functions**: + +- [left](#left) + +### `rpad` + +Pads the right side of a string with another string to a specified string length. + +``` +rpad(str, n[, padding_str]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **n**: String length to pad to. +- **padding_str**: String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._ + +#### Example + +```sql +> select rpad('datafusion', 20, '_-'); ++-----------------------------------------------+ +| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) | ++-----------------------------------------------+ +| datafusion_-_-_-_-_- | ++-----------------------------------------------+ +``` + +**Related functions**: + +- [lpad](#lpad) + +### `rtrim` + +Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string. + +``` +rtrim(str[, trim_str]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **trim_str**: String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators.
_Default is whitespace characters._ + +#### Example + +```sql +> select rtrim(' datafusion '); ++-------------------------------+ +| rtrim(Utf8(" datafusion ")) | ++-------------------------------+ +| datafusion | ++-------------------------------+ +> select rtrim('___datafusion___', '_'); ++-------------------------------------------+ +| rtrim(Utf8("___datafusion___"),Utf8("_")) | ++-------------------------------------------+ +| ___datafusion | ++-------------------------------------------+ +``` + +#### Alternative Syntax + +```sql +trim(TRAILING trim_str FROM str) +``` + +**Related functions**: + +- [btrim](#btrim) +- [ltrim](#ltrim) + +### `split_part` + +Splits a string based on a specified delimiter and returns the substring in the specified position. + +``` +split_part(str, delimiter, pos) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **delimiter**: String or character to split on. +- **pos**: Position of the part to return. + +#### Example + +```sql +> select split_part('1.2.3.4.5', '.', 3); ++--------------------------------------------------+ +| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) | ++--------------------------------------------------+ +| 3 | ++--------------------------------------------------+ +``` + +### `starts_with` + +Tests if a string starts with a substring. + +``` +starts_with(str, substr) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **substr**: Substring to test for. + +#### Example + +```sql +> select starts_with('datafusion','data'); ++----------------------------------------------+ +| starts_with(Utf8("datafusion"),Utf8("data")) | ++----------------------------------------------+ +| true | ++----------------------------------------------+ +``` + +### `strpos` + +Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0. + +``` +strpos(str, substr) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **substr**: Substring expression to search for. + +#### Example + +```sql +> select strpos('datafusion', 'fus'); ++----------------------------------------+ +| strpos(Utf8("datafusion"),Utf8("fus")) | ++----------------------------------------+ +| 5 | ++----------------------------------------+ +``` + +#### Alternative Syntax + +```sql +position(substr in origstr) +``` + +#### Aliases + +- instr +- position + +### `substr` + +Extracts a substring of a specified number of characters from a specific starting position in a string. + +``` +substr(str, start_pos[, length]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **start_pos**: Character position to start the substring at. The first character in the string has a position of 1. +- **length**: Number of characters to extract. If not specified, returns the rest of the string after the start position. 
+ +#### Example + +```sql +> select substr('datafusion', 5, 3); ++----------------------------------------------+ +| substr(Utf8("datafusion"),Int64(5),Int64(3)) | ++----------------------------------------------+ +| fus | ++----------------------------------------------+ +``` + +#### Alternative Syntax + +```sql +substring(str from start_pos for length) +``` + +#### Aliases + +- substring + +### `substr_index` + +Returns the substring from str before count occurrences of the delimiter delim. +If count is positive, everything to the left of the final delimiter (counting from the left) is returned. +If count is negative, everything to the right of the final delimiter (counting from the right) is returned. + +``` +substr_index(str, delim, count) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **delim**: The string to find in str to split str. +- **count**: The number of times to search for the delimiter. Can be either a positive or negative number. + +#### Example + +```sql +> select substr_index('www.apache.org', '.', 1); ++---------------------------------------------------------+ +| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) | ++---------------------------------------------------------+ +| www | ++---------------------------------------------------------+ +> select substr_index('www.apache.org', '.', -1); ++----------------------------------------------------------+ +| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) | ++----------------------------------------------------------+ +| org | ++----------------------------------------------------------+ +``` + +#### Aliases + +- substring_index + +### `substring` + +_Alias of [substr](#substr)._ + +### `substring_index` + +_Alias of [substr_index](#substr_index)._ + +### `to_hex` + +Converts an integer to a hexadecimal string. + +``` +to_hex(int) +``` + +#### Arguments + +- **int**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select to_hex(12345689); ++-------------------------+ +| to_hex(Int64(12345689)) | ++-------------------------+ +| bc6159 | ++-------------------------+ +``` + +### `translate` + +Translates characters in a string to specified translation characters. + +``` +translate(str, chars, translation) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **chars**: Characters to translate. +- **translation**: Translation characters. Translation characters replace only characters at the same position in the **chars** string. + +#### Example + +```sql +> select translate('twice', 'wic', 'her'); ++--------------------------------------------------+ +| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) | ++--------------------------------------------------+ +| there | ++--------------------------------------------------+ +``` + +### `trim` + +_Alias of [btrim](#btrim)._ + +### `upper` + +Converts a string to upper-case. + +``` +upper(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. 
+ +#### Example + +```sql +> select upper('dataFusion'); ++---------------------------+ +| upper(Utf8("dataFusion")) | ++---------------------------+ +| DATAFUSION | ++---------------------------+ +``` + +**Related functions**: + +- [initcap](#initcap) +- [lower](#lower) + +### `uuid` + +Returns a `UUID v4` string value, which is unique per row. + +``` +uuid() +``` + +#### Example + +```sql +> select uuid(); ++--------------------------------------+ +| uuid() | ++--------------------------------------+ +| 6ec17ef8-1934-41cc-8d59-d0c8f9eea1f0 | ++--------------------------------------+ +``` + +## Binary String Functions + +- [decode](#decode) +- [encode](#encode) + +### `decode` + +Decodes binary data from its textual representation in a string. + +``` +decode(expression, format) +``` + +#### Arguments + +- **expression**: Expression containing encoded string data. +- **format**: Same arguments as [encode](#encode). + +**Related functions**: + +- [encode](#encode) + +### `encode` + +Encodes binary data into a textual representation. + +``` +encode(expression, format) +``` + +#### Arguments + +- **expression**: Expression containing string or binary data. +- **format**: Supported formats are: `base64`, `hex`.
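+ +#### Example + +An illustrative example using the `hex` format (the column header follows the display convention of the other examples in this document; applying [decode](#decode) with the same format reverses the operation): + +```sql +> select encode('datafusion', 'hex'); ++----------------------------------------+ +| encode(Utf8("datafusion"),Utf8("hex")) | ++----------------------------------------+ +| 64617461667573696f6e | ++----------------------------------------+ +```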
+ +**Related functions**: + +- [decode](#decode) + +## Regular Expression Functions + +Apache DataFusion uses a [PCRE-like](https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions) +regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) +(minus support for several features including look-around and backreferences). +The following regular expression functions are supported: + +- [regexp_count](#regexp_count) +- [regexp_like](#regexp_like) +- [regexp_match](#regexp_match) +- [regexp_replace](#regexp_replace) + +### `regexp_count` + +Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string. + +``` +regexp_count(str, regexp[, start, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? + +#### Example + +```sql +> select regexp_count('abcAbAbc', 'abc', 2, 'i'); ++---------------------------------------------------------------+ +| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) | ++---------------------------------------------------------------+ +| 1 | ++---------------------------------------------------------------+ +``` + +### `regexp_like` + +Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise. + +``` +regexp_like(str, regexp[, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? + +#### Example + +```sql +select regexp_like('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); ++--------------------------------------------------------+ +| regexp_like(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | ++--------------------------------------------------------+ +| true | ++--------------------------------------------------------+ +SELECT regexp_like('aBc', '(b|d)', 'i'); ++--------------------------------------------------+ +| regexp_like(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | ++--------------------------------------------------+ +| true | ++--------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) + +### `regexp_match` + +Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) match in a string. + +``` +regexp_match(str, regexp[, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to match against. + Can be a constant, column, or function. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? + +#### Example + +```sql + > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); + +---------------------------------------------------------+ + | regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | + +---------------------------------------------------------+ + | [Köln] | + +---------------------------------------------------------+ + SELECT regexp_match('aBc', '(b|d)', 'i'); + +---------------------------------------------------+ + | regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | + +---------------------------------------------------+ + | [B] | + +---------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) + +### `regexp_replace` + +Replaces substrings in a string that match a [regular expression](https://docs.rs/regex/latest/regex/#syntax). + +``` +regexp_replace(str, regexp, replacement[, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to match against. + Can be a constant, column, or function. +- **replacement**: Replacement string expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **flags**: Optional regular expression flags that control the behavior of the regular expression.
The following flags are supported: + - **g**: (global) Search globally and don't return after the first match + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? + +#### Example + +```sql +> select regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); ++------------------------------------------------------------------------+ +| regexp_replace(Utf8("foobarbaz"),Utf8("b(..)"),Utf8("X\1Y"),Utf8("g")) | ++------------------------------------------------------------------------+ +| fooXarYXazY | ++------------------------------------------------------------------------+ +SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); ++-------------------------------------------------------------------+ +| regexp_replace(Utf8("aBc"),Utf8("(b|d)"),Utf8("Ab\1a"),Utf8("i")) | ++-------------------------------------------------------------------+ +| aAbBac | ++-------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) + +## Time and Date Functions + +- [current_date](#current_date) +- [current_time](#current_time) +- [current_timestamp](#current_timestamp) +- [date_bin](#date_bin) +- [date_format](#date_format) +- [date_part](#date_part) +- [date_trunc](#date_trunc) +- [datepart](#datepart) +- [datetrunc](#datetrunc) +- [from_unixtime](#from_unixtime) +- [make_date](#make_date) +- [now](#now) +- [to_char](#to_char) +- [to_date](#to_date) +- [to_local_time](#to_local_time) +- [to_timestamp](#to_timestamp) +- [to_timestamp_micros](#to_timestamp_micros) +- [to_timestamp_millis](#to_timestamp_millis) +- [to_timestamp_nanos](#to_timestamp_nanos) +- [to_timestamp_seconds](#to_timestamp_seconds) +- [to_unixtime](#to_unixtime) +- [today](#today) + +### `current_date` + +Returns the current UTC date. + +The `current_date()` return value is determined at query time and will return the same date, no matter when in the query plan the function executes. + +``` +current_date() +``` + +#### Aliases + +- today + +### `current_time` + +Returns the current UTC time. + +The `current_time()` return value is determined at query time and will return the same time, no matter when in the query plan the function executes. + +``` +current_time() +``` + +### `current_timestamp` + +_Alias of [now](#now)._ + +### `date_bin` + +Calculates time intervals and returns the start of the interval nearest to the specified timestamp. Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" and applying an aggregate or selector function to each window. + +For example, if you "bin" or "window" data into 15 minute intervals, an input timestamp of `2023-01-01T18:18:18Z` will be updated to the start time of the 15 minute bin it is in: `2023-01-01T18:15:00Z`. + +``` +date_bin(interval, expression, origin-timestamp) +``` + +#### Arguments + +- **interval**: Bin interval. +- **expression**: Time expression to operate on. Can be a constant, column, or function. +- **origin-timestamp**: Optional. Starting point used to determine bin boundaries. If not specified, defaults to 1970-01-01T00:00:00Z (the UNIX epoch in UTC). + +The following intervals are supported: + +- nanoseconds +- microseconds +- milliseconds +- seconds +- minutes +- hours +- days +- weeks +- months +- years +- century
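+ +#### Example + +An illustrative example binning a timestamp into 15 minute intervals, matching the description above (the `bin` alias is used only to keep the output table compact): + +```sql +> select date_bin(interval '15 minutes', '2023-01-01T18:18:18Z'::timestamp) as bin; ++---------------------+ +| bin | ++---------------------+ +| 2023-01-01T18:15:00 | ++---------------------+ +```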
+ +### `date_format` + +_Alias of [to_char](#to_char)._ + +### `date_part` + +Returns the specified part of the date as an integer. + +``` +date_part(part, expression) +``` + +#### Arguments + +- **part**: Part of the date to return. The following date parts are supported: + + - year + - quarter (emits value in inclusive range [1, 4] based on which quarter of the year the date is in) + - month + - week (week of the year) + - day (day of the month) + - hour + - minute + - second + - millisecond + - microsecond + - nanosecond + - dow (day of the week) + - doy (day of the year) + - epoch (seconds since Unix epoch) + +- **expression**: Time expression to operate on. Can be a constant, column, or function. + +#### Alternative Syntax + +```sql +extract(field FROM source) +``` + +#### Aliases + +- datepart + +### `date_trunc` + +Truncates a timestamp value to a specified precision. + +``` +date_trunc(precision, expression) +``` + +#### Arguments + +- **precision**: Time precision to truncate to. The following precisions are supported: + + - year / YEAR + - quarter / QUARTER + - month / MONTH + - week / WEEK + - day / DAY + - hour / HOUR + - minute / MINUTE + - second / SECOND + +- **expression**: Time expression to operate on. Can be a constant, column, or function. + +#### Aliases + +- datetrunc + +### `datepart` + +_Alias of [date_part](#date_part)._ + +### `datetrunc` + +_Alias of [date_trunc](#date_trunc)._ + +### `from_unixtime` + +Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00Z`). Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) and return the corresponding timestamp. + +``` +from_unixtime(expression) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators.
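+ +#### Example + +An illustrative example, the inverse of the [to_unixtime](#to_unixtime) example shown later in this document (the column header follows the display convention of the other examples here): + +```sql +> select from_unixtime(1599566400); ++----------------------------------+ +| from_unixtime(Int64(1599566400)) | ++----------------------------------+ +| 2020-09-08T12:00:00 | ++----------------------------------+ +```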
+ +### `make_date` + +Makes a date from year/month/day component parts. + +``` +make_date(year, month, day) +``` + +#### Arguments + +- **year**: Year to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. +- **month**: Month to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. +- **day**: Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. + +#### Example + +```sql +> select make_date(2023, 1, 31); ++-------------------------------------------+ +| make_date(Int64(2023),Int64(1),Int64(31)) | ++-------------------------------------------+ +| 2023-01-31 | ++-------------------------------------------+ +> select make_date('2023', '01', '31'); ++-----------------------------------------------+ +| make_date(Utf8("2023"),Utf8("01"),Utf8("31")) | ++-----------------------------------------------+ +| 2023-01-31 | ++-----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) + +### `now` + +Returns the current UTC timestamp. + +The `now()` return value is determined at query time and will return the same timestamp, no matter when in the query plan the function executes. + +``` +now() +``` + +#### Aliases + +- current_timestamp + +### `to_char` + +Returns a string representation of a date, time, timestamp or duration based on a [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). Unlike the PostgreSQL equivalent of this function, numerical formatting is not supported. + +``` +to_char(expression, format) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function that results in a date, time, timestamp or duration. +- **format**: A [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) string to use to convert the expression. + +#### Example + +```sql +> select to_char('2023-03-01'::date, '%d-%m-%Y'); ++----------------------------------------------+ +| to_char(Utf8("2023-03-01"),Utf8("%d-%m-%Y")) | ++----------------------------------------------+ +| 01-03-2023 | ++----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs) + +#### Aliases + +- date_format + +### `to_date` + +Converts a value to a date (`YYYY-MM-DD`). +Supports strings, integer and double types as input. +Strings are parsed as YYYY-MM-DD (e.g. '2023-07-20') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. +Integers and doubles are interpreted as days since the unix epoch (`1970-01-01T00:00:00Z`). +Returns the corresponding date. + +Note: `to_date` returns Date32, which represents its values as the number of days since the unix epoch (`1970-01-01`), stored as a signed 32-bit value. The largest supported date value is `9999-12-31`. + +``` +to_date(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order + they appear with the first successful one being returned. If none of the formats successfully parse the expression + an error will be returned. + +#### Example + +```sql +> select to_date('2023-01-31'); ++-----------------------------+ +| to_date(Utf8("2023-01-31")) | ++-----------------------------+ +| 2023-01-31 | ++-----------------------------+ +> select to_date('2023/01/31', '%Y-%m-%d', '%Y/%m/%d'); ++---------------------------------------------------------------+ +| to_date(Utf8("2023/01/31"),Utf8("%Y-%m-%d"),Utf8("%Y/%m/%d")) | ++---------------------------------------------------------------+ +| 2023-01-31 | ++---------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) + +### `to_local_time` + +Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or timezone information). This function handles daylight saving time changes. + +``` +to_local_time(expression) +``` + +#### Arguments + +- **expression**: Time expression to operate on. Can be a constant, column, or function.
+ +#### Example + +```sql +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +FROM ( + SELECT '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' AS time +); ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| time | type | to_local_time | to_local_time_type | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| 2024-04-01T00:00:20+02:00 | Timestamp(Nanosecond, Some("Europe/Brussels")) | 2024-04-01T00:00:20 | Timestamp(Nanosecond, None) | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ + +# combine `to_local_time()` with `date_bin()` to bin on boundaries in the timezone rather +# than UTC boundaries + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AS date_bin; ++---------------------+ +| date_bin | ++---------------------+ +| 2024-04-01T00:00:00 | ++---------------------+ + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels' AS date_bin_with_timezone; ++---------------------------+ +| date_bin_with_timezone | ++---------------------------+ +| 2024-04-01T00:00:00+02:00 | ++---------------------------+ +``` + +### `to_timestamp` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for input outside of the supported bounds. + +``` +to_timestamp(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.
+ +#### Example + +```sql +> select to_timestamp('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------+ +| to_timestamp(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------+ +> select to_timestamp('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------+ +| to_timestamp(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++--------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_timestamp_micros` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +``` +to_timestamp_micros(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_timestamp_micros('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456 | ++------------------------------------------------------------------+ +> select to_timestamp_micros('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456 | ++---------------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_timestamp_millis` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided.
Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +``` +to_timestamp_millis(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_timestamp_millis('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123 | ++------------------------------------------------------------------+ +> select to_timestamp_millis('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123 | ++---------------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_timestamp_nanos` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +``` +to_timestamp_nanos(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. 
+ +#### Example + +```sql +> select to_timestamp_nanos('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------------+ +> select to_timestamp_nanos('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++--------------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_timestamp_seconds` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +``` +to_timestamp_seconds(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_timestamp_seconds('2023-01-31T09:26:56.123456789-05:00'); ++-------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-------------------------------------------------------------------+ +| 2023-01-31T14:26:56 | ++-------------------------------------------------------------------+ +> select to_timestamp_seconds('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++----------------------------------------------------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++----------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00 | ++----------------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_unixtime` + +Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00Z`). Supports strings, dates, timestamps and double types as input. Strings are parsed as RFC3339 (e.g.
'2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. + +``` +to_unixtime(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_unixtime('2020-09-08T12:00:00+00:00'); ++------------------------------------------------+ +| to_unixtime(Utf8("2020-09-08T12:00:00+00:00")) | ++------------------------------------------------+ +| 1599566400 | ++------------------------------------------------+ +> select to_unixtime('01-14-2023 01:01:30+05:30', '%q', '%d-%m-%Y %H/%M/%S', '%+', '%m-%d-%Y %H:%M:%S%#z'); ++-----------------------------------------------------------------------------------------------------------------------------+ +| to_unixtime(Utf8("01-14-2023 01:01:30+05:30"),Utf8("%q"),Utf8("%d-%m-%Y %H/%M/%S"),Utf8("%+"),Utf8("%m-%d-%Y %H:%M:%S%#z")) | ++-----------------------------------------------------------------------------------------------------------------------------+ +| 1673638290 | ++-----------------------------------------------------------------------------------------------------------------------------+ +``` + +### `today` + +_Alias of [current_date](#current_date)._ + +## Array Functions + +- [array_any_value](#array_any_value) +- [array_append](#array_append) +- [array_cat](#array_cat) +- [array_concat](#array_concat) +- [array_contains](#array_contains) +- [array_dims](#array_dims) +- [array_distance](#array_distance) +- [array_distinct](#array_distinct) +- [array_element](#array_element) +- [array_empty](#array_empty) +- [array_except](#array_except) +- [array_extract](#array_extract) +- [array_has](#array_has) +- [array_has_all](#array_has_all) +- [array_has_any](#array_has_any) +- [array_indexof](#array_indexof) +- [array_intersect](#array_intersect) +- [array_join](#array_join) +- [array_length](#array_length) +- [array_ndims](#array_ndims) +- [array_pop_back](#array_pop_back) +- [array_pop_front](#array_pop_front) +- [array_position](#array_position) +- [array_positions](#array_positions) +- [array_prepend](#array_prepend) +- [array_push_back](#array_push_back) +- [array_push_front](#array_push_front) +- [array_remove](#array_remove) +- [array_remove_all](#array_remove_all) +- [array_remove_n](#array_remove_n) +- [array_repeat](#array_repeat) +- [array_replace](#array_replace) +- [array_replace_all](#array_replace_all) +- [array_replace_n](#array_replace_n) +- [array_resize](#array_resize) +- [array_reverse](#array_reverse) +- [array_slice](#array_slice) +- [array_sort](#array_sort) +- [array_to_string](#array_to_string) +- [array_union](#array_union) +- [cardinality](#cardinality) +- [empty](#empty) +- [flatten](#flatten) +- [generate_series](#generate_series) +- [list_any_value](#list_any_value) +- [list_append](#list_append) +- [list_cat](#list_cat) +- [list_concat](#list_concat) +- [list_contains](#list_contains) +- [list_dims](#list_dims) +- [list_distance](#list_distance) +- [list_distinct](#list_distinct) +- [list_element](#list_element) +- [list_empty](#list_empty) +- [list_except](#list_except) +- 
[list_extract](#list_extract)
+- [list_has](#list_has)
+- [list_has_all](#list_has_all)
+- [list_has_any](#list_has_any)
+- [list_indexof](#list_indexof)
+- [list_intersect](#list_intersect)
+- [list_join](#list_join)
+- [list_length](#list_length)
+- [list_ndims](#list_ndims)
+- [list_pop_back](#list_pop_back)
+- [list_pop_front](#list_pop_front)
+- [list_position](#list_position)
+- [list_positions](#list_positions)
+- [list_prepend](#list_prepend)
+- [list_push_back](#list_push_back)
+- [list_push_front](#list_push_front)
+- [list_remove](#list_remove)
+- [list_remove_all](#list_remove_all)
+- [list_remove_n](#list_remove_n)
+- [list_repeat](#list_repeat)
+- [list_replace](#list_replace)
+- [list_replace_all](#list_replace_all)
+- [list_replace_n](#list_replace_n)
+- [list_resize](#list_resize)
+- [list_reverse](#list_reverse)
+- [list_slice](#list_slice)
+- [list_sort](#list_sort)
+- [list_to_string](#list_to_string)
+- [list_union](#list_union)
+- [make_array](#make_array)
+- [make_list](#make_list)
+- [range](#range)
+- [string_to_array](#string_to_array)
+- [string_to_list](#string_to_list)
+
+### `array_any_value`
+
+Returns the first non-null element in the array.
+
+```
+array_any_value(array)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_any_value([NULL, 1, 2, 3]);
++-------------------------------------+
+| array_any_value(List([NULL,1,2,3])) |
++-------------------------------------+
+| 1 |
++-------------------------------------+
+```
+
+#### Aliases
+
+- list_any_value
+
+### `array_append`
+
+Appends an element to the end of an array.
+
+```
+array_append(array, element)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **element**: Element to append to the array.
+
+#### Example
+
+```sql
+> select array_append([1, 2, 3], 4);
++--------------------------------------+
+| array_append(List([1,2,3]),Int64(4)) |
++--------------------------------------+
+| [1, 2, 3, 4] |
++--------------------------------------+
+```
+
+#### Aliases
+
+- list_append
+- array_push_back
+- list_push_back
+
+### `array_cat`
+
+_Alias of [array_concat](#array_concat)._
+
+### `array_concat`
+
+Concatenates arrays.
+
+```
+array_concat(array[, ..., array_n])
+```
+
+#### Arguments
+
+- **array**: Array expression to concatenate. Can be a constant, column, or function, and any combination of array operators.
+- **array_n**: Subsequent array expression to concatenate with the previous one.
+
+#### Example
+
+```sql
+> select array_concat([1, 2], [3, 4], [5, 6]);
++---------------------------------------------------+
+| array_concat(List([1,2]),List([3,4]),List([5,6])) |
++---------------------------------------------------+
+| [1, 2, 3, 4, 5, 6] |
++---------------------------------------------------+
+```
+
+#### Aliases
+
+- array_cat
+- list_concat
+- list_cat
+
+### `array_contains`
+
+_Alias of [array_has](#array_has)._
+
+### `array_dims`
+
+Returns an array of the array's dimensions.
+
+```
+array_dims(array)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_dims([[1, 2, 3], [4, 5, 6]]);
++---------------------------------+
+| array_dims(List([1,2,3,4,5,6])) |
++---------------------------------+
+| [2, 3] |
++---------------------------------+
+```
+
+#### Aliases
+
+- list_dims
+
+### `array_distance`
+
+Returns the Euclidean distance between two input arrays of equal length.
+
+```
+array_distance(array1, array2)
+```
+
+#### Arguments
+
+- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_distance([1, 2], [1, 4]);
++------------------------------------+
+| array_distance(List([1,2], [1,4])) |
++------------------------------------+
+| 2.0 |
++------------------------------------+
+```
+
+#### Aliases
+
+- list_distance
+
+### `array_distinct`
+
+Returns distinct values from the array after removing duplicates.
+
+```
+array_distinct(array)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_distinct([1, 3, 2, 3, 1, 2, 4]);
++---------------------------------+
+| array_distinct(List([1,2,3,4])) |
++---------------------------------+
+| [1, 2, 3, 4] |
++---------------------------------+
+```
+
+#### Aliases
+
+- list_distinct
+
+### `array_element`
+
+Extracts the element with the index n from the array.
+
+```
+array_element(array, index)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **index**: Index to extract the element from the array.
+
+#### Example
+
+```sql
+> select array_element([1, 2, 3, 4], 3);
++-----------------------------------------+
+| array_element(List([1,2,3,4]),Int64(3)) |
++-----------------------------------------+
+| 3 |
++-----------------------------------------+
+```
+
+#### Aliases
+
+- array_extract
+- list_element
+- list_extract
+
+### `array_empty`
+
+_Alias of [empty](#empty)._
+
+### `array_except`
+
+Returns an array of the elements that appear in the first array but not in the second.
+
+```
+array_except(array1, array2)
+```
+
+#### Arguments
+
+- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_except([1, 2, 3, 4], [5, 6, 3, 4]);
++------------------------------------------+
+| array_except([1, 2, 3, 4], [5, 6, 3, 4]) |
++------------------------------------------+
+| [1, 2] |
++------------------------------------------+
+> select array_except([1, 2, 3, 4], [3, 4, 5, 6]);
++------------------------------------------+
+| array_except([1, 2, 3, 4], [3, 4, 5, 6]) |
++------------------------------------------+
+| [1, 2] |
++------------------------------------------+
+```
+
+#### Aliases
+
+- list_except
+
+### `array_extract`
+
+_Alias of [array_element](#array_element)._
+
+### `array_has`
+
+Returns true if the array contains the element.
+
+```
+array_has(array, element)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **element**: Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_has([1, 2, 3], 2);
++-----------------------------+
+| array_has(List([1,2,3]), 2) |
++-----------------------------+
+| true |
++-----------------------------+
+```
+
+#### Aliases
+
+- list_has
+- array_contains
+- list_contains
+
+### `array_has_all`
+
+Returns true if all elements of the sub-array exist in the array.
+
+```
+array_has_all(array, sub-array)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **sub-array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_has_all([1, 2, 3, 4], [2, 3]);
++---------------------------------------------+
+| array_has_all(List([1,2,3,4]), List([2,3])) |
++---------------------------------------------+
+| true |
++---------------------------------------------+
+```
+
+#### Aliases
+
+- list_has_all
+
+### `array_has_any`
+
+Returns true if any element exists in both arrays.
+
+```
+array_has_any(array, sub-array)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **sub-array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_has_any([1, 2, 3], [3, 4]);
++--------------------------------------------+
+| array_has_any(List([1,2,3]), List([3,4])) |
++--------------------------------------------+
+| true |
++--------------------------------------------+
+```
+
+#### Aliases
+
+- list_has_any
+
+### `array_indexof`
+
+_Alias of [array_position](#array_position)._
+
+### `array_intersect`
+
+Returns an array of the elements in the intersection of array1 and array2.
+
+```
+array_intersect(array1, array2)
+```
+
+#### Arguments
+
+- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_intersect([1, 2, 3, 4], [5, 6, 3, 4]);
++---------------------------------------------+
+| array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) |
++---------------------------------------------+
+| [3, 4] |
++---------------------------------------------+
+```
+
+#### Aliases
+
+- list_intersect
+
+### `array_join`
+
+_Alias of [array_to_string](#array_to_string)._
+
+### `array_length`
+
+Returns the length of the array dimension.
+
+```
+array_length(array, dimension)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **dimension**: Array dimension.
+
+#### Example
+
+```sql
+> select array_length([1, 2, 3, 4, 5], 1);
++-------------------------------------------+
+| array_length(List([1,2,3,4,5]), 1) |
++-------------------------------------------+
+| 5 |
++-------------------------------------------+
+```
+
+#### Aliases
+
+- list_length
+
+### `array_ndims`
+
+Returns the number of dimensions of the array.
+
+```
+array_ndims(array)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_ndims([[1, 2, 3], [4, 5, 6]]);
++----------------------------------+
+| array_ndims(List([1,2,3,4,5,6])) |
++----------------------------------+
+| 2 |
++----------------------------------+
+```
+
+#### Aliases
+
+- list_ndims
+
+### `array_pop_back`
+
+Returns the array without the last element.
+
+```
+array_pop_back(array)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_pop_back([1, 2, 3]);
++-------------------------------+
+| array_pop_back(List([1,2,3])) |
++-------------------------------+
+| [1, 2] |
++-------------------------------+
+```
+
+#### Aliases
+
+- list_pop_back
+
+### `array_pop_front`
+
+Returns the array without the first element.
+
+```
+array_pop_front(array)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_pop_front([1, 2, 3]);
++--------------------------------+
+| array_pop_front(List([1,2,3])) |
++--------------------------------+
+| [2, 3] |
++--------------------------------+
+```
+
+#### Aliases
+
+- list_pop_front
+
+### `array_position`
+
+Returns the position of the first occurrence of the specified element in the array.
+
+```
+array_position(array, element)
+array_position(array, element, index)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **element**: Element to search for position in the array.
+- **index**: Index at which to start searching.
+
+#### Example
+
+```sql
+> select array_position([1, 2, 2, 3, 1, 4], 2);
++----------------------------------------------+
+| array_position(List([1,2,2,3,1,4]),Int64(2)) |
++----------------------------------------------+
+| 2 |
++----------------------------------------------+
+> select array_position([1, 2, 2, 3, 1, 4], 2, 3);
++-------------------------------------------------------+
+| array_position(List([1,2,2,3,1,4]),Int64(2),Int64(3)) |
++-------------------------------------------------------+
+| 3 |
++-------------------------------------------------------+
+```
+
+#### Aliases
+
+- list_position
+- array_indexof
+- list_indexof
+
+### `array_positions`
+
+Searches for an element in the array and returns the positions of all occurrences.
+
+```
+array_positions(array, element)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **element**: Element to search for positions in the array.
+
+#### Example
+
+```sql
+> select array_positions([1, 2, 2, 3, 1, 4], 2);
++-----------------------------------------------+
+| array_positions(List([1,2,2,3,1,4]),Int64(2)) |
++-----------------------------------------------+
+| [2, 3] |
++-----------------------------------------------+
+```
+
+#### Aliases
+
+- list_positions
+
+### `array_prepend`
+
+Prepends an element to the beginning of an array.
+
+```
+array_prepend(element, array)
+```
+
+#### Arguments
+
+- **element**: Element to prepend to the array.
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_prepend(1, [2, 3, 4]);
++---------------------------------------+
+| array_prepend(Int64(1),List([2,3,4])) |
++---------------------------------------+
+| [1, 2, 3, 4] |
++---------------------------------------+
+```
+
+#### Aliases
+
+- list_prepend
+- array_push_front
+- list_push_front
+
+### `array_push_back`
+
+_Alias of [array_append](#array_append)._
+
+### `array_push_front`
+
+_Alias of [array_prepend](#array_prepend)._
+
+### `array_remove`
+
+Removes the first element from the array equal to the given value.
+
+```
+array_remove(array, element)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **element**: Element to be removed from the array.
+
+#### Example
+
+```sql
+> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
++----------------------------------------------+
+| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) |
++----------------------------------------------+
+| [1, 2, 3, 2, 1, 4] |
++----------------------------------------------+
+```
+
+#### Aliases
+
+- list_remove
+
+### `array_remove_all`
+
+Removes all elements from the array equal to the given value.
+
+```
+array_remove_all(array, element)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **element**: Element to be removed from the array.
+
+#### Example
+
+```sql
+> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
++--------------------------------------------------+
+| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) |
++--------------------------------------------------+
+| [1, 3, 1, 4] |
++--------------------------------------------------+
+```
+
+#### Aliases
+
+- list_remove_all
+
+### `array_remove_n`
+
+Removes the first `max` elements from the array equal to the given value.
+
+```
+array_remove_n(array, element, max)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **element**: Element to be removed from the array.
+- **max**: Number of first occurrences to remove.
+
+#### Example
+
+```sql
+> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
++---------------------------------------------------------+
+| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
++---------------------------------------------------------+
+| [1, 3, 2, 1, 4] |
++---------------------------------------------------------+
+```
+
+#### Aliases
+
+- list_remove_n
+
+### `array_repeat`
+
+Returns an array containing element `count` times.
+
+```
+array_repeat(element, count)
+```
+
+#### Arguments
+
+- **element**: Element expression. Can be a constant, column, or function, and any combination of array operators.
+- **count**: Value of how many times to repeat the element.
+
+#### Example
+
+```sql
+> select array_repeat(1, 3);
++---------------------------------+
+| array_repeat(Int64(1),Int64(3)) |
++---------------------------------+
+| [1, 1, 1] |
++---------------------------------+
+> select array_repeat([1, 2], 2);
++------------------------------------+
+| array_repeat(List([1,2]),Int64(2)) |
++------------------------------------+
+| [[1, 2], [1, 2]] |
++------------------------------------+
+```
+
+#### Aliases
+
+- list_repeat
+
+### `array_replace`
+
+Replaces the first occurrence of the specified element with another specified element.
+
+```
+array_replace(array, from, to)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **from**: Initial element.
+- **to**: Final element.
+
+#### Example
+
+```sql
+> select array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5);
++--------------------------------------------------------+
+| array_replace(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
++--------------------------------------------------------+
+| [1, 5, 2, 3, 2, 1, 4] |
++--------------------------------------------------------+
+```
+
+#### Aliases
+
+- list_replace
+
+### `array_replace_all`
+
+Replaces all occurrences of the specified element with another specified element.
+
+```
+array_replace_all(array, from, to)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **from**: Initial element.
+- **to**: Final element.
+
+#### Example
+
+```sql
+> select array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5);
++------------------------------------------------------------+
+| array_replace_all(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
++------------------------------------------------------------+
+| [1, 5, 5, 3, 5, 1, 4] |
++------------------------------------------------------------+
+```
+
+#### Aliases
+
+- list_replace_all
+
+### `array_replace_n`
+
+Replaces the first `max` occurrences of the specified element with another specified element.
+
+```
+array_replace_n(array, from, to, max)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **from**: Initial element.
+- **to**: Final element.
+- **max**: Number of first occurrences to replace.
+
+#### Example
+
+```sql
+> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2);
++-------------------------------------------------------------------+
+| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) |
++-------------------------------------------------------------------+
+| [1, 5, 5, 3, 2, 1, 4] |
++-------------------------------------------------------------------+
+```
+
+#### Aliases
+
+- list_replace_n
+
+### `array_resize`
+
+Resizes the list to contain `size` elements. Initializes new elements with `value`, or with empty values if `value` is not set.
+
+```
+array_resize(array, size, value)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **size**: New size of given array.
+- **value**: Defines new elements' value, or empty if value is not set.
+
+#### Example
+
+```sql
+> select array_resize([1, 2, 3], 5, 0);
++-------------------------------------+
+| array_resize(List([1,2,3],5,0)) |
++-------------------------------------+
+| [1, 2, 3, 0, 0] |
++-------------------------------------+
+```
+
+#### Aliases
+
+- list_resize
+
+### `array_reverse`
+
+Returns the array with the order of the elements reversed.
+
+```
+array_reverse(array)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_reverse([1, 2, 3, 4]);
++------------------------------------+
+| array_reverse(List([1, 2, 3, 4])) |
++------------------------------------+
+| [4, 3, 2, 1] |
++------------------------------------+
+```
+
+#### Aliases
+
+- list_reverse
+
+### `array_slice`
+
+Returns a slice of the array based on 1-indexed start and end positions.
+
+```
+array_slice(array, begin, end)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **begin**: Index of the first element. If negative, it counts backward from the end of the array.
+- **end**: Index of the last element. If negative, it counts backward from the end of the array.
+
+#### Example
+
+```sql
+> select array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6);
++--------------------------------------------------------+
+| array_slice(List([1,2,3,4,5,6,7,8]),Int64(3),Int64(6)) |
++--------------------------------------------------------+
+| [3, 4, 5, 6] |
++--------------------------------------------------------+
+```
+
+#### Aliases
+
+- list_slice
+
+### `array_sort`
+
+Sorts the array.
+
+```
+array_sort(array, desc, nulls_first)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **desc**: Whether to sort in descending order (`ASC` or `DESC`).
+- **nulls_first**: Whether to sort nulls first (`NULLS FIRST` or `NULLS LAST`).
+
+#### Example
+
+```sql
+> select array_sort([3, 1, 2]);
++-----------------------------+
+| array_sort(List([3,1,2])) |
++-----------------------------+
+| [1, 2, 3] |
++-----------------------------+
+```
+
+#### Aliases
+
+- list_sort
+
+### `array_to_string`
+
+Converts each element to its text representation.
+
+```
+array_to_string(array, delimiter)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **delimiter**: Array element separator.
+
+#### Example
+
+```sql
+> select array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], ',');
++----------------------------------------------------+
+| array_to_string(List([1,2,3,4,5,6,7,8]),Utf8(",")) |
++----------------------------------------------------+
+| 1,2,3,4,5,6,7,8 |
++----------------------------------------------------+
+```
+
+#### Aliases
+
+- list_to_string
+- array_join
+- list_join
+
+### `array_union`
+
+Returns an array of the elements that are present in either array, without duplicates.
+
+```
+array_union(array1, array2)
+```
+
+#### Arguments
+
+- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_union([1, 2, 3, 4], [5, 6, 3, 4]);
++------------------------------------------+
+| array_union([1, 2, 3, 4], [5, 6, 3, 4]) |
++------------------------------------------+
+| [1, 2, 3, 4, 5, 6] |
++------------------------------------------+
+```
+
+#### Aliases
+
+- list_union
+
+### `cardinality`
+
+Returns the total number of elements in the array.
+
+```
+cardinality(array)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select cardinality([[1, 2, 3, 4], [5, 6, 7, 8]]);
++--------------------------------------+
+| cardinality(List([1,2,3,4,5,6,7,8])) |
++--------------------------------------+
+| 8 |
++--------------------------------------+
+```
+
+### `empty`
+
+Returns 1 for an empty array or 0 for a non-empty array.
+
+```
+empty(array)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select empty([1]);
++------------------+
+| empty(List([1])) |
++------------------+
+| 0 |
++------------------+
+```
+
+#### Aliases
+
+- array_empty
+- list_empty
+
+### `flatten`
+
+Converts an array of arrays to a flat array.
+
+- Applies to any depth of nested arrays
+- Does not change arrays that are already flat
+
+The flattened array contains all the elements from all source arrays.
+
+```
+flatten(array)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select flatten([[1, 2], [3, 4]]);
++------------------------------+
+| flatten(List([1,2], [3,4])) |
++------------------------------+
+| [1, 2, 3, 4] |
++------------------------------+
+```
+
+### `generate_series`
+
+Similar to the range function, but it includes the upper bound.
+
+```
+generate_series(start, stop, step)
+```
+
+#### Arguments
+
+- **start**: Start of the series. Ints, timestamps, dates or string types that can be coerced to Date32 are supported.
+- **stop**: End of the series (included). Type must be the same as start.
+- **step**: Increase by step (can not be 0). Steps less than a day are supported only for timestamp ranges.
+
+#### Example
+
+```sql
+> select generate_series(1,3);
++------------------------------------+
+| generate_series(Int64(1),Int64(3)) |
++------------------------------------+
+| [1, 2, 3] |
++------------------------------------+
+```
+
+### `list_any_value`
+
+_Alias of [array_any_value](#array_any_value)._
+
+### `list_append`
+
+_Alias of [array_append](#array_append)._
+
+### `list_cat`
+
+_Alias of [array_concat](#array_concat)._
+
+### `list_concat`
+
+_Alias of [array_concat](#array_concat)._
+
+### `list_contains`
+
+_Alias of [array_has](#array_has)._
+
+### `list_dims`
+
+_Alias of [array_dims](#array_dims)._
+
+### `list_distance`
+
+_Alias of [array_distance](#array_distance)._
+
+### `list_distinct`
+
+_Alias of [array_distinct](#array_distinct)._
+
+### `list_element`
+
+_Alias of [array_element](#array_element)._
+
+### `list_empty`
+
+_Alias of [empty](#empty)._
+
+### `list_except`
+
+_Alias of [array_except](#array_except)._
+
+### `list_extract`
+
+_Alias of [array_element](#array_element)._
+
+### `list_has`
+
+_Alias of [array_has](#array_has)._
+
+### `list_has_all`
+
+_Alias of [array_has_all](#array_has_all)._
+
+### `list_has_any`
+
+_Alias of [array_has_any](#array_has_any)._
+
+### `list_indexof`
+
+_Alias of [array_position](#array_position)._
+
+### `list_intersect`
+
+_Alias of [array_intersect](#array_intersect)._
+
+### `list_join`
+
+_Alias of [array_to_string](#array_to_string)._
+
+### `list_length`
+
+_Alias of [array_length](#array_length)._
+
+### `list_ndims`
+
+_Alias of [array_ndims](#array_ndims)._
+
+### `list_pop_back`
+
+_Alias of [array_pop_back](#array_pop_back)._
+
+### `list_pop_front`
+
+_Alias of [array_pop_front](#array_pop_front)._
+
+### `list_position`
+
+_Alias of [array_position](#array_position)._
+
+### `list_positions`
+
+_Alias of [array_positions](#array_positions)._
+
+### `list_prepend`
+
+_Alias of [array_prepend](#array_prepend)._
+
+### `list_push_back`
+
+_Alias of [array_append](#array_append)._
+
+### `list_push_front`
+
+_Alias of [array_prepend](#array_prepend)._
+
+### `list_remove`
+
+_Alias of [array_remove](#array_remove)._
+
+### `list_remove_all`
+
+_Alias of [array_remove_all](#array_remove_all)._
+
+### `list_remove_n`
+
+_Alias of [array_remove_n](#array_remove_n)._
+
+### `list_repeat`
+
+_Alias of [array_repeat](#array_repeat)._
+
+### `list_replace`
+
+_Alias of [array_replace](#array_replace)._
+
+### `list_replace_all`
+
+_Alias of [array_replace_all](#array_replace_all)._
+
+### `list_replace_n`
+
+_Alias of [array_replace_n](#array_replace_n)._
+
+### `list_resize`
+
+_Alias of [array_resize](#array_resize)._
+
+### `list_reverse`
+
+_Alias of [array_reverse](#array_reverse)._
+
+### `list_slice`
+
+_Alias of [array_slice](#array_slice)._
+
+### `list_sort`
+
+_Alias of [array_sort](#array_sort)._
+
+### `list_to_string`
+
+_Alias of [array_to_string](#array_to_string)._
+
+### `list_union`
+
+_Alias of [array_union](#array_union)._
+
+### `make_array`
+
+Returns an array using the specified input expressions.
+
+```
+make_array(expression1[, ..., expression_n])
+```
+
+#### Arguments
+
+- **expression_n**: Expression to include in the output array. Can be a constant, column, or function, and any combination of arithmetic or string operators.
+
+#### Example
+
+```sql
+> select make_array(1, 2, 3, 4, 5);
++----------------------------------------------------------+
+| make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)) |
++----------------------------------------------------------+
+| [1, 2, 3, 4, 5] |
++----------------------------------------------------------+
+```
+
+#### Aliases
+
+- make_list
+
+### `make_list`
+
+_Alias of [make_array](#make_array)._
+
+### `range`
+
+Returns an Arrow array between start and stop with step. The range start..stop contains all values with start <= x < stop. It is empty if start >= stop. Step cannot be 0.
+
+```
+range(start, stop, step)
+```
+
+#### Arguments
+
+- **start**: Start of the range. Ints, timestamps, dates or string types that can be coerced to Date32 are supported.
+- **stop**: End of the range (not included). Type must be the same as start.
+- **step**: Increase by step (cannot be 0). Steps less than a day are supported only for timestamp ranges.
+
+#### Example
+
+```sql
+> select range(2, 10, 3);
++------------------------------------+
+| range(Int64(2),Int64(10),Int64(3)) |
++------------------------------------+
+| [2, 5, 8] |
++------------------------------------+
+
+> select range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH);
++--------------------------------------------------------------------------+
+| range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH) |
++--------------------------------------------------------------------------+
+| [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01] |
++--------------------------------------------------------------------------+
+```
+
+### `string_to_array`
+
+Splits a string into an array of substrings based on a delimiter. Any substrings matching the optional `null_str` argument are replaced with `NULL`.
+
+```
+string_to_array(str, delimiter[, null_str])
+```
+
+#### Arguments
+
+- **str**: String expression to split.
+- **delimiter**: Delimiter string to split on.
+- **null_str**: Substring values to be replaced with `NULL`.
+
+#### Example
+
+```sql
+> select string_to_array('abc##def', '##');
++----------------------------------------------+
+| string_to_array(Utf8("abc##def"),Utf8("##")) |
++----------------------------------------------+
+| [abc, def] |
++----------------------------------------------+
+```
+
+#### Aliases
+
+- string_to_list
+
+### `string_to_list`
+
+_Alias of [string_to_array](#string_to_array)._
+
+## Struct Functions
+
+- [named_struct](#named_struct)
+- [row](#row)
+- [struct](#struct)
+
+### `named_struct`
+
+Returns an Arrow struct using the specified name and input expression pairs.
+
+```
+named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input])
+```
+
+#### Arguments
+
+- **expression_n_name**: Name of the column field. Must be a constant string.
+- **expression_n_input**: Expression to include in the output struct. Can be a constant, column, or function, and any combination of arithmetic or string operators.
+
+#### Example
+
+For example, this query converts two columns `a` and `b` to a single column with
+a struct type of fields `field_a` and `field_b`:
+
+```sql
+> select * from t;
++---+---+
+| a | b |
++---+---+
+| 1 | 2 |
+| 3 | 4 |
++---+---+
+> select named_struct('field_a', a, 'field_b', b) from t;
++-------------------------------------------------------+
+| named_struct(Utf8("field_a"),t.a,Utf8("field_b"),t.b) |
++-------------------------------------------------------+
+| {field_a: 1, field_b: 2} |
+| {field_a: 3, field_b: 4} |
++-------------------------------------------------------+
+```
+
+### `row`
+
+_Alias of [struct](#struct)._
+
+### `struct`
+
+Returns an Arrow struct using the specified input expressions, optionally named.
+Fields in the returned struct use the optional name or the `cN` naming convention.
+For example: `c0`, `c1`, `c2`, etc.
+
+```
+struct(expression1[, ..., expression_n])
+```
+
+#### Arguments
+
+- **expression1, expression_n**: Expression to include in the output struct. Can be a constant, column, or function, and any combination of arithmetic or string operators.
+
+#### Example
+
+For example, this query converts two columns `a` and `b` to a single column with
+a struct type of fields `field_a` and `c1`:
+
+```sql
+> select * from t;
++---+---+
+| a | b |
++---+---+
+| 1 | 2 |
+| 3 | 4 |
++---+---+
+
+-- use default names `c0`, `c1`
+> select struct(a, b) from t;
++-----------------+
+| struct(t.a,t.b) |
++-----------------+
+| {c0: 1, c1: 2} |
+| {c0: 3, c1: 4} |
++-----------------+
+
+-- name the first field `field_a`
+> select struct(a as field_a, b) from t;
++--------------------------------------------------+
+| named_struct(Utf8("field_a"),t.a,Utf8("c1"),t.b) |
++--------------------------------------------------+
+| {field_a: 1, c1: 2} |
+| {field_a: 3, c1: 4} |
++--------------------------------------------------+
+```
+
+#### Aliases
+
+- row
+
+## Map Functions
+
+- [element_at](#element_at)
+- [map](#map)
+- [map_extract](#map_extract)
+- [map_keys](#map_keys)
+- [map_values](#map_values)
+
+### `element_at`
+
+_Alias of [map_extract](#map_extract)._
+
+### `map`
+
+Returns an Arrow map with the specified key-value pairs.
+
+The `make_map` function creates a map from two lists: one for keys and one for values. Each key must be unique and non-null.
+
+```
+map(key, value)
+map(key: value)
+make_map(['key1', 'key2'], ['value1', 'value2'])
+```
+
+#### Arguments
+
+- **key**: For `map`: Expression to be used for key. Can be a constant, column, function, or any combination of arithmetic or string operators.
+  For `make_map`: The list of keys to be used in the map. Each key must be unique and non-null.
+- **value**: For `map`: Expression to be used for value. Can be a constant, column, function, or any combination of arithmetic or string operators.
+  For `make_map`: The list of values to be mapped to the corresponding keys.
+
+#### Example
+
+```sql
+-- Using map function
+SELECT MAP('type', 'test');
+----
+{type: test}
+
+SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]);
+----
+{POST: 41, HEAD: 33, PATCH: }
+
+SELECT MAP([[1,2], [3,4]], ['a', 'b']);
+----
+{[1, 2]: a, [3, 4]: b}
+
+SELECT MAP { 'a': 1, 'b': 2 };
+----
+{a: 1, b: 2}
+
+-- Using make_map function
+SELECT MAKE_MAP(['POST', 'HEAD'], [41, 33]);
+----
+{POST: 41, HEAD: 33}
+
+SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]);
+----
+{key1: value1, key2: }
+```
+
+### `map_extract`
+
+Returns a list containing the value for the given key or an empty list if the key is not present in the map.
+
+```
+map_extract(map, key)
+```
+
+#### Arguments
+
+- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators.
+- **key**: Key to extract from the map. Can be a constant, column, or function, any combination of arithmetic or string operators, or a named expression of the previously listed.
+
+#### Example
+
+```sql
+SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a');
+----
+[1]
+
+SELECT map_extract(MAP {1: 'one', 2: 'two'}, 2);
+----
+['two']
+
+SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'y');
+----
+[]
+```
+
+#### Aliases
+
+- element_at
+
+### `map_keys`
+
+Returns a list of all keys in the map.
+
+```
+map_keys(map)
+```
+
+#### Arguments
+
+- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators.
+
+#### Example
+
+```sql
+SELECT map_keys(MAP {'a': 1, 'b': NULL, 'c': 3});
+----
+[a, b, c]
+
+SELECT map_keys(map([100, 5], [42, 43]));
+----
+[100, 5]
+```
+
+### `map_values`
+
+Returns a list of all values in the map.
+
+```
+map_values(map)
+```
+
+#### Arguments
+
+- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators.
+
+#### Example
+
+```sql
+SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3});
+----
+[1, , 3]
+
+SELECT map_values(map([100, 5], [42, 43]));
+----
+[42, 43]
+```
+
+## Hashing Functions
+
+- [digest](#digest)
+- [md5](#md5)
+- [sha224](#sha224)
+- [sha256](#sha256)
+- [sha384](#sha384)
+- [sha512](#sha512)
+
+### `digest`
+
+Computes the binary hash of an expression using the specified algorithm.
+
+```
+digest(expression, algorithm)
+```
+
+#### Arguments
+
+- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **algorithm**: String expression specifying algorithm to use. Must be one of:
+  - md5
+  - sha224
+  - sha256
+  - sha384
+  - sha512
+  - blake2s
+  - blake2b
+  - blake3
+
+#### Example
+
+```sql
+> select digest('foo', 'sha256');
++------------------------------------------+
+| digest(Utf8("foo"), Utf8("sha256")) |
++------------------------------------------+
+| |
++------------------------------------------+
+```
+
+### `md5`
+
+Computes an MD5 128-bit checksum for a string expression.
+
+```
+md5(expression)
+```
+
+#### Arguments
+
+- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select md5('foo');
++-------------------------------------+
+| md5(Utf8("foo")) |
++-------------------------------------+
+| |
++-------------------------------------+
+```
+
+### `sha224`
+
+Computes the SHA-224 hash of a binary string.
+
+```
+sha224(expression)
+```
+
+#### Arguments
+
+- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select sha224('foo');
++------------------------------------------+
+| sha224(Utf8("foo")) |
++------------------------------------------+
+| |
++------------------------------------------+
+```
+
+### `sha256`
+
+Computes the SHA-256 hash of a binary string.
+
+```
+sha256(expression)
+```
+
+#### Arguments
+
+- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select sha256('foo');
++--------------------------------------+
+| sha256(Utf8("foo")) |
++--------------------------------------+
+| |
++--------------------------------------+
+```
+
+### `sha384`
+
+Computes the SHA-384 hash of a binary string.
+
+```
+sha384(expression)
+```
+
+#### Arguments
+
+- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select sha384('foo');
++-----------------------------------------+
+| sha384(Utf8("foo")) |
++-----------------------------------------+
+| |
++-----------------------------------------+
+```
+
+### `sha512`
+
+Computes the SHA-512 hash of a binary string.
+
+```
+sha512(expression)
+```
+
+#### Arguments
+
+- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select sha512('foo');
++-------------------------------------------+
+| sha512(Utf8("foo")) |
++-------------------------------------------+
+| |
++-------------------------------------------+
+```
+
+## Other Functions
+
+- [arrow_cast](#arrow_cast)
+- [arrow_typeof](#arrow_typeof)
+- [get_field](#get_field)
+- [version](#version)
+
+### `arrow_cast`
+
+Casts a value to a specific Arrow data type.
+
+```
+arrow_cast(expression, datatype)
+```
+
+#### Arguments
+
+- **expression**: Expression to cast. The expression can be a constant, column, or function, and any combination of operators.
+- **datatype**: [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`].
+
+#### Example
+
+```sql
+> select arrow_cast(-5, 'Int8') as a,
+  arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b,
+  arrow_cast('bar', 'LargeUtf8') as c,
+  arrow_cast('2023-01-02T12:53:02', 'Timestamp(Microsecond, Some("+08:00"))') as d
+  ;
++----+-----+-----+---------------------------+
+| a | b | c | d |
++----+-----+-----+---------------------------+
+| -5 | foo | bar | 2023-01-02T12:53:02+08:00 |
++----+-----+-----+---------------------------+
+```
+
+### `arrow_typeof`
+
+Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression.
+
+```
+arrow_typeof(expression)
+```
+
+#### Arguments
+
+- **expression**: Expression to evaluate. The expression can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select arrow_typeof('foo'), arrow_typeof(1);
++---------------------------+------------------------+
+| arrow_typeof(Utf8("foo")) | arrow_typeof(Int64(1)) |
++---------------------------+------------------------+
+| Utf8 | Int64 |
++---------------------------+------------------------+
+```
+
+### `get_field`
+
+Returns a field within a map or a struct with the given key.
+Note: most users invoke `get_field` indirectly via field access
+syntax such as `my_struct_col['field_name']` which results in a call to
+`get_field(my_struct_col, 'field_name')`.
+
+```
+get_field(expression1, expression2)
+```
+
+#### Arguments
+
+- **expression1**: The map or struct to retrieve a field for.
+- **expression2**: The field name in the map or struct to retrieve data for. Must evaluate to a string.
+
+#### Example
+
+```sql
+> create table t (idx varchar, v varchar) as values ('data','fusion'), ('apache', 'arrow');
+> select struct(idx, v) from t as c;
++-------------------------+
+| struct(c.idx,c.v) |
++-------------------------+
+| {c0: data, c1: fusion} |
+| {c0: apache, c1: arrow} |
++-------------------------+
+> select get_field((select struct(idx, v) from t), 'c0');
++-----------------------+
+| struct(t.idx,t.v)[c0] |
++-----------------------+
+| data |
+| apache |
++-----------------------+
+> select get_field((select struct(idx, v) from t), 'c1');
++-----------------------+
+| struct(t.idx,t.v)[c1] |
++-----------------------+
+| fusion |
+| arrow |
++-----------------------+
+```
+
+### `version`
+
+Returns the version of DataFusion.
+
+```
+version()
+```
+
+#### Example
+
+```sql
+> select version();
++--------------------------------------------+
+| version() |
++--------------------------------------------+
+| Apache DataFusion 42.0.0, aarch64 on macos |
++--------------------------------------------+
+```
diff --git a/docs/source/user-guide/sql/special_functions.md b/docs/source/user-guide/sql/special_functions.md
new file mode 100644
index 0000000000000..7c9efbb66218f
--- /dev/null
+++ b/docs/source/user-guide/sql/special_functions.md
@@ -0,0 +1,100 @@
+
+
+# Special Functions
+
+## Expansion Functions
+
+- [unnest](#unnest)
+- [unnest(struct)](#unnest-struct)
+
+### `unnest`
+
+Expands an array or map into rows.
+
+#### Arguments
+
+- **array**: Array expression to unnest.
+  Can be a constant, column, or function, and any combination of array operators.
+
+#### Examples
+
+```sql
+> select unnest(make_array(1, 2, 3, 4, 5)) as unnested;
++----------+
+| unnested |
++----------+
+| 1 |
+| 2 |
+| 3 |
+| 4 |
+| 5 |
++----------+
+```
+
+```sql
+> select unnest(range(0, 10)) as unnested_range;
++----------------+
+| unnested_range |
++----------------+
+| 0 |
+| 1 |
+| 2 |
+| 3 |
+| 4 |
+| 5 |
+| 6 |
+| 7 |
+| 8 |
+| 9 |
++----------------+
+```
+
+### `unnest (struct)`
+
+Expands the fields of a struct into individual columns.
+
+#### Arguments
+
+- **struct**: Object expression to unnest.
+  Can be a constant, column, or function, and any combination of object operators.
+ +#### Examples + +```sql +> create table foo as values ({a: 5, b: 'a string'}), ({a:6, b: 'another string'}); + +> create view foov as select column1 as struct_column from foo; + +> select * from foov; ++---------------------------+ +| struct_column | ++---------------------------+ +| {a: 5, b: a string} | +| {a: 6, b: another string} | ++---------------------------+ + +> select unnest(struct_column) from foov; ++------------------------------------------+------------------------------------------+ +| unnest_placeholder(foov.struct_column).a | unnest_placeholder(foov.struct_column).b | ++------------------------------------------+------------------------------------------+ +| 5 | a string | +| 6 | another string | ++------------------------------------------+------------------------------------------+ +``` diff --git a/docs/source/user-guide/sql/window_functions.md b/docs/source/user-guide/sql/window_functions.md index 4d9d2557249f1..8216a3b258b8c 100644 --- a/docs/source/user-guide/sql/window_functions.md +++ b/docs/source/user-guide/sql/window_functions.md @@ -19,7 +19,15 @@ # Window Functions -A _window function_ performs a calculation across a set of table rows that are somehow related to the current row. This is comparable to the type of calculation that can be done with an aggregate function. However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would. Instead, the rows retain their separate identities. Behind the scenes, the window function is able to access more than just the current row of the query result +A _window function_ performs a calculation across a set of table rows that are somehow related to the current row. + +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Window Functions (new)](window_functions_new.md) page for +the rest of the documentation. + +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 + +Window functions are comparable to the type of calculation that can be done with an aggregate function. However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would. Instead, the rows retain their separate identities. Behind the scenes, the window function is able to access more than just the current row of the query result Here is an example that shows how to compare each employee's salary with the average salary in his or her department: @@ -138,103 +146,12 @@ RANGE and GROUPS modes require an ORDER BY clause (with RANGE the ORDER BY must All [aggregate functions](aggregate_functions.md) can be used as window functions. -## Ranking functions - -- [row_number](#row_number) -- [rank](#rank) -- [dense_rank](#dense_rank) -- [ntile](#ntile) - -### `row_number` - -Number of the current row within its partition, counting from 1. - -```sql -row_number() -``` - -### `rank` - -Rank of the current row with gaps; same as row_number of its first peer. - -```sql -rank() -``` - -### `dense_rank` - -Rank of the current row without gaps; this function counts peer groups. - -```sql -dense_rank() -``` - -### `ntile` - -Integer ranging from 1 to the argument value, dividing the partition as equally as possible. 
- -```sql -ntile(expression) -``` - -#### Arguments - -- **expression**: An integer describing the number groups the partition should be split into - ## Analytical functions -- [cume_dist](#cume_dist) -- [percent_rank](#percent_rank) -- [lag](#lag) -- [lead](#lead) - [first_value](#first_value) - [last_value](#last_value) - [nth_value](#nth_value) -### `cume_dist` - -Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows). - -```sql -cume_dist() -``` - -### `percent_rank` - -Relative rank of the current row: (rank - 1) / (total rows - 1). - -```sql -percent_rank() -``` - -### `lag` - -Returns value evaluated at the row that is offset rows before the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). Both offset and default are evaluated with respect to the current row. If omitted, offset defaults to 1 and default to null. - -```sql -lag(expression, offset, default) -``` - -#### Arguments - -- **expression**: Expression to operate on -- **offset**: Integer. Specifies how many rows back the value of _expression_ should be retrieved. Defaults to 1. -- **default**: The default value if the offset is not within the partition. Must be of the same type as _expression_. - -### `lead` - -Returns value evaluated at the row that is offset rows after the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). Both offset and default are evaluated with respect to the current row. If omitted, offset defaults to 1 and default to null. - -```sql -lead(expression, offset, default) -``` - -#### Arguments - -- **expression**: Expression to operate on -- **offset**: Integer. Specifies how many rows forward the value of _expression_ should be retrieved. Defaults to 1. -- **default**: The default value if the offset is not within the partition. Must be of the same type as _expression_. - ### `first_value` Returns value evaluated at the row that is the first row of the window frame. diff --git a/docs/source/user-guide/sql/window_functions_new.md b/docs/source/user-guide/sql/window_functions_new.md new file mode 100644 index 0000000000000..ae3edb832fcb1 --- /dev/null +++ b/docs/source/user-guide/sql/window_functions_new.md @@ -0,0 +1,250 @@ + + + + +# Window Functions (NEW) + +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Window Functions (Old)](window_functions.md) page for +the rest of the documentation. + +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 + +A _window function_ performs a calculation across a set of table rows that are somehow related to the current row. +This is comparable to the type of calculation that can be done with an aggregate function. +However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would. +Instead, the rows retain their separate identities. 
Behind the scenes, the window function is able to access more than just the current row of the query result.
+
+Here is an example that shows how to compare each employee's salary with the average salary in his or her department:
+
+```sql
+SELECT depname, empno, salary, avg(salary) OVER (PARTITION BY depname) FROM empsalary;
+
++-----------+-------+--------+-------------------+
+| depname | empno | salary | avg |
++-----------+-------+--------+-------------------+
+| personnel | 2 | 3900 | 3700.0 |
+| personnel | 5 | 3500 | 3700.0 |
+| develop | 8 | 6000 | 5020.0 |
+| develop | 10 | 5200 | 5020.0 |
+| develop | 11 | 5200 | 5020.0 |
+| develop | 9 | 4500 | 5020.0 |
+| develop | 7 | 4200 | 5020.0 |
+| sales | 1 | 5000 | 4866.666666666667 |
+| sales | 4 | 4800 | 4866.666666666667 |
+| sales | 3 | 4800 | 4866.666666666667 |
++-----------+-------+--------+-------------------+
+```
+
+A window function call always contains an OVER clause directly following the window function's name and argument(s). This is what syntactically distinguishes it from a normal function or non-window aggregate. The OVER clause determines exactly how the rows of the query are split up for processing by the window function. The PARTITION BY clause within OVER divides the rows into groups, or partitions, that share the same values of the PARTITION BY expression(s). For each row, the window function is computed across the rows that fall into the same partition as the current row. The previous example showed how to compute the average of a column per partition.
+
+You can also control the order in which rows are processed by window functions using ORDER BY within OVER. (The window ORDER BY does not even have to match the order in which the rows are output.) Here is an example:
+
+```sql
+SELECT depname, empno, salary,
+       rank() OVER (PARTITION BY depname ORDER BY salary DESC)
+FROM empsalary;
+
++-----------+-------+--------+--------+
+| depname | empno | salary | rank |
++-----------+-------+--------+--------+
+| personnel | 2 | 3900 | 1 |
+| develop | 8 | 6000 | 1 |
+| develop | 10 | 5200 | 2 |
+| develop | 11 | 5200 | 2 |
+| develop | 9 | 4500 | 4 |
+| develop | 7 | 4200 | 5 |
+| sales | 1 | 5000 | 1 |
+| sales | 4 | 4800 | 2 |
+| personnel | 5 | 3500 | 2 |
+| sales | 3 | 4800 | 2 |
++-----------+-------+--------+--------+
+```
+
+There is another important concept associated with window functions: for each row, there is a set of rows within its partition called its window frame. Some window functions act only on the rows of the window frame, rather than of the whole partition.
Here is an example of using window frames in queries:
+
+```sql
+SELECT depname, empno, salary,
+    avg(salary) OVER(ORDER BY salary ASC ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS avg,
+    min(salary) OVER(ORDER BY empno ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cum_min
+FROM empsalary
+ORDER BY empno ASC;
+
++-----------+-------+--------+--------------------+---------+
+| depname | empno | salary | avg | cum_min |
++-----------+-------+--------+--------------------+---------+
+| sales | 1 | 5000 | 5000.0 | 5000 |
+| personnel | 2 | 3900 | 3866.6666666666665 | 3900 |
+| sales | 3 | 4800 | 4700.0 | 3900 |
+| sales | 4 | 4800 | 4866.666666666667 | 3900 |
+| personnel | 5 | 3500 | 3700.0 | 3500 |
+| develop | 7 | 4200 | 4200.0 | 3500 |
+| develop | 8 | 6000 | 5600.0 | 3500 |
+| develop | 9 | 4500 | 4500.0 | 3500 |
+| develop | 10 | 5200 | 5133.333333333333 | 3500 |
+| develop | 11 | 5200 | 5466.666666666667 | 3500 |
++-----------+-------+--------+--------------------+---------+
+```
+
+When a query involves multiple window functions, it is possible to write out each one with a separate OVER clause, but this is duplicative and error-prone if the same windowing behavior is wanted for several functions. Instead, each windowing behavior can be named in a WINDOW clause and then referenced in OVER. For example:
+
+```sql
+SELECT sum(salary) OVER w, avg(salary) OVER w
+FROM empsalary
+WINDOW w AS (PARTITION BY depname ORDER BY salary DESC);
+```
+
+## Syntax
+
+The syntax for the OVER-clause is
+
+```
+function([expr])
+  OVER(
+    [PARTITION BY expr[, …]]
+    [ORDER BY expr [ ASC | DESC ][, …]]
+    [ frame_clause ]
+    )
+```
+
+where **frame_clause** is one of:
+
+```
+  { RANGE | ROWS | GROUPS } frame_start
+  { RANGE | ROWS | GROUPS } BETWEEN frame_start AND frame_end
+```
+
+and **frame_start** and **frame_end** can be one of
+
+```sql
+UNBOUNDED PRECEDING
+offset PRECEDING
+CURRENT ROW
+offset FOLLOWING
+UNBOUNDED FOLLOWING
+```
+
+where **offset** is a non-negative integer.
+
+RANGE and GROUPS modes require an ORDER BY clause (with RANGE the ORDER BY must specify exactly one column).
+
+## Aggregate functions
+
+All [aggregate functions](aggregate_functions.md) can be used as window functions.
+
+## Ranking Functions
+
+- [cume_dist](#cume_dist)
+- [dense_rank](#dense_rank)
+- [ntile](#ntile)
+- [percent_rank](#percent_rank)
+- [rank](#rank)
+- [row_number](#row_number)
+
+### `cume_dist`
+
+Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows).
+
+```
+cume_dist()
+```
+
+### `dense_rank`
+
+Returns the rank of the current row without gaps. This function ranks rows in a dense manner, meaning consecutive ranks are assigned even for identical values.
+
+```
+dense_rank()
+```
+
+### `ntile`
+
+Integer ranging from 1 to the argument value, dividing the partition as equally as possible.
+
+```
+ntile(expression)
+```
+
+#### Arguments
+
+- **expression**: An integer describing the number of groups the partition should be split into
+
+### `percent_rank`
+
+Returns the percentage rank of the current row within its partition. The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`.
+
+```
+percent_rank()
+```
+
+### `rank`
+
+Returns the rank of the current row within its partition, allowing gaps between ranks. This function provides a ranking similar to `row_number`, but skips ranks for identical values.
+
+```
+rank()
+```
+
+### `row_number`
+
+Number of the current row within its partition, counting from 1.
+
+```
+row_number()
+```
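+
+As a minimal illustrative sketch (reusing the hypothetical `empsalary` table
+from the examples above), the following query contrasts the three numbering
+functions on tied salary values:
+
+```sql
+SELECT depname, empno, salary,
+    row_number() OVER w AS row_num,
+    rank() OVER w AS rnk,
+    dense_rank() OVER w AS dense_rnk
+FROM empsalary
+WINDOW w AS (ORDER BY salary DESC);
+-- For the two employees tied at salary 5200: row_number assigns 2 and 3,
+-- rank assigns 2 to both (the next salary gets rank 4), and dense_rank
+-- assigns 2 to both (the next salary gets dense_rank 3).
+```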
+
+## Analytical Functions
+
+- [lag](#lag)
+- [lead](#lead)
+
+### `lag`
+
+Returns value evaluated at the row that is offset rows before the current row within the partition; if there is no such row, instead return default (which must be of the same type as value).
+
+```
+lag(expression, offset, default)
+```
+
+#### Arguments
+
+- **expression**: Expression to operate on
+- **offset**: Integer. Specifies how many rows back the value of expression should be retrieved. Defaults to 1.
+- **default**: The default value if the offset is not within the partition. Must be of the same type as expression.
+
+### `lead`
+
+Returns value evaluated at the row that is offset rows after the current row within the partition; if there is no such row, instead return default (which must be of the same type as value).
+
+```
+lead(expression, offset, default)
+```
+
+#### Arguments
+
+- **expression**: Expression to operate on
+- **offset**: Integer. Specifies how many rows forward the value of expression should be retrieved. Defaults to 1.
+- **default**: The default value if the offset is not within the partition. Must be of the same type as expression.
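+
+As a minimal illustrative sketch (again over the hypothetical `empsalary`
+table), `lag` and `lead` can be used to compare each salary with its
+neighbors in salary order:
+
+```sql
+SELECT empno, salary,
+    lag(salary, 1) OVER (ORDER BY salary) AS prev_salary,
+    lead(salary, 1) OVER (ORDER BY salary) AS next_salary
+FROM empsalary;
+-- prev_salary is NULL on the first row and next_salary is NULL on the
+-- last row, because no default value was supplied.
+```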
diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml
index 325a2cc2fcc48..414fa5569cfed 100644
--- a/test-utils/Cargo.toml
+++ b/test-utils/Cargo.toml
@@ -29,4 +29,5 @@ workspace = true
 arrow = { workspace = true }
 datafusion-common = { workspace = true, default-features = true }
 env_logger = { workspace = true }
+paste = "1.0.15"
 rand = { workspace = true }
diff --git a/test-utils/src/array_gen/mod.rs b/test-utils/src/array_gen/mod.rs
new file mode 100644
index 0000000000000..4a799ae737d7b
--- /dev/null
+++ b/test-utils/src/array_gen/mod.rs
@@ -0,0 +1,22 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod primitive;
+mod string;
+
+pub use primitive::PrimitiveArrayGenerator;
+pub use string::StringArrayGenerator;
diff --git a/test-utils/src/array_gen/primitive.rs b/test-utils/src/array_gen/primitive.rs
new file mode 100644
index 0000000000000..0581862d63bd6
--- /dev/null
+++ b/test-utils/src/array_gen/primitive.rs
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray, UInt32Array};
+use arrow::datatypes::DataType;
+use rand::distributions::Standard;
+use rand::prelude::Distribution;
+use rand::rngs::StdRng;
+use rand::Rng;
+
+/// Trait for safely converting from another native type; each native type
+/// supported by the generator implements this trait.
+pub trait FromNative: std::fmt::Debug + Send + Sync + Copy + Default {
+    /// Convert native type from i64.
+    fn from_i64(_: i64) -> Option<Self> {
+        None
+    }
+}
+
+macro_rules! native_type {
+    ($t: ty $(, $from:ident)*) => {
+        impl FromNative for $t {
+            $(
+                #[inline]
+                fn $from(v: $t) -> Option<Self> {
+                    Some(v)
+                }
+            )*
+        }
+    };
+}
+
+native_type!(i8);
+native_type!(i16);
+native_type!(i32);
+native_type!(i64, from_i64);
+native_type!(u8);
+native_type!(u16);
+native_type!(u32);
+native_type!(u64);
+native_type!(f32);
+native_type!(f64);
+
+/// Randomly generate primitive arrays
+pub struct PrimitiveArrayGenerator {
+    /// the total number of primitives in the output
+    pub num_primitives: usize,
+    /// The number of distinct primitives in the column
+    pub num_distinct_primitives: usize,
+    /// The percentage of nulls in the column
+    pub null_pct: f64,
+    /// Random number generator
+    pub rng: StdRng,
+}
+
+// TODO: support generating more primitive arrays
+impl PrimitiveArrayGenerator {
+    pub fn gen_data<A>(&mut self) -> ArrayRef
+    where
+        A: ArrowPrimitiveType,
+        A::Native: FromNative,
+        Standard: Distribution<<A as ArrowPrimitiveType>::Native>,
+    {
+        // table of primitives from which to draw
+        let distinct_primitives: PrimitiveArray<A> = (0..self.num_distinct_primitives)
+            .map(|_| {
+                Some(match A::DATA_TYPE {
+                    DataType::Int8
+                    | DataType::Int16
+                    | DataType::Int32
+                    | DataType::Int64
+                    | DataType::UInt8
+                    | DataType::UInt16
+                    | DataType::UInt32
+                    | DataType::UInt64
+                    | DataType::Float32
+                    | DataType::Float64
+                    | DataType::Date32 => self.rng.gen::<A::Native>(),
+
+                    DataType::Date64 => {
+                        // TODO: constrain this range to valid dates if necessary
+                        let date_value = self.rng.gen_range(i64::MIN..=i64::MAX);
+                        let millis_per_day = 86_400_000;
+                        let adjusted_value = date_value - (date_value % millis_per_day);
+                        A::Native::from_i64(adjusted_value).unwrap()
+                    }
+
+                    _ => {
+                        let arrow_type = A::DATA_TYPE;
+                        panic!("Unsupported arrow data type: {arrow_type}")
+                    }
+                })
+            })
+            .collect();
+
+        // pick num_primitives randomly from the distinct primitive table
+        let indices: UInt32Array = (0..self.num_primitives)
+            .map(|_| {
+                if self.rng.gen::<f64>() < self.null_pct {
+                    None
+                } else if self.num_distinct_primitives > 1 {
+                    let range = 1..(self.num_distinct_primitives as u32);
+                    Some(self.rng.gen_range(range))
+                } else {
+                    Some(0)
+                }
+            })
+            .collect();
+
+        let options = None;
+        arrow::compute::take(&distinct_primitives, &indices, options).unwrap()
+    }
+}
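A minimal sketch of how `PrimitiveArrayGenerator` might be driven from a test; the parameter values, seed, and `main` wrapper below are illustrative assumptions, not part of this change:

```rust
use arrow::array::Array;
use arrow::datatypes::Int32Type;
use rand::rngs::StdRng;
use rand::SeedableRng;
use test_utils::array_gen::PrimitiveArrayGenerator;

fn main() {
    // Illustrative parameters: 100 Int32 values drawn from a pool of
    // 10 distinct values, with roughly 5% nulls, from a fixed seed.
    let mut generator = PrimitiveArrayGenerator {
        num_primitives: 100,
        num_distinct_primitives: 10,
        null_pct: 0.05,
        rng: StdRng::seed_from_u64(42),
    };
    // Any ArrowPrimitiveType whose native type implements FromNative works here.
    let array = generator.gen_data::<Int32Type>();
    assert_eq!(array.len(), 100);
}
```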
diff --git a/test-utils/src/array_gen/string.rs b/test-utils/src/array_gen/string.rs
new file mode 100644
index 0000000000000..fbfa2bb941e00
--- /dev/null
+++ b/test-utils/src/array_gen/string.rs
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, UInt32Array};
+use rand::rngs::StdRng;
+use rand::Rng;
+
+/// Randomly generate string arrays
+pub struct StringArrayGenerator {
+    /// The maximum length of the strings
+    pub max_len: usize,
+    /// the total number of strings in the output
+    pub num_strings: usize,
+    /// The number of distinct strings in the columns
+    pub num_distinct_strings: usize,
+    /// The percentage of nulls in the columns
+    pub null_pct: f64,
+    /// Random number generator
+    pub rng: StdRng,
+}
+
+impl StringArrayGenerator {
+    /// Creates a StringArray or LargeStringArray with random strings according
+    /// to the parameters of the generator
+    pub fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
+        // table of strings from which to draw
+        let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings)
+            .map(|_| Some(random_string(&mut self.rng, self.max_len)))
+            .collect();
+
+        // pick num_strings randomly from the distinct string table
+        let indices: UInt32Array = (0..self.num_strings)
+            .map(|_| {
+                if self.rng.gen::<f64>() < self.null_pct {
+                    None
+                } else if self.num_distinct_strings > 1 {
+                    let range = 1..(self.num_distinct_strings as u32);
+                    Some(self.rng.gen_range(range))
+                } else {
+                    Some(0)
+                }
+            })
+            .collect();
+
+        let options = None;
+        arrow::compute::take(&distinct_strings, &indices, options).unwrap()
+    }
+}
+
+/// Return a string of random characters of length 1..=max_len
+fn random_string(rng: &mut StdRng, max_len: usize) -> String {
+    // pick characters at random (not just ascii)
+    match max_len {
+        0 => "".to_string(),
+        1 => String::from(rng.gen::<char>()),
+        _ => {
+            let len = rng.gen_range(1..=max_len);
+            rng.sample_iter::<char, _>(rand::distributions::Standard)
+                .take(len)
+                .map(char::from)
+                .collect::<String>()
+        }
+    }
+}
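A matching sketch for `StringArrayGenerator`; the offset type chosen for `gen_data` selects the array flavor, and the parameter values below are illustrative assumptions:

```rust
use arrow::array::Array;
use rand::rngs::StdRng;
use rand::SeedableRng;
use test_utils::array_gen::StringArrayGenerator;

fn main() {
    // Illustrative parameters: 50 strings of up to 12 characters,
    // drawn from 20 distinct values, with roughly 10% nulls.
    let mut generator = StringArrayGenerator {
        max_len: 12,
        num_strings: 50,
        num_distinct_strings: 20,
        null_pct: 0.1,
        rng: StdRng::seed_from_u64(0),
    };
    // i32 offsets build a StringArray; i64 offsets build a LargeStringArray.
    let array = generator.gen_data::<i32>();
    assert_eq!(array.len(), 50);
}
```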
diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs
index 3ddba2fec8007..9db8920833ae5 100644
--- a/test-utils/src/lib.rs
+++ b/test-utils/src/lib.rs
@@ -22,6 +22,7 @@ use datafusion_common::cast::as_int32_array;
 use rand::prelude::StdRng;
 use rand::{Rng, SeedableRng};
 
+pub mod array_gen;
 mod data_gen;
 mod string_gen;
 pub mod tpcds;
diff --git a/test-utils/src/string_gen.rs b/test-utils/src/string_gen.rs
index 530fc15353870..b598241db1e92 100644
--- a/test-utils/src/string_gen.rs
+++ b/test-utils/src/string_gen.rs
@@ -1,3 +1,4 @@
+use crate::array_gen::StringArrayGenerator;
 // Licensed to the Apache Software Foundation (ASF) under one
 // or more contributor license agreements. See the NOTICE file
 // distributed with this work for additional information
@@ -14,27 +15,14 @@
 // KIND, either express or implied. See the License for the
 // specific language governing permissions and limitations
 // under the License.
-//
-// use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, RecordBatch, UInt32Array};
+
 use crate::stagger_batch;
-use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, UInt32Array};
 use arrow::record_batch::RecordBatch;
 use rand::rngs::StdRng;
 use rand::{thread_rng, Rng, SeedableRng};
 
 /// Randomly generate strings
-pub struct StringBatchGenerator {
-    //// The maximum length of the strings
-    pub max_len: usize,
-    /// the total number of strings in the output
-    pub num_strings: usize,
-    /// The number of distinct strings in the columns
-    pub num_distinct_strings: usize,
-    /// The percentage of nulls in the columns
-    pub null_pct: f64,
-    /// Random number generator
-    pub rng: StdRng,
-}
+pub struct StringBatchGenerator(StringArrayGenerator);
 
 impl StringBatchGenerator {
     /// Make batches of random strings with a random length columns "a" and "b".
@@ -44,8 +32,8 @@ impl StringBatchGenerator {
     pub fn make_input_batches(&mut self) -> Vec<RecordBatch> {
         // use a random number generator to pick a random sized output
         let batch = RecordBatch::try_from_iter(vec![
-            ("a", self.gen_data::<i32>()),
-            ("b", self.gen_data::<i64>()),
+            ("a", self.0.gen_data::<i32>()),
+            ("b", self.0.gen_data::<i64>()),
         ])
         .unwrap();
         stagger_batch(batch)
@@ -57,9 +45,9 @@ impl StringBatchGenerator {
     /// if large is true, the array is a LargeStringArray
     pub fn make_sorted_input_batches(&mut self, large: bool) -> Vec<RecordBatch> {
         let array = if large {
-            self.gen_data::<i64>()
+            self.0.gen_data::<i64>()
         } else {
-            self.gen_data::<i32>()
+            self.0.gen_data::<i32>()
         };
 
         let array = arrow::compute::sort(&array, None).unwrap();
@@ -68,39 +56,13 @@
         stagger_batch(batch)
     }
 
-    /// Creates a StringArray or LargeStringArray with random strings according
-    /// to the parameters of the BatchGenerator
-    fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
-        // table of strings from which to draw
-        let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings)
-            .map(|_| Some(random_string(&mut self.rng, self.max_len)))
-            .collect();
-
-        // pick num_strings randomly from the distinct string table
-        let indicies: UInt32Array = (0..self.num_strings)
-            .map(|_| {
-                if self.rng.gen::<f64>() < self.null_pct {
-                    None
-                } else if self.num_distinct_strings > 1 {
-                    let range = 1..(self.num_distinct_strings as u32);
-                    Some(self.rng.gen_range(range))
-                } else {
-                    Some(0)
-                }
-            })
-            .collect();
-
-        let options = None;
-        arrow::compute::take(&distinct_strings, &indicies, options).unwrap()
-    }
-
     /// Return a set of `BatchGenerator`s that cover a range of interesting
     /// cases
     pub fn interesting_cases() -> Vec<Self> {
         let mut cases = vec![];
         let mut rng = thread_rng();
         for null_pct in [0.0, 0.01, 0.1, 0.5] {
-            for _ in 0..100 {
+            for _ in 0..10 {
                 // max length of generated strings
                 let max_len = rng.gen_range(1..50);
                 let num_strings = rng.gen_range(1..100);
                 let num_distinct_strings = if num_strings > 1 {
                     rng.gen_range(1..num_strings)
                 } else {
                     num_strings
                 };
-                cases.push(StringBatchGenerator {
+                cases.push(StringBatchGenerator(StringArrayGenerator {
                     max_len,
                     num_strings,
                     num_distinct_strings,
                     null_pct,
                     rng: StdRng::from_seed(rng.gen()),
-                })
+                }))
             }
         }
         cases
     }
 }
-
-/// Return a string of random characters of length 1..=max_len
-fn random_string(rng: &mut StdRng, max_len: usize) -> String {
-    // pick characters at random (not just ascii)
-    match max_len {
-        0 => "".to_string(),
-        1 => String::from(rng.gen::<char>()),
-        _ => {
-            let len = rng.gen_range(1..=max_len);
-            rng.sample_iter::<char, _>(rand::distributions::Standard)
-                .take(len)
-                .map(char::from)
-                .collect::<String>()
-        }
-    }
-}
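Since the refactor leaves `StringBatchGenerator`'s public methods unchanged, fuzz-style callers should keep working as before. A minimal sketch, assuming `StringBatchGenerator` remains re-exported from the `test_utils` crate root:

```rust
use test_utils::StringBatchGenerator;

fn main() {
    // Walk every "interesting" generator configuration, as a fuzz test might.
    for mut case in StringBatchGenerator::interesting_cases() {
        for batch in case.make_input_batches() {
            // Each batch carries the two generated string columns "a" and "b".
            assert_eq!(batch.num_columns(), 2);
        }
    }
}
```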