diff --git a/binding.gyp b/binding.gyp index 0b91d445..0555c625 100644 --- a/binding.gyp +++ b/binding.gyp @@ -26,6 +26,7 @@ "src/duckdb/ub_src_common_row_operations.cpp", "src/duckdb/ub_src_common_serializer.cpp", "src/duckdb/ub_src_common_sort.cpp", + "src/duckdb/ub_src_common_tree_renderer.cpp", "src/duckdb/ub_src_common_types.cpp", "src/duckdb/ub_src_common_types_column.cpp", "src/duckdb/ub_src_common_types_row.cpp", @@ -243,6 +244,7 @@ "src/duckdb/extension/parquet/parquet_writer.cpp", "src/duckdb/extension/parquet/serialize_parquet.cpp", "src/duckdb/extension/parquet/zstd_file_system.cpp", + "src/duckdb/extension/parquet/geo_parquet.cpp", "src/duckdb/third_party/parquet/parquet_constants.cpp", "src/duckdb/third_party/parquet/parquet_types.cpp", "src/duckdb/third_party/thrift/thrift/protocol/TProtocol.cpp", @@ -272,6 +274,37 @@ "src/duckdb/third_party/zstd/compress/zstd_ldm.cpp", "src/duckdb/third_party/zstd/compress/zstd_opt.cpp", "src/duckdb/third_party/lz4/lz4.cpp", + "src/duckdb/third_party/brotli/common/constants.cpp", + "src/duckdb/third_party/brotli/common/context.cpp", + "src/duckdb/third_party/brotli/common/dictionary.cpp", + "src/duckdb/third_party/brotli/common/platform.cpp", + "src/duckdb/third_party/brotli/common/shared_dictionary.cpp", + "src/duckdb/third_party/brotli/common/transform.cpp", + "src/duckdb/third_party/brotli/dec/bit_reader.cpp", + "src/duckdb/third_party/brotli/dec/decode.cpp", + "src/duckdb/third_party/brotli/dec/huffman.cpp", + "src/duckdb/third_party/brotli/dec/state.cpp", + "src/duckdb/third_party/brotli/enc/backward_references.cpp", + "src/duckdb/third_party/brotli/enc/backward_references_hq.cpp", + "src/duckdb/third_party/brotli/enc/bit_cost.cpp", + "src/duckdb/third_party/brotli/enc/block_splitter.cpp", + "src/duckdb/third_party/brotli/enc/brotli_bit_stream.cpp", + "src/duckdb/third_party/brotli/enc/cluster.cpp", + "src/duckdb/third_party/brotli/enc/command.cpp", + "src/duckdb/third_party/brotli/enc/compound_dictionary.cpp", + "src/duckdb/third_party/brotli/enc/compress_fragment.cpp", + "src/duckdb/third_party/brotli/enc/compress_fragment_two_pass.cpp", + "src/duckdb/third_party/brotli/enc/dictionary_hash.cpp", + "src/duckdb/third_party/brotli/enc/encode.cpp", + "src/duckdb/third_party/brotli/enc/encoder_dict.cpp", + "src/duckdb/third_party/brotli/enc/entropy_encode.cpp", + "src/duckdb/third_party/brotli/enc/fast_log.cpp", + "src/duckdb/third_party/brotli/enc/histogram.cpp", + "src/duckdb/third_party/brotli/enc/literal_cost.cpp", + "src/duckdb/third_party/brotli/enc/memory.cpp", + "src/duckdb/third_party/brotli/enc/metablock.cpp", + "src/duckdb/third_party/brotli/enc/static_dict.cpp", + "src/duckdb/third_party/brotli/enc/utf8_util.cpp", "src/duckdb/extension/icu/./icu-table-range.cpp", "src/duckdb/extension/icu/./icu-makedate.cpp", "src/duckdb/extension/icu/./icu-list-range.cpp", @@ -313,6 +346,10 @@ "src/duckdb/third_party/libpg_query", "src/duckdb/third_party/libpg_query/include", "src/duckdb/third_party/lz4", + "src/duckdb/third_party/brotli/include", + "src/duckdb/third_party/brotli/common", + "src/duckdb/third_party/brotli/dec", + "src/duckdb/third_party/brotli/enc", "src/duckdb/third_party/mbedtls", "src/duckdb/third_party/mbedtls/include", "src/duckdb/third_party/mbedtls/library", @@ -328,6 +365,10 @@ "src/duckdb/third_party/parquet", "src/duckdb/third_party/thrift", "src/duckdb/third_party/lz4", + "src/duckdb/third_party/brotli/include", + "src/duckdb/third_party/brotli/common", + "src/duckdb/third_party/brotli/dec", + "src/duckdb/third_party/brotli/enc", "src/duckdb/third_party/snappy", "src/duckdb/third_party/zstd/include", "src/duckdb/third_party/mbedtls", diff --git a/src/duckdb/extension/icu/icu-dateadd.cpp b/src/duckdb/extension/icu/icu-dateadd.cpp index 46f1fdfe..284e2561 100644 --- a/src/duckdb/extension/icu/icu-dateadd.cpp +++ b/src/duckdb/extension/icu/icu-dateadd.cpp @@ -101,13 +101,15 @@ timestamp_t ICUCalendarAdd::Operation(timestamp_t timestamp, interval_t interval calendar->add(UCAL_MINUTE, interval_m, status); CalendarAddHour(calendar, interval_h, status); - calendar->add(UCAL_DATE, interval.days, status); + // PG Adds months before days calendar->add(UCAL_MONTH, interval.months, status); + calendar->add(UCAL_DATE, interval.days, status); } else { - // Add interval fields from highest to lowest (ragged to non-ragged) + // PG Adds months before days calendar->add(UCAL_MONTH, interval.months, status); calendar->add(UCAL_DATE, interval.days, status); + // Add interval fields from highest to lowest (ragged to non-ragged) CalendarAddHour(calendar, interval_h, status); calendar->add(UCAL_MINUTE, interval_m, status); calendar->add(UCAL_SECOND, interval_s, status); diff --git a/src/duckdb/extension/icu/icu-datefunc.cpp b/src/duckdb/extension/icu/icu-datefunc.cpp index aeab8379..b0202f8d 100644 --- a/src/duckdb/extension/icu/icu-datefunc.cpp +++ b/src/duckdb/extension/icu/icu-datefunc.cpp @@ -73,6 +73,10 @@ unique_ptr ICUDateFunc::Bind(ClientContext &context, ScalarFunctio void ICUDateFunc::SetTimeZone(icu::Calendar *calendar, const string_t &tz_id) { auto tz = icu_66::TimeZone::createTimeZone(icu::UnicodeString::fromUTF8(icu::StringPiece(tz_id.GetString()))); + if (*tz == icu::TimeZone::getUnknown()) { + delete tz; + throw NotImplementedException("Unknown TimeZone '%s'", tz_id.GetString()); + } calendar->adoptTimeZone(tz); } @@ -83,7 +87,7 @@ timestamp_t ICUDateFunc::GetTimeUnsafe(icu::Calendar *calendar, uint64_t micros) if (U_FAILURE(status)) { throw InternalException("Unable to get ICU calendar time."); } - return timestamp_t(millis * Interval::MICROS_PER_MSEC + micros); + return timestamp_t(millis * Interval::MICROS_PER_MSEC + int64_t(micros)); } bool ICUDateFunc::TryGetTime(icu::Calendar *calendar, uint64_t micros, timestamp_t &result) { @@ -98,7 +102,7 @@ bool ICUDateFunc::TryGetTime(icu::Calendar *calendar, uint64_t micros, timestamp if (!TryMultiplyOperator::Operation(millis, Interval::MICROS_PER_MSEC, millis)) { return false; } - if (!TryAddOperator::Operation(millis, micros, millis)) { + if (!TryAddOperator::Operation(millis, int64_t(micros), millis)) { return false; } diff --git a/src/duckdb/extension/icu/icu-datesub.cpp b/src/duckdb/extension/icu/icu-datesub.cpp index 708dbd52..c56ccebc 100644 --- a/src/duckdb/extension/icu/icu-datesub.cpp +++ b/src/duckdb/extension/icu/icu-datesub.cpp @@ -209,6 +209,17 @@ struct ICUCalendarDiff : public ICUDateFunc { return sub_func(calendar, start_date, end_date); } + static part_trunc_t DiffTruncationFactory(DatePartSpecifier type) { + switch (type) { + case DatePartSpecifier::WEEK: + // Weeks are computed without anchors + return TruncationFactory(DatePartSpecifier::DAY); + default: + break; + } + return TruncationFactory(type); + } + template static void ICUDateDiffFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 3); @@ -229,7 +240,7 @@ struct ICUCalendarDiff : public ICUDateFunc { } else { const auto specifier = ConstantVector::GetData(part_arg)->GetString(); const auto part = GetDatePartSpecifier(specifier); - auto trunc_func = TruncationFactory(part); + auto trunc_func = DiffTruncationFactory(part); auto sub_func = SubtractFactory(part); BinaryExecutor::ExecuteWithNulls( startdate_arg, enddate_arg, result, args.size(), @@ -248,7 +259,7 @@ struct ICUCalendarDiff : public ICUDateFunc { [&](string_t specifier, T start_date, T end_date, ValidityMask &mask, idx_t idx) { if (Timestamp::IsFinite(start_date) && Timestamp::IsFinite(end_date)) { const auto part = GetDatePartSpecifier(specifier.GetString()); - auto trunc_func = TruncationFactory(part); + auto trunc_func = DiffTruncationFactory(part); auto sub_func = SubtractFactory(part); return DifferenceFunc(calendar, start_date, end_date, trunc_func, sub_func); } else { diff --git a/src/duckdb/extension/icu/icu-strptime.cpp b/src/duckdb/extension/icu/icu-strptime.cpp index d54d15a3..c7a5351f 100644 --- a/src/duckdb/extension/icu/icu-strptime.cpp +++ b/src/duckdb/extension/icu/icu-strptime.cpp @@ -68,15 +68,15 @@ struct ICUStrptime : public ICUDateFunc { } // Now get the parts in the given time zone - uint64_t micros = 0; + uint64_t micros = parsed.GetMicros(); calendar->set(UCAL_EXTENDED_YEAR, parsed.data[0]); // strptime doesn't understand eras calendar->set(UCAL_MONTH, parsed.data[1] - 1); calendar->set(UCAL_DATE, parsed.data[2]); calendar->set(UCAL_HOUR_OF_DAY, parsed.data[3]); calendar->set(UCAL_MINUTE, parsed.data[4]); calendar->set(UCAL_SECOND, parsed.data[5]); - calendar->set(UCAL_MILLISECOND, parsed.data[6] / Interval::MICROS_PER_MSEC); - micros = parsed.data[6] % Interval::MICROS_PER_MSEC; + calendar->set(UCAL_MILLISECOND, micros / Interval::MICROS_PER_MSEC); + micros %= Interval::MICROS_PER_MSEC; // This overrides the TZ setting, so only use it if an offset was parsed. // Note that we don't bother/worry about the DST setting because the two just combine. @@ -158,7 +158,7 @@ struct ICUStrptime : public ICUDateFunc { } } - static bind_scalar_function_t bind_strptime; + static bind_scalar_function_t bind_strptime; // NOLINT static duckdb::unique_ptr StrpTimeBindFunction(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { @@ -194,7 +194,7 @@ struct ICUStrptime : public ICUDateFunc { throw InvalidInputException("strptime format list must not be empty"); } vector formats; - bool has_tz = true; + bool has_tz = false; for (const auto &child : children) { format_string = child.ToString(); format.format_specifier = format_string; @@ -341,7 +341,7 @@ struct ICUStrptime : public ICUDateFunc { } }; -bind_scalar_function_t ICUStrptime::bind_strptime = nullptr; +bind_scalar_function_t ICUStrptime::bind_strptime = nullptr; // NOLINT struct ICUStrftime : public ICUDateFunc { static void ParseFormatSpecifier(string_t &format_str, StrfTimeFormat &format) { diff --git a/src/duckdb/extension/icu/icu-table-range.cpp b/src/duckdb/extension/icu/icu-table-range.cpp index 9f466a7d..f7efd856 100644 --- a/src/duckdb/extension/icu/icu-table-range.cpp +++ b/src/duckdb/extension/icu/icu-table-range.cpp @@ -13,14 +13,13 @@ namespace duckdb { struct ICUTableRange { using CalendarPtr = unique_ptr; - struct BindData : public TableFunctionData { - BindData(const BindData &other) + struct ICURangeBindData : public TableFunctionData { + ICURangeBindData(const ICURangeBindData &other) : TableFunctionData(other), tz_setting(other.tz_setting), cal_setting(other.cal_setting), - calendar(other.calendar->clone()), start(other.start), end(other.end), increment(other.increment), - inclusive_bound(other.inclusive_bound), greater_than_check(other.greater_than_check) { + calendar(other.calendar->clone()) { } - explicit BindData(ClientContext &context) { + explicit ICURangeBindData(ClientContext &context) { Value tz_value; if (context.TryGetCurrentSetting("TimeZone", tz_value)) { tz_setting = tz_value.ToString(); @@ -48,6 +47,15 @@ struct ICUTableRange { string tz_setting; string cal_setting; CalendarPtr calendar; + }; + + struct ICURangeLocalState : public LocalTableFunctionState { + ICURangeLocalState() { + } + + bool initialized_row = false; + idx_t current_input_row = 0; + timestamp_t current_state; timestamp_t start; timestamp_t end; @@ -55,17 +63,6 @@ struct ICUTableRange { bool inclusive_bound; bool greater_than_check; - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return other.start == start && other.end == end && other.increment == increment && - other.inclusive_bound == inclusive_bound && other.greater_than_check == greater_than_check && - *calendar == *other.calendar; - } - - unique_ptr Copy() const override { - return make_uniq(*this); - } - bool Finished(timestamp_t current_value) const { if (greater_than_check) { if (inclusive_bound) { @@ -84,107 +81,129 @@ struct ICUTableRange { }; template - static unique_ptr Bind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(context); - - auto &inputs = input.inputs; - D_ASSERT(inputs.size() == 3); - for (const auto &value : inputs) { - if (value.IsNull()) { - throw BinderException("RANGE with NULL bounds is not supported"); + static void GenerateRangeDateTimeParameters(DataChunk &input, idx_t row_id, ICURangeLocalState &result) { + input.Flatten(); + for (idx_t c = 0; c < input.ColumnCount(); c++) { + if (FlatVector::IsNull(input.data[c], row_id)) { + result.start = timestamp_t(0); + result.end = timestamp_t(0); + result.increment = interval_t(); + result.greater_than_check = true; + result.inclusive_bound = false; + return; } } - result->start = inputs[0].GetValue(); - result->end = inputs[1].GetValue(); - result->increment = inputs[2].GetValue(); + + result.start = FlatVector::GetValue(input.data[0], row_id); + result.end = FlatVector::GetValue(input.data[1], row_id); + result.increment = FlatVector::GetValue(input.data[2], row_id); // Infinities either cause errors or infinite loops, so just ban them - if (!Timestamp::IsFinite(result->start) || !Timestamp::IsFinite(result->end)) { + if (!Timestamp::IsFinite(result.start) || !Timestamp::IsFinite(result.end)) { throw BinderException("RANGE with infinite bounds is not supported"); } - if (result->increment.months == 0 && result->increment.days == 0 && result->increment.micros == 0) { + if (result.increment.months == 0 && result.increment.days == 0 && result.increment.micros == 0) { throw BinderException("interval cannot be 0!"); } // all elements should point in the same direction - if (result->increment.months > 0 || result->increment.days > 0 || result->increment.micros > 0) { - if (result->increment.months < 0 || result->increment.days < 0 || result->increment.micros < 0) { + if (result.increment.months > 0 || result.increment.days > 0 || result.increment.micros > 0) { + if (result.increment.months < 0 || result.increment.days < 0 || result.increment.micros < 0) { throw BinderException("RANGE with composite interval that has mixed signs is not supported"); } - result->greater_than_check = true; - if (result->start > result->end) { + result.greater_than_check = true; + if (result.start > result.end) { throw BinderException( "start is bigger than end, but increment is positive: cannot generate infinite series"); } } else { - result->greater_than_check = false; - if (result->start < result->end) { + result.greater_than_check = false; + if (result.start < result.end) { throw BinderException( "start is smaller than end, but increment is negative: cannot generate infinite series"); } } - return_types.push_back(inputs[0].type()); + result.inclusive_bound = GENERATE_SERIES; + } + + template + static unique_ptr Bind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + auto result = make_uniq(context); + + return_types.push_back(LogicalType::TIMESTAMP_TZ); if (GENERATE_SERIES) { - // generate_series has inclusive bounds on the RHS - result->inclusive_bound = true; names.emplace_back("generate_series"); } else { - result->inclusive_bound = false; names.emplace_back("range"); } return std::move(result); } - struct State : public GlobalTableFunctionState { - explicit State(timestamp_t start_p) : current_state(start_p) { - } - - timestamp_t current_state; - bool finished = false; - }; - - static unique_ptr Init(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - return make_uniq(bind_data.start); + static unique_ptr RangeDateTimeLocalInit(ExecutionContext &context, + TableFunctionInitInput &input, + GlobalTableFunctionState *global_state) { + return make_uniq(); } - static void ICUTableRangeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); + template + static OperatorResultType ICUTableRangeFunction(ExecutionContext &context, TableFunctionInput &data_p, + DataChunk &input, DataChunk &output) { + auto &bind_data = data_p.bind_data->Cast(); + auto &state = data_p.local_state->Cast(); CalendarPtr calendar_ptr(bind_data.calendar->clone()); auto calendar = calendar_ptr.get(); - auto &state = data_p.global_state->Cast(); - if (state.finished) { - return; - } - - idx_t size = 0; - auto data = FlatVector::GetData(output.data[0]); while (true) { - data[size++] = state.current_state; - state.current_state = ICUDateFunc::Add(calendar, state.current_state, bind_data.increment); - if (bind_data.Finished(state.current_state)) { - state.finished = true; - break; + if (!state.initialized_row) { + // initialize for the current input row + if (state.current_input_row >= input.size()) { + // ran out of rows + state.current_input_row = 0; + state.initialized_row = false; + return OperatorResultType::NEED_MORE_INPUT; + } + GenerateRangeDateTimeParameters(input, state.current_input_row, state); + state.initialized_row = true; + state.current_state = state.start; + } + idx_t size = 0; + auto data = FlatVector::GetData(output.data[0]); + while (true) { + if (state.Finished(state.current_state)) { + break; + } + data[size++] = state.current_state; + state.current_state = ICUDateFunc::Add(calendar, state.current_state, state.increment); + if (size >= STANDARD_VECTOR_SIZE) { + break; + } } - if (size >= STANDARD_VECTOR_SIZE) { - break; + if (size == 0) { + // move to next row + state.current_input_row++; + state.initialized_row = false; + continue; } + output.SetCardinality(size); + return OperatorResultType::HAVE_MORE_OUTPUT; } - output.SetCardinality(size); } static void AddICUTableRangeFunction(DatabaseInstance &db) { TableFunctionSet range("range"); - range.AddFunction(TableFunction({LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP_TZ, LogicalType::INTERVAL}, - ICUTableRangeFunction, Bind, Init)); + TableFunction range_function({LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP_TZ, LogicalType::INTERVAL}, + nullptr, Bind, nullptr, RangeDateTimeLocalInit); + range_function.in_out_function = ICUTableRangeFunction; + range.AddFunction(range_function); ExtensionUtil::AddFunctionOverload(db, range); // generate_series: similar to range, but inclusive instead of exclusive bounds on the RHS TableFunctionSet generate_series("generate_series"); - generate_series.AddFunction( - TableFunction({LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP_TZ, LogicalType::INTERVAL}, - ICUTableRangeFunction, Bind, Init)); + TableFunction generate_series_function( + {LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP_TZ, LogicalType::INTERVAL}, nullptr, Bind, nullptr, + RangeDateTimeLocalInit); + generate_series_function.in_out_function = ICUTableRangeFunction; + generate_series.AddFunction(generate_series_function); ExtensionUtil::AddFunctionOverload(db, generate_series); } }; diff --git a/src/duckdb/extension/icu/icu-timebucket.cpp b/src/duckdb/extension/icu/icu-timebucket.cpp index 1d928704..d7be40dc 100644 --- a/src/duckdb/extension/icu/icu-timebucket.cpp +++ b/src/duckdb/extension/icu/icu-timebucket.cpp @@ -63,6 +63,9 @@ struct ICUTimeBucket : public ICUDateFunc { static inline timestamp_t WidthConvertibleToMicrosCommon(int64_t bucket_width_micros, const timestamp_t ts, const timestamp_t origin, icu::Calendar *calendar) { + if (!bucket_width_micros) { + throw OutOfRangeException("Can't bucket using zero microseconds"); + } int64_t ts_micros = SubtractOperatorOverflowCheck::Operation( Timestamp::GetEpochMicroSeconds(ts), Timestamp::GetEpochMicroSeconds(origin)); int64_t result_micros = (ts_micros / bucket_width_micros) * bucket_width_micros; @@ -76,6 +79,9 @@ struct ICUTimeBucket : public ICUDateFunc { static inline timestamp_t WidthConvertibleToDaysCommon(int32_t bucket_width_days, const timestamp_t ts, const timestamp_t origin, icu::Calendar *calendar) { + if (!bucket_width_days) { + throw OutOfRangeException("Can't bucket using zero days"); + } const auto sub_days = SubtractFactory(DatePartSpecifier::DAY); int64_t ts_days = sub_days(calendar, origin, ts); @@ -95,6 +101,9 @@ struct ICUTimeBucket : public ICUDateFunc { static inline timestamp_t WidthConvertibleToMonthsCommon(int32_t bucket_width_months, const timestamp_t ts, const timestamp_t origin, icu::Calendar *calendar) { + if (!bucket_width_months) { + throw OutOfRangeException("Can't bucket using zero months"); + } const auto trunc_months = TruncationFactory(DatePartSpecifier::MONTH); const auto sub_months = SubtractFactory(DatePartSpecifier::MONTH); @@ -106,8 +115,9 @@ struct ICUTimeBucket : public ICUDateFunc { trunc_months(calendar, tmp_micros); timestamp_t truncated_origin = GetTimeUnsafe(calendar, tmp_micros); - int64_t ts_months = sub_months(calendar, truncated_origin, truncated_ts); - int64_t result_months = (ts_months / bucket_width_months) * bucket_width_months; + int32_t ts_months = + NumericCast(sub_months(calendar, truncated_origin, truncated_ts)); // NOLINT + auto result_months = (ts_months / bucket_width_months) * bucket_width_months; if (result_months < NumericLimits::Minimum() || result_months > NumericLimits::Maximum()) { throw OutOfRangeException("Timestamp out of range"); } diff --git a/src/duckdb/extension/icu/icu-timezone.cpp b/src/duckdb/extension/icu/icu-timezone.cpp index bd0f9d21..0a91ef75 100644 --- a/src/duckdb/extension/icu/icu-timezone.cpp +++ b/src/duckdb/extension/icu/icu-timezone.cpp @@ -125,7 +125,7 @@ struct ICUFromNaiveTimestamp : public ICUDateFunc { int32_t secs; int32_t frac; Time::Convert(local_time, hr, mn, secs, frac); - int32_t millis = frac / Interval::MICROS_PER_MSEC; + int32_t millis = frac / int32_t(Interval::MICROS_PER_MSEC); uint64_t micros = frac % Interval::MICROS_PER_MSEC; // Use them to set the time in the time zone @@ -199,7 +199,7 @@ struct ICUToNaiveTimestamp : public ICUDateFunc { } // Extract the time zone parts - auto micros = SetTime(calendar, instant); + auto micros = int32_t(SetTime(calendar, instant)); const auto era = ExtractField(calendar, UCAL_ERA); const auto year = ExtractField(calendar, UCAL_YEAR); const auto mm = ExtractField(calendar, UCAL_MONTH) + 1; @@ -216,7 +216,7 @@ struct ICUToNaiveTimestamp : public ICUDateFunc { const auto secs = ExtractField(calendar, UCAL_SECOND); const auto millis = ExtractField(calendar, UCAL_MILLISECOND); - micros += millis * Interval::MICROS_PER_MSEC; + micros += millis * int32_t(Interval::MICROS_PER_MSEC); dtime_t local_time = Time::FromTime(hr, mn, secs, micros); timestamp_t naive; diff --git a/src/duckdb/extension/icu/icu_extension.cpp b/src/duckdb/extension/icu/icu_extension.cpp index 46efa60b..e5c03880 100644 --- a/src/duckdb/extension/icu/icu_extension.cpp +++ b/src/duckdb/extension/icu/icu_extension.cpp @@ -26,6 +26,7 @@ #include "include/icu_extension.hpp" #include "unicode/calendar.h" #include "unicode/coll.h" +#include "unicode/errorcode.h" #include "unicode/sortkey.h" #include "unicode/stringpiece.h" #include "unicode/timezone.h" @@ -39,6 +40,10 @@ struct IcuBindData : public FunctionData { duckdb::unique_ptr collator; string language; string country; + string tag; + + explicit IcuBindData(duckdb::unique_ptr collator_p) : collator(std::move(collator_p)) { + } IcuBindData(string language_p, string country_p) : language(std::move(language_p)), country(std::move(country_p)) { UErrorCode status = U_ZERO_ERROR; @@ -54,13 +59,32 @@ struct IcuBindData : public FunctionData { } } + explicit IcuBindData(string tag_p) : tag(std::move(tag_p)) { + UErrorCode status = U_ZERO_ERROR; + UCollator *ucollator = ucol_open(tag.c_str(), &status); + if (U_FAILURE(status)) { + auto error_name = u_errorName(status); + throw InvalidInputException("Failed to create ICU collator with tag %s: %s", tag, error_name); + } + collator = unique_ptr(icu::Collator::fromUCollator(ucollator)); + } + + static duckdb::unique_ptr CreateInstance(string language, string country, string tag) { + //! give priority to tagged collation + if (!tag.empty()) { + return make_uniq(tag); + } else { + return make_uniq(language, country); + } + } + duckdb::unique_ptr Copy() const override { - return make_uniq(language, country); + return CreateInstance(language, country, tag); } bool Equals(const FunctionData &other_p) const override { auto &other = other_p.Cast(); - return language == other.language && country == other.country; + return language == other.language && country == other.country && tag == other.tag; } static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, @@ -68,15 +92,17 @@ struct IcuBindData : public FunctionData { auto &bind_data = bind_data_p->Cast(); serializer.WriteProperty(100, "language", bind_data.language); serializer.WriteProperty(101, "country", bind_data.country); + serializer.WritePropertyWithDefault(102, "tag", bind_data.tag); } static unique_ptr Deserialize(Deserializer &deserializer, ScalarFunction &function) { string language; string country; + string tag; deserializer.ReadProperty(100, "language", language); deserializer.ReadProperty(101, "country", country); - - return make_uniq(language, country); + deserializer.ReadPropertyWithDefault(102, "tag", tag); + return CreateInstance(language, country, tag); } static const string FUNCTION_PREFIX; @@ -94,7 +120,7 @@ const string IcuBindData::FUNCTION_PREFIX = "icu_collate_"; static int32_t ICUGetSortKey(icu::Collator &collator, string_t input, duckdb::unique_ptr &buffer, int32_t &buffer_size) { icu::UnicodeString unicode_string = - icu::UnicodeString::fromUTF8(icu::StringPiece(input.GetData(), input.GetSize())); + icu::UnicodeString::fromUTF8(icu::StringPiece(input.GetData(), int32_t(input.GetSize()))); int32_t string_size = collator.getSortKey(unicode_string, reinterpret_cast(buffer.get()), buffer_size); if (string_size > buffer_size) { // have to resize the buffer @@ -135,6 +161,10 @@ static void ICUCollateFunction(DataChunk &args, ExpressionState &state, Vector & static duckdb::unique_ptr ICUCollateBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { + //! Return a tagged collator + if (!bound_function.extra_info.empty()) { + return make_uniq(bound_function.extra_info); + } const auto collation = IcuBindData::DecodeFunctionName(bound_function.name); auto splits = StringUtil::Split(collation, "_"); @@ -156,6 +186,10 @@ static duckdb::unique_ptr ICUSortKeyBind(ClientContext &context, S if (val.IsNull()) { throw NotImplementedException("ICU_SORT_KEY(VARCHAR, VARCHAR) expected a non-null collation"); } + //! Verify tagged collation + if (!bound_function.extra_info.empty()) { + return make_uniq(bound_function.extra_info); + } auto splits = StringUtil::Split(StringValue::Get(val), "_"); if (splits.size() == 1) { return make_uniq(splits[0], ""); @@ -166,20 +200,23 @@ static duckdb::unique_ptr ICUSortKeyBind(ClientContext &context, S } } -static ScalarFunction GetICUCollateFunction(const string &collation) { +static ScalarFunction GetICUCollateFunction(const string &collation, const string &tag) { string fname = IcuBindData::EncodeFunctionName(collation); ScalarFunction result(fname, {LogicalType::VARCHAR}, LogicalType::VARCHAR, ICUCollateFunction, ICUCollateBind); + //! collation tag is added into the Function extra info + result.extra_info = tag; result.serialize = IcuBindData::Serialize; result.deserialize = IcuBindData::Deserialize; return result; } static void SetICUTimeZone(ClientContext &context, SetScope scope, Value ¶meter) { - icu::StringPiece utf8(StringValue::Get(parameter)); + auto str = StringValue::Get(parameter); + icu::StringPiece utf8(str); const auto uid = icu::UnicodeString::fromUTF8(utf8); duckdb::unique_ptr tz(icu::TimeZone::createTimeZone(uid)); if (*tz == icu::TimeZone::getUnknown()) { - throw NotImplementedException("Unknown TimeZone setting"); + throw NotImplementedException("Unknown TimeZone '%s'", str); } } @@ -259,9 +296,24 @@ static void LoadInternal(DuckDB &ddb) { } collation = StringUtil::Lower(collation); - CreateCollationInfo info(collation, GetICUCollateFunction(collation), false, false); + CreateCollationInfo info(collation, GetICUCollateFunction(collation, ""), false, false); ExtensionUtil::RegisterCollation(db, info); } + + /** + * This collation function is inpired on the Postgres "ignore_accents": + * See: https://www.postgresql.org/docs/current/collation.html + * CREATE COLLATION ignore_accents (provider = icu, locale = 'und-u-ks-level1-kc-true', deterministic = false); + * + * Also, according with the source file: postgres/src/backend/utils/adt/pg_locale.c. + * "und-u-kc-ks-level1" is converted to the equivalent ICU format locale ID, + * e.g. "und@colcaselevel=yes;colstrength=primary" + * + */ + CreateCollationInfo info("icu_noaccent", GetICUCollateFunction("noaccent", "und-u-ks-level1-kc-true"), false, + false); + ExtensionUtil::RegisterCollation(db, info); + ScalarFunction sort_key("icu_sort_key", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, ICUCollateFunction, ICUSortKeyBind); ExtensionUtil::RegisterFunction(db, sort_key); diff --git a/src/duckdb/extension/json/include/json_executors.hpp b/src/duckdb/extension/json/include/json_executors.hpp index 78da4526..0eeff5e4 100644 --- a/src/duckdb/extension/json/include/json_executors.hpp +++ b/src/duckdb/extension/json/include/json_executors.hpp @@ -13,26 +13,28 @@ namespace duckdb { +template +using json_function_t = std::function; + struct JSONExecutors { public: //! Single-argument JSON read function, i.e. json_type('[1, 2, 3]') template - static void UnaryExecute(DataChunk &args, ExpressionState &state, Vector &result, - std::function fun) { + static void UnaryExecute(DataChunk &args, ExpressionState &state, Vector &result, const json_function_t fun) { auto &lstate = JSONFunctionLocalState::ResetAndGet(state); auto alc = lstate.json_allocator.GetYYAlc(); auto &inputs = args.data[0]; - UnaryExecutor::Execute(inputs, result, args.size(), [&](string_t input) { - auto doc = JSONCommon::ReadDocument(input, JSONCommon::READ_FLAG, alc); - return fun(doc->root, alc, result); - }); + UnaryExecutor::ExecuteWithNulls( + inputs, result, args.size(), [&](string_t input, ValidityMask &mask, idx_t idx) { + auto doc = JSONCommon::ReadDocument(input, JSONCommon::READ_FLAG, alc); + return fun(doc->root, alc, result, mask, idx); + }); } //! Two-argument JSON read function (with path query), i.e. json_type('[1, 2, 3]', '$[0]') - template - static void BinaryExecute(DataChunk &args, ExpressionState &state, Vector &result, - std::function fun) { + template + static void BinaryExecute(DataChunk &args, ExpressionState &state, Vector &result, const json_function_t fun) { auto &func_expr = state.expr.Cast(); const auto &info = func_expr.bind_info->Cast(); auto &lstate = JSONFunctionLocalState::ResetAndGet(state); @@ -48,11 +50,11 @@ struct JSONExecutors { auto doc = JSONCommon::ReadDocument(input, JSONCommon::READ_FLAG, lstate.json_allocator.GetYYAlc()); auto val = JSONCommon::GetUnsafe(doc->root, ptr, len); - if (!val || (NULL_IF_NULL && unsafe_yyjson_is_null(val))) { + if (SET_NULL_IF_NOT_FOUND && !val) { mask.SetInvalid(idx); return T {}; } else { - return fun(val, alc, result); + return fun(val, alc, result, mask, idx); } }); } else { @@ -76,11 +78,7 @@ struct JSONExecutors { for (idx_t i = 0; i < vals.size(); i++) { auto &val = vals[i]; D_ASSERT(val != nullptr); // Wildcard extract shouldn't give back nullptrs - if (NULL_IF_NULL && unsafe_yyjson_is_null(val)) { - child_validity.SetInvalid(current_size + i); - } else { - child_vals[current_size + i] = fun(val, alc, result); - } + child_vals[current_size + i] = fun(val, alc, result, child_validity, current_size + i); } ListVector::SetListSize(result, new_size); @@ -95,11 +93,11 @@ struct JSONExecutors { inputs, paths, result, args.size(), [&](string_t input, string_t path, ValidityMask &mask, idx_t idx) { auto doc = JSONCommon::ReadDocument(input, JSONCommon::READ_FLAG, lstate.json_allocator.GetYYAlc()); auto val = JSONCommon::Get(doc->root, path); - if (!val || unsafe_yyjson_is_null(val)) { + if (SET_NULL_IF_NOT_FOUND && !val) { mask.SetInvalid(idx); return T {}; } else { - return fun(val, alc, result); + return fun(val, alc, result, mask, idx); } }); } @@ -109,9 +107,8 @@ struct JSONExecutors { } //! JSON read function with list of path queries, i.e. json_type('[1, 2, 3]', ['$[0]', '$[1]']) - template - static void ExecuteMany(DataChunk &args, ExpressionState &state, Vector &result, - std::function fun) { + template + static void ExecuteMany(DataChunk &args, ExpressionState &state, Vector &result, const json_function_t fun) { auto &func_expr = state.expr.Cast(); const auto &info = func_expr.bind_info->Cast(); auto &lstate = JSONFunctionLocalState::ResetAndGet(state); @@ -148,10 +145,10 @@ struct JSONExecutors { for (idx_t path_i = 0; path_i < num_paths; path_i++) { auto child_idx = offset + path_i; val = JSONCommon::GetUnsafe(doc->root, info.ptrs[path_i], info.lens[path_i]); - if (!val || (NULL_IF_NULL && unsafe_yyjson_is_null(val))) { + if (SET_NULL_IF_NOT_FOUND && !val) { child_validity.SetInvalid(child_idx); } else { - child_data[child_idx] = fun(val, alc, child); + child_data[child_idx] = fun(val, alc, child, child_validity, child_idx); } } diff --git a/src/duckdb/extension/json/include/json_functions.hpp b/src/duckdb/extension/json/include/json_functions.hpp index 51fae296..cd19f373 100644 --- a/src/duckdb/extension/json/include/json_functions.hpp +++ b/src/duckdb/extension/json/include/json_functions.hpp @@ -96,13 +96,17 @@ class JSONFunctions { static ScalarFunctionSet GetArrayLengthFunction(); static ScalarFunctionSet GetContainsFunction(); + static ScalarFunctionSet GetExistsFunction(); static ScalarFunctionSet GetKeysFunction(); static ScalarFunctionSet GetTypeFunction(); static ScalarFunctionSet GetValidFunction(); + static ScalarFunctionSet GetValueFunction(); static ScalarFunctionSet GetSerializeSqlFunction(); static ScalarFunctionSet GetDeserializeSqlFunction(); static ScalarFunctionSet GetSerializePlanFunction(); + static ScalarFunctionSet GetPrettyPrintFunction(); + static PragmaFunctionSet GetExecuteJsonSerializedSqlPragmaFunction(); template diff --git a/src/duckdb/extension/json/include/json_scan.hpp b/src/duckdb/extension/json/include/json_scan.hpp index 1e0b9dc2..a2ad431b 100644 --- a/src/duckdb/extension/json/include/json_scan.hpp +++ b/src/duckdb/extension/json/include/json_scan.hpp @@ -124,12 +124,15 @@ struct JSONScanData : public TableFunctionData { idx_t max_depth = NumericLimits::Maximum(); //! We divide the number of appearances of each JSON field by the auto-detection sample size //! If the average over the fields of an object is less than this threshold, - //! we default to the JSON type for this object rather than the shredded type + //! we default to the MAP type with value type of merged field types double field_appearance_threshold = 0.1; //! The maximum number of files we sample to sample sample_size rows idx_t maximum_sample_files = 32; //! Whether we auto-detect and convert JSON strings to integers bool convert_strings_to_integers = false; + //! If a struct contains more fields than this threshold with at least 80% similar types, + //! we infer it as MAP type + idx_t map_inference_threshold = 25; //! All column names (in order) vector names; @@ -237,7 +240,8 @@ struct JSONScanLocalState { void SkipOverArrayStart(); - void ReadAndAutoDetect(JSONScanGlobalState &gstate, AllocatedData &buffer, optional_idx &buffer_index); + void ReadAndAutoDetect(JSONScanGlobalState &gstate, AllocatedData &buffer, optional_idx &buffer_index, + bool &file_done); bool ReconstructFirstObject(JSONScanGlobalState &gstate); void ParseNextChunk(JSONScanGlobalState &gstate); diff --git a/src/duckdb/extension/json/include/json_structure.hpp b/src/duckdb/extension/json/include/json_structure.hpp index 1f008777..3102dc90 100644 --- a/src/duckdb/extension/json/include/json_structure.hpp +++ b/src/duckdb/extension/json/include/json_structure.hpp @@ -19,7 +19,8 @@ struct StrpTimeFormat; struct JSONStructureNode { public: JSONStructureNode(); - JSONStructureNode(yyjson_val *key_p, yyjson_val *val_p, const bool ignore_errors); + JSONStructureNode(const char *key_ptr, const size_t key_len); + JSONStructureNode(yyjson_val *key_p, yyjson_val *val_p, bool ignore_errors); //! Disable copy constructors JSONStructureNode(const JSONStructureNode &other) = delete; @@ -31,7 +32,7 @@ struct JSONStructureNode { JSONStructureDescription &GetOrCreateDescription(LogicalTypeId type); bool ContainsVarchar() const; - void InitializeCandidateTypes(const idx_t max_depth, const bool convert_strings_to_integers, idx_t depth = 0); + void InitializeCandidateTypes(idx_t max_depth, bool convert_strings_to_integers, idx_t depth = 0); void RefineCandidateTypes(yyjson_val *vals[], idx_t val_count, Vector &string_vector, ArenaAllocator &allocator, DateFormatMap &date_format_map); @@ -43,14 +44,15 @@ struct JSONStructureNode { void RefineCandidateTypesString(yyjson_val *vals[], idx_t val_count, Vector &string_vector, DateFormatMap &date_format_map); void EliminateCandidateTypes(idx_t vec_count, Vector &string_vector, DateFormatMap &date_format_map); - bool EliminateCandidateFormats(idx_t vec_count, Vector &string_vector, Vector &result_vector, + bool EliminateCandidateFormats(idx_t vec_count, Vector &string_vector, const Vector &result_vector, vector &formats); public: - duckdb::unique_ptr key; + unique_ptr key; bool initialized = false; vector descriptions; idx_t count; + idx_t null_count; }; struct JSONStructureDescription { @@ -64,7 +66,8 @@ struct JSONStructureDescription { JSONStructureDescription &operator=(JSONStructureDescription &&) noexcept; JSONStructureNode &GetOrCreateChild(); - JSONStructureNode &GetOrCreateChild(yyjson_val *key, yyjson_val *val, const bool ignore_errors); + JSONStructureNode &GetOrCreateChild(const char *key_ptr, size_t key_size); + JSONStructureNode &GetOrCreateChild(yyjson_val *key, yyjson_val *val, bool ignore_errors); public: //! Type of this description @@ -80,10 +83,10 @@ struct JSONStructureDescription { struct JSONStructure { public: - static void ExtractStructure(yyjson_val *val, JSONStructureNode &node, const bool ignore_errors); - static LogicalType StructureToType(ClientContext &context, const JSONStructureNode &node, const idx_t max_depth, - const double field_appearance_threshold, idx_t depth = 0, - idx_t sample_count = DConstants::INVALID_INDEX); + static void ExtractStructure(yyjson_val *val, JSONStructureNode &node, bool ignore_errors); + static LogicalType StructureToType(ClientContext &context, const JSONStructureNode &node, idx_t max_depth, + double field_appearance_threshold, idx_t map_inference_threshold, + idx_t depth = 0, const LogicalType &null_type = LogicalType::JSON()); }; } // namespace duckdb diff --git a/src/duckdb/extension/json/json_common.cpp b/src/duckdb/extension/json/json_common.cpp index edb961e2..bc5f1412 100644 --- a/src/duckdb/extension/json/json_common.cpp +++ b/src/duckdb/extension/json/json_common.cpp @@ -35,15 +35,23 @@ string ThrowPathError(const char *ptr, const char *end, const bool binder) { struct JSONKeyReadResult { public: static inline JSONKeyReadResult Empty() { - return {idx_t(0), string()}; + return {idx_t(0), false, string()}; } static inline JSONKeyReadResult WildCard() { - return {1, "*"}; + return {1, false, "*"}; + } + + static inline JSONKeyReadResult RecWildCard() { + return {2, true, "*"}; + } + + static inline JSONKeyReadResult RecWildCardShortcut() { + return {1, true, "*"}; } inline bool IsValid() { - return chars_read != 0; + return (chars_read != 0); } inline bool IsWildCard() { @@ -52,13 +60,14 @@ struct JSONKeyReadResult { public: idx_t chars_read; + bool recursive; string key; }; static inline JSONKeyReadResult ReadString(const char *ptr, const char *const end, const bool escaped) { const char *const before = ptr; if (escaped) { - auto key = make_unsafe_uniq_array(end - ptr); + auto key = make_unsafe_uniq_array_uninitialized(end - ptr); idx_t key_len = 0; bool backslash = false; @@ -82,7 +91,7 @@ static inline JSONKeyReadResult ReadString(const char *ptr, const char *const en if (ptr == end || backslash) { return JSONKeyReadResult::Empty(); } else { - return {idx_t(ptr - before), string(key.get(), key_len)}; + return {idx_t(ptr - before), false, string(key.get(), key_len)}; } } else { while (ptr != end) { @@ -91,7 +100,7 @@ static inline JSONKeyReadResult ReadString(const char *ptr, const char *const en } ptr++; } - return {idx_t(ptr - before), string(before, ptr - before)}; + return {idx_t(ptr - before), false, string(before, ptr - before)}; } } @@ -125,8 +134,23 @@ static inline idx_t ReadInteger(const char *ptr, const char *const end, idx_t &i static inline JSONKeyReadResult ReadKey(const char *ptr, const char *const end) { D_ASSERT(ptr != end); if (*ptr == '*') { // Wildcard + if (*(ptr + 1) == '*') { + return JSONKeyReadResult::RecWildCard(); + } return JSONKeyReadResult::WildCard(); } + bool recursive = false; + if (*ptr == '.') { + char next = *(ptr + 1); + if (next == '*') { + return JSONKeyReadResult::RecWildCard(); + } + if (next == '[') { + return JSONKeyReadResult::RecWildCardShortcut(); + } + ptr++; + recursive = true; + } bool escaped = false; if (*ptr == '"') { ptr++; // Skip past opening '"' @@ -139,6 +163,10 @@ static inline JSONKeyReadResult ReadKey(const char *ptr, const char *const end) if (escaped) { result.chars_read += 2; // Account for surrounding quotes } + if (recursive) { + result.chars_read += 1; + result.recursive = true; + } return result; } @@ -197,7 +225,7 @@ JSONPathType JSONCommon::ValidatePath(const char *ptr, const idx_t &len, const b auto key = ReadKey(ptr, end); if (!key.IsValid()) { ThrowPathError(ptr, end, binder); - } else if (key.IsWildCard()) { + } else if (key.IsWildCard() || key.recursive) { path_type = JSONPathType::WILDCARD; } ptr += key.chars_read; @@ -272,12 +300,39 @@ void GetWildcardPathInternal(yyjson_val *val, const char *ptr, const char *const D_ASSERT(ptr != end); switch (c) { case '.': { // Object field - if (!unsafe_yyjson_is_obj(val)) { - return; - } auto key_result = ReadKey(ptr, end); D_ASSERT(key_result.IsValid()); + if (key_result.recursive) { + if (key_result.IsWildCard()) { + ptr += key_result.chars_read; + } + vector rec_vals; + rec_vals.emplace_back(val); + for (idx_t i = 0; i < rec_vals.size(); i++) { + yyjson_val *rec_val = rec_vals[i]; + if (yyjson_is_arr(rec_val)) { + size_t idx, max; + yyjson_val *element; + yyjson_arr_foreach(rec_val, idx, max, element) { + rec_vals.emplace_back(element); + } + } else if (yyjson_is_obj(rec_val)) { + size_t idx, max; + yyjson_val *key, *element; + yyjson_obj_foreach(rec_val, idx, max, key, element) { + rec_vals.emplace_back(element); + } + } + if (i > 0 || ptr != end) { + GetWildcardPathInternal(rec_val, ptr, end, vals); + } + } + return; + } ptr += key_result.chars_read; + if (!unsafe_yyjson_is_obj(val)) { + return; + } if (key_result.IsWildCard()) { // Wildcard size_t idx, max; yyjson_val *key, *obj_val; @@ -325,6 +380,7 @@ void GetWildcardPathInternal(yyjson_val *val, const char *ptr, const char *const if (val != nullptr) { vals.emplace_back(val); } + return; } void JSONCommon::GetWildcardPath(yyjson_val *val, const char *ptr, const idx_t &len, vector &vals) { diff --git a/src/duckdb/extension/json/json_extension.cpp b/src/duckdb/extension/json/json_extension.cpp index 07bba320..b594c26d 100644 --- a/src/duckdb/extension/json/json_extension.cpp +++ b/src/duckdb/extension/json/json_extension.cpp @@ -17,11 +17,19 @@ namespace duckdb { static DefaultMacro json_macros[] = { - {DEFAULT_SCHEMA, "json_group_array", {"x", nullptr}, "to_json(list(x))"}, - {DEFAULT_SCHEMA, "json_group_object", {"name", "value", nullptr}, "to_json(map(list(name), list(value)))"}, - {DEFAULT_SCHEMA, "json_group_structure", {"x", nullptr}, "json_structure(json_group_array(x))->'0'"}, - {DEFAULT_SCHEMA, "json", {"x", nullptr}, "json_extract(x, '$')"}, - {nullptr, nullptr, {nullptr}, nullptr}}; + {DEFAULT_SCHEMA, "json_group_array", {"x", nullptr}, {{nullptr, nullptr}}, "to_json(list(x))"}, + {DEFAULT_SCHEMA, + "json_group_object", + {"name", "value", nullptr}, + {{nullptr, nullptr}}, + "to_json(map(list(name), list(value)))"}, + {DEFAULT_SCHEMA, + "json_group_structure", + {"x", nullptr}, + {{nullptr, nullptr}}, + "json_structure(json_group_array(x))->'0'"}, + {DEFAULT_SCHEMA, "json", {"x", nullptr}, {{nullptr, nullptr}}, "json_extract(x, '$')"}, + {nullptr, nullptr, {nullptr}, {{nullptr, nullptr}}, nullptr}}; void JsonExtension::Load(DuckDB &db) { auto &db_instance = *db.instance; diff --git a/src/duckdb/extension/json/json_functions.cpp b/src/duckdb/extension/json/json_functions.cpp index a4c818bc..0ad68376 100644 --- a/src/duckdb/extension/json/json_functions.cpp +++ b/src/duckdb/extension/json/json_functions.cpp @@ -160,13 +160,17 @@ vector JSONFunctions::GetScalarFunctions() { // Other functions.push_back(GetArrayLengthFunction()); functions.push_back(GetContainsFunction()); + functions.push_back(GetExistsFunction()); functions.push_back(GetKeysFunction()); functions.push_back(GetTypeFunction()); functions.push_back(GetValidFunction()); + functions.push_back(GetValueFunction()); functions.push_back(GetSerializePlanFunction()); functions.push_back(GetSerializeSqlFunction()); functions.push_back(GetDeserializeSqlFunction()); + functions.push_back(GetPrettyPrintFunction()); + return functions; } @@ -196,7 +200,7 @@ vector JSONFunctions::GetTableFunctions() { unique_ptr JSONFunctions::ReadJSONReplacement(ClientContext &context, ReplacementScanInput &input, optional_ptr data) { - auto &table_name = input.table_name; + auto table_name = ReplacementScan::GetFullPath(input); if (!ReplacementScan::CanReplace(table_name, {"json", "jsonl", "ndjson"})) { return nullptr; } diff --git a/src/duckdb/extension/json/json_functions/json_array_length.cpp b/src/duckdb/extension/json/json_functions/json_array_length.cpp index fc33f371..c487239b 100644 --- a/src/duckdb/extension/json/json_functions/json_array_length.cpp +++ b/src/duckdb/extension/json/json_functions/json_array_length.cpp @@ -2,7 +2,7 @@ namespace duckdb { -static inline uint64_t GetArrayLength(yyjson_val *val, yyjson_alc *alc, Vector &result) { +static inline uint64_t GetArrayLength(yyjson_val *val, yyjson_alc *, Vector &, ValidityMask &, idx_t) { return yyjson_arr_size(val); } diff --git a/src/duckdb/extension/json/json_functions/json_create.cpp b/src/duckdb/extension/json/json_functions/json_create.cpp index a903cb99..3927daa1 100644 --- a/src/duckdb/extension/json/json_functions/json_create.cpp +++ b/src/duckdb/extension/json/json_functions/json_create.cpp @@ -223,6 +223,12 @@ struct CreateJSONValue { } }; +template +inline yyjson_mut_val *CreateJSONValueFromJSON(yyjson_mut_doc *doc, const T &value) { + return nullptr; // This function should only be called with string_t as template +} + +template <> inline yyjson_mut_val *CreateJSONValueFromJSON(yyjson_mut_doc *doc, const string_t &value) { auto value_doc = JSONCommon::ReadDocument(value, JSONCommon::READ_FLAG, &doc->alc); auto result = yyjson_val_mut_copy(doc, value_doc->root); @@ -273,7 +279,7 @@ static void TemplatedCreateValues(yyjson_mut_doc *doc, yyjson_mut_val *vals[], V if (!value_data.validity.RowIsValid(val_idx)) { vals[i] = yyjson_mut_null(doc); } else if (type_is_json) { - vals[i] = CreateJSONValueFromJSON(doc, (string_t &)values[val_idx]); + vals[i] = CreateJSONValueFromJSON(doc, values[val_idx]); } else { vals[i] = CreateJSONValue::Operation(doc, values[val_idx]); } @@ -544,6 +550,7 @@ static void CreateValues(const StructNames &names, yyjson_mut_doc *doc, yyjson_m case LogicalTypeId::TIMESTAMP_NS: case LogicalTypeId::TIMESTAMP_MS: case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::VARINT: case LogicalTypeId::UUID: { Vector string_vector(LogicalTypeId::VARCHAR, count); VectorOperations::DefaultCast(value_v, string_vector, count); @@ -556,7 +563,17 @@ static void CreateValues(const StructNames &names, yyjson_mut_doc *doc, yyjson_m TemplatedCreateValues(doc, vals, double_vector, count); break; } - default: + case LogicalTypeId::INVALID: + case LogicalTypeId::UNKNOWN: + case LogicalTypeId::ANY: + case LogicalTypeId::USER: + case LogicalTypeId::CHAR: + case LogicalTypeId::STRING_LITERAL: + case LogicalTypeId::INTEGER_LITERAL: + case LogicalTypeId::POINTER: + case LogicalTypeId::VALIDITY: + case LogicalTypeId::TABLE: + case LogicalTypeId::LAMBDA: throw InternalException("Unsupported type arrived at JSON create function"); } } @@ -647,7 +664,7 @@ static void ToJSONFunctionInternal(const StructNames &names, Vector &input, cons } } - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR || count == 1) { result.SetVectorType(VectorType::CONSTANT_VECTOR); } } @@ -749,7 +766,7 @@ void JSONFunctions::RegisterJSONCreateCastFunctions(CastFunctionSet &casts) { source_type = LogicalType::UNION({{"any", LogicalType::ANY}}); break; case LogicalTypeId::ARRAY: - source_type = LogicalType::ARRAY(LogicalType::ANY); + source_type = LogicalType::ARRAY(LogicalType::ANY, optional_idx()); break; case LogicalTypeId::VARCHAR: // We skip this one here as it's handled in json_functions.cpp diff --git a/src/duckdb/extension/json/json_functions/json_exists.cpp b/src/duckdb/extension/json/json_functions/json_exists.cpp new file mode 100644 index 00000000..f9d3548b --- /dev/null +++ b/src/duckdb/extension/json/json_functions/json_exists.cpp @@ -0,0 +1,32 @@ +#include "json_executors.hpp" + +namespace duckdb { + +static inline bool JSONExists(yyjson_val *val, yyjson_alc *, Vector &, ValidityMask &, idx_t) { + return val; +} + +static void BinaryExistsFunction(DataChunk &args, ExpressionState &state, Vector &result) { + JSONExecutors::BinaryExecute(args, state, result, JSONExists); +} + +static void ManyExistsFunction(DataChunk &args, ExpressionState &state, Vector &result) { + JSONExecutors::ExecuteMany(args, state, result, JSONExists); +} + +static void GetExistsFunctionsInternal(ScalarFunctionSet &set, const LogicalType &input_type) { + set.AddFunction(ScalarFunction({input_type, LogicalType::VARCHAR}, LogicalType::BOOLEAN, BinaryExistsFunction, + JSONReadFunctionData::Bind, nullptr, nullptr, JSONFunctionLocalState::Init)); + set.AddFunction(ScalarFunction({input_type, LogicalType::LIST(LogicalType::VARCHAR)}, + LogicalType::LIST(LogicalType::BOOLEAN), ManyExistsFunction, + JSONReadManyFunctionData::Bind, nullptr, nullptr, JSONFunctionLocalState::Init)); +} + +ScalarFunctionSet JSONFunctions::GetExistsFunction() { + ScalarFunctionSet set("json_exists"); + GetExistsFunctionsInternal(set, LogicalType::VARCHAR); + GetExistsFunctionsInternal(set, LogicalType::JSON()); + return set; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/json/json_functions/json_extract.cpp b/src/duckdb/extension/json/json_functions/json_extract.cpp index 59daa49c..2fc32612 100644 --- a/src/duckdb/extension/json/json_functions/json_extract.cpp +++ b/src/duckdb/extension/json/json_functions/json_extract.cpp @@ -2,11 +2,11 @@ namespace duckdb { -static inline string_t ExtractFromVal(yyjson_val *val, yyjson_alc *alc, Vector &) { +static inline string_t ExtractFromVal(yyjson_val *val, yyjson_alc *alc, Vector &, ValidityMask &, idx_t) { return JSONCommon::WriteVal(val, alc); } -static inline string_t ExtractStringFromVal(yyjson_val *val, yyjson_alc *alc, Vector &) { +static inline string_t ExtractStringFromVal(yyjson_val *val, yyjson_alc *alc, Vector &, ValidityMask &, idx_t) { return yyjson_is_str(val) ? string_t(unsafe_yyjson_get_str(val), unsafe_yyjson_get_len(val)) : JSONCommon::WriteVal(val, alc); } diff --git a/src/duckdb/extension/json/json_functions/json_keys.cpp b/src/duckdb/extension/json/json_functions/json_keys.cpp index eb991883..0b672c08 100644 --- a/src/duckdb/extension/json/json_functions/json_keys.cpp +++ b/src/duckdb/extension/json/json_functions/json_keys.cpp @@ -2,7 +2,7 @@ namespace duckdb { -static inline list_entry_t GetJSONKeys(yyjson_val *val, yyjson_alc *alc, Vector &result) { +static inline list_entry_t GetJSONKeys(yyjson_val *val, yyjson_alc *, Vector &result, ValidityMask &, idx_t) { auto num_keys = yyjson_obj_size(val); auto current_size = ListVector::GetListSize(result); auto new_size = current_size + num_keys; diff --git a/src/duckdb/extension/json/json_functions/json_pretty.cpp b/src/duckdb/extension/json/json_functions/json_pretty.cpp new file mode 100644 index 00000000..1fb96081 --- /dev/null +++ b/src/duckdb/extension/json/json_functions/json_pretty.cpp @@ -0,0 +1,32 @@ +#include "json_executors.hpp" + +namespace duckdb { + +//! Pretty Print a given JSON Document +string_t PrettyPrint(yyjson_val *val, yyjson_alc *alc, Vector &, ValidityMask &, idx_t) { + D_ASSERT(alc); + idx_t len; + auto data = + yyjson_val_write_opts(val, JSONCommon::WRITE_PRETTY_FLAG, alc, reinterpret_cast(&len), nullptr); + return string_t(data, len); +} + +static void PrettyPrintFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto json_type = args.data[0].GetType(); + D_ASSERT(json_type == LogicalType::VARCHAR || json_type == LogicalType::JSON()); + + JSONExecutors::UnaryExecute(args, state, result, PrettyPrint); +} + +static void GetPrettyPrintFunctionInternal(ScalarFunctionSet &set, const LogicalType &json) { + set.AddFunction(ScalarFunction("json_pretty", {json}, LogicalType::VARCHAR, PrettyPrintFunction, nullptr, nullptr, + nullptr, JSONFunctionLocalState::Init)); +} + +ScalarFunctionSet JSONFunctions::GetPrettyPrintFunction() { + ScalarFunctionSet set("json_pretty"); + GetPrettyPrintFunctionInternal(set, LogicalType::JSON()); + return set; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/json/json_functions/json_serialize_sql.cpp b/src/duckdb/extension/json/json_functions/json_serialize_sql.cpp index 17fd3336..2873eba3 100644 --- a/src/duckdb/extension/json/json_functions/json_serialize_sql.cpp +++ b/src/duckdb/extension/json/json_functions/json_serialize_sql.cpp @@ -196,7 +196,11 @@ static unique_ptr DeserializeSelectStatement(string_t input, yy } auto stmt_json = yyjson_arr_get(statements, 0); JsonDeserializer deserializer(stmt_json, doc); - return SelectStatement::Deserialize(deserializer); + auto stmt = SelectStatement::Deserialize(deserializer); + if (!stmt->node) { + throw ParserException("Error parsing json: no select node found in json"); + } + return stmt; } //---------------------------------------------------------------------- diff --git a/src/duckdb/extension/json/json_functions/json_structure.cpp b/src/duckdb/extension/json/json_functions/json_structure.cpp index e6fb3456..04800572 100644 --- a/src/duckdb/extension/json/json_functions/json_structure.cpp +++ b/src/duckdb/extension/json/json_functions/json_structure.cpp @@ -5,13 +5,15 @@ #include "json_scan.hpp" #include "json_transform.hpp" +#include + namespace duckdb { -static inline bool IsNumeric(LogicalTypeId type) { +static bool IsNumeric(LogicalTypeId type) { return type == LogicalTypeId::DOUBLE || type == LogicalTypeId::UBIGINT || type == LogicalTypeId::BIGINT; } -static inline LogicalTypeId MaxNumericType(LogicalTypeId &a, LogicalTypeId &b) { +static LogicalTypeId MaxNumericType(const LogicalTypeId &a, const LogicalTypeId &b) { D_ASSERT(a != b); if (a == LogicalTypeId::DOUBLE || b == LogicalTypeId::DOUBLE) { return LogicalTypeId::DOUBLE; @@ -19,31 +21,36 @@ static inline LogicalTypeId MaxNumericType(LogicalTypeId &a, LogicalTypeId &b) { return LogicalTypeId::BIGINT; } -JSONStructureNode::JSONStructureNode() : initialized(false), count(0) { +JSONStructureNode::JSONStructureNode() : count(0), null_count(0) { +} + +JSONStructureNode::JSONStructureNode(const char *key_ptr, const size_t key_len) : JSONStructureNode() { + key = make_uniq(key_ptr, key_len); } JSONStructureNode::JSONStructureNode(yyjson_val *key_p, yyjson_val *val_p, const bool ignore_errors) - : key(make_uniq(unsafe_yyjson_get_str(key_p), unsafe_yyjson_get_len(key_p))), initialized(false), count(0) { - D_ASSERT(yyjson_is_str(key_p)); + : JSONStructureNode(unsafe_yyjson_get_str(key_p), unsafe_yyjson_get_len(key_p)) { JSONStructure::ExtractStructure(val_p, *this, ignore_errors); } +static void SwapJSONStructureNode(JSONStructureNode &a, JSONStructureNode &b) noexcept { + std::swap(a.key, b.key); + std::swap(a.initialized, b.initialized); + std::swap(a.descriptions, b.descriptions); + std::swap(a.count, b.count); + std::swap(a.null_count, b.null_count); +} + JSONStructureNode::JSONStructureNode(JSONStructureNode &&other) noexcept { - std::swap(key, other.key); - std::swap(initialized, other.initialized); - std::swap(descriptions, other.descriptions); - std::swap(count, other.count); + SwapJSONStructureNode(*this, other); } JSONStructureNode &JSONStructureNode::operator=(JSONStructureNode &&other) noexcept { - std::swap(key, other.key); - std::swap(initialized, other.initialized); - std::swap(descriptions, other.descriptions); - std::swap(count, other.count); + SwapJSONStructureNode(*this, other); return *this; } -JSONStructureDescription &JSONStructureNode::GetOrCreateDescription(LogicalTypeId type) { +JSONStructureDescription &JSONStructureNode::GetOrCreateDescription(const LogicalTypeId type) { if (descriptions.empty()) { // Empty, just put this type in there descriptions.emplace_back(type); @@ -66,7 +73,8 @@ JSONStructureDescription &JSONStructureNode::GetOrCreateDescription(LogicalTypeI for (auto &description : descriptions) { if (type == description.type) { return description; - } else if (is_numeric && IsNumeric(description.type)) { + } + if (is_numeric && IsNumeric(description.type)) { description.type = MaxNumericType(type, description.type); return description; } @@ -95,7 +103,7 @@ bool JSONStructureNode::ContainsVarchar() const { } void JSONStructureNode::InitializeCandidateTypes(const idx_t max_depth, const bool convert_strings_to_integers, - idx_t depth) { + const idx_t depth) { if (depth >= max_depth) { return; } @@ -113,14 +121,15 @@ void JSONStructureNode::InitializeCandidateTypes(const idx_t max_depth, const bo description.candidate_types = {LogicalTypeId::UUID, LogicalTypeId::TIMESTAMP, LogicalTypeId::DATE, LogicalTypeId::TIME}; } - } - initialized = true; - for (auto &child : description.children) { - child.InitializeCandidateTypes(max_depth, convert_strings_to_integers, depth + 1); + initialized = true; + } else { + for (auto &child : description.children) { + child.InitializeCandidateTypes(max_depth, convert_strings_to_integers, depth + 1); + } } } -void JSONStructureNode::RefineCandidateTypes(yyjson_val *vals[], idx_t val_count, Vector &string_vector, +void JSONStructureNode::RefineCandidateTypes(yyjson_val *vals[], const idx_t val_count, Vector &string_vector, ArenaAllocator &allocator, DateFormatMap &date_format_map) { if (descriptions.size() != 1) { // We can't refine types if we have more than 1 description (yet), defaults to JSON type for now @@ -142,7 +151,7 @@ void JSONStructureNode::RefineCandidateTypes(yyjson_val *vals[], idx_t val_count } } -void JSONStructureNode::RefineCandidateTypesArray(yyjson_val *vals[], idx_t val_count, Vector &string_vector, +void JSONStructureNode::RefineCandidateTypesArray(yyjson_val *vals[], const idx_t val_count, Vector &string_vector, ArenaAllocator &allocator, DateFormatMap &date_format_map) { D_ASSERT(descriptions.size() == 1 && descriptions[0].type == LogicalTypeId::LIST); auto &desc = descriptions[0]; @@ -173,7 +182,7 @@ void JSONStructureNode::RefineCandidateTypesArray(yyjson_val *vals[], idx_t val_ child.RefineCandidateTypes(child_vals, total_list_size, string_vector, allocator, date_format_map); } -void JSONStructureNode::RefineCandidateTypesObject(yyjson_val *vals[], idx_t val_count, Vector &string_vector, +void JSONStructureNode::RefineCandidateTypesObject(yyjson_val *vals[], const idx_t val_count, Vector &string_vector, ArenaAllocator &allocator, DateFormatMap &date_format_map) { D_ASSERT(descriptions.size() == 1 && descriptions[0].type == LogicalTypeId::STRUCT); auto &desc = descriptions[0]; @@ -186,22 +195,21 @@ void JSONStructureNode::RefineCandidateTypesObject(yyjson_val *vals[], idx_t val reinterpret_cast(allocator.AllocateAligned(val_count * sizeof(yyjson_val *)))); } - idx_t found_key_count; - auto found_keys = reinterpret_cast(allocator.AllocateAligned(sizeof(bool) * child_count)); + const auto found_keys = reinterpret_cast(allocator.AllocateAligned(sizeof(bool) * child_count)); const auto &key_map = desc.key_map; size_t idx, max; yyjson_val *child_key, *child_val; for (idx_t i = 0; i < val_count; i++) { if (vals[i] && !unsafe_yyjson_is_null(vals[i])) { - found_key_count = 0; + idx_t found_key_count = 0; memset(found_keys, false, child_count); D_ASSERT(yyjson_is_obj(vals[i])); yyjson_obj_foreach(vals[i], idx, max, child_key, child_val) { D_ASSERT(yyjson_is_str(child_key)); - auto key_ptr = unsafe_yyjson_get_str(child_key); - auto key_len = unsafe_yyjson_get_len(child_key); + const auto key_ptr = unsafe_yyjson_get_str(child_key); + const auto key_len = unsafe_yyjson_get_len(child_key); auto it = key_map.find({key_ptr, key_len}); D_ASSERT(it != key_map.end()); const auto child_idx = it->second; @@ -231,7 +239,7 @@ void JSONStructureNode::RefineCandidateTypesObject(yyjson_val *vals[], idx_t val } } -void JSONStructureNode::RefineCandidateTypesString(yyjson_val *vals[], idx_t val_count, Vector &string_vector, +void JSONStructureNode::RefineCandidateTypesString(yyjson_val *vals[], const idx_t val_count, Vector &string_vector, DateFormatMap &date_format_map) { D_ASSERT(descriptions.size() == 1 && descriptions[0].type == LogicalTypeId::VARCHAR); if (descriptions[0].candidate_types.empty()) { @@ -242,7 +250,7 @@ void JSONStructureNode::RefineCandidateTypesString(yyjson_val *vals[], idx_t val EliminateCandidateTypes(val_count, string_vector, date_format_map); } -void JSONStructureNode::EliminateCandidateTypes(idx_t vec_count, Vector &string_vector, +void JSONStructureNode::EliminateCandidateTypes(const idx_t vec_count, Vector &string_vector, DateFormatMap &date_format_map) { D_ASSERT(descriptions.size() == 1 && descriptions[0].type == LogicalTypeId::VARCHAR); auto &description = descriptions[0]; @@ -296,12 +304,12 @@ bool TryParse(Vector &string_vector, StrpTimeFormat &format, const idx_t count) return true; } -bool JSONStructureNode::EliminateCandidateFormats(idx_t vec_count, Vector &string_vector, Vector &result_vector, - vector &formats) { +bool JSONStructureNode::EliminateCandidateFormats(const idx_t vec_count, Vector &string_vector, + const Vector &result_vector, vector &formats) { D_ASSERT(descriptions.size() == 1 && descriptions[0].type == LogicalTypeId::VARCHAR); const auto type = result_vector.GetType().id(); for (idx_t i = formats.size(); i != 0; i--) { - idx_t actual_index = i - 1; + const idx_t actual_index = i - 1; auto &format = formats[actual_index]; bool success; switch (type) { @@ -324,21 +332,22 @@ bool JSONStructureNode::EliminateCandidateFormats(idx_t vec_count, Vector &strin return false; } -JSONStructureDescription::JSONStructureDescription(LogicalTypeId type_p) : type(type_p) { +JSONStructureDescription::JSONStructureDescription(const LogicalTypeId type_p) : type(type_p) { +} + +static void SwapJSONStructureDescription(JSONStructureDescription &a, JSONStructureDescription &b) noexcept { + std::swap(a.type, b.type); + std::swap(a.key_map, b.key_map); + std::swap(a.children, b.children); + std::swap(a.candidate_types, b.candidate_types); } JSONStructureDescription::JSONStructureDescription(JSONStructureDescription &&other) noexcept { - std::swap(type, other.type); - std::swap(key_map, other.key_map); - std::swap(children, other.children); - std::swap(candidate_types, other.candidate_types); + SwapJSONStructureDescription(*this, other); } JSONStructureDescription &JSONStructureDescription::operator=(JSONStructureDescription &&other) noexcept { - std::swap(type, other.type); - std::swap(key_map, other.key_map); - std::swap(children, other.children); - std::swap(candidate_types, other.candidate_types); + SwapJSONStructureDescription(*this, other); return *this; } @@ -351,27 +360,31 @@ JSONStructureNode &JSONStructureDescription::GetOrCreateChild() { return children.back(); } +JSONStructureNode &JSONStructureDescription::GetOrCreateChild(const char *key_ptr, const size_t key_size) { + // Check if there is already a child with the same key + const JSONKey temp_key {key_ptr, key_size}; + const auto it = key_map.find(temp_key); + if (it != key_map.end()) { + return children[it->second]; // Found it + } + + // Didn't find, create a new child + children.emplace_back(key_ptr, key_size); + const auto &persistent_key_string = *children.back().key; + JSONKey new_key {persistent_key_string.c_str(), persistent_key_string.length()}; + key_map.emplace(new_key, children.size() - 1); + return children.back(); +} + JSONStructureNode &JSONStructureDescription::GetOrCreateChild(yyjson_val *key, yyjson_val *val, const bool ignore_errors) { D_ASSERT(yyjson_is_str(key)); - // Check if there is already a child with the same key - idx_t child_idx; - JSONKey temp_key {unsafe_yyjson_get_str(key), unsafe_yyjson_get_len(key)}; - auto it = key_map.find(temp_key); - if (it == key_map.end()) { // Didn't find, create a new child - child_idx = children.size(); - children.emplace_back(key, val, ignore_errors); - const auto &persistent_key_string = children.back().key; - JSONKey new_key {persistent_key_string->c_str(), persistent_key_string->length()}; - key_map.emplace(new_key, child_idx); - } else { // Found it - child_idx = it->second; - JSONStructure::ExtractStructure(val, children[child_idx], ignore_errors); - } - return children[child_idx]; -} - -static inline void ExtractStructureArray(yyjson_val *arr, JSONStructureNode &node, const bool ignore_errors) { + auto &child = GetOrCreateChild(unsafe_yyjson_get_str(key), unsafe_yyjson_get_len(key)); + JSONStructure::ExtractStructure(val, child, ignore_errors); + return child; +} + +static void ExtractStructureArray(yyjson_val *arr, JSONStructureNode &node, const bool ignore_errors) { D_ASSERT(yyjson_is_arr(arr)); auto &description = node.GetOrCreateDescription(LogicalTypeId::LIST); auto &child = description.GetOrCreateChild(); @@ -383,7 +396,7 @@ static inline void ExtractStructureArray(yyjson_val *arr, JSONStructureNode &nod } } -static inline void ExtractStructureObject(yyjson_val *obj, JSONStructureNode &node, const bool ignore_errors) { +static void ExtractStructureObject(yyjson_val *obj, JSONStructureNode &node, const bool ignore_errors) { D_ASSERT(yyjson_is_obj(obj)); auto &description = node.GetOrCreateDescription(LogicalTypeId::STRUCT); @@ -409,14 +422,19 @@ static inline void ExtractStructureObject(yyjson_val *obj, JSONStructureNode &no } } -static inline void ExtractStructureVal(yyjson_val *val, JSONStructureNode &node) { +static void ExtractStructureVal(yyjson_val *val, JSONStructureNode &node) { D_ASSERT(!yyjson_is_arr(val) && !yyjson_is_obj(val)); node.GetOrCreateDescription(JSONCommon::ValTypeToLogicalTypeId(val)); } void JSONStructure::ExtractStructure(yyjson_val *val, JSONStructureNode &node, const bool ignore_errors) { node.count++; - switch (yyjson_get_tag(val)) { + const auto tag = yyjson_get_tag(val); + if (tag == (YYJSON_TYPE_NULL | YYJSON_SUBTYPE_NONE)) { + node.null_count++; + } + + switch (tag) { case YYJSON_TYPE_ARR | YYJSON_SUBTYPE_NONE: return ExtractStructureArray(val, node, ignore_errors); case YYJSON_TYPE_OBJ | YYJSON_SUBTYPE_NONE: @@ -433,19 +451,19 @@ JSONStructureNode ExtractStructureInternal(yyjson_val *val, const bool ignore_er } //! Forward declaration for recursion -static inline yyjson_mut_val *ConvertStructure(const JSONStructureNode &node, yyjson_mut_doc *doc); +static yyjson_mut_val *ConvertStructure(const JSONStructureNode &node, yyjson_mut_doc *doc); -static inline yyjson_mut_val *ConvertStructureArray(const JSONStructureNode &node, yyjson_mut_doc *doc) { +static yyjson_mut_val *ConvertStructureArray(const JSONStructureNode &node, yyjson_mut_doc *doc) { D_ASSERT(node.descriptions.size() == 1 && node.descriptions[0].type == LogicalTypeId::LIST); const auto &desc = node.descriptions[0]; D_ASSERT(desc.children.size() == 1); - auto arr = yyjson_mut_arr(doc); + const auto arr = yyjson_mut_arr(doc); yyjson_mut_arr_append(arr, ConvertStructure(desc.children[0], doc)); return arr; } -static inline yyjson_mut_val *ConvertStructureObject(const JSONStructureNode &node, yyjson_mut_doc *doc) { +static yyjson_mut_val *ConvertStructureObject(const JSONStructureNode &node, yyjson_mut_doc *doc) { D_ASSERT(node.descriptions.size() == 1 && node.descriptions[0].type == LogicalTypeId::STRUCT); auto &desc = node.descriptions[0]; if (desc.children.empty()) { @@ -453,7 +471,7 @@ static inline yyjson_mut_val *ConvertStructureObject(const JSONStructureNode &no return yyjson_mut_str(doc, LogicalType::JSON_TYPE_NAME); } - auto obj = yyjson_mut_obj(doc); + const auto obj = yyjson_mut_obj(doc); for (auto &child : desc.children) { D_ASSERT(child.key); yyjson_mut_obj_add(obj, yyjson_mut_strn(doc, child.key->c_str(), child.key->length()), @@ -462,7 +480,7 @@ static inline yyjson_mut_val *ConvertStructureObject(const JSONStructureNode &no return obj; } -static inline yyjson_mut_val *ConvertStructure(const JSONStructureNode &node, yyjson_mut_doc *doc) { +static yyjson_mut_val *ConvertStructure(const JSONStructureNode &node, yyjson_mut_doc *doc) { if (node.descriptions.empty()) { return yyjson_mut_str(doc, JSONCommon::TYPE_STRING_NULL); } @@ -481,7 +499,7 @@ static inline yyjson_mut_val *ConvertStructure(const JSONStructureNode &node, yy } } -static inline string_t JSONStructureFunction(yyjson_val *val, yyjson_alc *alc, Vector &) { +static string_t JSONStructureFunction(yyjson_val *val, yyjson_alc *alc, Vector &, ValidityMask &, idx_t) { return JSONCommon::WriteVal( ConvertStructure(ExtractStructureInternal(val, true), yyjson_mut_doc_new(alc)), alc); } @@ -503,45 +521,236 @@ ScalarFunctionSet JSONFunctions::GetStructureFunction() { } static LogicalType StructureToTypeArray(ClientContext &context, const JSONStructureNode &node, const idx_t max_depth, - const double field_appearance_threshold, idx_t depth) { + const double field_appearance_threshold, const idx_t map_inference_threshold, + const idx_t depth, const LogicalType &null_type) { D_ASSERT(node.descriptions.size() == 1 && node.descriptions[0].type == LogicalTypeId::LIST); const auto &desc = node.descriptions[0]; D_ASSERT(desc.children.size() == 1); - return LogicalType::LIST(JSONStructure::StructureToType( - context, desc.children[0], max_depth, field_appearance_threshold, depth + 1, desc.children[0].count)); + return LogicalType::LIST(JSONStructure::StructureToType(context, desc.children[0], max_depth, + field_appearance_threshold, map_inference_threshold, + depth + 1, null_type)); +} + +static void MergeNodes(JSONStructureNode &merged, const JSONStructureNode &node); + +static void MergeNodeArray(JSONStructureNode &merged, const JSONStructureDescription &child_desc) { + D_ASSERT(child_desc.type == LogicalTypeId::LIST); + auto &merged_desc = merged.GetOrCreateDescription(LogicalTypeId::LIST); + auto &merged_child = merged_desc.GetOrCreateChild(); + for (auto &list_child : child_desc.children) { + MergeNodes(merged_child, list_child); + } +} + +static void MergeNodeObject(JSONStructureNode &merged, const JSONStructureDescription &child_desc) { + D_ASSERT(child_desc.type == LogicalTypeId::STRUCT); + auto &merged_desc = merged.GetOrCreateDescription(LogicalTypeId::STRUCT); + for (auto &struct_child : child_desc.children) { + const auto &struct_child_key = *struct_child.key; + auto &merged_child = merged_desc.GetOrCreateChild(struct_child_key.c_str(), struct_child_key.length()); + MergeNodes(merged_child, struct_child); + } +} + +static void MergeNodeVal(JSONStructureNode &merged, const JSONStructureDescription &child_desc, + const bool node_initialized) { + D_ASSERT(child_desc.type != LogicalTypeId::LIST && child_desc.type != LogicalTypeId::STRUCT); + auto &merged_desc = merged.GetOrCreateDescription(child_desc.type); + if (merged_desc.type != LogicalTypeId::VARCHAR || !node_initialized || merged.descriptions.size() != 1) { + return; + } + if (!merged.initialized) { + merged_desc.candidate_types = child_desc.candidate_types; + } else if (!merged_desc.candidate_types.empty() && !child_desc.candidate_types.empty() && + merged_desc.candidate_types.back() != child_desc.candidate_types.back()) { + merged_desc.candidate_types.clear(); // Not the same, default to VARCHAR + } + merged.initialized = true; +} + +static void MergeNodes(JSONStructureNode &merged, const JSONStructureNode &node) { + merged.count += node.count; + merged.null_count += node.null_count; + for (const auto &child_desc : node.descriptions) { + switch (child_desc.type) { + case LogicalTypeId::LIST: + MergeNodeArray(merged, child_desc); + break; + case LogicalTypeId::STRUCT: + MergeNodeObject(merged, child_desc); + break; + default: + MergeNodeVal(merged, child_desc, node.initialized); + break; + } + } +} + +static double CalculateTypeSimilarity(const LogicalType &merged, const LogicalType &type, idx_t max_depth, idx_t depth); + +static double CalculateMapAndStructSimilarity(const LogicalType &map_type, const LogicalType &struct_type, + const bool swapped, const idx_t max_depth, const idx_t depth) { + const auto &map_value_type = MapType::ValueType(map_type); + const auto &struct_child_types = StructType::GetChildTypes(struct_type); + double total_similarity = 0; + for (const auto &struct_child_type : struct_child_types) { + const auto similarity = + swapped ? CalculateTypeSimilarity(struct_child_type.second, map_value_type, max_depth, depth + 1) + : CalculateTypeSimilarity(map_value_type, struct_child_type.second, max_depth, depth + 1); + if (similarity < 0) { + return similarity; + } + total_similarity += similarity; + } + return total_similarity / static_cast(struct_child_types.size()); +} + +static double CalculateTypeSimilarity(const LogicalType &merged, const LogicalType &type, const idx_t max_depth, + const idx_t depth) { + if (depth >= max_depth || merged.id() == LogicalTypeId::SQLNULL || type.id() == LogicalTypeId::SQLNULL) { + return 1; + } + if (merged.IsJSONType()) { + // Incompatible types + return -1; + } + if (type.IsJSONType() || merged == type) { + return 1; + } + + switch (merged.id()) { + case LogicalTypeId::STRUCT: { + if (type.id() == LogicalTypeId::MAP) { + // This can happen for empty structs/maps ("{}"), or in rare cases where an inconsistent struct becomes + // consistent when merged, but does not have enough children to be considered a map. + return CalculateMapAndStructSimilarity(type, merged, true, max_depth, depth); + } + + // Only structs can be merged into a struct + D_ASSERT(type.id() == LogicalTypeId::STRUCT); + const auto &merged_child_types = StructType::GetChildTypes(merged); + const auto &type_child_types = StructType::GetChildTypes(type); + + unordered_map merged_child_types_map; + for (const auto &merged_child : merged_child_types) { + merged_child_types_map.emplace(merged_child.first, merged_child.second); + } + + double total_similarity = 0; + for (const auto &type_child_type : type_child_types) { + const auto it = merged_child_types_map.find(type_child_type.first); + if (it == merged_child_types_map.end()) { + return -1; + } + const auto similarity = CalculateTypeSimilarity(it->second, type_child_type.second, max_depth, depth + 1); + if (similarity < 0) { + return similarity; + } + total_similarity += similarity; + } + return total_similarity / static_cast(merged_child_types.size()); + } + case LogicalTypeId::MAP: { + if (type.id() == LogicalTypeId::MAP) { + return CalculateTypeSimilarity(MapType::ValueType(merged), MapType::ValueType(type), max_depth, depth + 1); + } + + // Only maps and structs can be merged into a map + D_ASSERT(type.id() == LogicalTypeId::STRUCT); + return CalculateMapAndStructSimilarity(merged, type, false, max_depth, depth); + } + case LogicalTypeId::LIST: { + // Only lists can be merged into a list + D_ASSERT(type.id() == LogicalTypeId::LIST); + const auto &merged_child_type = ListType::GetChildType(merged); + const auto &type_child_type = ListType::GetChildType(type); + return CalculateTypeSimilarity(merged_child_type, type_child_type, max_depth, depth + 1); + } + default: + // This is only reachable if type has been inferred using candidate_types, but candidate_types were not + // consistent among all map values + return 1; + } +} + +static bool IsStructureInconsistent(const JSONStructureDescription &desc, const idx_t sample_count, + const idx_t null_count, const double field_appearance_threshold) { + D_ASSERT(sample_count > null_count); + double total_child_counts = 0; + for (const auto &child : desc.children) { + total_child_counts += static_cast(child.count) / static_cast(sample_count - null_count); + } + const auto avg_occurrence = total_child_counts / static_cast(desc.children.size()); + return avg_occurrence < field_appearance_threshold; +} + +static LogicalType GetMergedType(ClientContext &context, const JSONStructureNode &node, const idx_t max_depth, + const double field_appearance_threshold, const idx_t map_inference_threshold, + const idx_t depth, const LogicalType &null_type) { + D_ASSERT(node.descriptions.size() == 1); + auto &desc = node.descriptions[0]; + JSONStructureNode merged; + for (const auto &child : desc.children) { + MergeNodes(merged, child); + } + return JSONStructure::StructureToType(context, merged, max_depth, field_appearance_threshold, + map_inference_threshold, depth + 1, null_type); } static LogicalType StructureToTypeObject(ClientContext &context, const JSONStructureNode &node, const idx_t max_depth, - const double field_appearance_threshold, idx_t depth, - const idx_t sample_count) { + const double field_appearance_threshold, const idx_t map_inference_threshold, + const idx_t depth, const LogicalType &null_type) { D_ASSERT(node.descriptions.size() == 1 && node.descriptions[0].type == LogicalTypeId::STRUCT); auto &desc = node.descriptions[0]; - // If it's an empty struct we do JSON instead + // If it's an empty struct we do MAP of JSON instead if (desc.children.empty()) { - // Empty struct - let's do JSON instead - return LogicalType::JSON(); + // Empty struct - let's do MAP of JSON instead + return LogicalType::MAP(LogicalType::VARCHAR, null_type); } - // If it's an inconsistent object we also just do JSON - double total_child_counts = 0; - for (const auto &child : desc.children) { - total_child_counts += double(child.count) / sample_count; - } - const auto avg_occurrence = total_child_counts / desc.children.size(); - if (avg_occurrence < field_appearance_threshold) { - return LogicalType::JSON(); + // If it's an inconsistent object we also just do MAP with the best-possible, recursively-merged value type + if (IsStructureInconsistent(desc, node.count, node.null_count, field_appearance_threshold)) { + return LogicalType::MAP(LogicalType::VARCHAR, + GetMergedType(context, node, max_depth, field_appearance_threshold, + map_inference_threshold, depth + 1, null_type)); } + // We have a consistent object child_list_t child_types; child_types.reserve(desc.children.size()); for (auto &child : desc.children) { D_ASSERT(child.key); child_types.emplace_back(*child.key, JSONStructure::StructureToType(context, child, max_depth, field_appearance_threshold, - depth + 1, sample_count)); + map_inference_threshold, depth + 1, null_type)); } + + // If we have many children and all children have similar-enough types we infer map + if (desc.children.size() >= map_inference_threshold) { + LogicalType map_value_type = GetMergedType(context, node, max_depth, field_appearance_threshold, + map_inference_threshold, depth + 1, LogicalTypeId::SQLNULL); + + double total_similarity = 0; + for (const auto &child_type : child_types) { + const auto similarity = CalculateTypeSimilarity(map_value_type, child_type.second, max_depth, depth + 1); + if (similarity < 0) { + total_similarity = similarity; + break; + } + total_similarity += similarity; + } + const auto avg_similarity = total_similarity / static_cast(child_types.size()); + if (avg_similarity >= 0.8) { + if (null_type != LogicalTypeId::SQLNULL) { + map_value_type = GetMergedType(context, node, max_depth, field_appearance_threshold, + map_inference_threshold, depth + 1, null_type); + } + return LogicalType::MAP(LogicalType::VARCHAR, map_value_type); + } + } + return LogicalType::STRUCT(child_types); } @@ -555,30 +764,32 @@ static LogicalType StructureToTypeString(const JSONStructureNode &node) { } LogicalType JSONStructure::StructureToType(ClientContext &context, const JSONStructureNode &node, const idx_t max_depth, - const double field_appearance_threshold, idx_t depth, idx_t sample_count) { + const double field_appearance_threshold, const idx_t map_inference_threshold, + const idx_t depth, const LogicalType &null_type) { if (depth >= max_depth) { return LogicalType::JSON(); } if (node.descriptions.empty()) { - return LogicalType::JSON(); + return null_type; } if (node.descriptions.size() != 1) { // Inconsistent types, so we resort to JSON return LogicalType::JSON(); } - sample_count = sample_count == DConstants::INVALID_INDEX ? node.count : sample_count; auto &desc = node.descriptions[0]; D_ASSERT(desc.type != LogicalTypeId::INVALID); switch (desc.type) { case LogicalTypeId::LIST: - return StructureToTypeArray(context, node, max_depth, field_appearance_threshold, depth); + return StructureToTypeArray(context, node, max_depth, field_appearance_threshold, map_inference_threshold, + depth, null_type); case LogicalTypeId::STRUCT: - return StructureToTypeObject(context, node, max_depth, field_appearance_threshold, depth, sample_count); + return StructureToTypeObject(context, node, max_depth, field_appearance_threshold, map_inference_threshold, + depth, null_type); case LogicalTypeId::VARCHAR: return StructureToTypeString(node); - case LogicalTypeId::SQLNULL: - return LogicalType::JSON(); case LogicalTypeId::UBIGINT: return LogicalTypeId::BIGINT; // We prefer not to return UBIGINT in our type auto-detection + case LogicalTypeId::SQLNULL: + return null_type; default: return desc.type; } diff --git a/src/duckdb/extension/json/json_functions/json_transform.cpp b/src/duckdb/extension/json/json_functions/json_transform.cpp index e6b724f4..d1f1caa1 100644 --- a/src/duckdb/extension/json/json_functions/json_transform.cpp +++ b/src/duckdb/extension/json/json_functions/json_transform.cpp @@ -1012,7 +1012,7 @@ void JSONFunctions::RegisterJSONTransformCastFunctions(CastFunctionSet &casts) { target_type = LogicalType::UNION({{"any", LogicalType::ANY}}); break; case LogicalTypeId::ARRAY: - target_type = LogicalType::ARRAY(LogicalType::ANY); + target_type = LogicalType::ARRAY(LogicalType::ANY, optional_idx()); break; case LogicalTypeId::VARCHAR: // We skip this one here as it's handled in json_functions.cpp diff --git a/src/duckdb/extension/json/json_functions/json_type.cpp b/src/duckdb/extension/json/json_functions/json_type.cpp index 8f3fb3ad..47aec2d6 100644 --- a/src/duckdb/extension/json/json_functions/json_type.cpp +++ b/src/duckdb/extension/json/json_functions/json_type.cpp @@ -2,7 +2,7 @@ namespace duckdb { -static inline string_t GetType(yyjson_val *val, yyjson_alc *alc, Vector &result) { +static inline string_t GetType(yyjson_val *val, yyjson_alc *, Vector &, ValidityMask &mask, idx_t idx) { return JSONCommon::ValTypeToStringT(val); } @@ -11,11 +11,11 @@ static void UnaryTypeFunction(DataChunk &args, ExpressionState &state, Vector &r } static void BinaryTypeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - JSONExecutors::BinaryExecute(args, state, result, GetType); + JSONExecutors::BinaryExecute(args, state, result, GetType); } static void ManyTypeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - JSONExecutors::ExecuteMany(args, state, result, GetType); + JSONExecutors::ExecuteMany(args, state, result, GetType); } static void GetTypeFunctionsInternal(ScalarFunctionSet &set, const LogicalType &input_type) { diff --git a/src/duckdb/extension/json/json_functions/json_value.cpp b/src/duckdb/extension/json/json_functions/json_value.cpp new file mode 100644 index 00000000..06afbd94 --- /dev/null +++ b/src/duckdb/extension/json/json_functions/json_value.cpp @@ -0,0 +1,42 @@ +#include "json_executors.hpp" + +namespace duckdb { + +static inline string_t ValueFromVal(yyjson_val *val, yyjson_alc *alc, Vector &, ValidityMask &mask, idx_t idx) { + switch (yyjson_get_tag(val)) { + case YYJSON_TYPE_ARR | YYJSON_SUBTYPE_NONE: + case YYJSON_TYPE_OBJ | YYJSON_SUBTYPE_NONE: + mask.SetInvalid(idx); + return string_t {}; + default: + return JSONCommon::WriteVal(val, alc); + } +} + +static void ValueFunction(DataChunk &args, ExpressionState &state, Vector &result) { + JSONExecutors::BinaryExecute(args, state, result, ValueFromVal); +} + +static void ValueManyFunction(DataChunk &args, ExpressionState &state, Vector &result) { + JSONExecutors::ExecuteMany(args, state, result, ValueFromVal); +} + +static void GetValueFunctionsInternal(ScalarFunctionSet &set, const LogicalType &input_type) { + set.AddFunction(ScalarFunction({input_type, LogicalType::BIGINT}, LogicalType::JSON(), ValueFunction, + JSONReadFunctionData::Bind, nullptr, nullptr, JSONFunctionLocalState::Init)); + set.AddFunction(ScalarFunction({input_type, LogicalType::VARCHAR}, LogicalType::JSON(), ValueFunction, + JSONReadFunctionData::Bind, nullptr, nullptr, JSONFunctionLocalState::Init)); + set.AddFunction(ScalarFunction({input_type, LogicalType::LIST(LogicalType::VARCHAR)}, + LogicalType::LIST(LogicalType::JSON()), ValueManyFunction, + JSONReadManyFunctionData::Bind, nullptr, nullptr, JSONFunctionLocalState::Init)); +} + +ScalarFunctionSet JSONFunctions::GetValueFunction() { + // The value function is just like the extract function but returns NULL if the JSON is not a scalar value + ScalarFunctionSet set("json_value"); + GetValueFunctionsInternal(set, LogicalType::VARCHAR); + GetValueFunctionsInternal(set, LogicalType::JSON()); + return set; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/json/json_functions/read_json.cpp b/src/duckdb/extension/json/json_functions/read_json.cpp index 9dc6bfde..d56b3e98 100644 --- a/src/duckdb/extension/json/json_functions/read_json.cpp +++ b/src/duckdb/extension/json/json_functions/read_json.cpp @@ -99,8 +99,8 @@ void JSONScan::AutoDetect(ClientContext &context, JSONScanData &bind_data, vecto bind_data.type = JSONScanType::READ_JSON; // Convert structure to logical type - auto type = - JSONStructure::StructureToType(context, node, bind_data.max_depth, bind_data.field_appearance_threshold); + auto type = JSONStructure::StructureToType(context, node, bind_data.max_depth, bind_data.field_appearance_threshold, + bind_data.map_inference_threshold); // Auto-detect record type if (bind_data.options.record_type == JSONRecordType::AUTO_DETECT) { @@ -145,6 +145,9 @@ unique_ptr ReadJSONBind(ClientContext &context, TableFunctionBindI bind_data->Bind(context, input); for (auto &kv : input.named_parameters) { + if (kv.second.IsNull()) { + throw BinderException("Cannot use NULL as function argument"); + } auto loption = StringUtil::Lower(kv.first); if (kv.second.IsNull()) { throw BinderException("read_json parameter \"%s\" cannot be NULL.", loption); @@ -196,6 +199,16 @@ unique_ptr ReadJSONBind(ClientContext &context, TableFunctionBindI "read_json_auto \"field_appearance_threshold\" parameter must be between 0 and 1"); } bind_data->field_appearance_threshold = arg; + } else if (loption == "map_inference_threshold") { + auto arg = BigIntValue::Get(kv.second); + if (arg == -1) { + bind_data->map_inference_threshold = NumericLimits::Maximum(); + } else if (arg >= 0) { + bind_data->map_inference_threshold = arg; + } else { + throw BinderException("read_json_auto \"map_inference_threshold\" parameter must be 0 or positive, " + "or -1 to disable map inference for consistent objects."); + } } else if (loption == "dateformat" || loption == "date_format") { auto format_string = StringValue::Get(kv.second); if (StringUtil::Lower(format_string) == "iso") { @@ -380,6 +393,7 @@ TableFunctionSet CreateJSONFunctionInfo(string name, shared_ptr in table_function.named_parameters["maximum_depth"] = LogicalType::BIGINT; table_function.named_parameters["field_appearance_threshold"] = LogicalType::DOUBLE; table_function.named_parameters["convert_strings_to_integers"] = LogicalType::BOOLEAN; + table_function.named_parameters["map_inference_threshold"] = LogicalType::BIGINT; return MultiFileReader::CreateFunctionSet(table_function); } diff --git a/src/duckdb/extension/json/json_functions/read_json_objects.cpp b/src/duckdb/extension/json/json_functions/read_json_objects.cpp index 46d4e798..7e97b647 100644 --- a/src/duckdb/extension/json/json_functions/read_json_objects.cpp +++ b/src/duckdb/extension/json/json_functions/read_json_objects.cpp @@ -33,8 +33,9 @@ static void ReadJSONObjectsFunction(ClientContext &context, TableFunctionInput & if (!gstate.names.empty()) { // Create the strings without copying them - auto strings = FlatVector::GetData(output.data[0]); - auto &validity = FlatVector::Validity(output.data[0]); + const auto col_idx = gstate.column_indices[0]; + auto strings = FlatVector::GetData(output.data[col_idx]); + auto &validity = FlatVector::Validity(output.data[col_idx]); for (idx_t i = 0; i < count; i++) { if (objects[i]) { strings[i] = string_t(units[i].pointer, units[i].size); diff --git a/src/duckdb/extension/json/json_scan.cpp b/src/duckdb/extension/json/json_scan.cpp index f406e162..75232a7d 100644 --- a/src/duckdb/extension/json/json_scan.cpp +++ b/src/duckdb/extension/json/json_scan.cpp @@ -29,6 +29,9 @@ void JSONScanData::Bind(ClientContext &context, TableFunctionBindInput &input) { auto_detect = info.auto_detect; for (auto &kv : input.named_parameters) { + if (kv.second.IsNull()) { + throw BinderException("Cannot use NULL as function argument"); + } if (MultiFileReader().ParseOption(kv.first, kv.second, options.file_options, context)) { continue; } @@ -600,20 +603,18 @@ bool JSONScanLocalState::ReadNextBuffer(JSONScanGlobalState &gstate) { // Open the file if it is not yet open if (!current_reader->IsOpen()) { current_reader->OpenJSONFile(); - if (current_reader->GetFileHandle().FileSize() == 0 && !current_reader->GetFileHandle().IsPipe()) { - current_reader->GetFileHandle().Close(); - // Skip over empty files - if (gstate.enable_parallel_scans) { - TryIncrementFileIndex(gstate); - } - continue; - } } // Auto-detect if we haven't yet done this during the bind if (gstate.bind_data.options.record_type == JSONRecordType::AUTO_DETECT || current_reader->GetFormat() == JSONFormat::AUTO_DETECT) { - ReadAndAutoDetect(gstate, buffer, buffer_index); + bool file_done = false; + ReadAndAutoDetect(gstate, buffer, buffer_index, file_done); + if (file_done) { + TryIncrementFileIndex(gstate); + lock_guard reader_guard(current_reader->lock); + current_reader->GetFileHandle().Close(); + } } if (gstate.enable_parallel_scans) { @@ -653,9 +654,8 @@ bool JSONScanLocalState::ReadNextBuffer(JSONScanGlobalState &gstate) { } void JSONScanLocalState::ReadAndAutoDetect(JSONScanGlobalState &gstate, AllocatedData &buffer, - optional_idx &buffer_index) { + optional_idx &buffer_index, bool &file_done) { // We have to detect the JSON format - hold the gstate lock while we do this - bool file_done = false; if (!ReadNextBufferInternal(gstate, buffer, buffer_index, file_done)) { return; } @@ -980,8 +980,9 @@ void JSONScan::ComplexFilterPushdown(ClientContext &context, LogicalGet &get, Fu SimpleMultiFileList file_list(std::move(data.files)); + MultiFilePushdownInfo info(get); auto filtered_list = - MultiFileReader().ComplexFilterPushdown(context, file_list, data.options.file_options, get, filters); + MultiFileReader().ComplexFilterPushdown(context, file_list, data.options.file_options, info, filters); if (filtered_list) { MultiFileReader().PruneReaders(data, *filtered_list); data.files = filtered_list->GetAllFiles(); diff --git a/src/duckdb/extension/json/serialize_json.cpp b/src/duckdb/extension/json/serialize_json.cpp index c45ffc2f..ee8b4368 100644 --- a/src/duckdb/extension/json/serialize_json.cpp +++ b/src/duckdb/extension/json/serialize_json.cpp @@ -44,6 +44,7 @@ void JSONScanData::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault(113, "field_appearance_threshold", field_appearance_threshold, 0.1); serializer.WritePropertyWithDefault(114, "maximum_sample_files", maximum_sample_files, 32); serializer.WritePropertyWithDefault(115, "convert_strings_to_integers", convert_strings_to_integers, false); + serializer.WritePropertyWithDefault(116, "map_inference_threshold", map_inference_threshold, 25); } unique_ptr JSONScanData::Deserialize(Deserializer &deserializer) { @@ -71,9 +72,10 @@ unique_ptr JSONScanData::Deserialize(Deserializer &deserializer) { result->max_depth = max_depth; result->transform_options = transform_options; result->names = std::move(names); - deserializer.ReadPropertyWithDefault(113, "field_appearance_threshold", result->field_appearance_threshold, 0.1); - deserializer.ReadPropertyWithDefault(114, "maximum_sample_files", result->maximum_sample_files, 32); - deserializer.ReadPropertyWithDefault(115, "convert_strings_to_integers", result->convert_strings_to_integers, false); + deserializer.ReadPropertyWithExplicitDefault(113, "field_appearance_threshold", result->field_appearance_threshold, 0.1); + deserializer.ReadPropertyWithExplicitDefault(114, "maximum_sample_files", result->maximum_sample_files, 32); + deserializer.ReadPropertyWithExplicitDefault(115, "convert_strings_to_integers", result->convert_strings_to_integers, false); + deserializer.ReadPropertyWithExplicitDefault(116, "map_inference_threshold", result->map_inference_threshold, 25); return result; } diff --git a/src/duckdb/extension/parquet/column_reader.cpp b/src/duckdb/extension/parquet/column_reader.cpp index 08471fdf..fac6d89d 100644 --- a/src/duckdb/extension/parquet/column_reader.cpp +++ b/src/duckdb/extension/parquet/column_reader.cpp @@ -1,23 +1,25 @@ #include "column_reader.hpp" #include "boolean_column_reader.hpp" +#include "brotli/decode.h" #include "callback_column_reader.hpp" #include "cast_column_reader.hpp" #include "duckdb.hpp" +#include "expression_column_reader.hpp" #include "list_column_reader.hpp" +#include "lz4.hpp" #include "miniz_wrapper.hpp" +#include "null_column_reader.hpp" #include "parquet_decimal_utils.hpp" #include "parquet_reader.hpp" #include "parquet_timestamp.hpp" #include "row_number_column_reader.hpp" #include "snappy.h" #include "string_column_reader.hpp" -#include "null_column_reader.hpp" #include "struct_column_reader.hpp" #include "templated_column_reader.hpp" #include "utf8proc_wrapper.hpp" #include "zstd.h" -#include "lz4.hpp" #ifndef DUCKDB_AMALGAMATION #include "duckdb/common/helper.hpp" @@ -343,7 +345,9 @@ void ColumnReader::DecompressInternal(CompressionCodec::type codec, const_data_p break; } case CompressionCodec::LZ4_RAW: { - auto res = duckdb_lz4::LZ4_decompress_safe(const_char_ptr_cast(src), char_ptr_cast(dst), src_size, dst_size); + auto res = + duckdb_lz4::LZ4_decompress_safe(const_char_ptr_cast(src), char_ptr_cast(dst), + UnsafeNumericCast(src_size), UnsafeNumericCast(dst_size)); if (res != NumericCast(dst_size)) { throw std::runtime_error("LZ4 decompression failure"); } @@ -373,12 +377,26 @@ void ColumnReader::DecompressInternal(CompressionCodec::type codec, const_data_p } break; } + case CompressionCodec::BROTLI: { + auto state = duckdb_brotli::BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + size_t total_out = 0; + auto src_size_size_t = NumericCast(src_size); + auto dst_size_size_t = NumericCast(dst_size); + + auto res = duckdb_brotli::BrotliDecoderDecompressStream(state, &src_size_size_t, &src, &dst_size_size_t, &dst, + &total_out); + if (res != duckdb_brotli::BROTLI_DECODER_RESULT_SUCCESS) { + throw std::runtime_error("Brotli Decompression failure"); + } + duckdb_brotli::BrotliDecoderDestroyInstance(state); + break; + } default: { std::stringstream codec_name; codec_name << codec; throw std::runtime_error("Unsupported compression codec \"" + codec_name.str() + - "\". Supported options are uncompressed, gzip, lz4_raw, snappy or zstd"); + "\". Supported options are uncompressed, brotli, gzip, lz4_raw, snappy or zstd"); } } } @@ -641,7 +659,7 @@ uint32_t StringColumnReader::VerifyString(const char *str_data, uint32_t str_len void StringColumnReader::Dictionary(shared_ptr data, idx_t num_entries) { dict = std::move(data); - dict_strings = unique_ptr(new string_t[num_entries]); + dict_strings = unsafe_unique_ptr(new string_t[num_entries]); for (idx_t dict_idx = 0; dict_idx < num_entries; dict_idx++) { uint32_t str_len; if (fixed_width_string_length == 0) { @@ -741,7 +759,7 @@ void StringColumnReader::DeltaByteArray(uint8_t *defines, idx_t num_values, parq result_mask.SetInvalid(row_idx + result_offset); continue; } - if (filter[row_idx + result_offset]) { + if (filter.test(row_idx + result_offset)) { if (delta_offset >= byte_array_count) { throw IOException("DELTA_BYTE_ARRAY - length mismatch between values and byte array lengths (attempted " "read of %d from %d entries) - corrupt file?", @@ -773,8 +791,7 @@ void StringColumnReader::PlainReference(shared_ptr plain_data, Vecto } string_t StringParquetValueConversion::DictRead(ByteBuffer &dict, uint32_t &offset, ColumnReader &reader) { - auto &dict_strings = reader.Cast().dict_strings; - return dict_strings[offset]; + return reader.Cast().dict_strings[offset]; } string_t StringParquetValueConversion::PlainRead(ByteBuffer &plain_data, ColumnReader &reader) { @@ -794,6 +811,18 @@ void StringParquetValueConversion::PlainSkip(ByteBuffer &plain_data, ColumnReade plain_data.inc(str_len); } +bool StringParquetValueConversion::PlainAvailable(const ByteBuffer &plain_data, const idx_t count) { + return true; +} + +string_t StringParquetValueConversion::UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader) { + return PlainRead(plain_data, reader); +} + +void StringParquetValueConversion::UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { + PlainSkip(plain_data, reader); +} + //===--------------------------------------------------------------------===// // List Column Reader //===--------------------------------------------------------------------===// @@ -952,8 +981,9 @@ unique_ptr RowNumberColumnReader::Stats(idx_t row_group_idx_p, c row_group_offset_min += row_groups[i].num_rows; } - NumericStats::SetMin(stats, Value::BIGINT(row_group_offset_min)); - NumericStats::SetMax(stats, Value::BIGINT(row_group_offset_min + row_groups[row_group_idx_p].num_rows)); + NumericStats::SetMin(stats, Value::BIGINT(UnsafeNumericCast(row_group_offset_min))); + NumericStats::SetMax( + stats, Value::BIGINT(UnsafeNumericCast(row_group_offset_min + row_groups[row_group_idx_p].num_rows))); stats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); return stats.ToUnique(); } @@ -972,7 +1002,7 @@ idx_t RowNumberColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, auto data_ptr = FlatVector::GetData(result); for (idx_t i = 0; i < num_values; i++) { - data_ptr[i] = row_group_offset++; + data_ptr[i] = UnsafeNumericCast(row_group_offset++); } return num_values; } @@ -1010,7 +1040,7 @@ idx_t CastColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data intermediate_vector.Flatten(amount); auto &validity = FlatVector::Validity(intermediate_vector); for (idx_t i = 0; i < amount; i++) { - if (!filter[i]) { + if (!filter.test(i)) { validity.SetInvalid(i); } } @@ -1019,14 +1049,33 @@ idx_t CastColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data bool all_succeeded = VectorOperations::DefaultTryCast(intermediate_vector, result, amount, &error_message); if (!all_succeeded) { string extended_error; - extended_error = - StringUtil::Format("In file \"%s\" the column \"%s\" has type %s, but we are trying to read it as type %s.", - reader.file_name, schema.name, intermediate_vector.GetType(), result.GetType()); - extended_error += "\nThis can happen when reading multiple Parquet files. The schema information is taken from " - "the first Parquet file by default. Possible solutions:\n"; - extended_error += "* Enable the union_by_name=True option to combine the schema of all Parquet files " - "(duckdb.org/docs/data/multiple_files/combining_schemas)\n"; - extended_error += "* Use a COPY statement to automatically derive types from an existing table."; + if (!reader.table_columns.empty()) { + // COPY .. FROM + extended_error = StringUtil::Format( + "In file \"%s\" the column \"%s\" has type %s, but we are trying to load it into column ", + reader.file_name, schema.name, intermediate_vector.GetType()); + if (FileIdx() < reader.table_columns.size()) { + extended_error += "\"" + reader.table_columns[FileIdx()] + "\" "; + } + extended_error += StringUtil::Format("with type %s.", result.GetType()); + extended_error += "\nThis means the Parquet schema does not match the schema of the table."; + extended_error += "\nPossible solutions:"; + extended_error += "\n* Insert by name instead of by position using \"INSERT INTO tbl BY NAME SELECT * FROM " + "read_parquet(...)\""; + extended_error += "\n* Manually specify which columns to insert using \"INSERT INTO tbl SELECT ... FROM " + "read_parquet(...)\""; + } else { + // read_parquet() with multiple files + extended_error = StringUtil::Format( + "In file \"%s\" the column \"%s\" has type %s, but we are trying to read it as type %s.", + reader.file_name, schema.name, intermediate_vector.GetType(), result.GetType()); + extended_error += + "\nThis can happen when reading multiple Parquet files. The schema information is taken from " + "the first Parquet file by default. Possible solutions:\n"; + extended_error += "* Enable the union_by_name=True option to combine the schema of all Parquet files " + "(duckdb.org/docs/data/multiple_files/combining_schemas)\n"; + extended_error += "* Use a COPY statement to automatically derive types from an existing table."; + } throw ConversionException( "In Parquet reader of file \"%s\": failed to cast column \"%s\" from type %s to %s: %s\n\n%s", reader.file_name, schema.name, intermediate_vector.GetType(), result.GetType(), error_message, @@ -1043,6 +1092,59 @@ idx_t CastColumnReader::GroupRowsAvailable() { return child_reader->GroupRowsAvailable(); } +//===--------------------------------------------------------------------===// +// Expression Column Reader +//===--------------------------------------------------------------------===// +ExpressionColumnReader::ExpressionColumnReader(ClientContext &context, unique_ptr child_reader_p, + unique_ptr expr_p) + : ColumnReader(child_reader_p->Reader(), expr_p->return_type, child_reader_p->Schema(), child_reader_p->FileIdx(), + child_reader_p->MaxDefine(), child_reader_p->MaxRepeat()), + child_reader(std::move(child_reader_p)), expr(std::move(expr_p)), executor(context, expr.get()) { + vector intermediate_types {child_reader->Type()}; + intermediate_chunk.Initialize(reader.allocator, intermediate_types); +} + +unique_ptr ExpressionColumnReader::Stats(idx_t row_group_idx_p, const vector &columns) { + // expression stats is not supported (yet) + return nullptr; +} + +void ExpressionColumnReader::InitializeRead(idx_t row_group_idx_p, const vector &columns, + TProtocol &protocol_p) { + child_reader->InitializeRead(row_group_idx_p, columns, protocol_p); +} + +idx_t ExpressionColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, + data_ptr_t repeat_out, Vector &result) { + intermediate_chunk.Reset(); + auto &intermediate_vector = intermediate_chunk.data[0]; + + auto amount = child_reader->Read(num_values, filter, define_out, repeat_out, intermediate_vector); + if (!filter.all()) { + // work-around for filters: set all values that are filtered to NULL to prevent the cast from failing on + // uninitialized data + intermediate_vector.Flatten(amount); + auto &validity = FlatVector::Validity(intermediate_vector); + for (idx_t i = 0; i < amount; i++) { + if (!filter[i]) { + validity.SetInvalid(i); + } + } + } + // Execute the expression + intermediate_chunk.SetCardinality(amount); + executor.ExecuteExpression(intermediate_chunk, result); + return amount; +} + +void ExpressionColumnReader::Skip(idx_t num_values) { + child_reader->Skip(num_values); +} + +idx_t ExpressionColumnReader::GroupRowsAvailable() { + return child_reader->GroupRowsAvailable(); +} + //===--------------------------------------------------------------------===// // Struct Column Reader //===--------------------------------------------------------------------===// @@ -1147,8 +1249,7 @@ idx_t StructColumnReader::GroupRowsAvailable() { template struct DecimalParquetValueConversion { static DUCKDB_PHYSICAL_TYPE DictRead(ByteBuffer &dict, uint32_t &offset, ColumnReader &reader) { - auto dict_ptr = reinterpret_cast(dict.ptr); - return dict_ptr[offset]; + return reinterpret_cast(dict.ptr)[offset]; } static DUCKDB_PHYSICAL_TYPE PlainRead(ByteBuffer &plain_data, ColumnReader &reader) { @@ -1170,6 +1271,18 @@ struct DecimalParquetValueConversion { uint32_t decimal_len = FIXED_LENGTH ? reader.Schema().type_length : plain_data.read(); plain_data.inc(decimal_len); } + + static bool PlainAvailable(const ByteBuffer &plain_data, const idx_t count) { + return true; + } + + static DUCKDB_PHYSICAL_TYPE UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader) { + return PlainRead(plain_data, reader); + } + + static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { + PlainSkip(plain_data, reader); + } }; template @@ -1236,7 +1349,7 @@ double ParquetDecimalUtils::ReadDecimalValue(const_data_ptr_t pointer, idx_t siz res_ptr[sizeof(uint64_t) - k - 1] = positive ? byte : byte ^ 0xFF; } res *= double(NumericLimits::Maximum()) + 1; - res += input; + res += static_cast(input); } if (!positive) { res += 1; @@ -1262,8 +1375,7 @@ unique_ptr ParquetDecimalUtils::CreateReader(ParquetReader &reader //===--------------------------------------------------------------------===// struct UUIDValueConversion { static hugeint_t DictRead(ByteBuffer &dict, uint32_t &offset, ColumnReader &reader) { - auto dict_ptr = reinterpret_cast(dict.ptr); - return dict_ptr[offset]; + return reinterpret_cast(dict.ptr)[offset]; } static hugeint_t ReadParquetUUID(const_data_ptr_t input) { @@ -1278,23 +1390,32 @@ struct UUIDValueConversion { result.lower <<= 8; result.lower += input[i]; } - result.upper = unsigned_upper; - result.upper ^= (int64_t(1) << 63); + result.upper = static_cast(unsigned_upper ^ (uint64_t(1) << 63)); return result; } static hugeint_t PlainRead(ByteBuffer &plain_data, ColumnReader &reader) { - idx_t byte_len = sizeof(hugeint_t); - plain_data.available(byte_len); - auto res = ReadParquetUUID(const_data_ptr_cast(plain_data.ptr)); - - plain_data.inc(byte_len); - return res; + plain_data.available(sizeof(hugeint_t)); + return UnsafePlainRead(plain_data, reader); } static void PlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { plain_data.inc(sizeof(hugeint_t)); } + + static bool PlainAvailable(const ByteBuffer &plain_data, const idx_t count) { + return plain_data.check_available(count * sizeof(hugeint_t)); + } + + static hugeint_t UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader) { + auto res = ReadParquetUUID(const_data_ptr_cast(plain_data.ptr)); + plain_data.unsafe_inc(sizeof(hugeint_t)); + return res; + } + + static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { + plain_data.unsafe_inc(sizeof(hugeint_t)); + } }; class UUIDColumnReader : public TemplatedColumnReader { @@ -1322,30 +1443,39 @@ struct IntervalValueConversion { static constexpr const idx_t PARQUET_INTERVAL_SIZE = 12; static interval_t DictRead(ByteBuffer &dict, uint32_t &offset, ColumnReader &reader) { - auto dict_ptr = reinterpret_cast(dict.ptr); - return dict_ptr[offset]; + return reinterpret_cast(dict.ptr)[offset]; } static interval_t ReadParquetInterval(const_data_ptr_t input) { interval_t result; - result.months = Load(input); - result.days = Load(input + sizeof(uint32_t)); + result.months = Load(input); + result.days = Load(input + sizeof(uint32_t)); result.micros = int64_t(Load(input + sizeof(uint32_t) * 2)) * 1000; return result; } static interval_t PlainRead(ByteBuffer &plain_data, ColumnReader &reader) { - idx_t byte_len = PARQUET_INTERVAL_SIZE; - plain_data.available(byte_len); - auto res = ReadParquetInterval(const_data_ptr_cast(plain_data.ptr)); - - plain_data.inc(byte_len); - return res; + plain_data.available(PARQUET_INTERVAL_SIZE); + return UnsafePlainRead(plain_data, reader); } static void PlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { plain_data.inc(PARQUET_INTERVAL_SIZE); } + + static bool PlainAvailable(const ByteBuffer &plain_data, const idx_t count) { + return plain_data.check_available(count * PARQUET_INTERVAL_SIZE); + } + + static interval_t UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader) { + auto res = ReadParquetInterval(const_data_ptr_cast(plain_data.ptr)); + plain_data.unsafe_inc(PARQUET_INTERVAL_SIZE); + return res; + } + + static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { + plain_data.unsafe_inc(PARQUET_INTERVAL_SIZE); + } }; class IntervalColumnReader : public TemplatedColumnReader { @@ -1464,6 +1594,39 @@ unique_ptr ColumnReader::CreateReader(ParquetReader &reader, const break; } break; + case LogicalTypeId::TIMESTAMP_NS: + switch (schema_p.type) { + case Type::INT96: + return make_uniq>( + reader, type_p, schema_p, file_idx_p, max_define, max_repeat); + case Type::INT64: + if (schema_p.__isset.logicalType && schema_p.logicalType.__isset.TIMESTAMP) { + if (schema_p.logicalType.TIMESTAMP.unit.__isset.MILLIS) { + return make_uniq>( + reader, type_p, schema_p, file_idx_p, max_define, max_repeat); + } else if (schema_p.logicalType.TIMESTAMP.unit.__isset.MICROS) { + return make_uniq>( + reader, type_p, schema_p, file_idx_p, max_define, max_repeat); + } else if (schema_p.logicalType.TIMESTAMP.unit.__isset.NANOS) { + return make_uniq>( + reader, type_p, schema_p, file_idx_p, max_define, max_repeat); + } + } else if (schema_p.__isset.converted_type) { + switch (schema_p.converted_type) { + case ConvertedType::TIMESTAMP_MICROS: + return make_uniq>( + reader, type_p, schema_p, file_idx_p, max_define, max_repeat); + case ConvertedType::TIMESTAMP_MILLIS: + return make_uniq>( + reader, type_p, schema_p, file_idx_p, max_define, max_repeat); + default: + break; + } + } + default: + break; + } + break; case LogicalTypeId::DATE: return make_uniq>(reader, type_p, schema_p, file_idx_p, max_define, max_repeat); diff --git a/src/duckdb/extension/parquet/column_writer.cpp b/src/duckdb/extension/parquet/column_writer.cpp index 11ded115..208aa043 100644 --- a/src/duckdb/extension/parquet/column_writer.cpp +++ b/src/duckdb/extension/parquet/column_writer.cpp @@ -4,27 +4,26 @@ #include "parquet_rle_bp_decoder.hpp" #include "parquet_rle_bp_encoder.hpp" #include "parquet_writer.hpp" +#include "geo_parquet.hpp" #ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/common.hpp" #include "duckdb/common/exception.hpp" -#include "duckdb/common/mutex.hpp" #include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/common/serializer/buffered_file_writer.hpp" #include "duckdb/common/serializer/memory_stream.hpp" #include "duckdb/common/serializer/write_stream.hpp" #include "duckdb/common/string_map_set.hpp" -#include "duckdb/common/types/date.hpp" #include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/types/string_heap.hpp" #include "duckdb/common/types/time.hpp" #include "duckdb/common/types/timestamp.hpp" #include "duckdb/common/types/uhugeint.hpp" +#include "duckdb/execution/expression_executor.hpp" #endif #include "lz4.hpp" #include "miniz_wrapper.hpp" #include "snappy.h" #include "zstd.h" +#include "brotli/encode.h" namespace duckdb { @@ -207,11 +206,11 @@ void ColumnWriter::CompressPage(MemoryStream &temp_writer, size_t &compressed_si break; } case CompressionCodec::LZ4_RAW: { - compressed_size = duckdb_lz4::LZ4_compressBound(temp_writer.GetPosition()); + compressed_size = duckdb_lz4::LZ4_compressBound(UnsafeNumericCast(temp_writer.GetPosition())); compressed_buf = unique_ptr(new data_t[compressed_size]); - compressed_size = duckdb_lz4::LZ4_compress_default(const_char_ptr_cast(temp_writer.GetData()), - char_ptr_cast(compressed_buf.get()), - temp_writer.GetPosition(), compressed_size); + compressed_size = duckdb_lz4::LZ4_compress_default( + const_char_ptr_cast(temp_writer.GetData()), char_ptr_cast(compressed_buf.get()), + UnsafeNumericCast(temp_writer.GetPosition()), UnsafeNumericCast(compressed_size)); compressed_data = compressed_buf.get(); break; } @@ -238,6 +237,18 @@ void ColumnWriter::CompressPage(MemoryStream &temp_writer, size_t &compressed_si compressed_data = compressed_buf.get(); break; } + case CompressionCodec::BROTLI: { + + compressed_size = duckdb_brotli::BrotliEncoderMaxCompressedSize(temp_writer.GetPosition()); + compressed_buf = unique_ptr(new data_t[compressed_size]); + + duckdb_brotli::BrotliEncoderCompress(BROTLI_DEFAULT_QUALITY, BROTLI_DEFAULT_WINDOW, BROTLI_DEFAULT_MODE, + temp_writer.GetPosition(), temp_writer.GetData(), &compressed_size, + compressed_buf.get()); + compressed_data = compressed_buf.get(); + + break; + } default: throw InternalException("Unsupported codec for Parquet Writer"); } @@ -249,7 +260,7 @@ void ColumnWriter::CompressPage(MemoryStream &temp_writer, size_t &compressed_si } void ColumnWriter::HandleRepeatLevels(ColumnWriterState &state, ColumnWriterState *parent, idx_t count, - idx_t max_repeat) { + idx_t max_repeat) const { if (!parent) { // no repeat levels without a parent node return; @@ -259,8 +270,8 @@ void ColumnWriter::HandleRepeatLevels(ColumnWriterState &state, ColumnWriterStat } } -void ColumnWriter::HandleDefineLevels(ColumnWriterState &state, ColumnWriterState *parent, ValidityMask &validity, - idx_t count, uint16_t define_value, uint16_t null_value) { +void ColumnWriter::HandleDefineLevels(ColumnWriterState &state, ColumnWriterState *parent, const ValidityMask &validity, + const idx_t count, const uint16_t define_value, const uint16_t null_value) const { if (parent) { // parent node: inherit definition level from the parent idx_t vector_index = 0; @@ -284,15 +295,12 @@ void ColumnWriter::HandleDefineLevels(ColumnWriterState &state, ColumnWriterStat } else { // no parent: set definition levels only from this validity mask for (idx_t i = 0; i < count; i++) { - if (validity.RowIsValid(i)) { - state.definition_levels.push_back(define_value); - } else { - if (!can_have_nulls) { - throw IOException("Parquet writer: map key column is not allowed to contain NULL values"); - } - state.null_count++; - state.definition_levels.push_back(null_value); - } + const auto is_null = !validity.RowIsValid(i); + state.definition_levels.emplace_back(is_null ? null_value : define_value); + state.null_count += is_null; + } + if (!can_have_nulls && state.null_count != 0) { + throw IOException("Parquet writer: map key column is not allowed to contain NULL values"); } } } @@ -386,8 +394,8 @@ class BasicColumnWriter : public ColumnWriter { void FinalizeWrite(ColumnWriterState &state) override; protected: - void WriteLevels(WriteStream &temp_writer, const vector &levels, idx_t max_value, idx_t start_offset, - idx_t count); + static void WriteLevels(WriteStream &temp_writer, const unsafe_vector &levels, idx_t max_value, + idx_t start_offset, idx_t count); virtual duckdb_parquet::format::Encoding::type GetEncoding(BasicColumnWriterState &state); @@ -404,7 +412,7 @@ class BasicColumnWriter : public ColumnWriter { virtual void FlushPageState(WriteStream &temp_writer, ColumnWriterPageState *state); //! Retrieves the row size of a vector at the specified location. Only used for scalar types. - virtual idx_t GetRowSize(Vector &vector, idx_t index, BasicColumnWriterState &state); + virtual idx_t GetRowSize(const Vector &vector, const idx_t index, const BasicColumnWriterState &state) const; //! Writes a (subset of a) vector to the specified serializer. Only used for scalar types. virtual void WriteVector(WriteStream &temp_writer, ColumnWriterStatistics *stats, ColumnWriterPageState *page_state, Vector &vector, idx_t chunk_start, idx_t chunk_end) = 0; @@ -456,8 +464,9 @@ void BasicColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *p HandleDefineLevels(state, parent, validity, count, max_define, max_define - 1); idx_t vector_index = 0; + reference page_info_ref = state.page_info.back(); for (idx_t i = start; i < vcount; i++) { - auto &page_info = state.page_info.back(); + auto &page_info = page_info_ref.get(); page_info.row_count++; col_chunk.meta_data.num_values++; if (parent && !parent->is_empty.empty() && parent->is_empty[parent_index + i]) { @@ -470,6 +479,7 @@ void BasicColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *p PageInformation new_info; new_info.offset = page_info.offset + page_info.row_count; state.page_info.push_back(new_info); + page_info_ref = state.page_info.back(); } } vector_index++; @@ -500,12 +510,13 @@ void BasicColumnWriter::BeginWrite(ColumnWriterState &state_p) { hdr.type = PageType::DATA_PAGE; hdr.__isset.data_page_header = true; - hdr.data_page_header.num_values = page_info.row_count; + hdr.data_page_header.num_values = UnsafeNumericCast(page_info.row_count); hdr.data_page_header.encoding = GetEncoding(state); hdr.data_page_header.definition_level_encoding = Encoding::RLE; hdr.data_page_header.repetition_level_encoding = Encoding::RLE; - write_info.temp_writer = make_uniq(); + write_info.temp_writer = make_uniq( + MaxValue(NextPowerOfTwo(page_info.estimated_page_size), MemoryStream::DEFAULT_INITIAL_CAPACITY)); write_info.write_count = page_info.empty_count; write_info.max_write_count = page_info.row_count; write_info.page_state = InitializePageState(state); @@ -520,7 +531,7 @@ void BasicColumnWriter::BeginWrite(ColumnWriterState &state_p) { NextPage(state); } -void BasicColumnWriter::WriteLevels(WriteStream &temp_writer, const vector &levels, idx_t max_value, +void BasicColumnWriter::WriteLevels(WriteStream &temp_writer, const unsafe_vector &levels, idx_t max_value, idx_t offset, idx_t count) { if (levels.empty() || count == 0) { return; @@ -585,11 +596,11 @@ void BasicColumnWriter::FlushPage(BasicColumnWriterState &state) { throw InternalException("Parquet writer: %d uncompressed page size out of range for type integer", temp_writer.GetPosition()); } - hdr.uncompressed_page_size = temp_writer.GetPosition(); + hdr.uncompressed_page_size = UnsafeNumericCast(temp_writer.GetPosition()); // compress the data CompressPage(temp_writer, write_info.compressed_size, write_info.compressed_data, write_info.compressed_buf); - hdr.compressed_page_size = write_info.compressed_size; + hdr.compressed_page_size = UnsafeNumericCast(write_info.compressed_size); D_ASSERT(hdr.uncompressed_page_size > 0); D_ASSERT(hdr.compressed_page_size > 0); @@ -604,7 +615,8 @@ unique_ptr BasicColumnWriter::InitializeStatsState() { return make_uniq(); } -idx_t BasicColumnWriter::GetRowSize(Vector &vector, idx_t index, BasicColumnWriterState &state) { +idx_t BasicColumnWriter::GetRowSize(const Vector &vector, const idx_t index, + const BasicColumnWriterState &state) const { throw InternalException("GetRowSize unsupported for struct/list column writers"); } @@ -665,7 +677,7 @@ void BasicColumnWriter::SetParquetStatistics(BasicColumnWriterState &state, column_chunk.meta_data.__isset.statistics = true; } if (HasDictionary(state)) { - column_chunk.meta_data.statistics.distinct_count = DictionarySize(state); + column_chunk.meta_data.statistics.distinct_count = UnsafeNumericCast(DictionarySize(state)); column_chunk.meta_data.statistics.__isset.distinct_count = true; column_chunk.meta_data.__isset.statistics = true; } @@ -685,9 +697,9 @@ void BasicColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { auto start_offset = column_writer.GetTotalWritten(); // flush the dictionary if (HasDictionary(state)) { - column_chunk.meta_data.statistics.distinct_count = DictionarySize(state); + column_chunk.meta_data.statistics.distinct_count = UnsafeNumericCast(DictionarySize(state)); column_chunk.meta_data.statistics.__isset.distinct_count = true; - column_chunk.meta_data.dictionary_page_offset = column_writer.GetTotalWritten(); + column_chunk.meta_data.dictionary_page_offset = UnsafeNumericCast(column_writer.GetTotalWritten()); column_chunk.meta_data.__isset.dictionary_page_offset = true; FlushDictionary(state, state.stats_state.get()); } @@ -702,7 +714,7 @@ void BasicColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { // set the data page offset whenever we see the *first* data page if (column_chunk.meta_data.data_page_offset == 0 && (write_info.page_header.type == PageType::DATA_PAGE || write_info.page_header.type == PageType::DATA_PAGE_V2)) { - column_chunk.meta_data.data_page_offset = column_writer.GetTotalWritten(); + column_chunk.meta_data.data_page_offset = UnsafeNumericCast(column_writer.GetTotalWritten()); ; } D_ASSERT(write_info.page_header.uncompressed_page_size > 0); @@ -713,8 +725,9 @@ void BasicColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { total_uncompressed_size += write_info.page_header.uncompressed_page_size; writer.WriteData(write_info.compressed_data, write_info.compressed_size); } - column_chunk.meta_data.total_compressed_size = column_writer.GetTotalWritten() - start_offset; - column_chunk.meta_data.total_uncompressed_size = total_uncompressed_size; + column_chunk.meta_data.total_compressed_size = + UnsafeNumericCast(column_writer.GetTotalWritten() - start_offset); + column_chunk.meta_data.total_uncompressed_size = UnsafeNumericCast(total_uncompressed_size); } void BasicColumnWriter::FlushDictionary(BasicColumnWriterState &state, ColumnWriterStatistics *stats) { @@ -734,13 +747,13 @@ void BasicColumnWriter::WriteDictionary(BasicColumnWriterState &state, unique_pt PageWriteInformation write_info; // set up the header auto &hdr = write_info.page_header; - hdr.uncompressed_page_size = temp_writer->GetPosition(); + hdr.uncompressed_page_size = UnsafeNumericCast(temp_writer->GetPosition()); hdr.type = PageType::DICTIONARY_PAGE; hdr.__isset.dictionary_page_header = true; hdr.dictionary_page_header.encoding = Encoding::PLAIN; hdr.dictionary_page_header.is_sorted = false; - hdr.dictionary_page_header.num_values = row_count; + hdr.dictionary_page_header.num_values = UnsafeNumericCast(row_count); write_info.temp_writer = std::move(temp_writer); write_info.write_count = 0; @@ -749,7 +762,7 @@ void BasicColumnWriter::WriteDictionary(BasicColumnWriterState &state, unique_pt // compress the contents of the dictionary page CompressPage(*write_info.temp_writer, write_info.compressed_size, write_info.compressed_data, write_info.compressed_buf); - hdr.compressed_page_size = write_info.compressed_size; + hdr.compressed_page_size = UnsafeNumericCast(write_info.compressed_size); // insert the dictionary page as the first page to write for this column state.write_info.insert(state.write_info.begin(), std::move(write_info)); @@ -814,7 +827,7 @@ struct ParquetCastOperator : public BaseParquetOperator { struct ParquetTimestampNSOperator : public BaseParquetOperator { template static TGT Operation(SRC input) { - return Timestamp::FromEpochNanoSecondsPossiblyInfinite(input).value; + return TGT(input); } }; @@ -865,16 +878,26 @@ struct ParquetUhugeintOperator { }; template -static void TemplatedWritePlain(Vector &col, ColumnWriterStatistics *stats, idx_t chunk_start, idx_t chunk_end, - ValidityMask &mask, WriteStream &ser) { - auto *ptr = FlatVector::GetData(col); +static void TemplatedWritePlain(Vector &col, ColumnWriterStatistics *stats, const idx_t chunk_start, + const idx_t chunk_end, ValidityMask &mask, WriteStream &ser) { + static constexpr idx_t WRITE_COMBINER_CAPACITY = 8; + TGT write_combiner[WRITE_COMBINER_CAPACITY]; + idx_t write_combiner_count = 0; + + const auto *ptr = FlatVector::GetData(col); for (idx_t r = chunk_start; r < chunk_end; r++) { - if (mask.RowIsValid(r)) { - TGT target_value = OP::template Operation(ptr[r]); - OP::template HandleStats(stats, ptr[r], target_value); - ser.Write(target_value); + if (!mask.RowIsValid(r)) { + continue; + } + TGT target_value = OP::template Operation(ptr[r]); + OP::template HandleStats(stats, ptr[r], target_value); + write_combiner[write_combiner_count++] = target_value; + if (write_combiner_count == WRITE_COMBINER_CAPACITY) { + ser.WriteData(const_data_ptr_cast(write_combiner), WRITE_COMBINER_CAPACITY * sizeof(TGT)); + write_combiner_count = 0; } } + ser.WriteData(const_data_ptr_cast(write_combiner), write_combiner_count * sizeof(TGT)); } template @@ -897,7 +920,7 @@ class StandardColumnWriter : public BasicColumnWriter { TemplatedWritePlain(input_column, stats, chunk_start, chunk_end, mask, temp_writer); } - idx_t GetRowSize(Vector &vector, idx_t index, BasicColumnWriterState &state) override { + idx_t GetRowSize(const Vector &vector, const idx_t index, const BasicColumnWriterState &state) const override { return sizeof(TGT); } }; @@ -991,7 +1014,7 @@ class BooleanColumnWriter : public BasicColumnWriter { } } - idx_t GetRowSize(Vector &vector, idx_t index, BasicColumnWriterState &state) override { + idx_t GetRowSize(const Vector &vector, const idx_t index, const BasicColumnWriterState &state) const override { return sizeof(bool); } }; @@ -1092,7 +1115,7 @@ class FixedDecimalColumnWriter : public BasicColumnWriter { } } - idx_t GetRowSize(Vector &vector, idx_t index, BasicColumnWriterState &state) override { + idx_t GetRowSize(const Vector &vector, const idx_t index, const BasicColumnWriterState &state) const override { return sizeof(hugeint_t); } }; @@ -1139,7 +1162,7 @@ class UUIDColumnWriter : public BasicColumnWriter { } } - idx_t GetRowSize(Vector &vector, idx_t index, BasicColumnWriterState &state) override { + idx_t GetRowSize(const Vector &vector, const idx_t index, const BasicColumnWriterState &state) const override { return PARQUET_UUID_SIZE; } }; @@ -1181,11 +1204,52 @@ class IntervalColumnWriter : public BasicColumnWriter { } } - idx_t GetRowSize(Vector &vector, idx_t index, BasicColumnWriterState &state) override { + idx_t GetRowSize(const Vector &vector, const idx_t index, const BasicColumnWriterState &state) const override { return PARQUET_INTERVAL_SIZE; } }; +//===--------------------------------------------------------------------===// +// Geometry Column Writer +//===--------------------------------------------------------------------===// +// This class just wraps another column writer, but also calculates the extent +// of the geometry column by updating the geodata object with every written +// vector. +template +class GeometryColumnWriter : public WRITER_IMPL { + GeoParquetColumnMetadata geo_data; + GeoParquetColumnMetadataWriter geo_data_writer; + string column_name; + +public: + void Write(ColumnWriterState &state, Vector &vector, idx_t count) override { + // Just write normally + WRITER_IMPL::Write(state, vector, count); + + // And update the geodata object + geo_data_writer.Update(geo_data, vector, count); + } + void FinalizeWrite(ColumnWriterState &state) override { + WRITER_IMPL::FinalizeWrite(state); + + // Add the geodata object to the writer + this->writer.GetGeoParquetData().geometry_columns[column_name] = geo_data; + } + +public: + GeometryColumnWriter(ClientContext &context, ParquetWriter &writer, idx_t schema_idx, vector schema_path_p, + idx_t max_repeat, idx_t max_define, bool can_have_nulls, string name) + : WRITER_IMPL(writer, schema_idx, std::move(schema_path_p), max_repeat, max_define, can_have_nulls), + geo_data_writer(context), column_name(std::move(name)) { + + auto &geo_data = writer.GetGeoParquetData(); + if (geo_data.primary_geometry_column.empty()) { + // Set the first column to the primary column + geo_data.primary_geometry_column = column_name; + } + } +}; + //===--------------------------------------------------------------------===// // String Column Writer //===--------------------------------------------------------------------===// @@ -1262,7 +1326,7 @@ class StringColumnWriterState : public BasicColumnWriterState { // key_bit_width== 0 signifies the chunk is written in plain encoding uint32_t key_bit_width; - bool IsDictionaryEncoded() { + bool IsDictionaryEncoded() const { return key_bit_width != 0; } }; @@ -1457,7 +1521,8 @@ class StringColumnWriter : public BasicColumnWriter { values[entry.second] = entry.first; } // first write the contents of the dictionary page to a temporary buffer - auto temp_writer = make_uniq(); + auto temp_writer = make_uniq( + MaxValue(NextPowerOfTwo(state.estimated_dict_page_size), MemoryStream::DEFAULT_INITIAL_CAPACITY)); for (idx_t r = 0; r < values.size(); r++) { auto &value = values[r]; // update the statistics @@ -1470,7 +1535,7 @@ class StringColumnWriter : public BasicColumnWriter { WriteDictionary(state, std::move(temp_writer), values.size()); } - idx_t GetRowSize(Vector &vector, idx_t index, BasicColumnWriterState &state_p) override { + idx_t GetRowSize(const Vector &vector, const idx_t index, const BasicColumnWriterState &state_p) const override { auto &state = state_p.Cast(); if (state.IsDictionaryEncoded()) { return (state.key_bit_width + 7) / 8; @@ -1614,7 +1679,7 @@ class EnumColumnWriter : public BasicColumnWriter { WriteDictionary(state, std::move(temp_writer), enum_count); } - idx_t GetRowSize(Vector &vector, idx_t index, BasicColumnWriterState &state) override { + idx_t GetRowSize(const Vector &vector, const idx_t index, const BasicColumnWriterState &state) const override { return (bit_width + 7) / 8; } }; @@ -2006,7 +2071,8 @@ void ArrayColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t // Create Column Writer //===--------------------------------------------------------------------===// -unique_ptr ColumnWriter::CreateWriterRecursive(vector &schemas, +unique_ptr ColumnWriter::CreateWriterRecursive(ClientContext &context, + vector &schemas, ParquetWriter &writer, const LogicalType &type, const string &name, vector schema_path, optional_ptr field_ids, @@ -2032,7 +2098,7 @@ unique_ptr ColumnWriter::CreateWriterRecursive(vector(child_types.size()); schema_element.__isset.num_children = true; schema_element.__isset.type = false; schema_element.__isset.repetition_type = true; @@ -2048,7 +2114,7 @@ unique_ptr ColumnWriter::CreateWriterRecursive(vector> child_writers; child_writers.reserve(child_types.size()); for (auto &child_type : child_types) { - child_writers.push_back(CreateWriterRecursive(schemas, writer, child_type.second, child_type.first, + child_writers.push_back(CreateWriterRecursive(context, schemas, writer, child_type.second, child_type.first, schema_path, child_field_ids, max_repeat, max_define + 1)); } return make_uniq(writer, schema_idx, std::move(schema_path), max_repeat, max_define, @@ -2087,8 +2153,8 @@ unique_ptr ColumnWriter::CreateWriterRecursive(vector(writer, schema_idx, std::move(schema_path), max_repeat, max_define, std::move(child_writer), can_have_nulls); @@ -2142,7 +2208,7 @@ unique_ptr ColumnWriter::CreateWriterRecursive(vector ColumnWriter::CreateWriterRecursive(vector>(context, writer, schema_idx, std::move(schema_path), + max_repeat, max_define, can_have_nulls, name); + } + switch (type.id()) { case LogicalTypeId::BOOLEAN: return make_uniq(writer, schema_idx, std::move(schema_path), max_repeat, max_define, diff --git a/src/duckdb/extension/parquet/geo_parquet.cpp b/src/duckdb/extension/parquet/geo_parquet.cpp new file mode 100644 index 00000000..6b47060e --- /dev/null +++ b/src/duckdb/extension/parquet/geo_parquet.cpp @@ -0,0 +1,391 @@ + +#include "geo_parquet.hpp" + +#include "column_reader.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/main/extension_helper.hpp" +#include "expression_column_reader.hpp" +#include "parquet_reader.hpp" +#include "yyjson.hpp" + +namespace duckdb { + +using namespace duckdb_yyjson; // NOLINT + +const char *WKBGeometryTypes::ToString(WKBGeometryType type) { + switch (type) { + case WKBGeometryType::POINT: + return "Point"; + case WKBGeometryType::LINESTRING: + return "LineString"; + case WKBGeometryType::POLYGON: + return "Polygon"; + case WKBGeometryType::MULTIPOINT: + return "MultiPoint"; + case WKBGeometryType::MULTILINESTRING: + return "MultiLineString"; + case WKBGeometryType::MULTIPOLYGON: + return "MultiPolygon"; + case WKBGeometryType::GEOMETRYCOLLECTION: + return "GeometryCollection"; + case WKBGeometryType::POINT_Z: + return "Point Z"; + case WKBGeometryType::LINESTRING_Z: + return "LineString Z"; + case WKBGeometryType::POLYGON_Z: + return "Polygon Z"; + case WKBGeometryType::MULTIPOINT_Z: + return "MultiPoint Z"; + case WKBGeometryType::MULTILINESTRING_Z: + return "MultiLineString Z"; + case WKBGeometryType::MULTIPOLYGON_Z: + return "MultiPolygon Z"; + case WKBGeometryType::GEOMETRYCOLLECTION_Z: + return "GeometryCollection Z"; + default: + throw NotImplementedException("Unsupported geometry type"); + } +} + +//------------------------------------------------------------------------------ +// GeoParquetColumnMetadataWriter +//------------------------------------------------------------------------------ +GeoParquetColumnMetadataWriter::GeoParquetColumnMetadataWriter(ClientContext &context) { + executor = make_uniq(context); + + auto &catalog = Catalog::GetSystemCatalog(context); + + // These functions are required to extract the geometry type, ZM flag and bounding box from a WKB blob + auto &type_func_set = + catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA, "st_geometrytype") + .Cast(); + auto &flag_func_set = catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA, "st_zmflag") + .Cast(); + auto &bbox_func_set = catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA, "st_extent") + .Cast(); + + auto wkb_type = LogicalType(LogicalTypeId::BLOB); + wkb_type.SetAlias("WKB_BLOB"); + + auto type_func = type_func_set.functions.GetFunctionByArguments(context, {wkb_type}); + auto flag_func = flag_func_set.functions.GetFunctionByArguments(context, {wkb_type}); + auto bbox_func = bbox_func_set.functions.GetFunctionByArguments(context, {wkb_type}); + + auto type_type = LogicalType::UTINYINT; + auto flag_type = flag_func.return_type; + auto bbox_type = bbox_func.return_type; + + vector> type_args; + type_args.push_back(make_uniq(wkb_type, 0)); + + vector> flag_args; + flag_args.push_back(make_uniq(wkb_type, 0)); + + vector> bbox_args; + bbox_args.push_back(make_uniq(wkb_type, 0)); + + type_expr = make_uniq(type_type, type_func, std::move(type_args), nullptr); + flag_expr = make_uniq(flag_type, flag_func, std::move(flag_args), nullptr); + bbox_expr = make_uniq(bbox_type, bbox_func, std::move(bbox_args), nullptr); + + // Add the expressions to the executor + executor->AddExpression(*type_expr); + executor->AddExpression(*flag_expr); + executor->AddExpression(*bbox_expr); + + // Initialize the input and result chunks + // The input chunk should be empty, as we always reference the input vector + input_chunk.InitializeEmpty({wkb_type}); + result_chunk.Initialize(context, {type_type, flag_type, bbox_type}); +} + +void GeoParquetColumnMetadataWriter::Update(GeoParquetColumnMetadata &meta, Vector &vector, idx_t count) { + input_chunk.Reset(); + result_chunk.Reset(); + + // Reference the vector + input_chunk.data[0].Reference(vector); + input_chunk.SetCardinality(count); + + // Execute the expression + executor->Execute(input_chunk, result_chunk); + + // The first column is the geometry type + // The second column is the zm flag + // The third column is the bounding box + + UnifiedVectorFormat type_format; + UnifiedVectorFormat flag_format; + UnifiedVectorFormat bbox_format; + + result_chunk.data[0].ToUnifiedFormat(count, type_format); + result_chunk.data[1].ToUnifiedFormat(count, flag_format); + result_chunk.data[2].ToUnifiedFormat(count, bbox_format); + + const auto &bbox_components = StructVector::GetEntries(result_chunk.data[2]); + D_ASSERT(bbox_components.size() == 4); + + UnifiedVectorFormat xmin_format; + UnifiedVectorFormat ymin_format; + UnifiedVectorFormat xmax_format; + UnifiedVectorFormat ymax_format; + + bbox_components[0]->ToUnifiedFormat(count, xmin_format); + bbox_components[1]->ToUnifiedFormat(count, ymin_format); + bbox_components[2]->ToUnifiedFormat(count, xmax_format); + bbox_components[3]->ToUnifiedFormat(count, ymax_format); + + for (idx_t in_idx = 0; in_idx < count; in_idx++) { + const auto type_idx = type_format.sel->get_index(in_idx); + const auto flag_idx = flag_format.sel->get_index(in_idx); + const auto bbox_idx = bbox_format.sel->get_index(in_idx); + + const auto type_valid = type_format.validity.RowIsValid(type_idx); + const auto flag_valid = flag_format.validity.RowIsValid(flag_idx); + const auto bbox_valid = bbox_format.validity.RowIsValid(bbox_idx); + + if (!type_valid || !flag_valid || !bbox_valid) { + continue; + } + + // Update the geometry type + const auto flag = UnifiedVectorFormat::GetData(flag_format)[flag_idx]; + const auto type = UnifiedVectorFormat::GetData(type_format)[type_idx]; + if (flag == 1 || flag == 3) { + // M or ZM + throw InvalidInputException("Geoparquet does not support geometries with M coordinates"); + } + const auto has_z = flag == 2; + auto wkb_type = static_cast((type + 1) + (has_z ? 1000 : 0)); + meta.geometry_types.insert(wkb_type); + + // Update the bounding box + const auto min_x = UnifiedVectorFormat::GetData(xmin_format)[bbox_idx]; + const auto min_y = UnifiedVectorFormat::GetData(ymin_format)[bbox_idx]; + const auto max_x = UnifiedVectorFormat::GetData(xmax_format)[bbox_idx]; + const auto max_y = UnifiedVectorFormat::GetData(ymax_format)[bbox_idx]; + meta.bbox.Combine(min_x, max_x, min_y, max_y); + } +} + +//------------------------------------------------------------------------------ +// GeoParquetFileMetadata +//------------------------------------------------------------------------------ + +unique_ptr +GeoParquetFileMetadata::TryRead(const duckdb_parquet::format::FileMetaData &file_meta_data, ClientContext &context) { + for (auto &kv : file_meta_data.key_value_metadata) { + if (kv.key == "geo") { + const auto geo_metadata = yyjson_read(kv.value.c_str(), kv.value.size(), 0); + if (!geo_metadata) { + // Could not parse the JSON + return nullptr; + } + + // Check if the spatial extension is loaded, or try to autoload it. + const auto is_loaded = ExtensionHelper::TryAutoLoadExtension(context, "spatial"); + if (!is_loaded) { + // Spatial extension is not available, we can't make use of the metadata anyway. + yyjson_doc_free(geo_metadata); + return nullptr; + } + + try { + // Check the root object + const auto root = yyjson_doc_get_root(geo_metadata); + if (!yyjson_is_obj(root)) { + throw InvalidInputException("Geoparquet metadata is not an object"); + } + + auto result = make_uniq(); + + // Check and parse the version + const auto version_val = yyjson_obj_get(root, "version"); + if (!yyjson_is_str(version_val)) { + throw InvalidInputException("Geoparquet metadata does not have a version"); + } + result->version = yyjson_get_str(version_val); + if (StringUtil::StartsWith(result->version, "2")) { + // Guard against a breaking future 2.0 version + throw InvalidInputException("Geoparquet version %s is not supported", result->version); + } + + // Check and parse the primary geometry column + const auto primary_geometry_column_val = yyjson_obj_get(root, "primary_column"); + if (!yyjson_is_str(primary_geometry_column_val)) { + throw InvalidInputException("Geoparquet metadata does not have a primary column"); + } + result->primary_geometry_column = yyjson_get_str(primary_geometry_column_val); + + // Check and parse the geometry columns + const auto columns_val = yyjson_obj_get(root, "columns"); + if (!yyjson_is_obj(columns_val)) { + throw InvalidInputException("Geoparquet metadata does not have a columns object"); + } + + // Iterate over all geometry columns + yyjson_obj_iter iter = yyjson_obj_iter_with(columns_val); + yyjson_val *column_key; + + while ((column_key = yyjson_obj_iter_next(&iter))) { + const auto column_val = yyjson_obj_iter_get_val(column_key); + const auto column_name = yyjson_get_str(column_key); + + auto &column = result->geometry_columns[column_name]; + + if (!yyjson_is_obj(column_val)) { + throw InvalidInputException("Geoparquet column '%s' is not an object", column_name); + } + + // Parse the encoding + const auto encoding_val = yyjson_obj_get(column_val, "encoding"); + if (!yyjson_is_str(encoding_val)) { + throw InvalidInputException("Geoparquet column '%s' does not have an encoding", column_name); + } + const auto encoding_str = yyjson_get_str(encoding_val); + if (strcmp(encoding_str, "WKB") == 0) { + column.geometry_encoding = GeoParquetColumnEncoding::WKB; + } else { + throw InvalidInputException("Geoparquet column '%s' has an unsupported encoding", column_name); + } + + // Parse the geometry types + const auto geometry_types_val = yyjson_obj_get(column_val, "geometry_types"); + if (!yyjson_is_arr(geometry_types_val)) { + throw InvalidInputException("Geoparquet column '%s' does not have geometry types", column_name); + } + // We dont care about the geometry types for now. + + // TODO: Parse the bounding box, other metadata that might be useful. + // (Only encoding and geometry types are required to be present) + } + + // Return the result + // Make sure to free the JSON document + yyjson_doc_free(geo_metadata); + return result; + + } catch (...) { + // Make sure to free the JSON document in case of an exception + yyjson_doc_free(geo_metadata); + throw; + } + } + } + return nullptr; +} + +void GeoParquetFileMetadata::Write(duckdb_parquet::format::FileMetaData &file_meta_data) const { + + yyjson_mut_doc *doc = yyjson_mut_doc_new(nullptr); + yyjson_mut_val *root = yyjson_mut_obj(doc); + yyjson_mut_doc_set_root(doc, root); + + // Add the version + yyjson_mut_obj_add_strncpy(doc, root, "version", version.c_str(), version.size()); + + // Add the primary column + yyjson_mut_obj_add_strncpy(doc, root, "primary_column", primary_geometry_column.c_str(), + primary_geometry_column.size()); + + // Add the columns + const auto json_columns = yyjson_mut_obj_add_obj(doc, root, "columns"); + + for (auto &column : geometry_columns) { + const auto column_json = yyjson_mut_obj_add_obj(doc, json_columns, column.first.c_str()); + yyjson_mut_obj_add_str(doc, column_json, "encoding", "WKB"); + const auto geometry_types = yyjson_mut_obj_add_arr(doc, column_json, "geometry_types"); + for (auto &geometry_type : column.second.geometry_types) { + const auto type_name = WKBGeometryTypes::ToString(geometry_type); + yyjson_mut_arr_add_str(doc, geometry_types, type_name); + } + const auto bbox = yyjson_mut_obj_add_arr(doc, column_json, "bbox"); + yyjson_mut_arr_add_real(doc, bbox, column.second.bbox.min_x); + yyjson_mut_arr_add_real(doc, bbox, column.second.bbox.min_y); + yyjson_mut_arr_add_real(doc, bbox, column.second.bbox.max_x); + yyjson_mut_arr_add_real(doc, bbox, column.second.bbox.max_y); + + // If the CRS is present, add it + if (!column.second.projjson.empty()) { + const auto crs_doc = yyjson_read(column.second.projjson.c_str(), column.second.projjson.size(), 0); + if (!crs_doc) { + yyjson_mut_doc_free(doc); + throw InvalidInputException("Failed to parse CRS JSON"); + } + const auto crs_root = yyjson_doc_get_root(crs_doc); + const auto crs_val = yyjson_val_mut_copy(doc, crs_root); + const auto crs_key = yyjson_mut_strcpy(doc, "projjson"); + yyjson_mut_obj_add(column_json, crs_key, crs_val); + yyjson_doc_free(crs_doc); + } + } + + yyjson_write_err err; + size_t len; + char *json = yyjson_mut_write_opts(doc, 0, nullptr, &len, &err); + if (!json) { + yyjson_mut_doc_free(doc); + throw SerializationException("Failed to write JSON string: %s", err.msg); + } + + // Create a string from the JSON + duckdb_parquet::format::KeyValue kv; + kv.__set_key("geo"); + kv.__set_value(string(json, len)); + + // Free the JSON and the document + free(json); + yyjson_mut_doc_free(doc); + + file_meta_data.key_value_metadata.push_back(kv); + file_meta_data.__isset.key_value_metadata = true; +} + +bool GeoParquetFileMetadata::IsGeometryColumn(const string &column_name) const { + return geometry_columns.find(column_name) != geometry_columns.end(); +} + +unique_ptr GeoParquetFileMetadata::CreateColumnReader(ParquetReader &reader, + const LogicalType &logical_type, + const SchemaElement &s_ele, idx_t schema_idx_p, + idx_t max_define_p, idx_t max_repeat_p, + ClientContext &context) { + + D_ASSERT(IsGeometryColumn(s_ele.name)); + + const auto &column = geometry_columns[s_ele.name]; + + // Get the catalog + auto &catalog = Catalog::GetSystemCatalog(context); + + // WKB encoding + if (logical_type.id() == LogicalTypeId::BLOB && column.geometry_encoding == GeoParquetColumnEncoding::WKB) { + // Look for a conversion function in the catalog + auto &conversion_func_set = + catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA, "st_geomfromwkb") + .Cast(); + auto conversion_func = conversion_func_set.functions.GetFunctionByArguments(context, {LogicalType::BLOB}); + + // Create a bound function call expression + auto args = vector>(); + args.push_back(std::move(make_uniq(LogicalType::BLOB, 0))); + auto expr = + make_uniq(conversion_func.return_type, conversion_func, std::move(args), nullptr); + + // Create a child reader + auto child_reader = + ColumnReader::CreateReader(reader, logical_type, s_ele, schema_idx_p, max_define_p, max_repeat_p); + + // Create an expression reader that applies the conversion function to the child reader + return make_uniq(context, std::move(child_reader), std::move(expr)); + } + + // Otherwise, unrecognized encoding + throw NotImplementedException("Unsupported geometry encoding"); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/boolean_column_reader.hpp b/src/duckdb/extension/parquet/include/boolean_column_reader.hpp index 9410ee30..125c548d 100644 --- a/src/duckdb/extension/parquet/include/boolean_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/boolean_column_reader.hpp @@ -46,18 +46,29 @@ struct BooleanParquetValueConversion { static bool PlainRead(ByteBuffer &plain_data, ColumnReader &reader) { plain_data.available(1); + return UnsafePlainRead(plain_data, reader); + } + + static void PlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { + PlainRead(plain_data, reader); + } + + static bool PlainAvailable(const ByteBuffer &plain_data, const idx_t count) { + return plain_data.check_available((count + 7) / 8); + } + + static bool UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader) { auto &byte_pos = reader.Cast().byte_pos; bool ret = (*plain_data.ptr >> byte_pos) & 1; - byte_pos++; - if (byte_pos == 8) { + if (++byte_pos == 8) { byte_pos = 0; - plain_data.inc(1); + plain_data.unsafe_inc(1); } return ret; } - static void PlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { - PlainRead(plain_data, reader); + static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { + UnsafePlainRead(plain_data, reader); } }; diff --git a/src/duckdb/extension/parquet/include/column_reader.hpp b/src/duckdb/extension/parquet/include/column_reader.hpp index f4e32f39..0a7a17ea 100644 --- a/src/duckdb/extension/parquet/include/column_reader.hpp +++ b/src/duckdb/extension/parquet/include/column_reader.hpp @@ -73,18 +73,43 @@ class ColumnReader { template void PlainTemplated(shared_ptr plain_data, uint8_t *defines, uint64_t num_values, parquet_filter_t &filter, idx_t result_offset, Vector &result) { - auto result_ptr = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - for (idx_t row_idx = 0; row_idx < num_values; row_idx++) { - if (HasDefines() && defines[row_idx + result_offset] != max_define) { - result_mask.SetInvalid(row_idx + result_offset); - continue; + if (HasDefines()) { + if (CONVERSION::PlainAvailable(*plain_data, num_values)) { + PlainTemplatedInternal(*plain_data, defines, num_values, filter, + result_offset, result); + } else { + PlainTemplatedInternal(*plain_data, defines, num_values, filter, + result_offset, result); + } + } else { + if (CONVERSION::PlainAvailable(*plain_data, num_values)) { + PlainTemplatedInternal(*plain_data, defines, num_values, filter, + result_offset, result); + } else { + PlainTemplatedInternal(*plain_data, defines, num_values, filter, + result_offset, result); } - if (filter[row_idx + result_offset]) { - VALUE_TYPE val = CONVERSION::PlainRead(*plain_data, *this); - result_ptr[row_idx + result_offset] = val; + } + } + +private: + template + void PlainTemplatedInternal(ByteBuffer &plain_data, const uint8_t *__restrict defines, const uint64_t num_values, + const parquet_filter_t &filter, const idx_t result_offset, Vector &result) { + const auto result_ptr = FlatVector::GetData(result); + auto &result_mask = FlatVector::Validity(result); + for (idx_t row_idx = result_offset; row_idx < result_offset + num_values; row_idx++) { + if (HAS_DEFINES && defines[row_idx] != max_define) { + result_mask.SetInvalid(row_idx); + } else if (filter.test(row_idx)) { + result_ptr[row_idx] = + UNSAFE ? CONVERSION::UnsafePlainRead(plain_data, *this) : CONVERSION::PlainRead(plain_data, *this); } else { // there is still some data there that we have to skip over - CONVERSION::PlainSkip(*plain_data, *this); + if (UNSAFE) { + CONVERSION::UnsafePlainSkip(plain_data, *this); + } else { + CONVERSION::PlainSkip(plain_data, *this); + } } } } @@ -110,11 +135,11 @@ class ColumnReader { // applies any skips that were registered using Skip() virtual void ApplyPendingSkips(idx_t num_values); - bool HasDefines() { + bool HasDefines() const { return max_define > 0; } - bool HasRepeats() { + bool HasRepeats() const { return max_repeat > 0; } diff --git a/src/duckdb/extension/parquet/include/column_writer.hpp b/src/duckdb/extension/parquet/include/column_writer.hpp index 3b92a53b..65f89e59 100644 --- a/src/duckdb/extension/parquet/include/column_writer.hpp +++ b/src/duckdb/extension/parquet/include/column_writer.hpp @@ -22,8 +22,8 @@ class ColumnWriterState { public: virtual ~ColumnWriterState(); - vector definition_levels; - vector repetition_levels; + unsafe_vector definition_levels; + unsafe_vector repetition_levels; vector is_empty; idx_t null_count = 0; @@ -79,12 +79,11 @@ class ColumnWriter { public: //! Create the column writer for a specific type recursively - static unique_ptr CreateWriterRecursive(vector &schemas, - ParquetWriter &writer, const LogicalType &type, - const string &name, vector schema_path, - optional_ptr field_ids, - idx_t max_repeat = 0, idx_t max_define = 1, - bool can_have_nulls = true); + static unique_ptr + CreateWriterRecursive(ClientContext &context, vector &schemas, + ParquetWriter &writer, const LogicalType &type, const string &name, + vector schema_path, optional_ptr field_ids, idx_t max_repeat = 0, + idx_t max_define = 1, bool can_have_nulls = true); virtual unique_ptr InitializeWriteState(duckdb_parquet::format::RowGroup &row_group) = 0; @@ -109,9 +108,9 @@ class ColumnWriter { virtual void FinalizeWrite(ColumnWriterState &state) = 0; protected: - void HandleDefineLevels(ColumnWriterState &state, ColumnWriterState *parent, ValidityMask &validity, idx_t count, - uint16_t define_value, uint16_t null_value); - void HandleRepeatLevels(ColumnWriterState &state_p, ColumnWriterState *parent, idx_t count, idx_t max_repeat); + void HandleDefineLevels(ColumnWriterState &state, ColumnWriterState *parent, const ValidityMask &validity, + const idx_t count, const uint16_t define_value, const uint16_t null_value) const; + void HandleRepeatLevels(ColumnWriterState &state_p, ColumnWriterState *parent, idx_t count, idx_t max_repeat) const; void CompressPage(MemoryStream &temp_writer, size_t &compressed_size, data_ptr_t &compressed_data, unique_ptr &compressed_buf); diff --git a/src/duckdb/extension/parquet/include/expression_column_reader.hpp b/src/duckdb/extension/parquet/include/expression_column_reader.hpp new file mode 100644 index 00000000..c94a816d --- /dev/null +++ b/src/duckdb/extension/parquet/include/expression_column_reader.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// expression_column_reader.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "column_reader.hpp" +#include "templated_column_reader.hpp" + +namespace duckdb { + +//! A column reader that executes an expression over a child reader +class ExpressionColumnReader : public ColumnReader { +public: + static constexpr const PhysicalType TYPE = PhysicalType::INVALID; + +public: + ExpressionColumnReader(ClientContext &context, unique_ptr child_reader, unique_ptr expr); + + unique_ptr child_reader; + DataChunk intermediate_chunk; + unique_ptr expr; + ExpressionExecutor executor; + +public: + unique_ptr Stats(idx_t row_group_idx_p, const vector &columns) override; + void InitializeRead(idx_t row_group_idx_p, const vector &columns, TProtocol &protocol_p) override; + + idx_t Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, data_ptr_t repeat_out, + Vector &result) override; + + void Skip(idx_t num_values) override; + idx_t GroupRowsAvailable() override; + + uint64_t TotalCompressedSize() override { + return child_reader->TotalCompressedSize(); + } + + idx_t FileOffset() const override { + return child_reader->FileOffset(); + } + + void RegisterPrefetch(ThriftFileTransport &transport, bool allow_merge) override { + child_reader->RegisterPrefetch(transport, allow_merge); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/geo_parquet.hpp b/src/duckdb/extension/parquet/include/geo_parquet.hpp new file mode 100644 index 00000000..24b65ab6 --- /dev/null +++ b/src/duckdb/extension/parquet/include/geo_parquet.hpp @@ -0,0 +1,139 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// geo_parquet.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "column_writer.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "parquet_types.h" + +namespace duckdb { + +enum class WKBGeometryType : uint16_t { + POINT = 1, + LINESTRING = 2, + POLYGON = 3, + MULTIPOINT = 4, + MULTILINESTRING = 5, + MULTIPOLYGON = 6, + GEOMETRYCOLLECTION = 7, + + POINT_Z = 1001, + LINESTRING_Z = 1002, + POLYGON_Z = 1003, + MULTIPOINT_Z = 1004, + MULTILINESTRING_Z = 1005, + MULTIPOLYGON_Z = 1006, + GEOMETRYCOLLECTION_Z = 1007, +}; + +struct WKBGeometryTypes { + static const char *ToString(WKBGeometryType type); +}; + +struct GeometryBounds { + double min_x = NumericLimits::Maximum(); + double max_x = NumericLimits::Minimum(); + double min_y = NumericLimits::Maximum(); + double max_y = NumericLimits::Minimum(); + + GeometryBounds() = default; + + void Combine(const GeometryBounds &other) { + min_x = std::min(min_x, other.min_x); + max_x = std::max(max_x, other.max_x); + min_y = std::min(min_y, other.min_y); + max_y = std::max(max_y, other.max_y); + } + + void Combine(const double &x, const double &y) { + min_x = std::min(min_x, x); + max_x = std::max(max_x, x); + min_y = std::min(min_y, y); + max_y = std::max(max_y, y); + } + + void Combine(const double &min_x, const double &max_x, const double &min_y, const double &max_y) { + this->min_x = std::min(this->min_x, min_x); + this->max_x = std::max(this->max_x, max_x); + this->min_y = std::min(this->min_y, min_y); + this->max_y = std::max(this->max_y, max_y); + } +}; + +//------------------------------------------------------------------------------ +// GeoParquetMetadata +//------------------------------------------------------------------------------ +class ParquetReader; +class ColumnReader; +class ClientContext; +class ExpressionExecutor; + +enum class GeoParquetColumnEncoding : uint8_t { + WKB = 1, + POINT, + LINESTRING, + POLYGON, + MULTIPOINT, + MULTILINESTRING, + MULTIPOLYGON, +}; + +struct GeoParquetColumnMetadata { + // The encoding of the geometry column + GeoParquetColumnEncoding geometry_encoding; + + // The geometry types that are present in the column + set geometry_types; + + // The bounds of the geometry column + GeometryBounds bbox; + + // The crs of the geometry column (if any) in PROJJSON format + string projjson; +}; + +class GeoParquetColumnMetadataWriter { + unique_ptr executor; + DataChunk input_chunk; + DataChunk result_chunk; + + unique_ptr type_expr; + unique_ptr flag_expr; + unique_ptr bbox_expr; + +public: + explicit GeoParquetColumnMetadataWriter(ClientContext &context); + void Update(GeoParquetColumnMetadata &meta, Vector &vector, idx_t count); +}; + +struct GeoParquetFileMetadata { +public: + // Try to read GeoParquet metadata. Returns nullptr if not found, invalid or the required spatial extension is not + // available. + static unique_ptr TryRead(const duckdb_parquet::format::FileMetaData &file_meta_data, + ClientContext &context); + void Write(duckdb_parquet::format::FileMetaData &file_meta_data) const; + +public: + // Default to 1.1.0 for now + string version = "1.1.0"; + string primary_geometry_column; + unordered_map geometry_columns; + + unique_ptr CreateColumnReader(ParquetReader &reader, const LogicalType &logical_type, + const duckdb_parquet::format::SchemaElement &s_ele, idx_t schema_idx_p, + idx_t max_define_p, idx_t max_repeat_p, ClientContext &context); + + bool IsGeometryColumn(const string &column_name) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_crypto.hpp b/src/duckdb/extension/parquet/include/parquet_crypto.hpp index 3a3185df..b4aed9d0 100644 --- a/src/duckdb/extension/parquet/include/parquet_crypto.hpp +++ b/src/duckdb/extension/parquet/include/parquet_crypto.hpp @@ -9,6 +9,7 @@ #pragma once #include "parquet_types.h" +#include "duckdb/common/encryption_state.hpp" #ifndef DUCKDB_AMALGAMATION #include "duckdb/storage/object_cache.hpp" @@ -62,26 +63,30 @@ class ParquetEncryptionConfig { class ParquetCrypto { public: //! Encrypted modules - static constexpr uint32_t LENGTH_BYTES = 4; - static constexpr uint32_t NONCE_BYTES = 12; - static constexpr uint32_t TAG_BYTES = 16; + static constexpr idx_t LENGTH_BYTES = 4; + static constexpr idx_t NONCE_BYTES = 12; + static constexpr idx_t TAG_BYTES = 16; //! Block size we encrypt/decrypt - static constexpr uint32_t CRYPTO_BLOCK_SIZE = 4096; + static constexpr idx_t CRYPTO_BLOCK_SIZE = 4096; + static constexpr idx_t BLOCK_SIZE = 16; public: //! Decrypt and read a Thrift object from the transport protocol - static uint32_t Read(TBase &object, TProtocol &iprot, const string &key); + static uint32_t Read(TBase &object, TProtocol &iprot, const string &key, const EncryptionUtil &encryption_util_p); //! Encrypt and write a Thrift object to the transport protocol - static uint32_t Write(const TBase &object, TProtocol &oprot, const string &key); + static uint32_t Write(const TBase &object, TProtocol &oprot, const string &key, + const EncryptionUtil &encryption_util_p); //! Decrypt and read a buffer - static uint32_t ReadData(TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size, const string &key); + static uint32_t ReadData(TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size, const string &key, + const EncryptionUtil &encryption_util_p); //! Encrypt and write a buffer to a file static uint32_t WriteData(TProtocol &oprot, const const_data_ptr_t buffer, const uint32_t buffer_size, - const string &key); + const string &key, const EncryptionUtil &encryption_util_p); public: static void AddKey(ClientContext &context, const FunctionParameters ¶meters); + static bool ValidKey(const std::string &key); }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_decimal_utils.hpp b/src/duckdb/extension/parquet/include/parquet_decimal_utils.hpp index 42debd2c..4f189bbc 100644 --- a/src/duckdb/extension/parquet/include/parquet_decimal_utils.hpp +++ b/src/duckdb/extension/parquet/include/parquet_decimal_utils.hpp @@ -32,6 +32,9 @@ class ParquetDecimalUtils { if (size > sizeof(PHYSICAL_TYPE)) { for (idx_t i = sizeof(PHYSICAL_TYPE); i < size; i++) { auto byte = *(pointer + (size - i - 1)); + if (!positive) { + byte ^= 0xFF; + } if (byte != 0) { throw InvalidInputException("Invalid decimal encoding in Parquet file"); } diff --git a/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp b/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp index 01d316dc..48b6448d 100644 --- a/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp +++ b/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp @@ -10,9 +10,9 @@ #include "duckdb.hpp" #ifndef DUCKDB_AMALGAMATION #include "duckdb/storage/object_cache.hpp" +#include "geo_parquet.hpp" #endif #include "parquet_types.h" - namespace duckdb { //! ParquetFileMetadataCache @@ -20,8 +20,9 @@ class ParquetFileMetadataCache : public ObjectCacheEntry { public: ParquetFileMetadataCache() : metadata(nullptr) { } - ParquetFileMetadataCache(unique_ptr file_metadata, time_t r_time) - : metadata(std::move(file_metadata)), read_time(r_time) { + ParquetFileMetadataCache(unique_ptr file_metadata, time_t r_time, + unique_ptr geo_metadata) + : metadata(std::move(file_metadata)), read_time(r_time), geo_metadata(std::move(geo_metadata)) { } ~ParquetFileMetadataCache() override = default; @@ -32,6 +33,9 @@ class ParquetFileMetadataCache : public ObjectCacheEntry { //! read time time_t read_time; + //! GeoParquet metadata + unique_ptr geo_metadata; + public: static string ObjectType() { return "parquet_metadata"; diff --git a/src/duckdb/extension/parquet/include/parquet_reader.hpp b/src/duckdb/extension/parquet/include/parquet_reader.hpp index 6e65d46d..ef8dcaf8 100644 --- a/src/duckdb/extension/parquet/include/parquet_reader.hpp +++ b/src/duckdb/extension/parquet/include/parquet_reader.hpp @@ -11,6 +11,7 @@ #include "duckdb.hpp" #ifndef DUCKDB_AMALGAMATION #include "duckdb/common/common.hpp" +#include "duckdb/common/encryption_state.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/multi_file_reader.hpp" #include "duckdb/common/multi_file_reader_options.hpp" @@ -88,6 +89,7 @@ struct ParquetOptions { bool binary_as_string = false; bool file_row_number = false; shared_ptr encryption_config; + bool debug_use_openssl = true; MultiFileReaderOptions file_options; vector schema; @@ -97,11 +99,28 @@ struct ParquetOptions { static ParquetOptions Deserialize(Deserializer &deserializer); }; +struct ParquetUnionData { + ~ParquetUnionData(); + + string file_name; + vector names; + vector types; + ParquetOptions options; + shared_ptr metadata; + unique_ptr reader; + + const string &GetFileName() { + return file_name; + } +}; + class ParquetReader { public: - ParquetReader(ClientContext &context, string file_name, ParquetOptions parquet_options); - ParquetReader(ClientContext &context, ParquetOptions parquet_options, - shared_ptr metadata); + using UNION_READER_DATA = unique_ptr; + +public: + ParquetReader(ClientContext &context, string file_name, ParquetOptions parquet_options, + shared_ptr metadata = nullptr); ~ParquetReader(); FileSystem &fs; @@ -113,16 +132,37 @@ class ParquetReader { ParquetOptions parquet_options; MultiFileReaderData reader_data; unique_ptr root_reader; + shared_ptr encryption_util; //! Index of the file_row_number column idx_t file_row_number_idx = DConstants::INVALID_INDEX; //! Parquet schema for the generated columns vector generated_column_schema; + //! Table column names - set when using COPY tbl FROM file.parquet + vector table_columns; public: - void InitializeScan(ParquetReaderScanState &state, vector groups_to_read); + void InitializeScan(ClientContext &context, ParquetReaderScanState &state, vector groups_to_read); void Scan(ParquetReaderScanState &state, DataChunk &output); + static unique_ptr StoreUnionReader(unique_ptr reader_p, idx_t file_idx) { + auto result = make_uniq(); + result->file_name = reader_p->file_name; + if (file_idx == 0) { + result->names = reader_p->names; + result->types = reader_p->return_types; + result->options = reader_p->parquet_options; + result->metadata = reader_p->metadata; + result->reader = std::move(reader_p); + } else { + result->names = std::move(reader_p->names); + result->types = std::move(reader_p->return_types); + result->options = std::move(reader_p->parquet_options); + result->metadata = std::move(reader_p->metadata); + } + return result; + } + idx_t NumRows(); idx_t NumRowGroups(); @@ -149,13 +189,20 @@ class ParquetReader { return return_types; } + static unique_ptr ReadStatistics(ClientContext &context, ParquetOptions parquet_options, + shared_ptr metadata, const string &name); + private: - void InitializeSchema(); + //! Construct a parquet reader but **do not** open a file, used in ReadStatistics only + ParquetReader(ClientContext &context, ParquetOptions parquet_options, + shared_ptr metadata); + + void InitializeSchema(ClientContext &context); bool ScanInternal(ParquetReaderScanState &state, DataChunk &output); - unique_ptr CreateReader(); + unique_ptr CreateReader(ClientContext &context); - unique_ptr CreateReaderRecursive(idx_t depth, idx_t max_define, idx_t max_repeat, - idx_t &next_schema_idx, idx_t &next_file_idx); + unique_ptr CreateReaderRecursive(ClientContext &context, idx_t depth, idx_t max_define, + idx_t max_repeat, idx_t &next_schema_idx, idx_t &next_file_idx); const duckdb_parquet::format::RowGroup &GetGroup(ParquetReaderScanState &state); uint64_t GetGroupCompressedSize(ParquetReaderScanState &state); idx_t GetGroupOffset(ParquetReaderScanState &state); diff --git a/src/duckdb/extension/parquet/include/parquet_rle_bp_decoder.hpp b/src/duckdb/extension/parquet/include/parquet_rle_bp_decoder.hpp index 125edf1d..27093388 100644 --- a/src/duckdb/extension/parquet/include/parquet_rle_bp_decoder.hpp +++ b/src/duckdb/extension/parquet/include/parquet_rle_bp_decoder.hpp @@ -7,10 +7,10 @@ //===----------------------------------------------------------------------===// #pragma once +#include "decode_utils.hpp" #include "parquet_types.h" -#include "thrift_tools.hpp" #include "resizable_buffer.hpp" -#include "decode_utils.hpp" +#include "thrift_tools.hpp" namespace duckdb { @@ -35,7 +35,7 @@ class RleBpDecoder { while (values_read < batch_size) { if (repeat_count_ > 0) { int repeat_batch = MinValue(batch_size - values_read, static_cast(repeat_count_)); - std::fill(values + values_read, values + values_read + repeat_batch, static_cast(current_value_)); + std::fill_n(values + values_read, repeat_batch, static_cast(current_value_)); repeat_count_ -= repeat_batch; values_read += repeat_batch; } else if (literal_count_ > 0) { diff --git a/src/duckdb/extension/parquet/include/parquet_rle_bp_encoder.hpp b/src/duckdb/extension/parquet/include/parquet_rle_bp_encoder.hpp index e9115a82..029dd06e 100644 --- a/src/duckdb/extension/parquet/include/parquet_rle_bp_encoder.hpp +++ b/src/duckdb/extension/parquet/include/parquet_rle_bp_encoder.hpp @@ -16,7 +16,7 @@ namespace duckdb { class RleBpEncoder { public: - RleBpEncoder(uint32_t bit_width); + explicit RleBpEncoder(uint32_t bit_width); public: //! NOTE: Prepare is only required if a byte count is required BEFORE writing diff --git a/src/duckdb/extension/parquet/include/parquet_timestamp.hpp b/src/duckdb/extension/parquet/include/parquet_timestamp.hpp index d9c33b4d..8631af99 100644 --- a/src/duckdb/extension/parquet/include/parquet_timestamp.hpp +++ b/src/duckdb/extension/parquet/include/parquet_timestamp.hpp @@ -17,14 +17,22 @@ struct Int96 { }; timestamp_t ImpalaTimestampToTimestamp(const Int96 &raw_ts); +timestamp_ns_t ImpalaTimestampToTimestampNS(const Int96 &raw_ts); Int96 TimestampToImpalaTimestamp(timestamp_t &ts); + timestamp_t ParquetTimestampMicrosToTimestamp(const int64_t &raw_ts); timestamp_t ParquetTimestampMsToTimestamp(const int64_t &raw_ts); timestamp_t ParquetTimestampNsToTimestamp(const int64_t &raw_ts); + +timestamp_ns_t ParquetTimestampMsToTimestampNs(const int64_t &raw_ms); +timestamp_ns_t ParquetTimestampUsToTimestampNs(const int64_t &raw_us); +timestamp_ns_t ParquetTimestampNsToTimestampNs(const int64_t &raw_ns); + date_t ParquetIntToDate(const int32_t &raw_date); dtime_t ParquetIntToTimeMs(const int32_t &raw_time); dtime_t ParquetIntToTime(const int64_t &raw_time); dtime_t ParquetIntToTimeNs(const int64_t &raw_time); + dtime_tz_t ParquetIntToTimeMsTZ(const int32_t &raw_time); dtime_tz_t ParquetIntToTimeTZ(const int64_t &raw_time); dtime_tz_t ParquetIntToTimeNsTZ(const int64_t &raw_time); diff --git a/src/duckdb/extension/parquet/include/parquet_writer.hpp b/src/duckdb/extension/parquet/include/parquet_writer.hpp index 137a946c..297d2efa 100644 --- a/src/duckdb/extension/parquet/include/parquet_writer.hpp +++ b/src/duckdb/extension/parquet/include/parquet_writer.hpp @@ -11,6 +11,7 @@ #include "duckdb.hpp" #ifndef DUCKDB_AMALGAMATION #include "duckdb/common/common.hpp" +#include "duckdb/common/encryption_state.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/mutex.hpp" #include "duckdb/common/serializer/buffered_file_writer.hpp" @@ -20,6 +21,7 @@ #include "column_writer.hpp" #include "parquet_types.h" +#include "geo_parquet.hpp" #include "thrift/protocol/TCompactProtocol.h" namespace duckdb { @@ -61,11 +63,11 @@ struct FieldID { class ParquetWriter { public: - ParquetWriter(FileSystem &fs, string file_name, vector types, vector names, - duckdb_parquet::format::CompressionCodec::type codec, ChildFieldIDs field_ids, + ParquetWriter(ClientContext &context, FileSystem &fs, string file_name, vector types, + vector names, duckdb_parquet::format::CompressionCodec::type codec, ChildFieldIDs field_ids, const vector> &kv_metadata, shared_ptr encryption_config, double dictionary_compression_ratio_threshold, - optional_idx compression_level); + optional_idx compression_level, bool debug_use_openssl); public: void PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGroup &result); @@ -85,6 +87,9 @@ class ParquetWriter { duckdb_parquet::format::Type::type GetType(idx_t schema_idx) { return file_meta_data.schema[schema_idx].type; } + LogicalType GetSQLType(idx_t schema_idx) const { + return sql_types[schema_idx]; + } BufferedFileWriter &GetWriter() { return *writer; } @@ -98,15 +103,20 @@ class ParquetWriter { optional_idx CompressionLevel() const { return compression_level; } - - static CopyTypeSupport TypeIsSupported(const LogicalType &type); + idx_t NumberOfRowGroups() { + lock_guard glock(lock); + return file_meta_data.row_groups.size(); + } uint32_t Write(const duckdb_apache::thrift::TBase &object); uint32_t WriteData(const const_data_ptr_t buffer, const uint32_t buffer_size); + GeoParquetFileMetadata &GetGeoParquetData(); + + static bool TryGetParquetType(const LogicalType &duckdb_type, + optional_ptr type = nullptr); + private: - static CopyTypeSupport DuckDBTypeToParquetTypeInternal(const LogicalType &duckdb_type, - duckdb_parquet::format::Type::type &type); string file_name; vector sql_types; vector column_names; @@ -115,6 +125,8 @@ class ParquetWriter { shared_ptr encryption_config; double dictionary_compression_ratio_threshold; optional_idx compression_level; + bool debug_use_openssl; + shared_ptr encryption_util; unique_ptr writer; std::shared_ptr protocol; @@ -122,6 +134,8 @@ class ParquetWriter { std::mutex lock; vector> column_writers; + + unique_ptr geoparquet_data; }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/resizable_buffer.hpp b/src/duckdb/extension/parquet/include/resizable_buffer.hpp index 39ee9338..65b639ba 100644 --- a/src/duckdb/extension/parquet/include/resizable_buffer.hpp +++ b/src/duckdb/extension/parquet/include/resizable_buffer.hpp @@ -25,50 +25,72 @@ class ByteBuffer { // on to the 10 thousandth impl uint64_t len = 0; public: - void inc(uint64_t increment) { + void inc(const uint64_t increment) { available(increment); + unsafe_inc(increment); + } + + void unsafe_inc(const uint64_t increment) { len -= increment; ptr += increment; } template T read() { - T val = get(); - inc(sizeof(T)); + available(sizeof(T)); + return unsafe_read(); + } + + template + T unsafe_read() { + T val = unsafe_get(); + unsafe_inc(sizeof(T)); return val; } template T get() { available(sizeof(T)); - T val = Load(ptr); - return val; + return unsafe_get(); } - void copy_to(char *dest, uint64_t len) { + template + T unsafe_get() { + return Load(ptr); + } + + void copy_to(char *dest, const uint64_t len) const { available(len); + unsafe_copy_to(dest, len); + } + + void unsafe_copy_to(char *dest, const uint64_t len) const { std::memcpy(dest, ptr, len); } - void zero() { + void zero() const { std::memset(ptr, 0, len); } - void available(uint64_t req_len) { - if (req_len > len) { + void available(const uint64_t req_len) const { + if (!check_available(req_len)) { throw std::runtime_error("Out of buffer"); } } + + bool check_available(const uint64_t req_len) const { + return req_len <= len; + } }; class ResizeableBuffer : public ByteBuffer { public: ResizeableBuffer() { } - ResizeableBuffer(Allocator &allocator, uint64_t new_size) { + ResizeableBuffer(Allocator &allocator, const uint64_t new_size) { resize(allocator, new_size); } - void resize(Allocator &allocator, uint64_t new_size) { + void resize(Allocator &allocator, const uint64_t new_size) { len = new_size; if (new_size == 0) { return; diff --git a/src/duckdb/extension/parquet/include/string_column_reader.hpp b/src/duckdb/extension/parquet/include/string_column_reader.hpp index df266015..f67bbd9d 100644 --- a/src/duckdb/extension/parquet/include/string_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/string_column_reader.hpp @@ -16,8 +16,11 @@ struct StringParquetValueConversion { static string_t DictRead(ByteBuffer &dict, uint32_t &offset, ColumnReader &reader); static string_t PlainRead(ByteBuffer &plain_data, ColumnReader &reader); - static void PlainSkip(ByteBuffer &plain_data, ColumnReader &reader); + + static bool PlainAvailable(const ByteBuffer &plain_data, const idx_t count); + static string_t UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader); + static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader); }; class StringColumnReader : public TemplatedColumnReader { @@ -28,7 +31,7 @@ class StringColumnReader : public TemplatedColumnReader dict_strings; + unsafe_unique_ptr dict_strings; idx_t fixed_width_string_length; idx_t delta_offset = 0; diff --git a/src/duckdb/extension/parquet/include/templated_column_reader.hpp b/src/duckdb/extension/parquet/include/templated_column_reader.hpp index b98371c6..e29d2a98 100644 --- a/src/duckdb/extension/parquet/include/templated_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/templated_column_reader.hpp @@ -27,6 +27,18 @@ struct TemplatedParquetValueConversion { static void PlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { plain_data.inc(sizeof(VALUE_TYPE)); } + + static bool PlainAvailable(const ByteBuffer &plain_data, const idx_t count) { + return plain_data.check_available(count * sizeof(VALUE_TYPE)); + } + + static VALUE_TYPE UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader) { + return plain_data.unsafe_read(); + } + + static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { + plain_data.unsafe_inc(sizeof(VALUE_TYPE)); + } }; template @@ -60,29 +72,39 @@ class TemplatedColumnReader : public ColumnReader { throw IOException( "Parquet file is likely corrupted, cannot have dictionary offsets without seeing a dictionary first."); } - auto result_ptr = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); + if (HasDefines()) { + OffsetsInternal(*dict, offsets, defines, num_values, filter, result_offset, result); + } else { + OffsetsInternal(*dict, offsets, defines, num_values, filter, result_offset, result); + } + } + void Plain(shared_ptr plain_data, uint8_t *defines, uint64_t num_values, parquet_filter_t &filter, + idx_t result_offset, Vector &result) override { + PlainTemplated(std::move(plain_data), defines, num_values, filter, result_offset, + result); + } + +private: + template + void OffsetsInternal(ResizeableBuffer &dict_ref, uint32_t *__restrict offsets, const uint8_t *__restrict defines, + const uint64_t num_values, const parquet_filter_t &filter, const idx_t result_offset, + Vector &result) { + const auto result_ptr = FlatVector::GetData(result); + auto &result_mask = FlatVector::Validity(result); idx_t offset_idx = 0; - for (idx_t row_idx = 0; row_idx < num_values; row_idx++) { - if (HasDefines() && defines[row_idx + result_offset] != max_define) { - result_mask.SetInvalid(row_idx + result_offset); + for (idx_t row_idx = result_offset; row_idx < result_offset + num_values; row_idx++) { + if (HAS_DEFINES && defines[row_idx] != max_define) { + result_mask.SetInvalid(row_idx); continue; } - if (filter[row_idx + result_offset]) { - VALUE_TYPE val = VALUE_CONVERSION::DictRead(*dict, offsets[offset_idx++], *this); - result_ptr[row_idx + result_offset] = val; + if (filter.test(row_idx)) { + result_ptr[row_idx] = VALUE_CONVERSION::DictRead(dict_ref, offsets[offset_idx++], *this); } else { offset_idx++; } } } - - void Plain(shared_ptr plain_data, uint8_t *defines, uint64_t num_values, parquet_filter_t &filter, - idx_t result_offset, Vector &result) override { - PlainTemplated(std::move(plain_data), defines, num_values, filter, result_offset, - result); - } }; template ()); + } + + static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { + plain_data.unsafe_inc(sizeof(PARQUET_PHYSICAL_TYPE)); + } }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_crypto.cpp b/src/duckdb/extension/parquet/parquet_crypto.cpp index d6bb7f1b..070c381d 100644 --- a/src/duckdb/extension/parquet/parquet_crypto.cpp +++ b/src/duckdb/extension/parquet/parquet_crypto.cpp @@ -4,8 +4,9 @@ #include "thrift_tools.hpp" #ifndef DUCKDB_AMALGAMATION +#include "duckdb/common/exception/conversion_exception.hpp" #include "duckdb/common/helper.hpp" -#include "duckdb/common/common.hpp" +#include "duckdb/common/types/blob.hpp" #include "duckdb/storage/arena_allocator.hpp" #endif @@ -45,6 +46,7 @@ ParquetEncryptionConfig::ParquetEncryptionConfig(ClientContext &context_p) : con ParquetEncryptionConfig::ParquetEncryptionConfig(ClientContext &context_p, const Value &arg) : ParquetEncryptionConfig(context_p) { + if (arg.type().id() != LogicalTypeId::STRUCT) { throw BinderException("Parquet encryption_config must be of type STRUCT"); } @@ -80,21 +82,16 @@ const string &ParquetEncryptionConfig::GetFooterKey() const { return keys.GetKey(footer_key); } -using duckdb_apache::thrift::transport::TTransport; -using AESGCMState = duckdb_mbedtls::MbedTlsWrapper::AESGCMState; using duckdb_apache::thrift::protocol::TCompactProtocolFactoryT; - -static void GenerateNonce(const data_ptr_t nonce) { - duckdb_mbedtls::MbedTlsWrapper::GenerateRandomData(nonce, ParquetCrypto::NONCE_BYTES); -} +using duckdb_apache::thrift::transport::TTransport; //! Encryption wrapper for a transport protocol class EncryptionTransport : public TTransport { public: - EncryptionTransport(TProtocol &prot_p, const string &key) - : prot(prot_p), trans(*prot.getTransport()), aes(key), + EncryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p) + : prot(prot_p), trans(*prot.getTransport()), aes(encryption_util_p.CreateEncryptionState()), allocator(Allocator::DefaultAllocator(), ParquetCrypto::CRYPTO_BLOCK_SIZE) { - Initialize(); + Initialize(key); } bool isOpen() const override { @@ -117,38 +114,41 @@ class EncryptionTransport : public TTransport { // Write length const auto ciphertext_length = allocator.SizeInBytes(); const uint32_t total_length = ParquetCrypto::NONCE_BYTES + ciphertext_length + ParquetCrypto::TAG_BYTES; - trans.write(const_data_ptr_cast(&total_length), ParquetCrypto::LENGTH_BYTES); - // Write nonce + trans.write(const_data_ptr_cast(&total_length), ParquetCrypto::LENGTH_BYTES); + // Write nonce at beginning of encrypted chunk trans.write(nonce, ParquetCrypto::NONCE_BYTES); - // Encrypt and write data data_t aes_buffer[ParquetCrypto::CRYPTO_BLOCK_SIZE]; auto current = allocator.GetTail(); + + // Loop through the whole chunk while (current != nullptr) { for (idx_t pos = 0; pos < current->current_position; pos += ParquetCrypto::CRYPTO_BLOCK_SIZE) { auto next = MinValue(current->current_position - pos, ParquetCrypto::CRYPTO_BLOCK_SIZE); auto write_size = - aes.Process(current->data.get() + pos, next, aes_buffer, ParquetCrypto::CRYPTO_BLOCK_SIZE); + aes->Process(current->data.get() + pos, next, aes_buffer, ParquetCrypto::CRYPTO_BLOCK_SIZE); trans.write(aes_buffer, write_size); } current = current->prev; } - // Finalize the last encrypted data and write tag + // Finalize the last encrypted data data_t tag[ParquetCrypto::TAG_BYTES]; - auto write_size = aes.Finalize(aes_buffer, ParquetCrypto::CRYPTO_BLOCK_SIZE, tag, ParquetCrypto::TAG_BYTES); + auto write_size = aes->Finalize(aes_buffer, 0, tag, ParquetCrypto::TAG_BYTES); trans.write(aes_buffer, write_size); + // Write tag for verification trans.write(tag, ParquetCrypto::TAG_BYTES); return ParquetCrypto::LENGTH_BYTES + total_length; } private: - void Initialize() { - // Generate nonce and initialize AES - GenerateNonce(nonce); - aes.InitializeEncryption(nonce, ParquetCrypto::NONCE_BYTES); + void Initialize(const string &key) { + // Generate Nonce + aes->GenerateRandomData(nonce, ParquetCrypto::NONCE_BYTES); + // Initialize Encryption + aes->InitializeEncryption(nonce, ParquetCrypto::NONCE_BYTES, &key); } private: @@ -156,8 +156,8 @@ class EncryptionTransport : public TTransport { TProtocol &prot; TTransport &trans; - //! AES context - AESGCMState aes; + //! AES context and buffers + shared_ptr aes; //! Nonce created by Initialize() data_t nonce[ParquetCrypto::NONCE_BYTES]; @@ -169,11 +169,11 @@ class EncryptionTransport : public TTransport { //! Decryption wrapper for a transport protocol class DecryptionTransport : public TTransport { public: - DecryptionTransport(TProtocol &prot_p, const string &key) - : prot(prot_p), trans(*prot.getTransport()), aes(key), read_buffer_size(0), read_buffer_offset(0) { - Initialize(); + DecryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p) + : prot(prot_p), trans(*prot.getTransport()), aes(encryption_util_p.CreateEncryptionState()), + read_buffer_size(0), read_buffer_offset(0) { + Initialize(key); } - uint32_t read_virt(uint8_t *buf, uint32_t len) override { const uint32_t result = len; @@ -183,10 +183,9 @@ class DecryptionTransport : public TTransport { while (len != 0) { if (read_buffer_offset == read_buffer_size) { - ReadBlock(); + ReadBlock(buf); } const auto next = MinValue(read_buffer_size - read_buffer_offset, len); - memcpy(buf, read_buffer + read_buffer_offset, next); read_buffer_offset += next; buf += next; len -= next; @@ -196,19 +195,29 @@ class DecryptionTransport : public TTransport { } uint32_t Finalize() { + if (read_buffer_offset != read_buffer_size) { - throw InternalException("DecryptionTransport::Finalize was called with bytes remaining in read buffer"); + throw InternalException("DecryptionTransport::Finalize was called with bytes remaining in read buffer: \n" + "read buffer offset: %d, read buffer size: %d", + read_buffer_offset, read_buffer_size); } data_t computed_tag[ParquetCrypto::TAG_BYTES]; - if (aes.Finalize(read_buffer, AESGCMState::BLOCK_SIZE, computed_tag, ParquetCrypto::TAG_BYTES) != 0) { - throw InternalException("DecryptionTransport::Finalize was called with bytes remaining in AES context"); - } - data_t read_tag[ParquetCrypto::TAG_BYTES]; - transport_remaining -= trans.read(read_tag, ParquetCrypto::TAG_BYTES); - if (memcmp(computed_tag, read_tag, ParquetCrypto::TAG_BYTES) != 0) { - throw InvalidInputException("Computed AES tag differs from read AES tag, are you using the right key?"); + if (aes->IsOpenSSL()) { + // For OpenSSL, the obtained tag is an input argument for aes->Finalize() + transport_remaining -= trans.read(computed_tag, ParquetCrypto::TAG_BYTES); + if (aes->Finalize(read_buffer, 0, computed_tag, ParquetCrypto::TAG_BYTES) != 0) { + throw InternalException( + "DecryptionTransport::Finalize was called with bytes remaining in AES context out"); + } + } else { + // For mbedtls, computed_tag is an output argument for aes->Finalize() + if (aes->Finalize(read_buffer, 0, computed_tag, ParquetCrypto::TAG_BYTES) != 0) { + throw InternalException( + "DecryptionTransport::Finalize was called with bytes remaining in AES context out"); + } + VerifyTag(computed_tag); } if (transport_remaining != 0) { @@ -227,45 +236,53 @@ class DecryptionTransport : public TTransport { } private: - void Initialize() { + void Initialize(const string &key) { // Read encoded length (don't add to read_bytes) data_t length_buf[ParquetCrypto::LENGTH_BYTES]; trans.read(length_buf, ParquetCrypto::LENGTH_BYTES); total_bytes = Load(length_buf); transport_remaining = total_bytes; - // Read nonce and initialize AES transport_remaining -= trans.read(nonce, ParquetCrypto::NONCE_BYTES); - aes.InitializeDecryption(nonce, ParquetCrypto::NONCE_BYTES); + // check whether context is initialized + aes->InitializeDecryption(nonce, ParquetCrypto::NONCE_BYTES, &key); } - void ReadBlock() { + void ReadBlock(uint8_t *buf) { // Read from transport into read_buffer at one AES block size offset (up to the tag) read_buffer_size = MinValue(ParquetCrypto::CRYPTO_BLOCK_SIZE, transport_remaining - ParquetCrypto::TAG_BYTES); - transport_remaining -= trans.read(read_buffer + AESGCMState::BLOCK_SIZE, read_buffer_size); + transport_remaining -= trans.read(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size); // Decrypt from read_buffer + block size into read_buffer start (decryption can trail behind in same buffer) #ifdef DEBUG - auto size = aes.Process(read_buffer + AESGCMState::BLOCK_SIZE, read_buffer_size, read_buffer, - ParquetCrypto::CRYPTO_BLOCK_SIZE + AESGCMState::BLOCK_SIZE); + auto size = aes->Process(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size, buf, + ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE); D_ASSERT(size == read_buffer_size); #else - aes.Process(read_buffer + AESGCMState::BLOCK_SIZE, read_buffer_size, read_buffer, - ParquetCrypto::CRYPTO_BLOCK_SIZE + AESGCMState::BLOCK_SIZE); + aes->Process(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size, buf, + ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE); #endif read_buffer_offset = 0; } + void VerifyTag(data_t *computed_tag) { + data_t read_tag[ParquetCrypto::TAG_BYTES]; + transport_remaining -= trans.read(read_tag, ParquetCrypto::TAG_BYTES); + if (memcmp(computed_tag, read_tag, ParquetCrypto::TAG_BYTES) != 0) { + throw InvalidInputException("Computed AES tag differs from read AES tag, are you using the right key?"); + } + } + private: //! Protocol and corresponding transport that we're wrapping TProtocol &prot; TTransport &trans; //! AES context and buffers - AESGCMState aes; + shared_ptr aes; //! We read/decrypt big blocks at a time - data_t read_buffer[ParquetCrypto::CRYPTO_BLOCK_SIZE + AESGCMState::BLOCK_SIZE]; + data_t read_buffer[ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE]; uint32_t read_buffer_size; uint32_t read_buffer_offset; @@ -298,10 +315,10 @@ class SimpleReadTransport : public TTransport { uint32_t read_buffer_offset; }; -uint32_t ParquetCrypto::Read(TBase &object, TProtocol &iprot, const string &key) { - // Create decryption protocol +uint32_t ParquetCrypto::Read(TBase &object, TProtocol &iprot, const string &key, + const EncryptionUtil &encryption_util_p) { TCompactProtocolFactoryT tproto_factory; - auto dprot = tproto_factory.getProtocol(std::make_shared(iprot, key)); + auto dprot = tproto_factory.getProtocol(std::make_shared(iprot, key, encryption_util_p)); auto &dtrans = reinterpret_cast(*dprot->getTransport()); // We have to read the whole thing otherwise thrift throws an error before we realize we're decryption is wrong @@ -316,10 +333,11 @@ uint32_t ParquetCrypto::Read(TBase &object, TProtocol &iprot, const string &key) return ParquetCrypto::LENGTH_BYTES + ParquetCrypto::NONCE_BYTES + all.GetSize() + ParquetCrypto::TAG_BYTES; } -uint32_t ParquetCrypto::Write(const TBase &object, TProtocol &oprot, const string &key) { +uint32_t ParquetCrypto::Write(const TBase &object, TProtocol &oprot, const string &key, + const EncryptionUtil &encryption_util_p) { // Create encryption protocol TCompactProtocolFactoryT tproto_factory; - auto eprot = tproto_factory.getProtocol(std::make_shared(oprot, key)); + auto eprot = tproto_factory.getProtocol(std::make_shared(oprot, key, encryption_util_p)); auto &etrans = reinterpret_cast(*eprot->getTransport()); // Write the object in memory @@ -330,10 +348,10 @@ uint32_t ParquetCrypto::Write(const TBase &object, TProtocol &oprot, const strin } uint32_t ParquetCrypto::ReadData(TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size, - const string &key) { + const string &key, const EncryptionUtil &encryption_util_p) { // Create decryption protocol TCompactProtocolFactoryT tproto_factory; - auto dprot = tproto_factory.getProtocol(std::make_shared(iprot, key)); + auto dprot = tproto_factory.getProtocol(std::make_shared(iprot, key, encryption_util_p)); auto &dtrans = reinterpret_cast(*dprot->getTransport()); // Read buffer @@ -344,11 +362,11 @@ uint32_t ParquetCrypto::ReadData(TProtocol &iprot, const data_ptr_t buffer, cons } uint32_t ParquetCrypto::WriteData(TProtocol &oprot, const const_data_ptr_t buffer, const uint32_t buffer_size, - const string &key) { + const string &key, const EncryptionUtil &encryption_util_p) { // FIXME: we know the size upfront so we could do a streaming write instead of this // Create encryption protocol TCompactProtocolFactoryT tproto_factory; - auto eprot = tproto_factory.getProtocol(std::make_shared(oprot, key)); + auto eprot = tproto_factory.getProtocol(std::make_shared(oprot, key, encryption_util_p)); auto &etrans = reinterpret_cast(*eprot->getTransport()); // Write the data in memory @@ -358,15 +376,45 @@ uint32_t ParquetCrypto::WriteData(TProtocol &oprot, const const_data_ptr_t buffe return etrans.Finalize(); } +bool ParquetCrypto::ValidKey(const std::string &key) { + switch (key.size()) { + case 16: + case 24: + case 32: + return true; + default: + return false; + } +} + +string Base64Decode(const string &key) { + auto result_size = Blob::FromBase64Size(key); + auto output = duckdb::unique_ptr(new unsigned char[result_size]); + Blob::FromBase64(key, output.get(), result_size); + string decoded_key(reinterpret_cast(output.get()), result_size); + return decoded_key; +} + void ParquetCrypto::AddKey(ClientContext &context, const FunctionParameters ¶meters) { const auto &key_name = StringValue::Get(parameters.values[0]); const auto &key = StringValue::Get(parameters.values[1]); - if (!AESGCMState::ValidKey(key)) { - throw InvalidInputException( - "Invalid AES key. Must have a length of 128, 192, or 256 bits (16, 24, or 32 bytes)"); - } + auto &keys = ParquetKeys::Get(context); - keys.AddKey(key_name, key); + if (ValidKey(key)) { + keys.AddKey(key_name, key); + } else { + string decoded_key; + try { + decoded_key = Base64Decode(key); + } catch (const ConversionException &e) { + throw InvalidInputException("Invalid AES key. Not a plain AES key NOR a base64 encoded string"); + } + if (!ValidKey(decoded_key)) { + throw InvalidInputException( + "Invalid AES key. Must have a length of 128, 192, or 256 bits (16, 24, or 32 bytes)"); + } + keys.AddKey(key_name, decoded_key); + } } } // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_extension.cpp b/src/duckdb/extension/parquet/parquet_extension.cpp index 6a978d9d..596fed87 100644 --- a/src/duckdb/extension/parquet/parquet_extension.cpp +++ b/src/duckdb/extension/parquet/parquet_extension.cpp @@ -4,6 +4,12 @@ #include "cast_column_reader.hpp" #include "duckdb.hpp" +#include "duckdb/parser/expression/positional_reference_expression.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "geo_parquet.hpp" #include "parquet_crypto.hpp" #include "parquet_metadata.hpp" #include "parquet_reader.hpp" @@ -17,15 +23,16 @@ #include #include #ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/helper.hpp" #include "duckdb/catalog/catalog.hpp" #include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" #include "duckdb/common/constants.hpp" #include "duckdb/common/enums/file_compression_type.hpp" #include "duckdb/common/file_system.hpp" +#include "duckdb/common/helper.hpp" #include "duckdb/common/multi_file_reader.hpp" #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/type_visitor.hpp" #include "duckdb/function/copy_function.hpp" #include "duckdb/function/pragma_function.hpp" #include "duckdb/function/table_function.hpp" @@ -37,6 +44,7 @@ #include "duckdb/parser/parsed_data/create_copy_function_info.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/storage/table/row_group.hpp" @@ -52,10 +60,12 @@ struct ParquetReadBindData : public TableFunctionData { atomic chunk_count; vector names; vector types; + //! Table column names - set when using COPY tbl FROM file.parquet + vector table_columns; // The union readers are created (when parquet union_by_name option is on) during binding // Those readers can be re-used during ParquetParallelStateNext - vector> union_readers; + vector> union_readers; // These come from the initial_reader, but need to be stored in case the initial_reader is removed by a filter idx_t initial_file_cardinality; @@ -70,6 +80,9 @@ struct ParquetReadBindData : public TableFunctionData { initial_file_row_groups = initial_reader->NumRowGroups(); parquet_options = initial_reader->parquet_options; } + void Initialize(ClientContext &, unique_ptr &union_data) { + Initialize(std::move(union_data->reader)); + } }; struct ParquetReadLocalState : public LocalTableFunctionState { @@ -94,6 +107,16 @@ struct ParquetFileReaderData { explicit ParquetFileReaderData(shared_ptr reader_p) : reader(std::move(reader_p)), file_state(ParquetFileState::OPEN), file_mutex(make_uniq()) { } + // Create data for an existing reader + explicit ParquetFileReaderData(unique_ptr union_data_p) : file_mutex(make_uniq()) { + if (union_data_p->reader) { + reader = std::move(union_data_p->reader); + file_state = ParquetFileState::OPEN; + } else { + union_data = std::move(union_data_p); + file_state = ParquetFileState::UNOPENED; + } + } //! Currently opened reader for the file shared_ptr reader; @@ -101,21 +124,33 @@ struct ParquetFileReaderData { ParquetFileState file_state; //! Mutexes to wait for the file when it is being opened unique_ptr file_mutex; + //! Parquet options for opening the file + unique_ptr union_data; //! (only set when file_state is UNOPENED) the file to be opened string file_to_be_opened; }; struct ParquetReadGlobalState : public GlobalTableFunctionState { + explicit ParquetReadGlobalState(MultiFileList &file_list_p) : file_list(file_list_p) { + } + explicit ParquetReadGlobalState(unique_ptr owned_file_list_p) + : file_list(*owned_file_list_p), owned_file_list(std::move(owned_file_list_p)) { + } + + //! The file list to scan + MultiFileList &file_list; //! The scan over the file_list MultiFileListScanData file_list_scan; + //! Owned multi file list - if filters have been dynamically pushed into the reader + unique_ptr owned_file_list; unique_ptr multi_file_reader_state; mutex lock; //! The current set of parquet readers - vector readers; + vector> readers; //! Signal to other threads that a file failed to open, letting every thread abort. bool error_opening_file = false; @@ -131,7 +166,7 @@ struct ParquetReadGlobalState : public GlobalTableFunctionState { vector projection_ids; vector scanned_types; vector column_ids; - TableFilterSet *filters; + optional_ptr filters; idx_t MaxThreads() const override { return max_threads; @@ -155,10 +190,14 @@ struct ParquetWriteBindData : public TableFunctionData { //! How/Whether to encrypt the data shared_ptr encryption_config; + bool debug_use_openssl = true; //! Dictionary compression is applied only if the compression ratio exceeds this threshold double dictionary_compression_ratio_threshold = 1.0; + //! After how many row groups to rotate to a new file + optional_idx row_groups_per_file; + ChildFieldIDs field_ids; //! The compression level, higher value is more optional_idx compression_level; @@ -191,6 +230,7 @@ BindInfo ParquetGetBindInfo(const optional_ptr bind_data) { bind_info.InsertOption("file_path", Value::LIST(LogicalType::VARCHAR, file_path)); bind_info.InsertOption("binary_as_string", Value::BOOLEAN(parquet_bind.parquet_options.binary_as_string)); bind_info.InsertOption("file_row_number", Value::BOOLEAN(parquet_bind.parquet_options.file_row_number)); + bind_info.InsertOption("debug_use_openssl", Value::BOOLEAN(parquet_bind.parquet_options.debug_use_openssl)); parquet_bind.parquet_options.file_options.AddBatchInfo(bind_info); // LCOV_EXCL_STOP return bind_info; @@ -251,6 +291,7 @@ static void InitializeParquetReader(ParquetReader &reader, const ParquetReadBind auto &parquet_options = bind_data.parquet_options; auto &reader_data = reader.reader_data; + reader.table_columns = bind_data.table_columns; // Mark the file in the file list we are scanning here reader_data.file_list_idx = file_idx; @@ -352,6 +393,7 @@ class ParquetScanFunction { table_function.table_scan_progress = ParquetProgress; table_function.named_parameters["binary_as_string"] = LogicalType::BOOLEAN; table_function.named_parameters["file_row_number"] = LogicalType::BOOLEAN; + table_function.named_parameters["debug_use_openssl"] = LogicalType::BOOLEAN; table_function.named_parameters["compression"] = LogicalType::VARCHAR; table_function.named_parameters["schema"] = LogicalType::MAP(LogicalType::INTEGER, LogicalType::STRUCT({{{"name", LogicalType::VARCHAR}, @@ -388,6 +430,8 @@ class ParquetScanFunction { parquet_options.binary_as_string = GetBooleanArgument(option); } else if (loption == "file_row_number") { parquet_options.file_row_number = GetBooleanArgument(option); + } else if (loption == "debug_use_openssl") { + parquet_options.debug_use_openssl = GetBooleanArgument(option); } else if (loption == "encryption_config") { if (option.second.size() != 1) { throw BinderException("Parquet encryption_config cannot be empty!"); @@ -453,9 +497,9 @@ class ParquetScanFunction { // for remote files we just avoid reading stats entirely return nullptr; } - ParquetReader reader(context, bind_data.parquet_options, metadata); // get and merge stats for file - auto file_stats = reader.ReadStatistics(bind_data.names[column_index]); + auto file_stats = ParquetReader::ReadStatistics(context, bind_data.parquet_options, metadata, + bind_data.names[column_index]); if (!file_stats) { return nullptr; } @@ -510,12 +554,31 @@ class ParquetScanFunction { if (return_types.size() != result->types.size()) { auto file_string = bound_on_first_file ? result->file_list->GetFirstFile() : StringUtil::Join(result->file_list->GetPaths(), ","); - throw std::runtime_error(StringUtil::Format( - "Failed to read file(s) \"%s\" - column count mismatch: expected %d columns but found %d", - file_string, return_types.size(), result->types.size())); + string extended_error; + extended_error = "Table schema: "; + for (idx_t col_idx = 0; col_idx < return_types.size(); col_idx++) { + if (col_idx > 0) { + extended_error += ", "; + } + extended_error += names[col_idx] + " " + return_types[col_idx].ToString(); + } + extended_error += "\nParquet schema: "; + for (idx_t col_idx = 0; col_idx < result->types.size(); col_idx++) { + if (col_idx > 0) { + extended_error += ", "; + } + extended_error += result->names[col_idx] + " " + result->types[col_idx].ToString(); + } + extended_error += "\n\nPossible solutions:"; + extended_error += "\n* Manually specify which columns to insert using \"INSERT INTO tbl SELECT ... " + "FROM read_parquet(...)\""; + throw ConversionException( + "Failed to read file(s) \"%s\" - column count mismatch: expected %d columns but found %d\n%s", + file_string, return_types.size(), result->types.size(), extended_error); } // expected types - overwrite the types we want to read instead result->types = return_types; + result->table_columns = names; } result->parquet_options = parquet_options; return std::move(result); @@ -527,6 +590,9 @@ class ParquetScanFunction { ParquetOptions parquet_options(context); for (auto &kv : input.named_parameters) { + if (kv.second.IsNull()) { + throw BinderException("Cannot use NULL as function argument"); + } auto loption = StringUtil::Lower(kv.first); if (multi_file_reader->ParseOption(kv.first, kv.second, parquet_options.file_options, context)) { continue; @@ -535,6 +601,8 @@ class ParquetScanFunction { parquet_options.binary_as_string = BooleanValue::Get(kv.second); } else if (loption == "file_row_number") { parquet_options.file_row_number = BooleanValue::Get(kv.second); + } else if (loption == "debug_use_openssl") { + parquet_options.debug_use_openssl = BooleanValue::Get(kv.second); } else if (loption == "schema") { // Argument is a map that defines the schema const auto &schema_value = kv.second; @@ -565,16 +633,16 @@ class ParquetScanFunction { auto &bind_data = bind_data_p->Cast(); auto &gstate = global_state->Cast(); - auto total_count = bind_data.file_list->GetTotalFileCount(); + auto total_count = gstate.file_list.GetTotalFileCount(); if (total_count == 0) { return 100.0; } if (bind_data.initial_file_cardinality == 0) { - return (100.0 * (gstate.file_index + 1)) / total_count; + return (100.0 * (static_cast(gstate.file_index) + 1.0)) / static_cast(total_count); } - auto percentage = MinValue( - 100.0, (bind_data.chunk_count * STANDARD_VECTOR_SIZE * 100.0 / bind_data.initial_file_cardinality)); - return (percentage + 100.0 * gstate.file_index) / total_count; + auto percentage = MinValue(100.0, (static_cast(bind_data.chunk_count) * STANDARD_VECTOR_SIZE * + 100.0 / static_cast(bind_data.initial_file_cardinality))); + return (percentage + 100.0 * static_cast(gstate.file_index)) / static_cast(total_count); } static unique_ptr @@ -595,16 +663,37 @@ class ParquetScanFunction { return std::move(result); } + static unique_ptr ParquetDynamicFilterPushdown(ClientContext &context, + const ParquetReadBindData &data, + const vector &column_ids, + optional_ptr filters) { + if (!filters) { + return nullptr; + } + auto new_list = data.multi_file_reader->DynamicFilterPushdown( + context, *data.file_list, data.parquet_options.file_options, data.names, data.types, column_ids, *filters); + return new_list; + } + static unique_ptr ParquetScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) { auto &bind_data = input.bind_data->CastNoConst(); - auto result = make_uniq(); - bind_data.file_list->InitializeScan(result->file_list_scan); + unique_ptr result; + + // before instantiating a scan trigger a dynamic filter pushdown if possible + auto new_list = ParquetDynamicFilterPushdown(context, bind_data, input.column_ids, input.filters); + if (new_list) { + result = make_uniq(std::move(new_list)); + } else { + result = make_uniq(*bind_data.file_list); + } + auto &file_list = result->file_list; + file_list.InitializeScan(result->file_list_scan); result->multi_file_reader_state = bind_data.multi_file_reader->InitializeGlobalState( - context, bind_data.parquet_options.file_options, bind_data.reader_bind, *bind_data.file_list, - bind_data.types, bind_data.names, input.column_ids); - if (bind_data.file_list->IsEmpty()) { + context, bind_data.parquet_options.file_options, bind_data.reader_bind, file_list, bind_data.types, + bind_data.names, input.column_ids); + if (file_list.IsEmpty()) { result->readers = {}; } else if (!bind_data.union_readers.empty()) { // TODO: confirm we are not changing behaviour by modifying the order here? @@ -612,33 +701,38 @@ class ParquetScanFunction { if (!reader) { break; } - result->readers.push_back(ParquetFileReaderData(std::move(reader))); + result->readers.push_back(make_uniq(std::move(reader))); } - if (result->readers.size() != bind_data.file_list->GetTotalFileCount()) { + if (result->readers.size() != file_list.GetTotalFileCount()) { // This case happens with recursive CTEs: the first execution the readers have already // been moved out of the bind data. // FIXME: clean up this process and make it more explicit result->readers = {}; } } else if (bind_data.initial_reader) { - // Ensure the initial reader was actually constructed from the first file - if (bind_data.initial_reader->file_name != bind_data.file_list->GetFirstFile()) { - throw InternalException("First file from list ('%s') does not match first reader ('%s')", - bind_data.initial_reader->file_name, bind_data.file_list->GetFirstFile()); + // we can only use the initial reader if it was constructed from the first file + if (bind_data.initial_reader->file_name == file_list.GetFirstFile()) { + result->readers.push_back(make_uniq(std::move(bind_data.initial_reader))); } - result->readers.emplace_back(std::move(bind_data.initial_reader)); } // Ensure all readers are initialized and FileListScan is sync with readers list for (auto &reader_data : result->readers) { string file_name; idx_t file_idx = result->file_list_scan.current_file_idx; - bind_data.file_list->Scan(result->file_list_scan, file_name); - if (file_name != reader_data.reader->file_name) { - throw InternalException("Mismatch in filename order and reader order in parquet scan"); + file_list.Scan(result->file_list_scan, file_name); + if (reader_data->union_data) { + if (file_name != reader_data->union_data->GetFileName()) { + throw InternalException("Mismatch in filename order and union reader order in parquet scan"); + } + } else { + D_ASSERT(reader_data->reader); + if (file_name != reader_data->reader->file_name) { + throw InternalException("Mismatch in filename order and reader order in parquet scan"); + } + InitializeParquetReader(*reader_data->reader, bind_data, input.column_ids, input.filters, context, + file_idx, result->multi_file_reader_state); } - InitializeParquetReader(*reader_data.reader, bind_data, input.column_ids, input.filters, context, file_idx, - result->multi_file_reader_state); } result->column_ids = input.column_ids; @@ -692,6 +786,9 @@ class ParquetScanFunction { serializer.WriteProperty(101, "types", bind_data.types); serializer.WriteProperty(102, "names", bind_data.names); serializer.WriteProperty(103, "parquet_options", bind_data.parquet_options); + if (serializer.ShouldSerialize(3)) { + serializer.WriteProperty(104, "table_columns", bind_data.table_columns); + } } static unique_ptr ParquetScanDeserialize(Deserializer &deserializer, TableFunction &function) { @@ -700,6 +797,8 @@ class ParquetScanFunction { auto types = deserializer.ReadProperty>(101, "types"); auto names = deserializer.ReadProperty>(102, "names"); auto parquet_options = deserializer.ReadProperty(103, "parquet_options"); + auto table_columns = + deserializer.ReadPropertyWithExplicitDefault>(104, "table_columns", vector {}); vector file_path; for (auto &path : files) { @@ -709,8 +808,10 @@ class ParquetScanFunction { auto multi_file_reader = MultiFileReader::Create(function); auto file_list = multi_file_reader->CreateFileList(context, Value::LIST(LogicalType::VARCHAR, file_path), FileGlobOptions::DISALLOW_EMPTY); - return ParquetScanBindInternal(context, std::move(multi_file_reader), std::move(file_list), types, names, - parquet_options); + auto bind_data = ParquetScanBindInternal(context, std::move(multi_file_reader), std::move(file_list), types, + names, parquet_options); + bind_data->Cast().table_columns = std::move(table_columns); + return bind_data; } static void ParquetScanImplementation(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { @@ -746,6 +847,12 @@ class ParquetScanFunction { static unique_ptr ParquetCardinality(ClientContext &context, const FunctionData *bind_data) { auto &data = bind_data->Cast(); + + auto file_list_cardinality_estimate = data.file_list->GetCardinality(context); + if (file_list_cardinality_estimate) { + return file_list_cardinality_estimate; + } + return make_uniq(data.initial_file_cardinality * data.file_list->GetTotalFileCount()); } @@ -761,14 +868,14 @@ class ParquetScanFunction { // Queries the metadataprovider for another file to scan, updating the files/reader lists in the process. // Returns true if resized - static bool ResizeFiles(const ParquetReadBindData &bind_data, ParquetReadGlobalState ¶llel_state) { + static bool ResizeFiles(ParquetReadGlobalState ¶llel_state) { string scanned_file; - if (!bind_data.file_list->Scan(parallel_state.file_list_scan, scanned_file)) { + if (!parallel_state.file_list.Scan(parallel_state.file_list_scan, scanned_file)) { return false; } // Push the file in the reader data, to be opened later - parallel_state.readers.emplace_back(scanned_file); + parallel_state.readers.push_back(make_uniq(scanned_file)); return true; } @@ -784,17 +891,17 @@ class ParquetScanFunction { return false; } - if (parallel_state.file_index >= parallel_state.readers.size() && !ResizeFiles(bind_data, parallel_state)) { + if (parallel_state.file_index >= parallel_state.readers.size() && !ResizeFiles(parallel_state)) { return false; } - auto ¤t_reader_data = parallel_state.readers[parallel_state.file_index]; + auto ¤t_reader_data = *parallel_state.readers[parallel_state.file_index]; if (current_reader_data.file_state == ParquetFileState::OPEN) { if (parallel_state.row_group_index < current_reader_data.reader->NumRowGroups()) { // The current reader has rowgroups left to be scanned scan_data.reader = current_reader_data.reader; vector group_indexes {parallel_state.row_group_index}; - scan_data.reader->InitializeScan(scan_data.scan_state, group_indexes); + scan_data.reader->InitializeScan(context, scan_data.scan_state, group_indexes); scan_data.batch_index = parallel_state.batch_index++; scan_data.file_index = parallel_state.file_index; parallel_state.row_group_index++; @@ -817,7 +924,7 @@ class ParquetScanFunction { } // Check if the current file is being opened, in that case we need to wait for it. - if (parallel_state.readers[parallel_state.file_index].file_state == ParquetFileState::OPENING) { + if (current_reader_data.file_state == ParquetFileState::OPENING) { WaitForFile(parallel_state.file_index, parallel_state, parallel_lock); } } @@ -827,8 +934,9 @@ class ParquetScanFunction { vector> &filters) { auto &data = bind_data_p->Cast(); + MultiFilePushdownInfo info(get); auto new_list = data.multi_file_reader->ComplexFilterPushdown(context, *data.file_list, - data.parquet_options.file_options, get, filters); + data.parquet_options.file_options, info, filters); if (new_list) { data.file_list = std::move(new_list); @@ -840,9 +948,8 @@ class ParquetScanFunction { static void WaitForFile(idx_t file_index, ParquetReadGlobalState ¶llel_state, unique_lock ¶llel_lock) { while (true) { - // Get pointer to file mutex before unlocking - auto &file_mutex = *parallel_state.readers[file_index].file_mutex; + auto &file_mutex = *parallel_state.readers[file_index]->file_mutex; // To get the file lock, we first need to release the parallel_lock to prevent deadlocking. Note that this // requires getting the ref to the file mutex pointer with the lock stil held: readers get be resized @@ -855,7 +962,7 @@ class ParquetScanFunction { // - the thread opening the file has failed // - the file was somehow scanned till the end while we were waiting if (parallel_state.file_index >= parallel_state.readers.size() || - parallel_state.readers[parallel_state.file_index].file_state != ParquetFileState::OPENING || + parallel_state.readers[parallel_state.file_index]->file_state != ParquetFileState::OPENING || parallel_state.error_opening_file) { return; } @@ -866,14 +973,17 @@ class ParquetScanFunction { static bool TryOpenNextFile(ClientContext &context, const ParquetReadBindData &bind_data, ParquetReadLocalState &scan_data, ParquetReadGlobalState ¶llel_state, unique_lock ¶llel_lock) { - const auto num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); - const auto file_index_limit = - MinValue(parallel_state.file_index + num_threads, parallel_state.readers.size()); + parallel_state.file_index + TaskScheduler::GetScheduler(context).NumberOfThreads(); for (idx_t i = parallel_state.file_index; i < file_index_limit; i++) { - if (parallel_state.readers[i].file_state == ParquetFileState::UNOPENED) { - auto ¤t_reader_data = parallel_state.readers[i]; + // We check if we can resize files in this loop too otherwise we will only ever open 1 file ahead + if (i >= parallel_state.readers.size() && !ResizeFiles(parallel_state)) { + return false; + } + + auto ¤t_reader_data = *parallel_state.readers[i]; + if (current_reader_data.file_state == ParquetFileState::UNOPENED) { current_reader_data.file_state = ParquetFileState::OPENING; auto pq_options = bind_data.parquet_options; @@ -887,7 +997,14 @@ class ParquetScanFunction { shared_ptr reader; try { - reader = make_shared_ptr(context, current_reader_data.file_to_be_opened, pq_options); + if (current_reader_data.union_data) { + auto &union_data = *current_reader_data.union_data; + reader = make_shared_ptr(context, union_data.file_name, union_data.options, + union_data.metadata); + } else { + reader = + make_shared_ptr(context, current_reader_data.file_to_be_opened, pq_options); + } InitializeParquetReader(*reader, bind_data, parallel_state.column_ids, parallel_state.filters, context, i, parallel_state.multi_file_reader_state); } catch (...) { @@ -898,7 +1015,7 @@ class ParquetScanFunction { // Now re-lock the state and add the reader parallel_lock.lock(); - current_reader_data.reader = reader; + current_reader_data.reader = std::move(reader); current_reader_data.file_state = ParquetFileState::OPEN; return true; @@ -962,7 +1079,7 @@ static void GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const ve D_ASSERT(names.size() == sql_types.size()); for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { const auto &col_name = names[col_idx]; - auto inserted = field_ids.ids->insert(make_pair(col_name, FieldID(field_id++))); + auto inserted = field_ids.ids->insert(make_pair(col_name, FieldID(UnsafeNumericCast(field_id++)))); D_ASSERT(inserted.second); const auto &col_type = sql_types[col_idx]; @@ -1006,8 +1123,10 @@ static void GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids, } names += name.first; } - throw BinderException("Column name \"%s\" specified in FIELD_IDS not found. Available column names: [%s]", - col_name, names); + throw BinderException( + "Column name \"%s\" specified in FIELD_IDS not found. Consider using WRITE_PARTITION_COLUMNS if this " + "column is a partition column. Available column names: [%s]", + col_name, names); } D_ASSERT(field_ids.ids->find(col_name) == field_ids.ids->end()); // Caught by STRUCT - deduplicates keys @@ -1038,7 +1157,7 @@ static void GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids, if (!unique_field_ids.insert(field_id_int).second) { throw BinderException("Duplicate field_id %s found in FIELD_IDS", field_id_integer_value.ToString()); } - field_id = FieldID(field_id_int); + field_id = FieldID(UnsafeNumericCast(field_id_int)); } auto inserted = field_ids.ids->insert(make_pair(col_name, std::move(field_id))); D_ASSERT(inserted.second); @@ -1078,6 +1197,8 @@ unique_ptr ParquetWriteBind(ClientContext &context, CopyFunctionBi bind_data->row_group_size_bytes = option.second[0].GetValue(); } row_group_size_bytes_set = true; + } else if (loption == "row_groups_per_file") { + bind_data->row_groups_per_file = option.second[0].GetValue(); } else if (loption == "compression" || loption == "codec") { const auto roption = StringUtil::Lower(option.second[0].ToString()); if (roption == "uncompressed") { @@ -1088,12 +1209,14 @@ unique_ptr ParquetWriteBind(ClientContext &context, CopyFunctionBi bind_data->codec = duckdb_parquet::format::CompressionCodec::GZIP; } else if (roption == "zstd") { bind_data->codec = duckdb_parquet::format::CompressionCodec::ZSTD; + } else if (roption == "brotli") { + bind_data->codec = duckdb_parquet::format::CompressionCodec::BROTLI; } else if (roption == "lz4" || roption == "lz4_raw") { /* LZ4 is technically another compression scheme, but deprecated and arrow also uses them * interchangeably */ bind_data->codec = duckdb_parquet::format::CompressionCodec::LZ4_RAW; } else { - throw BinderException("Expected %s argument to be either [uncompressed, snappy, gzip or zstd]", + throw BinderException("Expected %s argument to be either [uncompressed, brotli, gzip, snappy, or zstd]", loption); } } else if (loption == "field_ids") { @@ -1142,6 +1265,15 @@ unique_ptr ParquetWriteBind(ClientContext &context, CopyFunctionBi "dictionary compression"); } bind_data->dictionary_compression_ratio_threshold = val; + } else if (loption == "debug_use_openssl") { + auto val = StringUtil::Lower(option.second[0].GetValue()); + if (val == "false") { + bind_data->debug_use_openssl = false; + } else if (val == "true") { + bind_data->debug_use_openssl = true; + } else { + throw BinderException("Expected debug_use_openssl to be a BOOLEAN"); + } } else if (loption == "compression_level") { bind_data->compression_level = option.second[0].GetValue(); } else { @@ -1169,10 +1301,11 @@ unique_ptr ParquetWriteInitializeGlobal(ClientContext &conte auto &parquet_bind = bind_data.Cast(); auto &fs = FileSystem::GetFileSystem(context); - global_state->writer = make_uniq( - fs, file_path, parquet_bind.sql_types, parquet_bind.column_names, parquet_bind.codec, - parquet_bind.field_ids.Copy(), parquet_bind.kv_metadata, parquet_bind.encryption_config, - parquet_bind.dictionary_compression_ratio_threshold, parquet_bind.compression_level); + global_state->writer = + make_uniq(context, fs, file_path, parquet_bind.sql_types, parquet_bind.column_names, + parquet_bind.codec, parquet_bind.field_ids.Copy(), parquet_bind.kv_metadata, + parquet_bind.encryption_config, parquet_bind.dictionary_compression_ratio_threshold, + parquet_bind.compression_level, parquet_bind.debug_use_openssl); return std::move(global_state); } @@ -1185,8 +1318,8 @@ void ParquetWriteSink(ExecutionContext &context, FunctionData &bind_data_p, Glob // append data to the local (buffered) chunk collection local_state.buffer.Append(local_state.append_state, input); - if (local_state.buffer.Count() > bind_data.row_group_size || - local_state.buffer.SizeInBytes() > bind_data.row_group_size_bytes) { + if (local_state.buffer.Count() >= bind_data.row_group_size || + local_state.buffer.SizeInBytes() >= bind_data.row_group_size_bytes) { // if the chunk collection exceeds a certain size (rows/bytes) we flush it to the parquet file local_state.append_state.current_chunk_state.handles.clear(); global_state.writer->Flush(local_state.buffer); @@ -1294,6 +1427,8 @@ static void ParquetCopySerialize(Serializer &serializer, const FunctionData &bin serializer.WriteProperty(108, "dictionary_compression_ratio_threshold", bind_data.dictionary_compression_ratio_threshold); serializer.WritePropertyWithDefault(109, "compression_level", bind_data.compression_level); + serializer.WriteProperty(110, "row_groups_per_file", bind_data.row_groups_per_file); + serializer.WriteProperty(111, "debug_use_openssl", bind_data.debug_use_openssl); } static unique_ptr ParquetCopyDeserialize(Deserializer &deserializer, CopyFunction &function) { @@ -1305,11 +1440,14 @@ static unique_ptr ParquetCopyDeserialize(Deserializer &deserialize data->row_group_size_bytes = deserializer.ReadProperty(104, "row_group_size_bytes"); data->kv_metadata = deserializer.ReadProperty>>(105, "kv_metadata"); data->field_ids = deserializer.ReadProperty(106, "field_ids"); - deserializer.ReadPropertyWithDefault>(107, "encryption_config", - data->encryption_config, nullptr); - deserializer.ReadPropertyWithDefault(108, "dictionary_compression_ratio_threshold", - data->dictionary_compression_ratio_threshold, 1.0); + deserializer.ReadPropertyWithExplicitDefault>(107, "encryption_config", + data->encryption_config, nullptr); + deserializer.ReadPropertyWithExplicitDefault(108, "dictionary_compression_ratio_threshold", + data->dictionary_compression_ratio_threshold, 1.0); deserializer.ReadPropertyWithDefault(109, "compression_level", data->compression_level); + data->row_groups_per_file = + deserializer.ReadPropertyWithExplicitDefault(110, "row_groups_per_file", optional_idx::Invalid()); + data->debug_use_openssl = deserializer.ReadPropertyWithExplicitDefault(111, "debug_use_openssl", true); return std::move(data); } // LCOV_EXCL_STOP @@ -1361,11 +1499,25 @@ idx_t ParquetWriteDesiredBatchSize(ClientContext &context, FunctionData &bind_da } //===--------------------------------------------------------------------===// -// Current File Size +// File rotation //===--------------------------------------------------------------------===// -idx_t ParquetWriteFileSize(GlobalFunctionData &gstate) { +bool ParquetWriteRotateFiles(FunctionData &bind_data_p, const optional_idx &file_size_bytes) { + auto &bind_data = bind_data_p.Cast(); + return file_size_bytes.IsValid() || bind_data.row_groups_per_file.IsValid(); +} + +bool ParquetWriteRotateNextFile(GlobalFunctionData &gstate, FunctionData &bind_data_p, + const optional_idx &file_size_bytes) { auto &global_state = gstate.Cast(); - return global_state.writer->FileSize(); + auto &bind_data = bind_data_p.Cast(); + if (file_size_bytes.IsValid() && global_state.writer->FileSize() > file_size_bytes.GetIndex()) { + return true; + } + if (bind_data.row_groups_per_file.IsValid() && + global_state.writer->NumberOfRowGroups() >= bind_data.row_groups_per_file.GetIndex()) { + return true; + } + return false; } //===--------------------------------------------------------------------===// @@ -1373,7 +1525,7 @@ idx_t ParquetWriteFileSize(GlobalFunctionData &gstate) { //===--------------------------------------------------------------------===// unique_ptr ParquetScanReplacement(ClientContext &context, ReplacementScanInput &input, optional_ptr data) { - auto &table_name = input.table_name; + auto table_name = ReplacementScan::GetFullPath(input); if (!ReplacementScan::CanReplace(table_name, {"parquet"})) { return nullptr; } @@ -1390,6 +1542,86 @@ unique_ptr ParquetScanReplacement(ClientContext &context, ReplacementS return std::move(table_function); } +//===--------------------------------------------------------------------===// +// Select +//===--------------------------------------------------------------------===// +// Helper predicates for ParquetWriteSelect +static bool IsTypeNotSupported(const LogicalType &type) { + if (type.IsNested()) { + return false; + } + return !ParquetWriter::TryGetParquetType(type); +} + +static bool IsTypeLossy(const LogicalType &type) { + return type.id() == LogicalTypeId::HUGEINT || type.id() == LogicalTypeId::UHUGEINT; +} + +static vector> ParquetWriteSelect(CopyToSelectInput &input) { + + auto &context = input.context; + + vector> result; + + bool any_change = false; + + for (auto &expr : input.select_list) { + + const auto &type = expr->return_type; + const auto &name = expr->alias; + + // Spatial types need to be encoded into WKB when writing GeoParquet. + // But dont perform this conversion if this is a EXPORT DATABASE statement + if (input.copy_to_type == CopyToType::COPY_TO_FILE && type.id() == LogicalTypeId::BLOB && type.HasAlias() && + type.GetAlias() == "GEOMETRY") { + + LogicalType wkb_blob_type(LogicalTypeId::BLOB); + wkb_blob_type.SetAlias("WKB_BLOB"); + + auto cast_expr = BoundCastExpression::AddCastToType(context, std::move(expr), wkb_blob_type, false); + cast_expr->alias = name; + result.push_back(std::move(cast_expr)); + any_change = true; + } + // If this is an EXPORT DATABASE statement, we dont want to write "lossy" types, instead cast them to VARCHAR + else if (input.copy_to_type == CopyToType::EXPORT_DATABASE && TypeVisitor::Contains(type, IsTypeLossy)) { + // Replace all lossy types with VARCHAR + auto new_type = TypeVisitor::VisitReplace( + type, [](const LogicalType &ty) -> LogicalType { return IsTypeLossy(ty) ? LogicalType::VARCHAR : ty; }); + + // Cast the column to the new type + auto cast_expr = BoundCastExpression::AddCastToType(context, std::move(expr), new_type, false); + cast_expr->alias = name; + result.push_back(std::move(cast_expr)); + any_change = true; + } + // Else look if there is any unsupported type + else if (TypeVisitor::Contains(type, IsTypeNotSupported)) { + // If there is at least one unsupported type, replace all unsupported types with varchar + // and perform a CAST + auto new_type = TypeVisitor::VisitReplace(type, [](const LogicalType &ty) -> LogicalType { + return IsTypeNotSupported(ty) ? LogicalType::VARCHAR : ty; + }); + + auto cast_expr = BoundCastExpression::AddCastToType(context, std::move(expr), new_type, false); + cast_expr->alias = name; + result.push_back(std::move(cast_expr)); + any_change = true; + } + // Otherwise, just reference the input column + else { + result.push_back(std::move(expr)); + } + } + + // If any change was made, return the new expressions + // otherwise, return an empty vector to indicate no change and avoid pushing another projection on to the plan + if (any_change) { + return result; + } + return {}; +} + void ParquetExtension::Load(DuckDB &db) { auto &db_instance = *db.instance; auto &fs = db.GetFileSystem(); @@ -1418,6 +1650,7 @@ void ParquetExtension::Load(DuckDB &db) { ExtensionUtil::RegisterFunction(db_instance, MultiFileReader::CreateFunctionSet(file_meta_fun)); CopyFunction function("parquet"); + function.copy_to_select = ParquetWriteSelect; function.copy_to_bind = ParquetWriteBind; function.copy_to_initialize_global = ParquetWriteInitializeGlobal; function.copy_to_initialize_local = ParquetWriteInitializeLocal; @@ -1430,10 +1663,10 @@ void ParquetExtension::Load(DuckDB &db) { function.prepare_batch = ParquetWritePrepareBatch; function.flush_batch = ParquetWriteFlushBatch; function.desired_batch_size = ParquetWriteDesiredBatchSize; - function.file_size_bytes = ParquetWriteFileSize; + function.rotate_files = ParquetWriteRotateFiles; + function.rotate_next_file = ParquetWriteRotateNextFile; function.serialize = ParquetCopySerialize; function.deserialize = ParquetCopyDeserialize; - function.supports_type = ParquetWriter::TypeIsSupported; function.extension = "parquet"; ExtensionUtil::RegisterFunction(db_instance, function); diff --git a/src/duckdb/extension/parquet/parquet_metadata.cpp b/src/duckdb/extension/parquet/parquet_metadata.cpp index 72b16fae..b1d70a54 100644 --- a/src/duckdb/extension/parquet/parquet_metadata.cpp +++ b/src/duckdb/extension/parquet/parquet_metadata.cpp @@ -211,19 +211,19 @@ void ParquetMetaDataOperatorData::LoadRowGroupMetadata(ClientContext &context, c current_chunk.SetValue(0, count, file_path); // row_group_id, LogicalType::BIGINT - current_chunk.SetValue(1, count, Value::BIGINT(row_group_idx)); + current_chunk.SetValue(1, count, Value::BIGINT(UnsafeNumericCast(row_group_idx))); // row_group_num_rows, LogicalType::BIGINT current_chunk.SetValue(2, count, Value::BIGINT(row_group.num_rows)); // row_group_num_columns, LogicalType::BIGINT - current_chunk.SetValue(3, count, Value::BIGINT(row_group.columns.size())); + current_chunk.SetValue(3, count, Value::BIGINT(UnsafeNumericCast(row_group.columns.size()))); // row_group_bytes, LogicalType::BIGINT current_chunk.SetValue(4, count, Value::BIGINT(row_group.total_byte_size)); // column_id, LogicalType::BIGINT - current_chunk.SetValue(5, count, Value::BIGINT(col_idx)); + current_chunk.SetValue(5, count, Value::BIGINT(UnsafeNumericCast(col_idx))); // file_offset, LogicalType::BIGINT current_chunk.SetValue(6, count, ParquetElementBigint(column.file_offset, row_group.__isset.file_offset)); @@ -545,7 +545,7 @@ void ParquetMetaDataOperatorData::LoadFileMetaData(ClientContext &context, const // num_rows current_chunk.SetValue(2, 0, Value::BIGINT(meta_data->num_rows)); // num_row_groups - current_chunk.SetValue(3, 0, Value::BIGINT(meta_data->row_groups.size())); + current_chunk.SetValue(3, 0, Value::BIGINT(UnsafeNumericCast(meta_data->row_groups.size()))); // format_version current_chunk.SetValue(4, 0, Value::BIGINT(meta_data->version)); // encryption_algorithm diff --git a/src/duckdb/extension/parquet/parquet_reader.cpp b/src/duckdb/extension/parquet/parquet_reader.cpp index 75017ea3..0508d254 100644 --- a/src/duckdb/extension/parquet/parquet_reader.cpp +++ b/src/duckdb/extension/parquet/parquet_reader.cpp @@ -5,27 +5,29 @@ #include "cast_column_reader.hpp" #include "column_reader.hpp" #include "duckdb.hpp" +#include "expression_column_reader.hpp" +#include "geo_parquet.hpp" #include "list_column_reader.hpp" #include "parquet_crypto.hpp" #include "parquet_file_metadata_cache.hpp" #include "parquet_statistics.hpp" #include "parquet_timestamp.hpp" +#include "mbedtls_wrapper.hpp" #include "row_number_column_reader.hpp" #include "string_column_reader.hpp" #include "struct_column_reader.hpp" #include "templated_column_reader.hpp" #include "thrift_tools.hpp" +#include "duckdb/main/config.hpp" + #ifndef DUCKDB_AMALGAMATION +#include "duckdb/common/encryption_state.hpp" #include "duckdb/common/file_system.hpp" -#include "duckdb/common/hive_partitioning.hpp" -#include "duckdb/common/pair.hpp" #include "duckdb/common/helper.hpp" +#include "duckdb/common/hive_partitioning.hpp" #include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/filter/constant_filter.hpp" -#include "duckdb/planner/filter/null_filter.hpp" #include "duckdb/planner/filter/struct_filter.hpp" #include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/object_cache.hpp" @@ -55,8 +57,9 @@ CreateThriftFileProtocol(Allocator &allocator, FileHandle &file_handle, bool pre } static shared_ptr -LoadMetadata(Allocator &allocator, FileHandle &file_handle, - const shared_ptr &encryption_config) { +LoadMetadata(ClientContext &context, Allocator &allocator, FileHandle &file_handle, + const shared_ptr &encryption_config, + const EncryptionUtil &encryption_util) { auto current_time = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); auto file_proto = CreateThriftFileProtocol(allocator, file_handle, false); @@ -108,12 +111,15 @@ LoadMetadata(Allocator &allocator, FileHandle &file_handle, throw InvalidInputException("File '%s' is encrypted with AES_GCM_CTR_V1, but only AES_GCM_V1 is supported", file_handle.path); } - ParquetCrypto::Read(*metadata, *file_proto, encryption_config->GetFooterKey()); + ParquetCrypto::Read(*metadata, *file_proto, encryption_config->GetFooterKey(), encryption_util); } else { metadata->read(file_proto.get()); } - return make_shared_ptr(std::move(metadata), current_time); + // Try to read the GeoParquet metadata (if present) + auto geo_metadata = GeoParquetFileMetadata::TryRead(*metadata, context); + + return make_shared_ptr(std::move(metadata), current_time, std::move(geo_metadata)); } LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, bool binary_as_string) { @@ -129,6 +135,8 @@ LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, bool bi } else if (s_ele.logicalType.__isset.TIMESTAMP) { if (s_ele.logicalType.TIMESTAMP.isAdjustedToUTC) { return LogicalType::TIMESTAMP_TZ; + } else if (s_ele.logicalType.TIMESTAMP.unit.__isset.NANOS) { + return LogicalType::TIMESTAMP_NS; } return LogicalType::TIMESTAMP; } else if (s_ele.logicalType.__isset.TIME) { @@ -242,7 +250,7 @@ LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, bool bi case ConvertedType::INTERVAL: return LogicalType::INTERVAL; case ConvertedType::JSON: - return LogicalType::VARCHAR; + return LogicalType::JSON(); case ConvertedType::NULL_TYPE: return LogicalTypeId::SQLNULL; case ConvertedType::MAP: @@ -284,8 +292,9 @@ LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele) { return DeriveLogicalType(s_ele, parquet_options.binary_as_string); } -unique_ptr ParquetReader::CreateReaderRecursive(idx_t depth, idx_t max_define, idx_t max_repeat, - idx_t &next_schema_idx, idx_t &next_file_idx) { +unique_ptr ParquetReader::CreateReaderRecursive(ClientContext &context, idx_t depth, idx_t max_define, + idx_t max_repeat, idx_t &next_schema_idx, + idx_t &next_file_idx) { auto file_meta_data = GetFileMetadata(); D_ASSERT(file_meta_data); D_ASSERT(next_schema_idx < file_meta_data->schema.size()); @@ -302,6 +311,16 @@ unique_ptr ParquetReader::CreateReaderRecursive(idx_t depth, idx_t if (repetition_type == FieldRepetitionType::REPEATED) { max_repeat++; } + + // Check for geoparquet spatial types + if (depth == 1) { + // geoparquet types have to be at the root of the schema, and have to be present in the kv metadata + if (metadata->geo_metadata && metadata->geo_metadata->IsGeometryColumn(s_ele.name)) { + return metadata->geo_metadata->CreateColumnReader(*this, DeriveLogicalType(s_ele), s_ele, next_file_idx++, + max_define, max_repeat, context); + } + } + if (s_ele.__isset.num_children && s_ele.num_children > 0) { // inner node child_list_t child_types; vector> child_readers; @@ -313,12 +332,27 @@ unique_ptr ParquetReader::CreateReaderRecursive(idx_t depth, idx_t auto &child_ele = file_meta_data->schema[next_schema_idx]; auto child_reader = - CreateReaderRecursive(depth + 1, max_define, max_repeat, next_schema_idx, next_file_idx); + CreateReaderRecursive(context, depth + 1, max_define, max_repeat, next_schema_idx, next_file_idx); child_types.push_back(make_pair(child_ele.name, child_reader->Type())); child_readers.push_back(std::move(child_reader)); c_idx++; } + // rename child type entries if there are case-insensitive duplicates by appending _1, _2 etc. + // behavior consistent with CSV reader fwiw + case_insensitive_map_t name_collision_count; + // get header names from CSV + for (auto &child_type : child_types) { + auto col_name = child_type.first; + // avoid duplicate header names + while (name_collision_count.find(col_name) != name_collision_count.end()) { + name_collision_count[col_name] += 1; + col_name = col_name + "_" + to_string(name_collision_count[col_name]); + } + child_type.first = col_name; + name_collision_count[col_name] = 0; + } + D_ASSERT(!child_types.empty()); unique_ptr result; LogicalType result_type; @@ -387,7 +421,7 @@ unique_ptr ParquetReader::CreateReaderRecursive(idx_t depth, idx_t } // TODO we don't need readers for columns we are not going to read ay -unique_ptr ParquetReader::CreateReader() { +unique_ptr ParquetReader::CreateReader(ClientContext &context) { auto file_meta_data = GetFileMetadata(); idx_t next_schema_idx = 0; idx_t next_file_idx = 0; @@ -398,7 +432,7 @@ unique_ptr ParquetReader::CreateReader() { if (file_meta_data->schema[0].num_children == 0) { throw IOException("Parquet reader: root schema element has no children"); } - auto ret = CreateReaderRecursive(0, 0, 0, next_schema_idx, next_file_idx); + auto ret = CreateReaderRecursive(context, 0, 0, 0, next_schema_idx, next_file_idx); if (ret->Type().id() != LogicalTypeId::STRUCT) { throw InvalidInputException("Root element of Parquet file must be a struct"); } @@ -425,7 +459,7 @@ unique_ptr ParquetReader::CreateReader() { return ret; } -void ParquetReader::InitializeSchema() { +void ParquetReader::InitializeSchema(ClientContext &context) { auto file_meta_data = GetFileMetadata(); if (file_meta_data->__isset.encryption_algorithm) { @@ -438,7 +472,7 @@ void ParquetReader::InitializeSchema() { if (file_meta_data->schema.size() < 2) { throw FormatException("Need at least one non-root column in the file"); } - root_reader = CreateReader(); + root_reader = CreateReader(context); auto &root_type = root_reader->Type(); auto &child_types = StructType::GetChildTypes(root_type); D_ASSERT(root_type.id() == LogicalTypeId::STRUCT); @@ -484,7 +518,8 @@ ParquetColumnDefinition ParquetColumnDefinition::FromSchemaValue(ClientContext & return result; } -ParquetReader::ParquetReader(ClientContext &context_p, string file_name_p, ParquetOptions parquet_options_p) +ParquetReader::ParquetReader(ClientContext &context_p, string file_name_p, ParquetOptions parquet_options_p, + shared_ptr metadata_p) : fs(FileSystem::GetFileSystem(context_p)), allocator(BufferAllocator::Get(context_p)), parquet_options(std::move(parquet_options_p)) { file_name = std::move(file_name_p); @@ -494,27 +529,45 @@ ParquetReader::ParquetReader(ClientContext &context_p, string file_name_p, Parqu "Reading parquet files from a FIFO stream is not supported and cannot be efficiently supported since " "metadata is located at the end of the file. Write the stream to disk first and read from there instead."); } + + // set pointer to factory method for AES state + auto &config = DBConfig::GetConfig(context_p); + if (config.encryption_util && parquet_options.debug_use_openssl) { + encryption_util = config.encryption_util; + } else { + encryption_util = make_shared_ptr(); + } + // If object cached is disabled // or if this file has cached metadata // or if the cached version already expired - if (!ObjectCache::ObjectCacheEnabled(context_p)) { - metadata = LoadMetadata(allocator, *file_handle, parquet_options.encryption_config); - } else { - auto last_modify_time = fs.GetLastModifiedTime(*file_handle); - metadata = ObjectCache::GetObjectCache(context_p).Get(file_name); - if (!metadata || (last_modify_time + 10 >= metadata->read_time)) { - metadata = LoadMetadata(allocator, *file_handle, parquet_options.encryption_config); - ObjectCache::GetObjectCache(context_p).Put(file_name, metadata); + if (!metadata_p) { + if (!ObjectCache::ObjectCacheEnabled(context_p)) { + metadata = + LoadMetadata(context_p, allocator, *file_handle, parquet_options.encryption_config, *encryption_util); + } else { + auto last_modify_time = fs.GetLastModifiedTime(*file_handle); + metadata = ObjectCache::GetObjectCache(context_p).Get(file_name); + if (!metadata || (last_modify_time + 10 >= metadata->read_time)) { + metadata = LoadMetadata(context_p, allocator, *file_handle, parquet_options.encryption_config, + *encryption_util); + ObjectCache::GetObjectCache(context_p).Put(file_name, metadata); + } } + } else { + metadata = std::move(metadata_p); } - InitializeSchema(); + InitializeSchema(context_p); +} + +ParquetUnionData::~ParquetUnionData() { } ParquetReader::ParquetReader(ClientContext &context_p, ParquetOptions parquet_options_p, shared_ptr metadata_p) : fs(FileSystem::GetFileSystem(context_p)), allocator(BufferAllocator::Get(context_p)), metadata(std::move(metadata_p)), parquet_options(std::move(parquet_options_p)) { - InitializeSchema(); + InitializeSchema(context_p); } ParquetReader::~ParquetReader() { @@ -556,9 +609,16 @@ unique_ptr ParquetReader::ReadStatistics(const string &name) { return column_stats; } +unique_ptr ParquetReader::ReadStatistics(ClientContext &context, ParquetOptions parquet_options, + shared_ptr metadata, + const string &name) { + ParquetReader reader(context, std::move(parquet_options), std::move(metadata)); + return reader.ReadStatistics(name); +} + uint32_t ParquetReader::Read(duckdb_apache::thrift::TBase &object, TProtocol &iprot) { if (parquet_options.encryption_config) { - return ParquetCrypto::Read(object, iprot, parquet_options.encryption_config->GetFooterKey()); + return ParquetCrypto::Read(object, iprot, parquet_options.encryption_config->GetFooterKey(), *encryption_util); } else { return object.read(&iprot); } @@ -567,7 +627,8 @@ uint32_t ParquetReader::Read(duckdb_apache::thrift::TBase &object, TProtocol &ip uint32_t ParquetReader::ReadData(duckdb_apache::thrift::protocol::TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size) { if (parquet_options.encryption_config) { - return ParquetCrypto::ReadData(iprot, buffer, buffer_size, parquet_options.encryption_config->GetFooterKey()); + return ParquetCrypto::ReadData(iprot, buffer, buffer_size, parquet_options.encryption_config->GetFooterKey(), + *encryption_util); } else { return iprot.getTransport()->read(buffer, buffer_size); } @@ -642,6 +703,20 @@ idx_t ParquetReader::GetGroupOffset(ParquetReaderScanState &state) { return min_offset; } +static FilterPropagateResult CheckParquetStringFilter(BaseStatistics &stats, const Statistics &pq_col_stats, + TableFilter &filter) { + if (filter.filter_type == TableFilterType::CONSTANT_COMPARISON) { + auto &constant_filter = filter.Cast(); + auto &min_value = pq_col_stats.min_value; + auto &max_value = pq_col_stats.max_value; + return StringStats::CheckZonemap(const_data_ptr_cast(min_value.c_str()), min_value.size(), + const_data_ptr_cast(max_value.c_str()), max_value.size(), + constant_filter.comparison_type, StringValue::Get(constant_filter.constant)); + } else { + return filter.CheckStatistics(stats); + } +} + void ParquetReader::PrepareRowGroupBuffer(ParquetReaderScanState &state, idx_t col_idx) { auto &group = GetGroup(state); auto column_id = reader_data.column_ids[col_idx]; @@ -656,7 +731,36 @@ void ParquetReader::PrepareRowGroupBuffer(ParquetReaderScanState &state, idx_t c if (stats && filter_entry != reader_data.filters->filters.end()) { bool skip_chunk = false; auto &filter = *filter_entry->second; - auto prune_result = filter.CheckStatistics(*stats); + + FilterPropagateResult prune_result; + if (column_reader->Type().id() == LogicalTypeId::VARCHAR && + group.columns[column_reader->FileIdx()].meta_data.statistics.__isset.min_value && + group.columns[column_reader->FileIdx()].meta_data.statistics.__isset.max_value) { + // our StringStats only store the first 8 bytes of strings (even if Parquet has longer string stats) + // however, when reading remote Parquet files, skipping row groups is really important + // here, we implement a special case to check the full length for string filters + if (filter.filter_type == TableFilterType::CONJUNCTION_AND) { + const auto &and_filter = filter.Cast(); + auto and_result = FilterPropagateResult::FILTER_ALWAYS_TRUE; + for (auto &child_filter : and_filter.child_filters) { + auto child_prune_result = CheckParquetStringFilter( + *stats, group.columns[column_reader->FileIdx()].meta_data.statistics, *child_filter); + if (child_prune_result == FilterPropagateResult::FILTER_ALWAYS_FALSE) { + and_result = FilterPropagateResult::FILTER_ALWAYS_FALSE; + break; + } else if (child_prune_result != and_result) { + and_result = FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + } + prune_result = and_result; + } else { + prune_result = CheckParquetStringFilter( + *stats, group.columns[column_reader->FileIdx()].meta_data.statistics, filter); + } + } else { + prune_result = filter.CheckStatistics(*stats); + } + if (prune_result == FilterPropagateResult::FILTER_ALWAYS_FALSE) { skip_chunk = true; } @@ -680,7 +784,8 @@ idx_t ParquetReader::NumRowGroups() { return GetFileMetadata()->row_groups.size(); } -void ParquetReader::InitializeScan(ParquetReaderScanState &state, vector groups_to_read) { +void ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanState &state, + vector groups_to_read) { state.current_group = -1; state.finished = false; state.group_offset = 0; @@ -700,7 +805,7 @@ void ParquetReader::InitializeScan(ParquetReaderScanState &state, vector } state.thrift_file_proto = CreateThriftFileProtocol(allocator, *state.file_handle, state.prefetch_mode); - state.root_reader = CreateReader(); + state.root_reader = CreateReader(context); state.define_buf.resize(allocator, STANDARD_VECTOR_SIZE); state.repeat_buf.resize(allocator, STANDARD_VECTOR_SIZE); } @@ -720,7 +825,9 @@ void FilterIsNull(Vector &v, parquet_filter_t &filter_mask, idx_t count) { filter_mask.reset(); } else { for (idx_t i = 0; i < count; i++) { - filter_mask[i] = filter_mask[i] && !mask.RowIsValid(i); + if (filter_mask.test(i)) { + filter_mask.set(i, !mask.RowIsValid(i)); + } } } } @@ -738,7 +845,9 @@ void FilterIsNotNull(Vector &v, parquet_filter_t &filter_mask, idx_t count) { auto &mask = FlatVector::Validity(v); if (!mask.AllValid()) { for (idx_t i = 0; i < count; i++) { - filter_mask[i] = filter_mask[i] && mask.RowIsValid(i); + if (filter_mask.test(i)) { + filter_mask.set(i, mask.RowIsValid(i)); + } } } } @@ -763,13 +872,15 @@ void TemplatedFilterOperation(Vector &v, T constant, parquet_filter_t &filter_ma if (!mask.AllValid()) { for (idx_t i = 0; i < count; i++) { - if (mask.RowIsValid(i)) { - filter_mask[i] = filter_mask[i] && OP::Operation(v_ptr[i], constant); + if (filter_mask.test(i) && mask.RowIsValid(i)) { + filter_mask.set(i, OP::Operation(v_ptr[i], constant)); } } } else { for (idx_t i = 0; i < count; i++) { - filter_mask[i] = filter_mask[i] && OP::Operation(v_ptr[i], constant); + if (filter_mask.test(i)) { + filter_mask.set(i, OP::Operation(v_ptr[i], constant)); + } } } } @@ -932,7 +1043,7 @@ bool ParquetReader::ScanInternal(ParquetReaderScanState &state, DataChunk &resul uint64_t total_row_group_span = GetGroupSpan(state); - double scan_percentage = (double)(to_scan_compressed_bytes) / total_row_group_span; + double scan_percentage = (double)(to_scan_compressed_bytes) / static_cast(total_row_group_span); if (to_scan_compressed_bytes > total_row_group_span) { throw InvalidInputException( @@ -1049,7 +1160,7 @@ bool ParquetReader::ScanInternal(ParquetReaderScanState &state, DataChunk &resul idx_t sel_size = 0; for (idx_t i = 0; i < this_output_chunk_rows; i++) { - if (filter_mask[i]) { + if (filter_mask.test(i)) { state.sel.set_index(sel_size++, i); } } diff --git a/src/duckdb/extension/parquet/parquet_statistics.cpp b/src/duckdb/extension/parquet/parquet_statistics.cpp index 67074210..26896c57 100644 --- a/src/duckdb/extension/parquet/parquet_statistics.cpp +++ b/src/duckdb/extension/parquet/parquet_statistics.cpp @@ -26,17 +26,17 @@ static unique_ptr CreateNumericStats(const LogicalType &type, // `max_value`. All are optional. such elegance. Value min; Value max; - if (parquet_stats.__isset.min) { - min = ParquetStatisticsUtils::ConvertValue(type, schema_ele, parquet_stats.min).DefaultCastAs(type); - } else if (parquet_stats.__isset.min_value) { + if (parquet_stats.__isset.min_value) { min = ParquetStatisticsUtils::ConvertValue(type, schema_ele, parquet_stats.min_value).DefaultCastAs(type); + } else if (parquet_stats.__isset.min) { + min = ParquetStatisticsUtils::ConvertValue(type, schema_ele, parquet_stats.min).DefaultCastAs(type); } else { min = Value(type); } - if (parquet_stats.__isset.max) { - max = ParquetStatisticsUtils::ConvertValue(type, schema_ele, parquet_stats.max).DefaultCastAs(type); - } else if (parquet_stats.__isset.max_value) { + if (parquet_stats.__isset.max_value) { max = ParquetStatisticsUtils::ConvertValue(type, schema_ele, parquet_stats.max_value).DefaultCastAs(type); + } else if (parquet_stats.__isset.max) { + max = ParquetStatisticsUtils::ConvertValue(type, schema_ele, parquet_stats.max).DefaultCastAs(type); } else { max = Value(type); } @@ -188,7 +188,9 @@ Value ParquetStatisticsUtils::ConvertValue(const LogicalType &type, } case LogicalTypeId::TIME_TZ: { int64_t val; - if (stats.size() == sizeof(int64_t)) { + if (stats.size() == sizeof(int32_t)) { + val = Load(stats_data); + } else if (stats.size() == sizeof(int64_t)) { val = Load(stats_data); } else { throw InternalException("Incorrect stats size for type TIMETZ"); @@ -196,7 +198,7 @@ Value ParquetStatisticsUtils::ConvertValue(const LogicalType &type, if (schema_ele.__isset.logicalType && schema_ele.logicalType.__isset.TIME) { // logical type if (schema_ele.logicalType.TIME.unit.__isset.MILLIS) { - return Value::TIMETZ(ParquetIntToTimeMsTZ(val)); + return Value::TIMETZ(ParquetIntToTimeMsTZ(NumericCast(val))); } else if (schema_ele.logicalType.TIME.unit.__isset.MICROS) { return Value::TIMETZ(ParquetIntToTimeTZ(val)); } else if (schema_ele.logicalType.TIME.unit.__isset.NANOS) { @@ -244,6 +246,38 @@ Value ParquetStatisticsUtils::ConvertValue(const LogicalType &type, return Value::TIMESTAMP(timestamp_value); } } + case LogicalTypeId::TIMESTAMP_NS: { + timestamp_ns_t timestamp_value; + if (schema_ele.type == Type::INT96) { + if (stats.size() != sizeof(Int96)) { + throw InternalException("Incorrect stats size for type TIMESTAMP_NS"); + } + timestamp_value = ImpalaTimestampToTimestampNS(Load(stats_data)); + } else { + D_ASSERT(schema_ele.type == Type::INT64); + if (stats.size() != sizeof(int64_t)) { + throw InternalException("Incorrect stats size for type TIMESTAMP_NS"); + } + auto val = Load(stats_data); + if (schema_ele.__isset.logicalType && schema_ele.logicalType.__isset.TIMESTAMP) { + // logical type + if (schema_ele.logicalType.TIMESTAMP.unit.__isset.MILLIS) { + timestamp_value = ParquetTimestampMsToTimestampNs(val); + } else if (schema_ele.logicalType.TIMESTAMP.unit.__isset.NANOS) { + timestamp_value = ParquetTimestampNsToTimestampNs(val); + } else if (schema_ele.logicalType.TIMESTAMP.unit.__isset.MICROS) { + timestamp_value = ParquetTimestampUsToTimestampNs(val); + } else { + throw InternalException("Timestamp (NS) logicalType is set but unit is unknown"); + } + } else if (schema_ele.converted_type == duckdb_parquet::format::ConvertedType::TIMESTAMP_MILLIS) { + timestamp_value = ParquetTimestampMsToTimestampNs(val); + } else { + timestamp_value = ParquetTimestampUsToTimestampNs(val); + } + } + return Value::TIMESTAMPNS(timestamp_value); + } default: throw InternalException("Unsupported type for stats %s", type.ToString()); } @@ -315,21 +349,21 @@ unique_ptr ParquetStatisticsUtils::TransformColumnStatistics(con break; case LogicalTypeId::VARCHAR: { auto string_stats = StringStats::CreateEmpty(type); - if (parquet_stats.__isset.min) { - StringColumnReader::VerifyString(parquet_stats.min.c_str(), parquet_stats.min.size(), true); - StringStats::Update(string_stats, parquet_stats.min); - } else if (parquet_stats.__isset.min_value) { + if (parquet_stats.__isset.min_value) { StringColumnReader::VerifyString(parquet_stats.min_value.c_str(), parquet_stats.min_value.size(), true); StringStats::Update(string_stats, parquet_stats.min_value); + } else if (parquet_stats.__isset.min) { + StringColumnReader::VerifyString(parquet_stats.min.c_str(), parquet_stats.min.size(), true); + StringStats::Update(string_stats, parquet_stats.min); } else { return nullptr; } - if (parquet_stats.__isset.max) { - StringColumnReader::VerifyString(parquet_stats.max.c_str(), parquet_stats.max.size(), true); - StringStats::Update(string_stats, parquet_stats.max); - } else if (parquet_stats.__isset.max_value) { + if (parquet_stats.__isset.max_value) { StringColumnReader::VerifyString(parquet_stats.max_value.c_str(), parquet_stats.max_value.size(), true); StringStats::Update(string_stats, parquet_stats.max_value); + } else if (parquet_stats.__isset.max) { + StringColumnReader::VerifyString(parquet_stats.max.c_str(), parquet_stats.max.size(), true); + StringStats::Update(string_stats, parquet_stats.max); } else { return nullptr; } diff --git a/src/duckdb/extension/parquet/parquet_timestamp.cpp b/src/duckdb/extension/parquet/parquet_timestamp.cpp index 3451e266..a0ada7d1 100644 --- a/src/duckdb/extension/parquet/parquet_timestamp.cpp +++ b/src/duckdb/extension/parquet/parquet_timestamp.cpp @@ -14,14 +14,31 @@ static constexpr int64_t JULIAN_TO_UNIX_EPOCH_DAYS = 2440588LL; static constexpr int64_t MILLISECONDS_PER_DAY = 86400000LL; static constexpr int64_t MICROSECONDS_PER_DAY = MILLISECONDS_PER_DAY * 1000LL; static constexpr int64_t NANOSECONDS_PER_MICRO = 1000LL; +static constexpr int64_t NANOSECONDS_PER_DAY = MICROSECONDS_PER_DAY * 1000LL; + +static inline int64_t ImpalaTimestampToDays(const Int96 &impala_timestamp) { + return impala_timestamp.value[2] - JULIAN_TO_UNIX_EPOCH_DAYS; +} static int64_t ImpalaTimestampToMicroseconds(const Int96 &impala_timestamp) { - int64_t days_since_epoch = impala_timestamp.value[2] - JULIAN_TO_UNIX_EPOCH_DAYS; + int64_t days_since_epoch = ImpalaTimestampToDays(impala_timestamp); auto nanoseconds = Load(const_data_ptr_cast(impala_timestamp.value)); auto microseconds = nanoseconds / NANOSECONDS_PER_MICRO; return days_since_epoch * MICROSECONDS_PER_DAY + microseconds; } +static int64_t ImpalaTimestampToNanoseconds(const Int96 &impala_timestamp) { + int64_t days_since_epoch = ImpalaTimestampToDays(impala_timestamp); + auto nanoseconds = Load(const_data_ptr_cast(impala_timestamp.value)); + return days_since_epoch * NANOSECONDS_PER_DAY + nanoseconds; +} + +timestamp_ns_t ImpalaTimestampToTimestampNS(const Int96 &raw_ts) { + timestamp_ns_t result; + result.value = ImpalaTimestampToNanoseconds(raw_ts); + return result; +} + timestamp_t ImpalaTimestampToTimestamp(const Int96 &raw_ts) { auto impala_us = ImpalaTimestampToMicroseconds(raw_ts); return Timestamp::FromEpochMicroSeconds(impala_us); @@ -52,6 +69,30 @@ timestamp_t ParquetTimestampMsToTimestamp(const int64_t &raw_ts) { return Timestamp::FromEpochMs(raw_ts); } +timestamp_ns_t ParquetTimestampMsToTimestampNs(const int64_t &raw_ms) { + timestamp_ns_t input; + input.value = raw_ms; + if (!Timestamp::IsFinite(input)) { + return input; + } + return Timestamp::TimestampNsFromEpochMillis(raw_ms); +} + +timestamp_ns_t ParquetTimestampUsToTimestampNs(const int64_t &raw_us) { + timestamp_ns_t input; + input.value = raw_us; + if (!Timestamp::IsFinite(input)) { + return input; + } + return Timestamp::TimestampNsFromEpochMicros(raw_us); +} + +timestamp_ns_t ParquetTimestampNsToTimestampNs(const int64_t &raw_ns) { + timestamp_ns_t result; + result.value = raw_ns; + return result; +} + timestamp_t ParquetTimestampNsToTimestamp(const int64_t &raw_ts) { timestamp_t input(raw_ts); if (!Timestamp::IsFinite(input)) { diff --git a/src/duckdb/extension/parquet/parquet_writer.cpp b/src/duckdb/extension/parquet/parquet_writer.cpp index 10c33861..570b4edd 100644 --- a/src/duckdb/extension/parquet/parquet_writer.cpp +++ b/src/duckdb/extension/parquet/parquet_writer.cpp @@ -8,6 +8,8 @@ #ifndef DUCKDB_AMALGAMATION #include "duckdb/common/file_system.hpp" #include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/write_stream.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/function/table_function.hpp" @@ -15,8 +17,6 @@ #include "duckdb/main/connection.hpp" #include "duckdb/parser/parsed_data/create_copy_function_info.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" #endif namespace duckdb { @@ -82,8 +82,8 @@ class MyTransport : public TTransport { WriteStream &serializer; }; -CopyTypeSupport ParquetWriter::DuckDBTypeToParquetTypeInternal(const LogicalType &duckdb_type, - Type::type &parquet_type) { +bool ParquetWriter::TryGetParquetType(const LogicalType &duckdb_type, optional_ptr parquet_type_ptr) { + Type::type parquet_type; switch (duckdb_type.id()) { case LogicalTypeId::BOOLEAN: parquet_type = Type::BOOLEAN; @@ -106,7 +106,7 @@ CopyTypeSupport ParquetWriter::DuckDBTypeToParquetTypeInternal(const LogicalType case LogicalTypeId::UHUGEINT: case LogicalTypeId::HUGEINT: parquet_type = Type::DOUBLE; - return CopyTypeSupport::LOSSY; + break; case LogicalTypeId::ENUM: case LogicalTypeId::BLOB: case LogicalTypeId::VARCHAR: @@ -151,71 +151,29 @@ CopyTypeSupport ParquetWriter::DuckDBTypeToParquetTypeInternal(const LogicalType break; default: // Anything that is not supported - return CopyTypeSupport::UNSUPPORTED; + return false; } - return CopyTypeSupport::SUPPORTED; + if (parquet_type_ptr) { + *parquet_type_ptr = parquet_type; + } + return true; } Type::type ParquetWriter::DuckDBTypeToParquetType(const LogicalType &duckdb_type) { Type::type result; - auto type_supports = DuckDBTypeToParquetTypeInternal(duckdb_type, result); - if (type_supports == CopyTypeSupport::UNSUPPORTED) { - throw NotImplementedException("Unimplemented type for Parquet \"%s\"", duckdb_type.ToString()); - } - return result; -} - -CopyTypeSupport ParquetWriter::TypeIsSupported(const LogicalType &type) { - Type::type unused; - auto id = type.id(); - if (id == LogicalTypeId::LIST) { - auto &child_type = ListType::GetChildType(type); - return TypeIsSupported(child_type); - } - if (id == LogicalTypeId::ARRAY) { - auto &child_type = ArrayType::GetChildType(type); - return TypeIsSupported(child_type); - } - if (id == LogicalTypeId::UNION) { - auto count = UnionType::GetMemberCount(type); - for (idx_t i = 0; i < count; i++) { - auto &member_type = UnionType::GetMemberType(type, i); - auto type_support = TypeIsSupported(member_type); - if (type_support != CopyTypeSupport::SUPPORTED) { - return type_support; - } - } - return CopyTypeSupport::SUPPORTED; - } - if (id == LogicalTypeId::STRUCT) { - auto &children = StructType::GetChildTypes(type); - for (auto &child : children) { - auto &child_type = child.second; - auto type_support = TypeIsSupported(child_type); - if (type_support != CopyTypeSupport::SUPPORTED) { - return type_support; - } - } - return CopyTypeSupport::SUPPORTED; - } - if (id == LogicalTypeId::MAP) { - auto &key_type = MapType::KeyType(type); - auto &value_type = MapType::ValueType(type); - auto key_type_support = TypeIsSupported(key_type); - if (key_type_support != CopyTypeSupport::SUPPORTED) { - return key_type_support; - } - auto value_type_support = TypeIsSupported(value_type); - if (value_type_support != CopyTypeSupport::SUPPORTED) { - return value_type_support; - } - return CopyTypeSupport::SUPPORTED; + if (TryGetParquetType(duckdb_type, &result)) { + return result; } - return DuckDBTypeToParquetTypeInternal(type, unused); + throw NotImplementedException("Unimplemented type for Parquet \"%s\"", duckdb_type.ToString()); } void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::format::SchemaElement &schema_ele) { + if (duckdb_type.IsJSONType()) { + schema_ele.converted_type = ConvertedType::JSON; + schema_ele.__isset.converted_type = true; + return; + } switch (duckdb_type.id()) { case LogicalTypeId::TINYINT: schema_ele.converted_type = ConvertedType::INT_8; @@ -264,7 +222,6 @@ void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, break; case LogicalTypeId::TIMESTAMP_TZ: case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_NS: case LogicalTypeId::TIMESTAMP_SEC: schema_ele.converted_type = ConvertedType::TIMESTAMP_MICROS; schema_ele.__isset.converted_type = true; @@ -273,6 +230,13 @@ void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, schema_ele.logicalType.TIMESTAMP.isAdjustedToUTC = (duckdb_type.id() == LogicalTypeId::TIMESTAMP_TZ); schema_ele.logicalType.TIMESTAMP.unit.__isset.MICROS = true; break; + case LogicalTypeId::TIMESTAMP_NS: + schema_ele.__isset.converted_type = false; + schema_ele.__isset.logicalType = true; + schema_ele.logicalType.__isset.TIMESTAMP = true; + schema_ele.logicalType.TIMESTAMP.isAdjustedToUTC = false; + schema_ele.logicalType.TIMESTAMP.unit.__isset.NANOS = true; + break; case LogicalTypeId::TIMESTAMP_MS: schema_ele.converted_type = ConvertedType::TIMESTAMP_MILLIS; schema_ele.__isset.converted_type = true; @@ -321,7 +285,7 @@ void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, uint32_t ParquetWriter::Write(const duckdb_apache::thrift::TBase &object) { if (encryption_config) { - return ParquetCrypto::Write(object, *protocol, encryption_config->GetFooterKey()); + return ParquetCrypto::Write(object, *protocol, encryption_config->GetFooterKey(), *encryption_util); } else { return object.write(protocol.get()); } @@ -329,7 +293,8 @@ uint32_t ParquetWriter::Write(const duckdb_apache::thrift::TBase &object) { uint32_t ParquetWriter::WriteData(const const_data_ptr_t buffer, const uint32_t buffer_size) { if (encryption_config) { - return ParquetCrypto::WriteData(*protocol, buffer, buffer_size, encryption_config->GetFooterKey()); + return ParquetCrypto::WriteData(*protocol, buffer, buffer_size, encryption_config->GetFooterKey(), + *encryption_util); } else { protocol->getTransport()->write(buffer, buffer_size); return buffer_size; @@ -349,18 +314,27 @@ void VerifyUniqueNames(const vector &names) { #endif } -ParquetWriter::ParquetWriter(FileSystem &fs, string file_name_p, vector types_p, vector names_p, - CompressionCodec::type codec, ChildFieldIDs field_ids_p, +ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file_name_p, vector types_p, + vector names_p, CompressionCodec::type codec, ChildFieldIDs field_ids_p, const vector> &kv_metadata, shared_ptr encryption_config_p, - double dictionary_compression_ratio_threshold_p, optional_idx compression_level_p) + double dictionary_compression_ratio_threshold_p, optional_idx compression_level_p, + bool debug_use_openssl_p) : file_name(std::move(file_name_p)), sql_types(std::move(types_p)), column_names(std::move(names_p)), codec(codec), field_ids(std::move(field_ids_p)), encryption_config(std::move(encryption_config_p)), - dictionary_compression_ratio_threshold(dictionary_compression_ratio_threshold_p) { + dictionary_compression_ratio_threshold(dictionary_compression_ratio_threshold_p), + debug_use_openssl(debug_use_openssl_p) { // initialize the file writer writer = make_uniq(fs, file_name.c_str(), FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW); if (encryption_config) { + auto &config = DBConfig::GetConfig(context); + if (config.encryption_util && debug_use_openssl) { + // Use OpenSSL + encryption_util = config.encryption_util; + } else { + encryption_util = make_shared_ptr(); + } // encrypted parquet files start with the string "PARE" writer->WriteData(const_data_ptr_cast("PARE"), 4); // we only support this one for now, not "AES_GCM_CTR_V1" @@ -403,7 +377,7 @@ ParquetWriter::ParquetWriter(FileSystem &fs, string file_name_p, vector(sql_types.size()); file_meta_data.schema[0].__isset.num_children = true; file_meta_data.schema[0].repetition_type = duckdb_parquet::format::FieldRepetitionType::REQUIRED; file_meta_data.schema[0].__isset.repetition_type = true; @@ -413,8 +387,8 @@ ParquetWriter::ParquetWriter(FileSystem &fs, string file_name_p, vector schema_path; for (idx_t i = 0; i < sql_types.size(); i++) { - column_writers.push_back(ColumnWriter::CreateWriterRecursive(file_meta_data.schema, *this, sql_types[i], - unique_names[i], schema_path, &field_ids)); + column_writers.push_back(ColumnWriter::CreateWriterRecursive( + context, file_meta_data.schema, *this, sql_types[i], unique_names[i], schema_path, &field_ids)); } } @@ -428,8 +402,8 @@ void ParquetWriter::PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGro // set up a new row group for this chunk collection auto &row_group = result.row_group; - row_group.num_rows = buffer.Count(); - row_group.total_byte_size = buffer.SizeInBytes(); + row_group.num_rows = NumericCast(buffer.Count()); + row_group.total_byte_size = NumericCast(buffer.SizeInBytes()); row_group.__isset.file_offset = true; auto &states = result.states; @@ -460,6 +434,11 @@ void ParquetWriter::PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGro } } + // Reserving these once at the start really pays off + for (auto &write_state : write_states) { + write_state->definition_levels.reserve(buffer.Count()); + } + for (auto &chunk : buffer.Chunks({column_ids})) { for (idx_t i = 0; i < next; i++) { col_writers[i].get().Prepare(*write_states[i], nullptr, chunk.data[i], chunk.size()); @@ -526,7 +505,7 @@ void ParquetWriter::FlushRowGroup(PreparedRowGroup &prepared) { if (states.empty()) { throw InternalException("Attempting to flush a row group with no rows"); } - row_group.file_offset = writer->GetTotalWritten(); + row_group.file_offset = NumericCast(writer->GetTotalWritten()); for (idx_t col_idx = 0; col_idx < states.size(); col_idx++) { const auto &col_writer = column_writers[col_idx]; auto write_state = std::move(states[col_idx]); @@ -555,7 +534,7 @@ void ParquetWriter::Flush(ColumnDataCollection &buffer) { } void ParquetWriter::Finalize() { - auto start_offset = writer->GetTotalWritten(); + const auto start_offset = writer->GetTotalWritten(); if (encryption_config) { // Crypto metadata is written unencrypted FileCryptoMetaData crypto_metadata; @@ -565,6 +544,12 @@ void ParquetWriter::Finalize() { crypto_metadata.__set_encryption_algorithm(alg); crypto_metadata.write(protocol.get()); } + + // Add geoparquet metadata to the file metadata + if (geoparquet_data) { + geoparquet_data->Write(file_meta_data); + } + Write(file_meta_data); writer->Write(writer->GetTotalWritten() - start_offset); @@ -578,8 +563,15 @@ void ParquetWriter::Finalize() { } // flush to disk - writer->Sync(); + writer->Close(); writer.reset(); } +GeoParquetFileMetadata &ParquetWriter::GetGeoParquetData() { + if (!geoparquet_data) { + geoparquet_data = make_uniq(); + } + return *geoparquet_data; +} + } // namespace duckdb diff --git a/src/duckdb/extension/parquet/serialize_parquet.cpp b/src/duckdb/extension/parquet/serialize_parquet.cpp index cc5d1445..e6aeac02 100644 --- a/src/duckdb/extension/parquet/serialize_parquet.cpp +++ b/src/duckdb/extension/parquet/serialize_parquet.cpp @@ -71,6 +71,7 @@ void ParquetOptions::Serialize(Serializer &serializer) const { serializer.WriteProperty(102, "file_options", file_options); serializer.WritePropertyWithDefault>(103, "schema", schema); serializer.WritePropertyWithDefault>(104, "encryption_config", encryption_config, nullptr); + serializer.WritePropertyWithDefault(105, "debug_use_openssl", debug_use_openssl, true); } ParquetOptions ParquetOptions::Deserialize(Deserializer &deserializer) { @@ -79,7 +80,8 @@ ParquetOptions ParquetOptions::Deserialize(Deserializer &deserializer) { deserializer.ReadPropertyWithDefault(101, "file_row_number", result.file_row_number); deserializer.ReadProperty(102, "file_options", result.file_options); deserializer.ReadPropertyWithDefault>(103, "schema", result.schema); - deserializer.ReadPropertyWithDefault>(104, "encryption_config", result.encryption_config, nullptr); + deserializer.ReadPropertyWithExplicitDefault>(104, "encryption_config", result.encryption_config, nullptr); + deserializer.ReadPropertyWithExplicitDefault(105, "debug_use_openssl", result.debug_use_openssl, true); return result; } diff --git a/src/duckdb/extension/parquet/zstd_file_system.cpp b/src/duckdb/extension/parquet/zstd_file_system.cpp index 42b602d5..5a630b05 100644 --- a/src/duckdb/extension/parquet/zstd_file_system.cpp +++ b/src/duckdb/extension/parquet/zstd_file_system.cpp @@ -100,7 +100,7 @@ void ZstdStreamWrapper::Write(CompressedFile &file, StreamData &sd, data_ptr_t u sd.out_buff_start = sd.out_buff.get(); } uncompressed_data += input_consumed; - remaining -= input_consumed; + remaining -= UnsafeNumericCast(input_consumed); } } @@ -160,6 +160,10 @@ class ZStdFile : public CompressedFile { Initialize(write); } + FileCompressionType GetFileCompressionType() override { + return FileCompressionType::ZSTD; + } + ZStdFileSystem zstd_fs; }; diff --git a/src/duckdb/src/catalog/catalog.cpp b/src/duckdb/src/catalog/catalog.cpp index 4d334a4d..08a01d93 100644 --- a/src/duckdb/src/catalog/catalog.cpp +++ b/src/duckdb/src/catalog/catalog.cpp @@ -315,7 +315,6 @@ struct CatalogEntryLookup { // Generic //===--------------------------------------------------------------------===// void Catalog::DropEntry(ClientContext &context, DropInfo &info) { - ModifyCatalog(); if (info.type == CatalogType::SCHEMA_ENTRY) { // DROP SCHEMA DropSchema(context, info); @@ -364,7 +363,7 @@ SimilarCatalogEntry Catalog::SimilarEntryInSchemas(ClientContext &context, const // no similar entry found continue; } - if (!result.Found() || result.distance > entry.distance) { + if (!result.Found() || result.score < entry.score) { result = entry; result.schema = &schema; } @@ -563,10 +562,19 @@ CatalogException Catalog::CreateMissingEntryException(ClientContext &context, co reference_set_t unseen_schemas; auto &db_manager = DatabaseManager::Get(context); auto databases = db_manager.GetDatabases(context); + auto &config = DBConfig::GetConfig(context); + + auto max_schema_count = config.options.catalog_error_max_schemas; for (auto database : databases) { + if (unseen_schemas.size() >= max_schema_count) { + break; + } auto &catalog = database.get().GetCatalog(); auto current_schemas = catalog.GetAllSchemas(context); for (auto ¤t_schema : current_schemas) { + if (unseen_schemas.size() >= max_schema_count) { + break; + } unseen_schemas.insert(current_schema.get()); } } @@ -626,9 +634,12 @@ CatalogException Catalog::CreateMissingEntryException(ClientContext &context, co return CatalogException(error_message); } + // entries in other schemas get a penalty + // however, if there is an exact match in another schema, we will always show it + static constexpr const double UNSEEN_PENALTY = 0.2; auto unseen_entry = SimilarEntryInSchemas(context, entry_name, type, unseen_schemas); string did_you_mean; - if (unseen_entry.Found() && unseen_entry.distance < entry.distance) { + if (unseen_entry.Found() && (unseen_entry.score == 1.0 || unseen_entry.score - UNSEEN_PENALTY > entry.score)) { // the closest matching entry requires qualification as it is not in the default search path // check how to minimally qualify this entry auto catalog_name = unseen_entry.schema->catalog.GetName(); @@ -884,8 +895,6 @@ vector> Catalog::GetAllSchemas(ClientContext &cont } void Catalog::Alter(CatalogTransaction transaction, AlterInfo &info) { - ModifyCatalog(); - if (transaction.HasContext()) { auto lookup = LookupEntry(transaction.GetContext(), info.GetCatalogType(), info.schema, info.name, info.if_not_found); @@ -910,17 +919,6 @@ vector Catalog::GetMetadataInfo(ClientContext &context) { void Catalog::Verify() { } -//===--------------------------------------------------------------------===// -// Catalog Version -//===--------------------------------------------------------------------===// -idx_t Catalog::GetCatalogVersion() { - return GetDatabase().GetDatabaseManager().catalog_version; -} - -idx_t Catalog::ModifyCatalog() { - return GetDatabase().GetDatabaseManager().ModifyCatalog(); -} - bool Catalog::IsSystemCatalog() const { return db.IsSystem(); } diff --git a/src/duckdb/src/catalog/catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry.cpp index aa05a866..3c8f383a 100644 --- a/src/duckdb/src/catalog/catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry.cpp @@ -1,8 +1,11 @@ #include "duckdb/catalog/catalog_entry.hpp" -#include "duckdb/parser/parsed_data/create_info.hpp" + #include "duckdb/catalog/catalog.hpp" -#include "duckdb/common/serializer/binary_serializer.hpp" #include "duckdb/common/serializer/binary_deserializer.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/parser/parsed_data/create_info.hpp" namespace duckdb { @@ -12,7 +15,7 @@ CatalogEntry::CatalogEntry(CatalogType type, string name_p, idx_t oid) } CatalogEntry::CatalogEntry(CatalogType type, Catalog &catalog, string name_p) - : CatalogEntry(type, std::move(name_p), catalog.ModifyCatalog()) { + : CatalogEntry(type, std::move(name_p), catalog.GetDatabase().GetDatabaseManager().NextOid()) { } CatalogEntry::~CatalogEntry() { diff --git a/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp index b6da4918..be047bfe 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp @@ -1,6 +1,7 @@ #include "duckdb/catalog/catalog_entry/duck_index_entry.hpp" #include "duckdb/storage/data_table.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" namespace duckdb { @@ -12,28 +13,30 @@ IndexDataTableInfo::~IndexDataTableInfo() { if (!info) { return; } + // FIXME: this should happen differently. info->GetIndexes().RemoveIndex(index_name); } -DuckIndexEntry::DuckIndexEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &info) - : IndexCatalogEntry(catalog, schema, info) { +DuckIndexEntry::DuckIndexEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &create_info, + TableCatalogEntry &table_p) + : IndexCatalogEntry(catalog, schema, create_info), initial_index_size(0) { + auto &table = table_p.Cast(); + auto &storage = table.GetStorage(); + info = make_shared_ptr(storage.GetDataTableInfo(), name); +} + +DuckIndexEntry::DuckIndexEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &create_info, + shared_ptr storage_info) + : IndexCatalogEntry(catalog, schema, create_info), info(std::move(storage_info)), initial_index_size(0) { } unique_ptr DuckIndexEntry::Copy(ClientContext &context) const { auto info_copy = GetInfo(); auto &cast_info = info_copy->Cast(); - auto result = make_uniq(catalog, schema, cast_info); - result->info = info; + auto result = make_uniq(catalog, schema, cast_info, info); result->initial_index_size = initial_index_size; - for (auto &expr : expressions) { - result->expressions.push_back(expr->Copy()); - } - for (auto &expr : parsed_expressions) { - result->parsed_expressions.push_back(expr->Copy()); - } - return std::move(result); } diff --git a/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp index b9a3afec..42dea06f 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp @@ -1,40 +1,44 @@ #include "duckdb/catalog/catalog_entry/duck_schema_entry.hpp" -#include "duckdb/catalog/default/default_functions.hpp" -#include "duckdb/catalog/default/default_types.hpp" -#include "duckdb/catalog/default/default_views.hpp" + +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/collate_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/duck_index_entry.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" #include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/default/default_functions.hpp" +#include "duckdb/catalog/default/default_table_functions.hpp" +#include "duckdb/catalog/default/default_types.hpp" +#include "duckdb/catalog/default/default_views.hpp" #include "duckdb/catalog/dependency_list.hpp" -#include "duckdb/planner/constraints/bound_foreign_key_constraint.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/database.hpp" #include "duckdb/parser/constraints/foreign_key_constraint.hpp" #include "duckdb/parser/parsed_data/alter_table_info.hpp" -#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" -#include "duckdb/storage/data_table.hpp" -#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" #include "duckdb/parser/parsed_data/create_collation_info.hpp" #include "duckdb/parser/parsed_data/create_copy_function_info.hpp" #include "duckdb/parser/parsed_data/create_index_info.hpp" #include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" +#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" #include "duckdb/parser/parsed_data/create_schema_info.hpp" #include "duckdb/parser/parsed_data/create_sequence_info.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" #include "duckdb/parser/parsed_data/create_type_info.hpp" #include "duckdb/parser/parsed_data/create_view_info.hpp" #include "duckdb/parser/parsed_data/drop_info.hpp" +#include "duckdb/planner/constraints/bound_foreign_key_constraint.hpp" +#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" +#include "duckdb/storage/data_table.hpp" +#include "duckdb/transaction/duck_transaction.hpp" #include "duckdb/transaction/meta_transaction.hpp" -#include "duckdb/main/attached_database.hpp" namespace duckdb { @@ -78,7 +82,8 @@ static void LazyLoadIndexes(ClientContext &context, CatalogEntry &entry) { DuckSchemaEntry::DuckSchemaEntry(Catalog &catalog, CreateSchemaInfo &info) : SchemaCatalogEntry(catalog, info), tables(catalog, make_uniq(catalog, *this)), - indexes(catalog), table_functions(catalog), copy_functions(catalog), pragma_functions(catalog), + indexes(catalog), table_functions(catalog, make_uniq(catalog, *this)), + copy_functions(catalog), pragma_functions(catalog), functions(catalog, make_uniq(catalog, *this)), sequences(catalog), collations(catalog), types(catalog, make_uniq(catalog, *this)) { } @@ -125,6 +130,7 @@ optional_ptr DuckSchemaEntry::AddEntryInternal(CatalogTransaction throw CatalogException("Existing object %s is of type %s, trying to replace with type %s", entry_name, CatalogTypeToString(old_entry->type), CatalogTypeToString(entry_type)); } + OnDropEntry(transaction, *old_entry); (void)set.DropEntry(transaction, entry_name, false, entry->internal); } } @@ -240,7 +246,7 @@ optional_ptr DuckSchemaEntry::CreateIndex(CatalogTransaction trans throw CatalogException("An index with the name " + info.index_name + " already exists!"); } - auto index = make_uniq(catalog, *this, info); + auto index = make_uniq(catalog, *this, info, table); auto dependencies = index->dependencies; return AddEntryInternal(transaction, std::move(index), info.on_conflict, dependencies); } @@ -316,13 +322,14 @@ void DuckSchemaEntry::DropEntry(ClientContext &context, DropInfo &info) { // if this is a index or table with indexes, initialize any unknown index instances LazyLoadIndexes(context, *existing_entry); - // if there is a foreign key constraint, get that information vector> fk_arrays; if (existing_entry->type == CatalogType::TABLE_ENTRY) { - FindForeignKeyInformation(existing_entry->Cast(), AlterForeignKeyType::AFT_DELETE, - fk_arrays); + // if there is a foreign key constraint, get that information + auto &table_entry = existing_entry->Cast(); + FindForeignKeyInformation(table_entry, AlterForeignKeyType::AFT_DELETE, fk_arrays); } + OnDropEntry(transaction, *existing_entry); if (!set.DropEntry(transaction, info.name, info.cascade, info.allow_drop_internal)) { throw InternalException("Could not drop element because of an internal error"); } @@ -334,6 +341,19 @@ void DuckSchemaEntry::DropEntry(ClientContext &context, DropInfo &info) { } } +void DuckSchemaEntry::OnDropEntry(CatalogTransaction transaction, CatalogEntry &entry) { + if (!transaction.transaction) { + return; + } + if (entry.type != CatalogType::TABLE_ENTRY) { + return; + } + // if we have transaction local insertions for this table - clear them + auto &table_entry = entry.Cast(); + auto &local_storage = LocalStorage::Get(transaction.transaction->Cast()); + local_storage.DropTable(table_entry.GetStorage()); +} + optional_ptr DuckSchemaEntry::GetEntry(CatalogTransaction transaction, CatalogType type, const string &name) { return GetCatalogSet(type).GetEntry(transaction, name); diff --git a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp index 144fe314..be1041f9 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp @@ -1,10 +1,13 @@ #include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" #include "duckdb/common/enum_util.hpp" +#include "duckdb/common/exception/transaction_exception.hpp" #include "duckdb/common/index_map.hpp" #include "duckdb/execution/index/art/art.hpp" #include "duckdb/function/table/table_scan.hpp" +#include "duckdb/main/database.hpp" #include "duckdb/parser/constraints/list.hpp" +#include "duckdb/parser/parsed_data/comment_on_column_info.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/constraints/bound_check_constraint.hpp" @@ -13,7 +16,6 @@ #include "duckdb/planner/constraints/bound_unique_constraint.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression_binder/alter_binder.hpp" -#include "duckdb/planner/filter/null_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/operator/logical_update.hpp" @@ -21,13 +23,11 @@ #include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/storage_manager.hpp" #include "duckdb/storage/table_storage_info.hpp" -#include "duckdb/common/exception/transaction_exception.hpp" -#include "duckdb/parser/parsed_data/comment_on_column_info.hpp" namespace duckdb { void AddDataTableIndex(DataTable &storage, const ColumnList &columns, const vector &keys, - IndexConstraintType constraint_type, const IndexStorageInfo &info = IndexStorageInfo()) { + IndexConstraintType constraint_type, const IndexStorageInfo &info) { // fetch types and create expressions for the index from the columns vector column_ids; @@ -54,7 +54,7 @@ void AddDataTableIndex(DataTable &storage, const ColumnList &columns, const vect } void AddDataTableIndex(DataTable &storage, const ColumnList &columns, vector &keys, - IndexConstraintType constraint_type, const IndexStorageInfo &info = IndexStorageInfo()) { + IndexConstraintType constraint_type, const IndexStorageInfo &info) { vector new_keys; new_keys.reserve(keys.size()); for (auto &logical_key : keys) { @@ -63,12 +63,17 @@ void AddDataTableIndex(DataTable &storage, const ColumnList &columns, vector &create_info, - idx_t idx) { +IndexStorageInfo GetIndexInfo(const IndexConstraintType &constraint_type, const bool v1_0_0_storage, + unique_ptr &create_info, const idx_t identifier) { auto &create_table_info = create_info->Cast(); auto constraint_name = EnumUtil::ToString(constraint_type) + "_"; - return IndexStorageInfo(constraint_name + create_table_info.table + "_" + to_string(idx)); + auto name = constraint_name + create_table_info.table + "_" + to_string(identifier); + IndexStorageInfo info(name); + if (!v1_0_0_storage) { + info.options.emplace("v1_0_0_storage", v1_0_0_storage); + } + return info; } vector GetUniqueConstraintKeys(const ColumnList &columns, const UniqueConstraint &constraint) { @@ -88,66 +93,72 @@ DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, Bou : TableCatalogEntry(catalog, schema, info.Base()), storage(std::move(inherited_storage)), column_dependency_manager(std::move(info.column_dependency_manager)) { - if (!storage) { - // create the physical storage - vector storage_columns; - for (auto &col_def : columns.Physical()) { - storage_columns.push_back(col_def.Copy()); + if (storage) { + if (!info.indexes.empty()) { + storage->SetIndexStorageInfo(std::move(info.indexes)); } - storage = - make_shared_ptr(catalog.GetAttached(), StorageManager::Get(catalog).GetTableIOManager(&info), - schema.name, name, std::move(storage_columns), std::move(info.data)); - - // create the unique indexes for the UNIQUE and PRIMARY KEY and FOREIGN KEY constraints - idx_t indexes_idx = 0; - for (idx_t i = 0; i < constraints.size(); i++) { - auto &constraint = constraints[i]; - if (constraint->type == ConstraintType::UNIQUE) { - - // unique constraint: create a unique index - auto &unique = constraint->Cast(); - IndexConstraintType constraint_type = IndexConstraintType::UNIQUE; - if (unique.is_primary_key) { - constraint_type = IndexConstraintType::PRIMARY; - } - auto unique_keys = GetUniqueConstraintKeys(columns, unique); + return; + } + + // create the physical storage + vector storage_columns; + for (auto &col_def : columns.Physical()) { + storage_columns.push_back(col_def.Copy()); + } + storage = make_shared_ptr(catalog.GetAttached(), StorageManager::Get(catalog).GetTableIOManager(&info), + schema.name, name, std::move(storage_columns), std::move(info.data)); + + // create the unique indexes for the UNIQUE and PRIMARY KEY and FOREIGN KEY constraints + idx_t indexes_idx = 0; + for (idx_t i = 0; i < constraints.size(); i++) { + auto &constraint = constraints[i]; + if (constraint->type == ConstraintType::UNIQUE) { + // unique constraint: create a unique index + auto &unique = constraint->Cast(); + IndexConstraintType constraint_type = IndexConstraintType::UNIQUE; + if (unique.is_primary_key) { + constraint_type = IndexConstraintType::PRIMARY; + } + auto unique_keys = GetUniqueConstraintKeys(columns, unique); + if (info.indexes.empty()) { + auto index_storage_info = GetIndexInfo(constraint_type, false, info.base, i); + AddDataTableIndex(*storage, columns, unique_keys, constraint_type, index_storage_info); + continue; + } + + // We read the index from an old storage version applying a dummy name. + if (info.indexes[indexes_idx].name.empty()) { + auto name_info = GetIndexInfo(constraint_type, true, info.base, i); + info.indexes[indexes_idx].name = name_info.name; + } + + // now add the index + AddDataTableIndex(*storage, columns, unique_keys, constraint_type, info.indexes[indexes_idx++]); + continue; + } + + if (constraint->type == ConstraintType::FOREIGN_KEY) { + // foreign key constraint: create a foreign key index + auto &bfk = constraint->Cast(); + if (bfk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE || + bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { + if (info.indexes.empty()) { - AddDataTableIndex(*storage, columns, unique_keys, constraint_type, - GetIndexInfo(constraint_type, info.base, i)); - } else { - // we read the index from an old storage version, so we have to apply a dummy name - if (info.indexes[indexes_idx].name.empty()) { - auto name_info = GetIndexInfo(constraint_type, info.base, i); - info.indexes[indexes_idx].name = name_info.name; - } - - // now add the index - AddDataTableIndex(*storage, columns, unique_keys, constraint_type, info.indexes[indexes_idx++]); + auto constraint_type = IndexConstraintType::FOREIGN; + auto index_storage_info = GetIndexInfo(constraint_type, false, info.base, i); + AddDataTableIndex(*storage, columns, bfk.info.fk_keys, constraint_type, index_storage_info); + continue; } - } else if (constraint->type == ConstraintType::FOREIGN_KEY) { - // foreign key constraint: create a foreign key index - auto &bfk = constraint->Cast(); - if (bfk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE || - bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - - if (info.indexes.empty()) { - auto constraint_type = IndexConstraintType::FOREIGN; - AddDataTableIndex(*storage, columns, bfk.info.fk_keys, constraint_type, - GetIndexInfo(constraint_type, info.base, i)); - - } else { - // we read the index from an old storage version, so we have to apply a dummy name - if (info.indexes[indexes_idx].name.empty()) { - auto name_info = GetIndexInfo(IndexConstraintType::FOREIGN, info.base, i); - info.indexes[indexes_idx].name = name_info.name; - } - - // now add the index - AddDataTableIndex(*storage, columns, bfk.info.fk_keys, IndexConstraintType::FOREIGN, - info.indexes[indexes_idx++]); - } + // We read the index from an old storage version applying a dummy name. + if (info.indexes[indexes_idx].name.empty()) { + auto name_info = GetIndexInfo(IndexConstraintType::FOREIGN, true, info.base, i); + info.indexes[indexes_idx].name = name_info.name; } + + // now add the index + AddDataTableIndex(*storage, columns, bfk.info.fk_keys, IndexConstraintType::FOREIGN, + info.indexes[indexes_idx++]); } } } @@ -169,20 +180,25 @@ unique_ptr DuckTableEntry::GetStatistics(ClientContext &context, } unique_ptr DuckTableEntry::AlterEntry(CatalogTransaction transaction, AlterInfo &info) { - if (transaction.context) { - return AlterEntry(*transaction.context, info); - } - if (info.type == AlterType::ALTER_TABLE) { - auto &table_info = info.Cast(); - if (table_info.alter_table_type == AlterTableType::FOREIGN_KEY_CONSTRAINT) { - auto &foreign_key_constraint_info = table_info.Cast(); - if (foreign_key_constraint_info.type == AlterForeignKeyType::AFT_ADD) { - // for checkpoint loading we support adding foreign key constraints without a client context - return AddForeignKeyConstraint(nullptr, foreign_key_constraint_info); - } - } + if (transaction.HasContext()) { + return AlterEntry(transaction.GetContext(), info); + } + if (info.type != AlterType::ALTER_TABLE) { + return CatalogEntry::AlterEntry(transaction, info); } - return CatalogEntry::AlterEntry(transaction, info); + + auto &table_info = info.Cast(); + if (table_info.alter_table_type != AlterTableType::FOREIGN_KEY_CONSTRAINT) { + return CatalogEntry::AlterEntry(transaction, info); + } + + auto &foreign_key_constraint_info = table_info.Cast(); + if (foreign_key_constraint_info.type != AlterForeignKeyType::AFT_ADD) { + return CatalogEntry::AlterEntry(transaction, info); + } + + // We add foreign key constraints without a client context during checkpoint loading. + return AddForeignKeyConstraint(nullptr, foreign_key_constraint_info); } unique_ptr DuckTableEntry::AlterEntry(ClientContext &context, AlterInfo &info) { @@ -758,7 +774,6 @@ unique_ptr DuckTableEntry::AddForeignKeyConstraint(optional_ptr(catalog, schema, *bound_create_info, storage); } @@ -783,7 +798,6 @@ unique_ptr DuckTableEntry::DropForeignKeyConstraint(ClientContext auto binder = Binder::CreateBinder(context); auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); - return make_uniq(catalog, schema, *bound_create_info, storage); } diff --git a/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp index 9e704dcd..2c5cb9ae 100644 --- a/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp @@ -9,6 +9,14 @@ IndexCatalogEntry::IndexCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schem this->temporary = info.temporary; this->dependencies = info.dependencies; this->comment = info.comment; + for (auto &expr : expressions) { + D_ASSERT(expr); + expressions.push_back(expr->Copy()); + } + for (auto &parsed_expr : info.parsed_expressions) { + D_ASSERT(parsed_expr); + parsed_expressions.push_back(parsed_expr->Copy()); + } } unique_ptr IndexCatalogEntry::GetInfo() const { @@ -42,12 +50,12 @@ string IndexCatalogEntry::ToSQL() const { return info->ToString(); } -bool IndexCatalogEntry::IsUnique() { +bool IndexCatalogEntry::IsUnique() const { return (index_constraint_type == IndexConstraintType::UNIQUE || index_constraint_type == IndexConstraintType::PRIMARY); } -bool IndexCatalogEntry::IsPrimary() { +bool IndexCatalogEntry::IsPrimary() const { return (index_constraint_type == IndexConstraintType::PRIMARY); } diff --git a/src/duckdb/src/catalog/catalog_entry/macro_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/macro_catalog_entry.cpp index 473b07f4..6aa9a52c 100644 --- a/src/duckdb/src/catalog/catalog_entry/macro_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/macro_catalog_entry.cpp @@ -6,9 +6,9 @@ namespace duckdb { MacroCatalogEntry::MacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info) : FunctionEntry( - (info.function->type == MacroType::SCALAR_MACRO ? CatalogType::MACRO_ENTRY : CatalogType::TABLE_MACRO_ENTRY), + (info.macros[0]->type == MacroType::SCALAR_MACRO ? CatalogType::MACRO_ENTRY : CatalogType::TABLE_MACRO_ENTRY), catalog, schema, info), - function(std::move(info.function)) { + macros(std::move(info.macros)) { this->temporary = info.temporary; this->internal = info.internal; this->dependencies = info.dependencies; @@ -43,11 +43,18 @@ unique_ptr MacroCatalogEntry::GetInfo() const { info->catalog = catalog.GetName(); info->schema = schema.name; info->name = name; - info->function = function->Copy(); + for (auto &function : macros) { + info->macros.push_back(function->Copy()); + } info->dependencies = dependencies; info->comment = comment; info->tags = tags; return std::move(info); } +string MacroCatalogEntry::ToSQL() const { + auto create_info = GetInfo(); + return create_info->ToString(); +} + } // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp index 5789cff8..3c8e7c70 100644 --- a/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp @@ -30,9 +30,9 @@ SimilarCatalogEntry SchemaCatalogEntry::GetSimilarEntry(CatalogTransaction trans const string &name) { SimilarCatalogEntry result; Scan(transaction.GetContext(), type, [&](CatalogEntry &entry) { - auto ldist = StringUtil::SimilarityScore(entry.name, name); - if (ldist < result.distance) { - result.distance = ldist; + auto entry_score = StringUtil::SimilarityRating(entry.name, name); + if (entry_score > result.score) { + result.score = entry_score; result.name = entry.name; } }); diff --git a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp index 33f6c1e4..abb66b7e 100644 --- a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -32,7 +32,7 @@ bool TableCatalogEntry::HasGeneratedColumns() const { return columns.LogicalColumnCount() != columns.PhysicalColumnCount(); } -LogicalIndex TableCatalogEntry::GetColumnIndex(string &column_name, bool if_exists) { +LogicalIndex TableCatalogEntry::GetColumnIndex(string &column_name, bool if_exists) const { auto entry = columns.GetColumnIndex(column_name); if (!entry.IsValid()) { if (if_exists) { @@ -43,15 +43,15 @@ LogicalIndex TableCatalogEntry::GetColumnIndex(string &column_name, bool if_exis return entry; } -bool TableCatalogEntry::ColumnExists(const string &name) { +bool TableCatalogEntry::ColumnExists(const string &name) const { return columns.ColumnExists(name); } -const ColumnDefinition &TableCatalogEntry::GetColumn(const string &name) { +const ColumnDefinition &TableCatalogEntry::GetColumn(const string &name) const { return columns.GetColumn(name); } -vector TableCatalogEntry::GetTypes() { +vector TableCatalogEntry::GetTypes() const { vector types; for (auto &col : columns.Physical()) { types.push_back(col.Type()); @@ -185,7 +185,7 @@ const ColumnList &TableCatalogEntry::GetColumns() const { return columns; } -const ColumnDefinition &TableCatalogEntry::GetColumn(LogicalIndex idx) { +const ColumnDefinition &TableCatalogEntry::GetColumn(LogicalIndex idx) const { return columns.GetColumn(idx); } @@ -226,8 +226,8 @@ static void BindExtraColumns(TableCatalogEntry &table, LogicalGet &get, LogicalP update.expressions.push_back(make_uniq( column.Type(), ColumnBinding(proj.table_index, proj.expressions.size()))); proj.expressions.push_back(make_uniq( - column.Type(), ColumnBinding(get.table_index, get.column_ids.size()))); - get.column_ids.push_back(check_column_id.index); + column.Type(), ColumnBinding(get.table_index, get.GetColumnIds().size()))); + get.AddColumnId(check_column_id.index); update.columns.push_back(check_column_id); } } diff --git a/src/duckdb/src/catalog/catalog_set.cpp b/src/duckdb/src/catalog/catalog_set.cpp index 716be598..f3f7ef1b 100644 --- a/src/duckdb/src/catalog/catalog_set.cpp +++ b/src/duckdb/src/catalog/catalog_set.cpp @@ -11,6 +11,7 @@ #include "duckdb/parser/expression/constant_expression.hpp" #include "duckdb/parser/parsed_data/alter_table_info.hpp" #include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/transaction/duck_transaction_manager.hpp" #include "duckdb/transaction/transaction_manager.hpp" #include "duckdb/catalog/dependency_list.hpp" #include "duckdb/common/exception/transaction_exception.hpp" @@ -186,8 +187,8 @@ bool CatalogSet::CreateEntryInternal(CatalogTransaction transaction, const strin map.UpdateEntry(std::move(value)); // push the old entry in the undo buffer for this transaction if (transaction.transaction) { - auto &dtransaction = transaction.transaction->Cast(); - dtransaction.PushCatalogEntry(value_ptr->Child()); + DuckTransactionManager::Get(GetCatalog().GetAttached()) + .PushCatalogEntry(*transaction.transaction, value_ptr->Child()); } return true; } @@ -362,8 +363,8 @@ bool CatalogSet::AlterEntry(CatalogTransaction transaction, const string &name, serializer.WriteProperty(101, "alter_info", &alter_info); serializer.End(); - auto &dtransaction = transaction.transaction->Cast(); - dtransaction.PushCatalogEntry(new_entry->Child(), stream.GetData(), stream.GetPosition()); + DuckTransactionManager::Get(GetCatalog().GetAttached()) + .PushCatalogEntry(*transaction.transaction, new_entry->Child(), stream.GetData(), stream.GetPosition()); } read_lock.unlock(); @@ -414,8 +415,8 @@ bool CatalogSet::DropEntryInternal(CatalogTransaction transaction, const string // push the old entry in the undo buffer for this transaction if (transaction.transaction) { - auto &dtransaction = transaction.transaction->Cast(); - dtransaction.PushCatalogEntry(value_ptr->Child()); + DuckTransactionManager::Get(GetCatalog().GetAttached()) + .PushCatalogEntry(*transaction.transaction, value_ptr->Child()); } return true; } @@ -507,9 +508,9 @@ SimilarCatalogEntry CatalogSet::SimilarEntry(CatalogTransaction transaction, con SimilarCatalogEntry result; for (auto &kv : map.Entries()) { - auto ldist = StringUtil::SimilarityScore(kv.first, name); - if (ldist < result.distance) { - result.distance = ldist; + auto entry_score = StringUtil::SimilarityRating(kv.first, name); + if (entry_score > result.score) { + result.score = entry_score; result.name = kv.first; } } @@ -523,14 +524,10 @@ optional_ptr CatalogSet::CreateDefaultEntry(CatalogTransaction tra // no defaults either: return null return nullptr; } + read_lock.unlock(); // this catalog set has a default map defined // check if there is a default entry that we can create with this name - if (!transaction.context) { - // no context - cannot create default entry - return nullptr; - } - read_lock.unlock(); - auto entry = defaults->CreateDefaultEntry(*transaction.context, name); + auto entry = defaults->CreateDefaultEntry(transaction, name); read_lock.lock(); if (!entry) { @@ -604,12 +601,10 @@ void CatalogSet::Undo(CatalogEntry &entry) { // This was the root of the entry chain map.DropEntry(entry); } - // we mark the catalog as being modified, since this action can lead to e.g. tables being dropped - catalog.ModifyCatalog(); } void CatalogSet::CreateDefaultEntries(CatalogTransaction transaction, unique_lock &read_lock) { - if (!defaults || defaults->created_all_entries || !transaction.context) { + if (!defaults || defaults->created_all_entries) { return; } // this catalog set has a default set defined: @@ -620,7 +615,7 @@ void CatalogSet::CreateDefaultEntries(CatalogTransaction transaction, unique_loc // we unlock during the CreateEntry, since it might reference other catalog sets... // specifically for views this can happen since the view will be bound read_lock.unlock(); - auto entry = defaults->CreateDefaultEntry(*transaction.context, default_entry); + auto entry = defaults->CreateDefaultEntry(transaction, default_entry); if (!entry) { throw InternalException("Failed to create default entry for %s", default_entry); } diff --git a/src/duckdb/src/catalog/default/default_functions.cpp b/src/duckdb/src/catalog/default/default_functions.cpp index 60ac77ca..b4f7deca 100644 --- a/src/duckdb/src/catalog/default/default_functions.cpp +++ b/src/duckdb/src/catalog/default/default_functions.cpp @@ -10,211 +10,224 @@ namespace duckdb { static const DefaultMacro internal_macros[] = { - {DEFAULT_SCHEMA, "current_role", {nullptr}, "'duckdb'"}, // user name of current execution context - {DEFAULT_SCHEMA, "current_user", {nullptr}, "'duckdb'"}, // user name of current execution context - {DEFAULT_SCHEMA, "current_catalog", {nullptr}, "current_database()"}, // name of current database (called "catalog" in the SQL standard) - {DEFAULT_SCHEMA, "user", {nullptr}, "current_user"}, // equivalent to current_user - {DEFAULT_SCHEMA, "session_user", {nullptr}, "'duckdb'"}, // session user name - {"pg_catalog", "inet_client_addr", {nullptr}, "NULL"}, // address of the remote connection - {"pg_catalog", "inet_client_port", {nullptr}, "NULL"}, // port of the remote connection - {"pg_catalog", "inet_server_addr", {nullptr}, "NULL"}, // address of the local connection - {"pg_catalog", "inet_server_port", {nullptr}, "NULL"}, // port of the local connection - {"pg_catalog", "pg_my_temp_schema", {nullptr}, "0"}, // OID of session's temporary schema, or 0 if none - {"pg_catalog", "pg_is_other_temp_schema", {"schema_id", nullptr}, "false"}, // is schema another session's temporary schema? - - {"pg_catalog", "pg_conf_load_time", {nullptr}, "current_timestamp"}, // configuration load time - {"pg_catalog", "pg_postmaster_start_time", {nullptr}, "current_timestamp"}, // server start time - - {"pg_catalog", "pg_typeof", {"expression", nullptr}, "lower(typeof(expression))"}, // get the data type of any value - - {"pg_catalog", "current_database", {nullptr}, "current_database()"}, // name of current database (called "catalog" in the SQL standard) - {"pg_catalog", "current_query", {nullptr}, "current_query()"}, // the currently executing query (NULL if not inside a plpgsql function) - {"pg_catalog", "current_schema", {nullptr}, "current_schema()"}, // name of current schema - {"pg_catalog", "current_schemas", {"include_implicit"}, "current_schemas(include_implicit)"}, // names of schemas in search path + {DEFAULT_SCHEMA, "current_role", {nullptr}, {{nullptr, nullptr}}, "'duckdb'"}, // user name of current execution context + {DEFAULT_SCHEMA, "current_user", {nullptr}, {{nullptr, nullptr}}, "'duckdb'"}, // user name of current execution context + {DEFAULT_SCHEMA, "current_catalog", {nullptr}, {{nullptr, nullptr}}, "current_database()"}, // name of current database (called "catalog" in the SQL standard) + {DEFAULT_SCHEMA, "user", {nullptr}, {{nullptr, nullptr}}, "current_user"}, // equivalent to current_user + {DEFAULT_SCHEMA, "session_user", {nullptr}, {{nullptr, nullptr}}, "'duckdb'"}, // session user name + {"pg_catalog", "inet_client_addr", {nullptr}, {{nullptr, nullptr}}, "NULL"}, // address of the remote connection + {"pg_catalog", "inet_client_port", {nullptr}, {{nullptr, nullptr}}, "NULL"}, // port of the remote connection + {"pg_catalog", "inet_server_addr", {nullptr}, {{nullptr, nullptr}}, "NULL"}, // address of the local connection + {"pg_catalog", "inet_server_port", {nullptr}, {{nullptr, nullptr}}, "NULL"}, // port of the local connection + {"pg_catalog", "pg_my_temp_schema", {nullptr}, {{nullptr, nullptr}}, "0"}, // OID of session's temporary schema, or 0 if none + {"pg_catalog", "pg_is_other_temp_schema", {"schema_id", nullptr}, {{nullptr, nullptr}}, "false"}, // is schema another session's temporary schema? + + {"pg_catalog", "pg_conf_load_time", {nullptr}, {{nullptr, nullptr}}, "current_timestamp"}, // configuration load time + {"pg_catalog", "pg_postmaster_start_time", {nullptr}, {{nullptr, nullptr}}, "current_timestamp"}, // server start time + + {"pg_catalog", "pg_typeof", {"expression", nullptr}, {{nullptr, nullptr}}, "lower(typeof(expression))"}, // get the data type of any value + + {"pg_catalog", "current_database", {nullptr}, {{nullptr, nullptr}}, "current_database()"}, // name of current database (called "catalog" in the SQL standard) + {"pg_catalog", "current_query", {nullptr}, {{nullptr, nullptr}}, "current_query()"}, // the currently executing query (NULL if not inside a plpgsql function) + {"pg_catalog", "current_schema", {nullptr}, {{nullptr, nullptr}}, "current_schema()"}, // name of current schema + {"pg_catalog", "current_schemas", {"include_implicit"}, {{nullptr, nullptr}}, "current_schemas(include_implicit)"}, // names of schemas in search path // privilege functions - // {"has_any_column_privilege", {"user", "table", "privilege", nullptr}, "true"}, //boolean //does user have privilege for any column of table - {"pg_catalog", "has_any_column_privilege", {"table", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for any column of table - // {"has_column_privilege", {"user", "table", "column", "privilege", nullptr}, "true"}, //boolean //does user have privilege for column - {"pg_catalog", "has_column_privilege", {"table", "column", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for column - // {"has_database_privilege", {"user", "database", "privilege", nullptr}, "true"}, //boolean //does user have privilege for database - {"pg_catalog", "has_database_privilege", {"database", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for database - // {"has_foreign_data_wrapper_privilege", {"user", "fdw", "privilege", nullptr}, "true"}, //boolean //does user have privilege for foreign-data wrapper - {"pg_catalog", "has_foreign_data_wrapper_privilege", {"fdw", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for foreign-data wrapper - // {"has_function_privilege", {"user", "function", "privilege", nullptr}, "true"}, //boolean //does user have privilege for function - {"pg_catalog", "has_function_privilege", {"function", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for function - // {"has_language_privilege", {"user", "language", "privilege", nullptr}, "true"}, //boolean //does user have privilege for language - {"pg_catalog", "has_language_privilege", {"language", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for language - // {"has_schema_privilege", {"user", "schema, privilege", nullptr}, "true"}, //boolean //does user have privilege for schema - {"pg_catalog", "has_schema_privilege", {"schema", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for schema - // {"has_sequence_privilege", {"user", "sequence", "privilege", nullptr}, "true"}, //boolean //does user have privilege for sequence - {"pg_catalog", "has_sequence_privilege", {"sequence", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for sequence - // {"has_server_privilege", {"user", "server", "privilege", nullptr}, "true"}, //boolean //does user have privilege for foreign server - {"pg_catalog", "has_server_privilege", {"server", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for foreign server - // {"has_table_privilege", {"user", "table", "privilege", nullptr}, "true"}, //boolean //does user have privilege for table - {"pg_catalog", "has_table_privilege", {"table", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for table - // {"has_tablespace_privilege", {"user", "tablespace", "privilege", nullptr}, "true"}, //boolean //does user have privilege for tablespace - {"pg_catalog", "has_tablespace_privilege", {"tablespace", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for tablespace + {"pg_catalog", "has_any_column_privilege", {"table", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for any column of table + {"pg_catalog", "has_any_column_privilege", {"user", "table", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for any column of table + {"pg_catalog", "has_column_privilege", {"table", "column", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for column + {"pg_catalog", "has_column_privilege", {"user", "table", "column", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for column + {"pg_catalog", "has_database_privilege", {"database", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for database + {"pg_catalog", "has_database_privilege", {"user", "database", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for database + {"pg_catalog", "has_foreign_data_wrapper_privilege", {"fdw", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for foreign-data wrapper + {"pg_catalog", "has_foreign_data_wrapper_privilege", {"user", "fdw", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for foreign-data wrapper + {"pg_catalog", "has_function_privilege", {"function", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for function + {"pg_catalog", "has_function_privilege", {"user", "function", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for function + {"pg_catalog", "has_language_privilege", {"language", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for language + {"pg_catalog", "has_language_privilege", {"user", "language", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for language + {"pg_catalog", "has_schema_privilege", {"schema", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for schema + {"pg_catalog", "has_schema_privilege", {"user", "schema", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for schema + {"pg_catalog", "has_sequence_privilege", {"sequence", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for sequence + {"pg_catalog", "has_sequence_privilege", {"user", "sequence", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for sequence + {"pg_catalog", "has_server_privilege", {"server", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for foreign server + {"pg_catalog", "has_server_privilege", {"user", "server", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for foreign server + {"pg_catalog", "has_table_privilege", {"table", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for table + {"pg_catalog", "has_table_privilege", {"user", "table", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for table + {"pg_catalog", "has_tablespace_privilege", {"tablespace", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for tablespace + {"pg_catalog", "has_tablespace_privilege", {"user", "tablespace", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for tablespace // various postgres system functions - {"pg_catalog", "pg_get_viewdef", {"oid", nullptr}, "(select sql from duckdb_views() v where v.view_oid=oid)"}, - {"pg_catalog", "pg_get_constraintdef", {"constraint_oid", "pretty_bool", nullptr}, "(select constraint_text from duckdb_constraints() d_constraint where d_constraint.table_oid=constraint_oid//1000000 and d_constraint.constraint_index=constraint_oid%1000000)"}, - {"pg_catalog", "pg_get_expr", {"pg_node_tree", "relation_oid", nullptr}, "pg_node_tree"}, - {"pg_catalog", "format_pg_type", {"logical_type", "type_name", nullptr}, "case upper(logical_type) when 'FLOAT' then 'float4' when 'DOUBLE' then 'float8' when 'DECIMAL' then 'numeric' when 'ENUM' then lower(type_name) when 'VARCHAR' then 'varchar' when 'BLOB' then 'bytea' when 'TIMESTAMP' then 'timestamp' when 'TIME' then 'time' when 'TIMESTAMP WITH TIME ZONE' then 'timestamptz' when 'TIME WITH TIME ZONE' then 'timetz' when 'SMALLINT' then 'int2' when 'INTEGER' then 'int4' when 'BIGINT' then 'int8' when 'BOOLEAN' then 'bool' else lower(logical_type) end"}, - {"pg_catalog", "format_type", {"type_oid", "typemod", nullptr}, "(select format_pg_type(logical_type, type_name) from duckdb_types() t where t.type_oid=type_oid) || case when typemod>0 then concat('(', typemod//1000, ',', typemod%1000, ')') else '' end"}, - {"pg_catalog", "map_to_pg_oid", {"type_name", nullptr}, "case type_name when 'bool' then 16 when 'int16' then 21 when 'int' then 23 when 'bigint' then 20 when 'date' then 1082 when 'time' then 1083 when 'datetime' then 1114 when 'dec' then 1700 when 'float' then 700 when 'double' then 701 when 'bpchar' then 1043 when 'binary' then 17 when 'interval' then 1186 when 'timestamptz' then 1184 when 'timetz' then 1266 when 'bit' then 1560 when 'guid' then 2950 else null end"}, // map duckdb_oid to pg_oid. If no corresponding type, return null + {"pg_catalog", "pg_get_viewdef", {"oid", nullptr}, {{nullptr, nullptr}}, "(select sql from duckdb_views() v where v.view_oid=oid)"}, + {"pg_catalog", "pg_get_constraintdef", {"constraint_oid", nullptr}, {{nullptr, nullptr}}, "(select constraint_text from duckdb_constraints() d_constraint where d_constraint.table_oid=constraint_oid//1000000 and d_constraint.constraint_index=constraint_oid%1000000)"}, + {"pg_catalog", "pg_get_constraintdef", {"constraint_oid", "pretty_bool", nullptr}, {{nullptr, nullptr}}, "pg_get_constraintdef(constraint_oid)"}, + {"pg_catalog", "pg_get_expr", {"pg_node_tree", "relation_oid", nullptr}, {{nullptr, nullptr}}, "pg_node_tree"}, + {"pg_catalog", "format_pg_type", {"logical_type", "type_name", nullptr}, {{nullptr, nullptr}}, "case upper(logical_type) when 'FLOAT' then 'float4' when 'DOUBLE' then 'float8' when 'DECIMAL' then 'numeric' when 'ENUM' then lower(type_name) when 'VARCHAR' then 'varchar' when 'BLOB' then 'bytea' when 'TIMESTAMP' then 'timestamp' when 'TIME' then 'time' when 'TIMESTAMP WITH TIME ZONE' then 'timestamptz' when 'TIME WITH TIME ZONE' then 'timetz' when 'SMALLINT' then 'int2' when 'INTEGER' then 'int4' when 'BIGINT' then 'int8' when 'BOOLEAN' then 'bool' else lower(logical_type) end"}, + {"pg_catalog", "format_type", {"type_oid", "typemod", nullptr}, {{nullptr, nullptr}}, "(select format_pg_type(logical_type, type_name) from duckdb_types() t where t.type_oid=type_oid) || case when typemod>0 then concat('(', typemod//1000, ',', typemod%1000, ')') else '' end"}, + {"pg_catalog", "map_to_pg_oid", {"type_name", nullptr}, {{nullptr, nullptr}}, "case type_name when 'bool' then 16 when 'int16' then 21 when 'int' then 23 when 'bigint' then 20 when 'date' then 1082 when 'time' then 1083 when 'datetime' then 1114 when 'dec' then 1700 when 'float' then 700 when 'double' then 701 when 'bpchar' then 1043 when 'binary' then 17 when 'interval' then 1186 when 'timestamptz' then 1184 when 'timetz' then 1266 when 'bit' then 1560 when 'guid' then 2950 else null end"}, // map duckdb_oid to pg_oid. If no corresponding type, return null - {"pg_catalog", "pg_has_role", {"user", "role", "privilege", nullptr}, "true"}, //boolean //does user have privilege for role - {"pg_catalog", "pg_has_role", {"role", "privilege", nullptr}, "true"}, //boolean //does current user have privilege for role + {"pg_catalog", "pg_has_role", {"user", "role", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for role + {"pg_catalog", "pg_has_role", {"role", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for role - {"pg_catalog", "col_description", {"table_oid", "column_number", nullptr}, "NULL"}, // get comment for a table column - {"pg_catalog", "obj_description", {"object_oid", "catalog_name", nullptr}, "NULL"}, // get comment for a database object - {"pg_catalog", "shobj_description", {"object_oid", "catalog_name", nullptr}, "NULL"}, // get comment for a shared database object + {"pg_catalog", "col_description", {"table_oid", "column_number", nullptr}, {{nullptr, nullptr}}, "NULL"}, // get comment for a table column + {"pg_catalog", "obj_description", {"object_oid", "catalog_name", nullptr}, {{nullptr, nullptr}}, "NULL"}, // get comment for a database object + {"pg_catalog", "shobj_description", {"object_oid", "catalog_name", nullptr}, {{nullptr, nullptr}}, "NULL"}, // get comment for a shared database object // visibility functions - {"pg_catalog", "pg_collation_is_visible", {"collation_oid", nullptr}, "true"}, - {"pg_catalog", "pg_conversion_is_visible", {"conversion_oid", nullptr}, "true"}, - {"pg_catalog", "pg_function_is_visible", {"function_oid", nullptr}, "true"}, - {"pg_catalog", "pg_opclass_is_visible", {"opclass_oid", nullptr}, "true"}, - {"pg_catalog", "pg_operator_is_visible", {"operator_oid", nullptr}, "true"}, - {"pg_catalog", "pg_opfamily_is_visible", {"opclass_oid", nullptr}, "true"}, - {"pg_catalog", "pg_table_is_visible", {"table_oid", nullptr}, "true"}, - {"pg_catalog", "pg_ts_config_is_visible", {"config_oid", nullptr}, "true"}, - {"pg_catalog", "pg_ts_dict_is_visible", {"dict_oid", nullptr}, "true"}, - {"pg_catalog", "pg_ts_parser_is_visible", {"parser_oid", nullptr}, "true"}, - {"pg_catalog", "pg_ts_template_is_visible", {"template_oid", nullptr}, "true"}, - {"pg_catalog", "pg_type_is_visible", {"type_oid", nullptr}, "true"}, - - {"pg_catalog", "pg_size_pretty", {"bytes", nullptr}, "format_bytes(bytes)"}, - - {DEFAULT_SCHEMA, "round_even", {"x", "n", nullptr}, "CASE ((abs(x) * power(10, n+1)) % 10) WHEN 5 THEN round(x/2, n) * 2 ELSE round(x, n) END"}, - {DEFAULT_SCHEMA, "roundbankers", {"x", "n", nullptr}, "round_even(x, n)"}, - {DEFAULT_SCHEMA, "nullif", {"a", "b", nullptr}, "CASE WHEN a=b THEN NULL ELSE a END"}, - {DEFAULT_SCHEMA, "list_append", {"l", "e", nullptr}, "list_concat(l, list_value(e))"}, - {DEFAULT_SCHEMA, "array_append", {"arr", "el", nullptr}, "list_append(arr, el)"}, - {DEFAULT_SCHEMA, "list_prepend", {"e", "l", nullptr}, "list_concat(list_value(e), l)"}, - {DEFAULT_SCHEMA, "array_prepend", {"el", "arr", nullptr}, "list_prepend(el, arr)"}, - {DEFAULT_SCHEMA, "array_pop_back", {"arr", nullptr}, "arr[:LEN(arr)-1]"}, - {DEFAULT_SCHEMA, "array_pop_front", {"arr", nullptr}, "arr[2:]"}, - {DEFAULT_SCHEMA, "array_push_back", {"arr", "e", nullptr}, "list_concat(arr, list_value(e))"}, - {DEFAULT_SCHEMA, "array_push_front", {"arr", "e", nullptr}, "list_concat(list_value(e), arr)"}, - {DEFAULT_SCHEMA, "array_to_string", {"arr", "sep", nullptr}, "list_aggr(arr::varchar[], 'string_agg', sep)"}, - {DEFAULT_SCHEMA, "generate_subscripts", {"arr", "dim", nullptr}, "unnest(generate_series(1, array_length(arr, dim)))"}, - {DEFAULT_SCHEMA, "fdiv", {"x", "y", nullptr}, "floor(x/y)"}, - {DEFAULT_SCHEMA, "fmod", {"x", "y", nullptr}, "(x-y*floor(x/y))"}, - {DEFAULT_SCHEMA, "count_if", {"l", nullptr}, "sum(if(l, 1, 0))"}, - {DEFAULT_SCHEMA, "split_part", {"string", "delimiter", "position", nullptr}, "coalesce(string_split(string, delimiter)[position],'')"}, - {DEFAULT_SCHEMA, "geomean", {"x", nullptr}, "exp(avg(ln(x)))"}, - {DEFAULT_SCHEMA, "geometric_mean", {"x", nullptr}, "geomean(x)"}, - - {DEFAULT_SCHEMA, "list_reverse", {"l", nullptr}, "l[:-:-1]"}, - {DEFAULT_SCHEMA, "array_reverse", {"l", nullptr}, "list_reverse(l)"}, + {"pg_catalog", "pg_collation_is_visible", {"collation_oid", nullptr}, {{nullptr, nullptr}}, "true"}, + {"pg_catalog", "pg_conversion_is_visible", {"conversion_oid", nullptr}, {{nullptr, nullptr}}, "true"}, + {"pg_catalog", "pg_function_is_visible", {"function_oid", nullptr}, {{nullptr, nullptr}}, "true"}, + {"pg_catalog", "pg_opclass_is_visible", {"opclass_oid", nullptr}, {{nullptr, nullptr}}, "true"}, + {"pg_catalog", "pg_operator_is_visible", {"operator_oid", nullptr}, {{nullptr, nullptr}}, "true"}, + {"pg_catalog", "pg_opfamily_is_visible", {"opclass_oid", nullptr}, {{nullptr, nullptr}}, "true"}, + {"pg_catalog", "pg_table_is_visible", {"table_oid", nullptr}, {{nullptr, nullptr}}, "true"}, + {"pg_catalog", "pg_ts_config_is_visible", {"config_oid", nullptr}, {{nullptr, nullptr}}, "true"}, + {"pg_catalog", "pg_ts_dict_is_visible", {"dict_oid", nullptr}, {{nullptr, nullptr}}, "true"}, + {"pg_catalog", "pg_ts_parser_is_visible", {"parser_oid", nullptr}, {{nullptr, nullptr}}, "true"}, + {"pg_catalog", "pg_ts_template_is_visible", {"template_oid", nullptr}, {{nullptr, nullptr}}, "true"}, + {"pg_catalog", "pg_type_is_visible", {"type_oid", nullptr}, {{nullptr, nullptr}}, "true"}, + + {"pg_catalog", "pg_size_pretty", {"bytes", nullptr}, {{nullptr, nullptr}}, "format_bytes(bytes)"}, + + {DEFAULT_SCHEMA, "round_even", {"x", "n", nullptr}, {{nullptr, nullptr}}, "CASE ((abs(x) * power(10, n+1)) % 10) WHEN 5 THEN round(x/2, n) * 2 ELSE round(x, n) END"}, + {DEFAULT_SCHEMA, "roundbankers", {"x", "n", nullptr}, {{nullptr, nullptr}}, "round_even(x, n)"}, + {DEFAULT_SCHEMA, "nullif", {"a", "b", nullptr}, {{nullptr, nullptr}}, "CASE WHEN a=b THEN NULL ELSE a END"}, + {DEFAULT_SCHEMA, "list_append", {"l", "e", nullptr}, {{nullptr, nullptr}}, "list_concat(l, list_value(e))"}, + {DEFAULT_SCHEMA, "array_append", {"arr", "el", nullptr}, {{nullptr, nullptr}}, "list_append(arr, el)"}, + {DEFAULT_SCHEMA, "list_prepend", {"e", "l", nullptr}, {{nullptr, nullptr}}, "list_concat(list_value(e), l)"}, + {DEFAULT_SCHEMA, "array_prepend", {"el", "arr", nullptr}, {{nullptr, nullptr}}, "list_prepend(el, arr)"}, + {DEFAULT_SCHEMA, "array_pop_back", {"arr", nullptr}, {{nullptr, nullptr}}, "arr[:LEN(arr)-1]"}, + {DEFAULT_SCHEMA, "array_pop_front", {"arr", nullptr}, {{nullptr, nullptr}}, "arr[2:]"}, + {DEFAULT_SCHEMA, "array_push_back", {"arr", "e", nullptr}, {{nullptr, nullptr}}, "list_concat(arr, list_value(e))"}, + {DEFAULT_SCHEMA, "array_push_front", {"arr", "e", nullptr}, {{nullptr, nullptr}}, "list_concat(list_value(e), arr)"}, + {DEFAULT_SCHEMA, "array_to_string", {"arr", "sep", nullptr}, {{nullptr, nullptr}}, "list_aggr(arr::varchar[], 'string_agg', sep)"}, + // Test default parameters + {DEFAULT_SCHEMA, "array_to_string_comma_default", {"arr", nullptr}, {{"sep", "','"}, {nullptr, nullptr}}, "list_aggr(arr::varchar[], 'string_agg', sep)"}, + + {DEFAULT_SCHEMA, "generate_subscripts", {"arr", "dim", nullptr}, {{nullptr, nullptr}}, "unnest(generate_series(1, array_length(arr, dim)))"}, + {DEFAULT_SCHEMA, "fdiv", {"x", "y", nullptr}, {{nullptr, nullptr}}, "floor(x/y)"}, + {DEFAULT_SCHEMA, "fmod", {"x", "y", nullptr}, {{nullptr, nullptr}}, "(x-y*floor(x/y))"}, + {DEFAULT_SCHEMA, "count_if", {"l", nullptr}, {{nullptr, nullptr}}, "sum(if(l, 1, 0))"}, + {DEFAULT_SCHEMA, "split_part", {"string", "delimiter", "position", nullptr}, {{nullptr, nullptr}}, "coalesce(string_split(string, delimiter)[position],'')"}, + {DEFAULT_SCHEMA, "geomean", {"x", nullptr}, {{nullptr, nullptr}}, "exp(avg(ln(x)))"}, + {DEFAULT_SCHEMA, "geometric_mean", {"x", nullptr}, {{nullptr, nullptr}}, "geomean(x)"}, + + {DEFAULT_SCHEMA, "list_reverse", {"l", nullptr}, {{nullptr, nullptr}}, "l[:-:-1]"}, + {DEFAULT_SCHEMA, "array_reverse", {"l", nullptr}, {{nullptr, nullptr}}, "list_reverse(l)"}, // FIXME implement as actual function if we encounter a lot of performance issues. Complexity now: n * m, with hashing possibly n + m - {DEFAULT_SCHEMA, "list_intersect", {"l1", "l2", nullptr}, "list_filter(list_distinct(l1), (variable_intersect) -> list_contains(l2, variable_intersect))"}, - {DEFAULT_SCHEMA, "array_intersect", {"l1", "l2", nullptr}, "list_intersect(l1, l2)"}, - - {DEFAULT_SCHEMA, "list_has_any", {"l1", "l2", nullptr}, "CASE WHEN l1 IS NULL THEN NULL WHEN l2 IS NULL THEN NULL WHEN len(list_filter(l1, (variable_has_any) -> list_contains(l2, variable_has_any))) > 0 THEN true ELSE false END"}, - {DEFAULT_SCHEMA, "array_has_any", {"l1", "l2", nullptr}, "list_has_any(l1, l2)" }, - {DEFAULT_SCHEMA, "&&", {"l1", "l2", nullptr}, "list_has_any(l1, l2)" }, // "&&" is the operator for "list_has_any - - {DEFAULT_SCHEMA, "list_has_all", {"l1", "l2", nullptr}, "CASE WHEN l1 IS NULL THEN NULL WHEN l2 IS NULL THEN NULL WHEN len(list_filter(l2, (variable_has_all) -> list_contains(l1, variable_has_all))) = len(list_filter(l2, variable_has_all -> variable_has_all IS NOT NULL)) THEN true ELSE false END"}, - {DEFAULT_SCHEMA, "array_has_all", {"l1", "l2", nullptr}, "list_has_all(l1, l2)" }, - {DEFAULT_SCHEMA, "@>", {"l1", "l2", nullptr}, "list_has_all(l1, l2)" }, // "@>" is the operator for "list_has_all - {DEFAULT_SCHEMA, "<@", {"l1", "l2", nullptr}, "list_has_all(l2, l1)" }, // "<@" is the operator for "list_has_all + {DEFAULT_SCHEMA, "list_intersect", {"l1", "l2", nullptr}, {{nullptr, nullptr}}, "list_filter(list_distinct(l1), (variable_intersect) -> list_contains(l2, variable_intersect))"}, + {DEFAULT_SCHEMA, "array_intersect", {"l1", "l2", nullptr}, {{nullptr, nullptr}}, "list_intersect(l1, l2)"}, // algebraic list aggregates - {DEFAULT_SCHEMA, "list_avg", {"l", nullptr}, "list_aggr(l, 'avg')"}, - {DEFAULT_SCHEMA, "list_var_samp", {"l", nullptr}, "list_aggr(l, 'var_samp')"}, - {DEFAULT_SCHEMA, "list_var_pop", {"l", nullptr}, "list_aggr(l, 'var_pop')"}, - {DEFAULT_SCHEMA, "list_stddev_pop", {"l", nullptr}, "list_aggr(l, 'stddev_pop')"}, - {DEFAULT_SCHEMA, "list_stddev_samp", {"l", nullptr}, "list_aggr(l, 'stddev_samp')"}, - {DEFAULT_SCHEMA, "list_sem", {"l", nullptr}, "list_aggr(l, 'sem')"}, + {DEFAULT_SCHEMA, "list_avg", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'avg')"}, + {DEFAULT_SCHEMA, "list_var_samp", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'var_samp')"}, + {DEFAULT_SCHEMA, "list_var_pop", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'var_pop')"}, + {DEFAULT_SCHEMA, "list_stddev_pop", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'stddev_pop')"}, + {DEFAULT_SCHEMA, "list_stddev_samp", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'stddev_samp')"}, + {DEFAULT_SCHEMA, "list_sem", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'sem')"}, // distributive list aggregates - {DEFAULT_SCHEMA, "list_approx_count_distinct", {"l", nullptr}, "list_aggr(l, 'approx_count_distinct')"}, - {DEFAULT_SCHEMA, "list_bit_xor", {"l", nullptr}, "list_aggr(l, 'bit_xor')"}, - {DEFAULT_SCHEMA, "list_bit_or", {"l", nullptr}, "list_aggr(l, 'bit_or')"}, - {DEFAULT_SCHEMA, "list_bit_and", {"l", nullptr}, "list_aggr(l, 'bit_and')"}, - {DEFAULT_SCHEMA, "list_bool_and", {"l", nullptr}, "list_aggr(l, 'bool_and')"}, - {DEFAULT_SCHEMA, "list_bool_or", {"l", nullptr}, "list_aggr(l, 'bool_or')"}, - {DEFAULT_SCHEMA, "list_count", {"l", nullptr}, "list_aggr(l, 'count')"}, - {DEFAULT_SCHEMA, "list_entropy", {"l", nullptr}, "list_aggr(l, 'entropy')"}, - {DEFAULT_SCHEMA, "list_last", {"l", nullptr}, "list_aggr(l, 'last')"}, - {DEFAULT_SCHEMA, "list_first", {"l", nullptr}, "list_aggr(l, 'first')"}, - {DEFAULT_SCHEMA, "list_any_value", {"l", nullptr}, "list_aggr(l, 'any_value')"}, - {DEFAULT_SCHEMA, "list_kurtosis", {"l", nullptr}, "list_aggr(l, 'kurtosis')"}, - {DEFAULT_SCHEMA, "list_kurtosis_pop", {"l", nullptr}, "list_aggr(l, 'kurtosis_pop')"}, - {DEFAULT_SCHEMA, "list_min", {"l", nullptr}, "list_aggr(l, 'min')"}, - {DEFAULT_SCHEMA, "list_max", {"l", nullptr}, "list_aggr(l, 'max')"}, - {DEFAULT_SCHEMA, "list_product", {"l", nullptr}, "list_aggr(l, 'product')"}, - {DEFAULT_SCHEMA, "list_skewness", {"l", nullptr}, "list_aggr(l, 'skewness')"}, - {DEFAULT_SCHEMA, "list_sum", {"l", nullptr}, "list_aggr(l, 'sum')"}, - {DEFAULT_SCHEMA, "list_string_agg", {"l", nullptr}, "list_aggr(l, 'string_agg')"}, + {DEFAULT_SCHEMA, "list_approx_count_distinct", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'approx_count_distinct')"}, + {DEFAULT_SCHEMA, "list_bit_xor", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'bit_xor')"}, + {DEFAULT_SCHEMA, "list_bit_or", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'bit_or')"}, + {DEFAULT_SCHEMA, "list_bit_and", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'bit_and')"}, + {DEFAULT_SCHEMA, "list_bool_and", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'bool_and')"}, + {DEFAULT_SCHEMA, "list_bool_or", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'bool_or')"}, + {DEFAULT_SCHEMA, "list_count", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'count')"}, + {DEFAULT_SCHEMA, "list_entropy", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'entropy')"}, + {DEFAULT_SCHEMA, "list_last", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'last')"}, + {DEFAULT_SCHEMA, "list_first", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'first')"}, + {DEFAULT_SCHEMA, "list_any_value", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'any_value')"}, + {DEFAULT_SCHEMA, "list_kurtosis", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'kurtosis')"}, + {DEFAULT_SCHEMA, "list_kurtosis_pop", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'kurtosis_pop')"}, + {DEFAULT_SCHEMA, "list_min", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'min')"}, + {DEFAULT_SCHEMA, "list_max", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'max')"}, + {DEFAULT_SCHEMA, "list_product", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'product')"}, + {DEFAULT_SCHEMA, "list_skewness", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'skewness')"}, + {DEFAULT_SCHEMA, "list_sum", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'sum')"}, + {DEFAULT_SCHEMA, "list_string_agg", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'string_agg')"}, // holistic list aggregates - {DEFAULT_SCHEMA, "list_mode", {"l", nullptr}, "list_aggr(l, 'mode')"}, - {DEFAULT_SCHEMA, "list_median", {"l", nullptr}, "list_aggr(l, 'median')"}, - {DEFAULT_SCHEMA, "list_mad", {"l", nullptr}, "list_aggr(l, 'mad')"}, + {DEFAULT_SCHEMA, "list_mode", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'mode')"}, + {DEFAULT_SCHEMA, "list_median", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'median')"}, + {DEFAULT_SCHEMA, "list_mad", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'mad')"}, // nested list aggregates - {DEFAULT_SCHEMA, "list_histogram", {"l", nullptr}, "list_aggr(l, 'histogram')"}, + {DEFAULT_SCHEMA, "list_histogram", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'histogram')"}, + + // map functions + {DEFAULT_SCHEMA, "map_contains_entry", {"map", "key", "value"}, {{nullptr, nullptr}}, "contains(map_entries(map), {'key': key, 'value': value})"}, + {DEFAULT_SCHEMA, "map_contains_value", {"map", "value", nullptr}, {{nullptr, nullptr}}, "contains(map_values(map), value)"}, // date functions - {DEFAULT_SCHEMA, "date_add", {"date", "interval", nullptr}, "date + interval"}, + {DEFAULT_SCHEMA, "date_add", {"date", "interval", nullptr}, {{nullptr, nullptr}}, "date + interval"}, // regexp functions - {DEFAULT_SCHEMA, "regexp_split_to_table", {"text", "pattern", nullptr}, "unnest(string_split_regex(text, pattern))"}, + {DEFAULT_SCHEMA, "regexp_split_to_table", {"text", "pattern", nullptr}, {{nullptr, nullptr}}, "unnest(string_split_regex(text, pattern))"}, - // storage helper functions - {DEFAULT_SCHEMA, "get_block_size", {"db_name"}, "(SELECT block_size FROM pragma_database_size() WHERE database_name = db_name)"}, + // storage helper functions + {DEFAULT_SCHEMA, "get_block_size", {"db_name"}, {{nullptr, nullptr}}, "(SELECT block_size FROM pragma_database_size() WHERE database_name = db_name)"}, - {nullptr, nullptr, {nullptr}, nullptr} + // string functions + {DEFAULT_SCHEMA, "md5_number_upper", {"param"}, {{nullptr, nullptr}}, "((md5_number(param)::bit::varchar)[65:])::bit::uint64"}, + {DEFAULT_SCHEMA, "md5_number_lower", {"param"}, {{nullptr, nullptr}}, "((md5_number(param)::bit::varchar)[:64])::bit::uint64"}, + + {nullptr, nullptr, {nullptr}, {{nullptr, nullptr}}, nullptr} }; -unique_ptr DefaultFunctionGenerator::CreateInternalTableMacroInfo(const DefaultMacro &default_macro, unique_ptr function) { - for (idx_t param_idx = 0; default_macro.parameters[param_idx] != nullptr; param_idx++) { - function->parameters.push_back( - make_uniq(default_macro.parameters[param_idx])); - } +unique_ptr DefaultFunctionGenerator::CreateInternalMacroInfo(const DefaultMacro &default_macro) { + return CreateInternalMacroInfo(array_ptr(default_macro)); +} - auto type = function->type == MacroType::TABLE_MACRO ? CatalogType::TABLE_MACRO_ENTRY : CatalogType::MACRO_ENTRY; + +unique_ptr DefaultFunctionGenerator::CreateInternalMacroInfo(array_ptr macros) { + auto type = CatalogType::MACRO_ENTRY; auto bind_info = make_uniq(type); - bind_info->schema = default_macro.schema; - bind_info->name = default_macro.name; + for(auto &default_macro : macros) { + // parse the expression + auto expressions = Parser::ParseExpressionList(default_macro.macro); + D_ASSERT(expressions.size() == 1); + + auto function = make_uniq(std::move(expressions[0])); + for (idx_t param_idx = 0; default_macro.parameters[param_idx] != nullptr; param_idx++) { + function->parameters.push_back( + make_uniq(default_macro.parameters[param_idx])); + } + for (idx_t named_idx = 0; default_macro.named_parameters[named_idx].name != nullptr; named_idx++) { + auto expr_list = Parser::ParseExpressionList(default_macro.named_parameters[named_idx].default_value); + if (expr_list.size() != 1) { + throw InternalException("Expected a single expression"); + } + function->default_parameters.insert( + make_pair(default_macro.named_parameters[named_idx].name, std::move(expr_list[0]))); + } + D_ASSERT(function->type == MacroType::SCALAR_MACRO); + bind_info->macros.push_back(std::move(function)); + } + bind_info->schema = macros[0].schema; + bind_info->name = macros[0].name; bind_info->temporary = true; bind_info->internal = true; - bind_info->function = std::move(function); return bind_info; - -} - -unique_ptr DefaultFunctionGenerator::CreateInternalMacroInfo(const DefaultMacro &default_macro) { - // parse the expression - auto expressions = Parser::ParseExpressionList(default_macro.macro); - D_ASSERT(expressions.size() == 1); - - auto result = make_uniq(std::move(expressions[0])); - return CreateInternalTableMacroInfo(default_macro, std::move(result)); } -unique_ptr DefaultFunctionGenerator::CreateInternalTableMacroInfo(const DefaultMacro &default_macro) { - Parser parser; - parser.ParseQuery(default_macro.macro); - D_ASSERT(parser.statements.size() == 1); - D_ASSERT(parser.statements[0]->type == StatementType::SELECT_STATEMENT); - - auto &select = parser.statements[0]->Cast(); - auto result = make_uniq(std::move(select.node)); - return CreateInternalTableMacroInfo(default_macro, std::move(result)); +static bool DefaultFunctionMatches(const DefaultMacro ¯o, const string &schema, const string &name) { + return macro.schema == schema && macro.name == name; } static unique_ptr GetDefaultFunction(const string &input_schema, const string &input_name) { auto schema = StringUtil::Lower(input_schema); auto name = StringUtil::Lower(input_name); for (idx_t index = 0; internal_macros[index].name != nullptr; index++) { - if (internal_macros[index].schema == schema && internal_macros[index].name == name) { - return DefaultFunctionGenerator::CreateInternalMacroInfo(internal_macros[index]); + if (DefaultFunctionMatches(internal_macros[index], schema, name)) { + // found the function! keep on iterating to find all overloads + idx_t overload_count; + for(overload_count = 1; internal_macros[index + overload_count].name; overload_count++) { + if (!DefaultFunctionMatches(internal_macros[index + overload_count], schema, name)) { + break; + } + } + return DefaultFunctionGenerator::CreateInternalMacroInfo(array_ptr(internal_macros + index, overload_count)); } } return nullptr; diff --git a/src/duckdb/src/catalog/default/default_generator.cpp b/src/duckdb/src/catalog/default/default_generator.cpp new file mode 100644 index 00000000..2fbb2b64 --- /dev/null +++ b/src/duckdb/src/catalog/default/default_generator.cpp @@ -0,0 +1,24 @@ +#include "duckdb/catalog/default/default_generator.hpp" +#include "duckdb/catalog/catalog_transaction.hpp" + +namespace duckdb { + +DefaultGenerator::DefaultGenerator(Catalog &catalog) : catalog(catalog), created_all_entries(false) { +} +DefaultGenerator::~DefaultGenerator() { +} + +unique_ptr DefaultGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) { + throw InternalException("CreateDefaultEntry with ClientContext called but not supported in this generator"); +} + +unique_ptr DefaultGenerator::CreateDefaultEntry(CatalogTransaction transaction, + const string &entry_name) { + if (!transaction.context) { + // no context - cannot create default entry + return nullptr; + } + return CreateDefaultEntry(*transaction.context, entry_name); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/default/default_schemas.cpp b/src/duckdb/src/catalog/default/default_schemas.cpp index 72a95da7..64aaf56d 100644 --- a/src/duckdb/src/catalog/default/default_schemas.cpp +++ b/src/duckdb/src/catalog/default/default_schemas.cpp @@ -11,7 +11,7 @@ struct DefaultSchema { static const DefaultSchema internal_schemas[] = {{"information_schema"}, {"pg_catalog"}, {nullptr}}; -static bool GetDefaultSchema(const string &input_schema) { +bool DefaultSchemaGenerator::IsDefaultSchema(const string &input_schema) { auto schema = StringUtil::Lower(input_schema); for (idx_t index = 0; internal_schemas[index].name != nullptr; index++) { if (internal_schemas[index].name == schema) { @@ -24,8 +24,9 @@ static bool GetDefaultSchema(const string &input_schema) { DefaultSchemaGenerator::DefaultSchemaGenerator(Catalog &catalog) : DefaultGenerator(catalog) { } -unique_ptr DefaultSchemaGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) { - if (GetDefaultSchema(entry_name)) { +unique_ptr DefaultSchemaGenerator::CreateDefaultEntry(CatalogTransaction transaction, + const string &entry_name) { + if (IsDefaultSchema(entry_name)) { CreateSchemaInfo info; info.schema = StringUtil::Lower(entry_name); info.internal = true; diff --git a/src/duckdb/src/catalog/default/default_table_functions.cpp b/src/duckdb/src/catalog/default/default_table_functions.cpp new file mode 100644 index 00000000..b0755c83 --- /dev/null +++ b/src/duckdb/src/catalog/default/default_table_functions.cpp @@ -0,0 +1,148 @@ +#include "duckdb/catalog/default/default_table_functions.hpp" +#include "duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/parser/parsed_data/create_macro_info.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/function/table_macro_function.hpp" + +namespace duckdb { + +// clang-format off +static const DefaultTableMacro internal_table_macros[] = { + {DEFAULT_SCHEMA, "histogram_values", {"source", "col_name", nullptr}, {{"bin_count", "10"}, {"technique", "'auto'"}, {nullptr, nullptr}}, R"( +WITH bins AS ( + SELECT + CASE + WHEN (NOT (can_cast_implicitly(MIN(col_name), NULL::BIGINT) OR + can_cast_implicitly(MIN(col_name), NULL::DOUBLE) OR + can_cast_implicitly(MIN(col_name), NULL::TIMESTAMP)) AND technique='auto') + OR technique='sample' + THEN + approx_top_k(col_name, bin_count) + WHEN technique='equi-height' + THEN + quantile(col_name, [x / bin_count::DOUBLE for x in generate_series(1, bin_count)]) + WHEN technique='equi-width' + THEN + equi_width_bins(MIN(col_name), MAX(col_name), bin_count, false) + WHEN technique='equi-width-nice' OR technique='auto' + THEN + equi_width_bins(MIN(col_name), MAX(col_name), bin_count, true) + ELSE + error(concat('Unrecognized technique ', technique)) + END AS bins + FROM query_table(source::VARCHAR) + ) +SELECT UNNEST(map_keys(histogram)) AS bin, UNNEST(map_values(histogram)) AS count +FROM ( + SELECT CASE + WHEN (NOT (can_cast_implicitly(MIN(col_name), NULL::BIGINT) OR + can_cast_implicitly(MIN(col_name), NULL::DOUBLE) OR + can_cast_implicitly(MIN(col_name), NULL::TIMESTAMP)) AND technique='auto') + OR technique='sample' + THEN + histogram_exact(col_name, bins) + ELSE + histogram(col_name, bins) + END AS histogram + FROM query_table(source::VARCHAR), bins +); +)"}, + {DEFAULT_SCHEMA, "histogram", {"source", "col_name", nullptr}, {{"bin_count", "10"}, {"technique", "'auto'"}, {nullptr, nullptr}}, R"( +SELECT + CASE + WHEN is_histogram_other_bin(bin) + THEN '(other values)' + WHEN (NOT (can_cast_implicitly(bin, NULL::BIGINT) OR + can_cast_implicitly(bin, NULL::DOUBLE) OR + can_cast_implicitly(bin, NULL::TIMESTAMP)) AND technique='auto') + OR technique='sample' + THEN bin::VARCHAR + WHEN row_number() over () = 1 + THEN concat('x <= ', bin::VARCHAR) + ELSE concat(lag(bin::VARCHAR) over (), ' < x <= ', bin::VARCHAR) + END AS bin, + count, + bar(count, 0, max(count) over ()) AS bar +FROM histogram_values(source, col_name, bin_count := bin_count, technique := technique); +)"}, + {nullptr, nullptr, {nullptr}, {{nullptr, nullptr}}, nullptr} + }; +// clang-format on + +DefaultTableFunctionGenerator::DefaultTableFunctionGenerator(Catalog &catalog, SchemaCatalogEntry &schema) + : DefaultGenerator(catalog), schema(schema) { +} + +unique_ptr +DefaultTableFunctionGenerator::CreateInternalTableMacroInfo(const DefaultTableMacro &default_macro, + unique_ptr function) { + for (idx_t param_idx = 0; default_macro.parameters[param_idx] != nullptr; param_idx++) { + function->parameters.push_back(make_uniq(default_macro.parameters[param_idx])); + } + for (idx_t named_idx = 0; default_macro.named_parameters[named_idx].name != nullptr; named_idx++) { + auto expr_list = Parser::ParseExpressionList(default_macro.named_parameters[named_idx].default_value); + if (expr_list.size() != 1) { + throw InternalException("Expected a single expression"); + } + function->default_parameters.insert( + make_pair(default_macro.named_parameters[named_idx].name, std::move(expr_list[0]))); + } + + auto type = CatalogType::TABLE_MACRO_ENTRY; + auto bind_info = make_uniq(type); + bind_info->schema = default_macro.schema; + bind_info->name = default_macro.name; + bind_info->temporary = true; + bind_info->internal = true; + bind_info->macros.push_back(std::move(function)); + return bind_info; +} + +unique_ptr +DefaultTableFunctionGenerator::CreateTableMacroInfo(const DefaultTableMacro &default_macro) { + Parser parser; + parser.ParseQuery(default_macro.macro); + if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { + throw InternalException("Expected a single select statement in CreateTableMacroInfo internal"); + } + auto node = std::move(parser.statements[0]->Cast().node); + + auto result = make_uniq(std::move(node)); + return CreateInternalTableMacroInfo(default_macro, std::move(result)); +} + +static unique_ptr GetDefaultTableFunction(const string &input_schema, const string &input_name) { + auto schema = StringUtil::Lower(input_schema); + auto name = StringUtil::Lower(input_name); + for (idx_t index = 0; internal_table_macros[index].name != nullptr; index++) { + if (internal_table_macros[index].schema == schema && internal_table_macros[index].name == name) { + return DefaultTableFunctionGenerator::CreateTableMacroInfo(internal_table_macros[index]); + } + } + return nullptr; +} + +unique_ptr DefaultTableFunctionGenerator::CreateDefaultEntry(ClientContext &context, + const string &entry_name) { + auto info = GetDefaultTableFunction(schema.name, entry_name); + if (info) { + return make_uniq_base(catalog, schema, info->Cast()); + } + return nullptr; +} + +vector DefaultTableFunctionGenerator::GetDefaultEntries() { + vector result; + for (idx_t index = 0; internal_table_macros[index].name != nullptr; index++) { + if (StringUtil::Lower(internal_table_macros[index].name) != internal_table_macros[index].name) { + throw InternalException("Default macro name %s should be lowercase", internal_table_macros[index].name); + } + if (internal_table_macros[index].schema == schema.name) { + result.emplace_back(internal_table_macros[index].name); + } + } + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/default/default_views.cpp b/src/duckdb/src/catalog/default/default_views.cpp index aeec5865..71869647 100644 --- a/src/duckdb/src/catalog/default/default_views.cpp +++ b/src/duckdb/src/catalog/default/default_views.cpp @@ -50,9 +50,13 @@ static const DefaultView internal_views[] = { {"information_schema", "schemata", "SELECT database_name catalog_name, schema_name, 'duckdb' schema_owner, NULL::VARCHAR default_character_set_catalog, NULL::VARCHAR default_character_set_schema, NULL::VARCHAR default_character_set_name, sql sql_path FROM duckdb_schemas()"}, {"information_schema", "tables", "SELECT database_name table_catalog, schema_name table_schema, table_name, CASE WHEN temporary THEN 'LOCAL TEMPORARY' ELSE 'BASE TABLE' END table_type, NULL::VARCHAR self_referencing_column_name, NULL::VARCHAR reference_generation, NULL::VARCHAR user_defined_type_catalog, NULL::VARCHAR user_defined_type_schema, NULL::VARCHAR user_defined_type_name, 'YES' is_insertable_into, 'NO' is_typed, CASE WHEN temporary THEN 'PRESERVE' ELSE NULL END commit_action, comment AS TABLE_COMMENT FROM duckdb_tables() UNION ALL SELECT database_name table_catalog, schema_name table_schema, view_name table_name, 'VIEW' table_type, NULL self_referencing_column_name, NULL reference_generation, NULL user_defined_type_catalog, NULL user_defined_type_schema, NULL user_defined_type_name, 'NO' is_insertable_into, 'NO' is_typed, NULL commit_action, comment AS TABLE_COMMENT FROM duckdb_views;"}, {"information_schema", "character_sets", "SELECT NULL::VARCHAR character_set_catalog, NULL::VARCHAR character_set_schema, 'UTF8' character_set_name, 'UCS' character_repertoire, 'UTF8' form_of_use, current_database() default_collate_catalog, 'pg_catalog' default_collate_schema, 'ucs_basic' default_collate_name;"}, - {"information_schema", "referential_constraints", "SELECT f.database_name constraint_catalog, f.schema_name constraint_schema, concat(f.table_name, '_', f.source, '_fkey') constraint_name, current_database() unique_constraint_catalog, c.schema_name unique_constraint_schema, concat(c.table_name, '_', f.target_column, '_', CASE WHEN c.constraint_type == 'UNIQUE' THEN 'key' ELSE 'pkey' END) unique_constraint_name, 'NONE' match_option, 'NO ACTION' update_rule, 'NO ACTION' delete_rule FROM duckdb_constraints() c JOIN (SELECT *, name_extract['source'] as source, name_extract['target'] as target, name_extract['target_column'] as target_column FROM (SELECT *, regexp_extract(constraint_text, 'FOREIGN KEY \\(([a-zA-Z_0-9]+)\\) REFERENCES ([a-zA-Z_0-9]+)\\(([a-zA-Z_0-9]+)\\)', ['source', 'target', 'target_column']) name_extract FROM duckdb_constraints() WHERE constraint_type = 'FOREIGN KEY')) f ON name_extract['target'] = c.table_name AND (c.constraint_type = 'UNIQUE' OR c.constraint_type = 'PRIMARY KEY')"}, - {"information_schema", "key_column_usage", "SELECT current_database() constraint_catalog, schema_name constraint_schema, concat(table_name, '_', constraint_column_names[1], CASE constraint_type WHEN 'FOREIGN KEY' THEN '_fkey' WHEN 'PRIMARY KEY' THEN '_pkey' ELSE '_key' END) constraint_name, current_database() table_catalog, schema_name table_schema, table_name, constraint_column_names[1] column_name, 1 ordinal_position, CASE constraint_type WHEN 'FOREIGN KEY' THEN 1 ELSE NULL END position_in_unique_constraint FROM duckdb_constraints() WHERE constraint_type = 'FOREIGN KEY' OR constraint_type = 'PRIMARY KEY' OR constraint_type = 'UNIQUE';"}, - {"information_schema", "table_constraints", "SELECT current_database() constraint_catalog, schema_name constraint_schema, concat(table_name, '_', CASE WHEN length(constraint_column_names) > 1 THEN NULL ELSE constraint_column_names[1] || '_' END, CASE constraint_type WHEN 'FOREIGN KEY' THEN 'fkey' WHEN 'PRIMARY KEY' THEN 'pkey' WHEN 'UNIQUE' THEN 'key' WHEN 'CHECK' THEN 'check' WHEN 'NOT NULL' THEN 'not_null' END) constraint_name, current_database() table_catalog, schema_name table_schema, table_name, CASE constraint_type WHEN 'NOT NULL' THEN 'CHECK' ELSE constraint_type END constraint_type, 'NO' is_deferrable, 'NO' initially_deferred, 'YES' enforced, 'YES' nulls_distinct FROM duckdb_constraints() WHERE constraint_type = 'PRIMARY KEY' OR constraint_type = 'FOREIGN KEY' OR constraint_type = 'UNIQUE' OR constraint_type = 'CHECK' OR constraint_type = 'NOT NULL';"}, + {"information_schema", "referential_constraints", "SELECT f.database_name constraint_catalog, f.schema_name constraint_schema, f.constraint_name constraint_name, c.database_name unique_constraint_catalog, c.schema_name unique_constraint_schema, c.constraint_name unique_constraint_name, 'NONE' match_option, 'NO ACTION' update_rule, 'NO ACTION' delete_rule FROM duckdb_constraints() c, duckdb_constraints() f WHERE f.constraint_type = 'FOREIGN KEY' AND (c.constraint_type = 'UNIQUE' OR c.constraint_type = 'PRIMARY KEY') AND f.database_oid = c.database_oid AND f.schema_oid = c.schema_oid AND lower(f.referenced_table) = lower(c.table_name) AND [lower(x) for x in f.referenced_column_names] = [lower(x) for x in c.constraint_column_names]"}, + {"information_schema", "key_column_usage", "SELECT database_name constraint_catalog, schema_name constraint_schema, constraint_name, database_name table_catalog, schema_name table_schema, table_name, UNNEST(constraint_column_names) column_name, UNNEST(generate_series(1, len(constraint_column_names))) ordinal_position, CASE constraint_type WHEN 'FOREIGN KEY' THEN 1 ELSE NULL END position_in_unique_constraint FROM duckdb_constraints() WHERE constraint_type = 'FOREIGN KEY' OR constraint_type = 'PRIMARY KEY' OR constraint_type = 'UNIQUE';"}, + {"information_schema", "table_constraints", "SELECT database_name constraint_catalog, schema_name constraint_schema, constraint_name, database_name table_catalog, schema_name table_schema, table_name, CASE constraint_type WHEN 'NOT NULL' THEN 'CHECK' ELSE constraint_type END constraint_type, 'NO' is_deferrable, 'NO' initially_deferred, 'YES' enforced, 'YES' nulls_distinct FROM duckdb_constraints() WHERE constraint_type = 'PRIMARY KEY' OR constraint_type = 'FOREIGN KEY' OR constraint_type = 'UNIQUE' OR constraint_type = 'CHECK' OR constraint_type = 'NOT NULL';"}, + {"information_schema", "constraint_column_usage", "SELECT database_name AS table_catalog, schema_name AS table_schema, table_name, column_name, database_name AS constraint_catalog, schema_name AS constraint_schema, constraint_name, constraint_type, constraint_text FROM (SELECT dc.*, UNNEST(dc.constraint_column_names) AS column_name FROM duckdb_constraints() AS dc WHERE constraint_type NOT IN ('NOT NULL') );"}, + {"information_schema", "constraint_table_usage", "SELECT database_name AS table_catalog, schema_name AS table_schema, table_name, database_name AS constraint_catalog, schema_name AS constraint_schema, constraint_name, constraint_type FROM duckdb_constraints() WHERE constraint_type NOT IN ('NOT NULL');"}, + {"information_schema", "check_constraints", "SELECT database_name AS constraint_catalog, schema_name AS constraint_schema, constraint_name, CASE constraint_type WHEN 'NOT NULL' THEN column_name || ' IS NOT NULL' ELSE constraint_text END AS check_clause FROM (SELECT dc.*, UNNEST(dc.constraint_column_names) AS column_name FROM duckdb_constraints() AS dc WHERE constraint_type IN ('CHECK', 'NOT NULL'));"}, + {"information_schema", "views", "SELECT database_name AS table_catalog, schema_name AS table_schema, view_name AS table_name, sql AS view_definition, 'NONE' AS check_option, 'NO' AS is_updatable, 'NO' AS is_insertable_into, 'NO' AS is_trigger_updatable, 'NO' AS is_trigger_deletable, 'NO' AS is_trigger_insertable_into FROM duckdb_views();"}, {nullptr, nullptr, nullptr}}; static unique_ptr GetDefaultView(ClientContext &context, const string &input_schema, const string &input_name) { diff --git a/src/duckdb/src/catalog/duck_catalog.cpp b/src/duckdb/src/catalog/duck_catalog.cpp index 1adebff7..920740c9 100644 --- a/src/duckdb/src/catalog/duck_catalog.cpp +++ b/src/duckdb/src/catalog/duck_catalog.cpp @@ -95,7 +95,6 @@ optional_ptr DuckCatalog::CreateSchema(CatalogTransaction transact void DuckCatalog::DropSchema(CatalogTransaction transaction, DropInfo &info) { D_ASSERT(!info.name.empty()); - ModifyCatalog(); if (!schemas->DropEntry(transaction, info.name, info.cascade)) { if (info.if_not_found == OnEntryNotFound::THROW_EXCEPTION) { throw CatalogException::MissingEntry(CatalogType::SCHEMA_ENTRY, info.name, string()); @@ -156,4 +155,11 @@ void DuckCatalog::Verify() { #endif } +optional_idx DuckCatalog::GetCatalogVersion(ClientContext &context) { + auto &transaction_manager = DuckTransactionManager::Get(db); + auto transaction = GetCatalogTransaction(context); + D_ASSERT(transaction.transaction); + return transaction_manager.GetCatalogVersion(*transaction.transaction); +} + } // namespace duckdb diff --git a/src/duckdb/src/common/adbc/adbc.cpp b/src/duckdb/src/common/adbc/adbc.cpp index 65a147b0..1d002b57 100644 --- a/src/duckdb/src/common/adbc/adbc.cpp +++ b/src/duckdb/src/common/adbc/adbc.cpp @@ -26,7 +26,7 @@ AdbcStatusCode duckdb_adbc_init(int version, void *driver, struct AdbcError *err if (!driver) { return ADBC_STATUS_INVALID_ARGUMENT; } - auto adbc_driver = reinterpret_cast(driver); + auto adbc_driver = static_cast(driver); adbc_driver->DatabaseNew = duckdb_adbc::DatabaseNew; adbc_driver->DatabaseSetOption = duckdb_adbc::DatabaseSetOption; @@ -60,13 +60,16 @@ AdbcStatusCode duckdb_adbc_init(int version, void *driver, struct AdbcError *err namespace duckdb_adbc { enum class IngestionMode { CREATE = 0, APPEND = 1 }; + struct DuckDBAdbcStatementWrapper { - ::duckdb_connection connection; - ::duckdb_arrow result; - ::duckdb_prepared_statement statement; + duckdb_connection connection; + duckdb_arrow result; + duckdb_prepared_statement statement; char *ingestion_table_name; + char *db_schema; ArrowArrayStream ingestion_stream; IngestionMode ingestion_mode = IngestionMode::CREATE; + bool temporary_table = false; uint8_t *substrait_plan; uint64_t plan_length; }; @@ -99,9 +102,9 @@ static AdbcStatusCode QueryInternal(struct AdbcConnection *connection, struct Ar struct DuckDBAdbcDatabaseWrapper { //! The DuckDB Database Configuration - ::duckdb_config config = nullptr; + duckdb_config config = nullptr; //! The DuckDB Database - ::duckdb_database database = nullptr; + duckdb_database database = nullptr; //! Path of Disk-Based Database or :memory: database std::string path; }; @@ -110,7 +113,6 @@ static void EmptyErrorRelease(AdbcError *error) { // The object is valid but doesn't contain any data that needs to be cleaned up // Just set the release to nullptr to indicate that it's no longer valid. error->release = nullptr; - return; } void InitializeADBCError(AdbcError *error) { @@ -124,7 +126,7 @@ void InitializeADBCError(AdbcError *error) { error->vendor_code = -1; } -AdbcStatusCode CheckResult(duckdb_state &res, AdbcError *error, const char *error_msg) { +AdbcStatusCode CheckResult(const duckdb_state &res, AdbcError *error, const char *error_msg) { if (!error) { // Error should be a non-null pointer return ADBC_STATUS_INVALID_ARGUMENT; @@ -169,8 +171,8 @@ AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement *statement, const SetError(error, "Can't execute plan with size = 0"); return ADBC_STATUS_INVALID_ARGUMENT; } - auto wrapper = reinterpret_cast(statement->private_data); - wrapper->substrait_plan = (uint8_t *)malloc(sizeof(uint8_t) * (length)); + auto wrapper = static_cast(statement->private_data); + wrapper->substrait_plan = static_cast(malloc(sizeof(uint8_t) * length)); wrapper->plan_length = length; memcpy(wrapper->substrait_plan, plan, length); return ADBC_STATUS_OK; @@ -187,7 +189,7 @@ AdbcStatusCode DatabaseSetOption(struct AdbcDatabase *database, const char *key, return ADBC_STATUS_INVALID_ARGUMENT; } - auto wrapper = (DuckDBAdbcDatabaseWrapper *)database->private_data; + auto wrapper = static_cast(database->private_data); if (strcmp(key, "path") == 0) { wrapper->path = value; return ADBC_STATUS_OK; @@ -207,7 +209,7 @@ AdbcStatusCode DatabaseInit(struct AdbcDatabase *database, struct AdbcError *err } char *errormsg = nullptr; // TODO can we set the database path via option, too? Does not look like it... - auto wrapper = (DuckDBAdbcDatabaseWrapper *)database->private_data; + auto wrapper = static_cast(database->private_data); auto res = duckdb_open_ext(wrapper->path.c_str(), &wrapper->database, wrapper->config, &errormsg); auto adbc_result = CheckResult(res, error, errormsg); if (errormsg) { @@ -219,7 +221,7 @@ AdbcStatusCode DatabaseInit(struct AdbcDatabase *database, struct AdbcError *err AdbcStatusCode DatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error) { if (database && database->private_data) { - auto wrapper = (DuckDBAdbcDatabaseWrapper *)database->private_data; + auto wrapper = static_cast(database->private_data); duckdb_close(&wrapper->database); duckdb_destroy_config(&wrapper->config); @@ -290,7 +292,8 @@ AdbcStatusCode ConnectionSetOption(struct AdbcConnection *connection, const char SetError(error, "Connection is not set"); return ADBC_STATUS_INVALID_ARGUMENT; } - auto conn = (duckdb::Connection *)connection->private_data; + + auto conn = static_cast(connection->private_data); if (strcmp(key, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) { if (strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) { if (conn->HasActiveTransaction()) { @@ -343,7 +346,7 @@ AdbcStatusCode ConnectionCommit(struct AdbcConnection *connection, struct AdbcEr SetError(error, "Connection is not set"); return ADBC_STATUS_INVALID_ARGUMENT; } - auto conn = (duckdb::Connection *)connection->private_data; + auto conn = static_cast(connection->private_data); if (!conn->HasActiveTransaction()) { SetError(error, "No active transaction, cannot commit"); return ADBC_STATUS_INVALID_STATE; @@ -361,7 +364,7 @@ AdbcStatusCode ConnectionRollback(struct AdbcConnection *connection, struct Adbc SetError(error, "Connection is not set"); return ADBC_STATUS_INVALID_ARGUMENT; } - auto conn = (duckdb::Connection *)connection->private_data; + auto conn = static_cast(connection->private_data); if (!conn->HasActiveTransaction()) { SetError(error, "No active transaction, cannot rollback"); return ADBC_STATUS_INVALID_STATE; @@ -416,7 +419,7 @@ AdbcStatusCode ConnectionGetInfo(struct AdbcConnection *connection, const uint32 } // If 'info_codes' is NULL, we should output all the info codes we recognize - size_t length = info_codes ? info_codes_length : (size_t)AdbcInfoCode::UNRECOGNIZED; + size_t length = info_codes ? info_codes_length : static_cast(AdbcInfoCode::UNRECOGNIZED); duckdb::string q = R"EOF( select @@ -498,16 +501,17 @@ AdbcStatusCode ConnectionInit(struct AdbcConnection *connection, struct AdbcData SetError(error, "Missing connection object"); return ADBC_STATUS_INVALID_ARGUMENT; } - auto database_wrapper = (DuckDBAdbcDatabaseWrapper *)database->private_data; + auto database_wrapper = static_cast(database->private_data); connection->private_data = nullptr; - auto res = duckdb_connect(database_wrapper->database, (duckdb_connection *)&connection->private_data); + auto res = + duckdb_connect(database_wrapper->database, reinterpret_cast(&connection->private_data)); return CheckResult(res, error, "Failed to connect to Database"); } AdbcStatusCode ConnectionRelease(struct AdbcConnection *connection, struct AdbcError *error) { if (connection && connection->private_data) { - duckdb_disconnect((duckdb_connection *)&connection->private_data); + duckdb_disconnect(reinterpret_cast(&connection->private_data)); connection->private_data = nullptr; } return ADBC_STATUS_OK; @@ -519,7 +523,8 @@ static int get_schema(struct ArrowArrayStream *stream, struct ArrowSchema *out) if (!stream || !stream->private_data || !out) { return DuckDBError; } - return duckdb_query_arrow_schema((duckdb_arrow)stream->private_data, (duckdb_arrow_schema *)&out); + return duckdb_query_arrow_schema(static_cast(stream->private_data), + reinterpret_cast(&out)); } static int get_next(struct ArrowArrayStream *stream, struct ArrowArray *out) { @@ -528,7 +533,8 @@ static int get_next(struct ArrowArrayStream *stream, struct ArrowArray *out) { } out->release = nullptr; - return duckdb_query_arrow_array((duckdb_arrow)stream->private_data, (duckdb_arrow_array *)&out); + return duckdb_query_arrow_array(static_cast(stream->private_data), + reinterpret_cast(&out)); } void release(struct ArrowArrayStream *stream) { @@ -536,7 +542,7 @@ void release(struct ArrowArrayStream *stream) { return; } if (stream->private_data) { - duckdb_destroy_arrow((duckdb_arrow *)&stream->private_data); + duckdb_destroy_arrow(reinterpret_cast(&stream->private_data)); stream->private_data = nullptr; } stream->release = nullptr; @@ -558,7 +564,7 @@ duckdb::unique_ptr stream_produce(uintptr_t fac // TODO this will ignore any projections or filters but since we don't expose the scan it should be sort of fine auto res = duckdb::make_uniq(); - res->arrow_array_stream = *(ArrowArrayStream *)factory_ptr; + res->arrow_array_stream = *reinterpret_cast(factory_ptr); return res; } @@ -566,8 +572,9 @@ void stream_schema(ArrowArrayStream *stream, ArrowSchema &schema) { stream->get_schema(stream, &schema); } -AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, struct ArrowArrayStream *input, - struct AdbcError *error, IngestionMode ingestion_mode) { +AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, const char *schema, + struct ArrowArrayStream *input, struct AdbcError *error, IngestionMode ingestion_mode, + bool temporary) { if (!connection) { SetError(error, "Missing connection object"); @@ -581,27 +588,47 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, stru SetError(error, "Missing database object name"); return ADBC_STATUS_INVALID_ARGUMENT; } + if (schema && temporary) { + // Temporary option is not supported with ADBC_INGEST_OPTION_TARGET_DB_SCHEMA or + // ADBC_INGEST_OPTION_TARGET_CATALOG + SetError(error, "Temporary option is not supported with schema"); + return ADBC_STATUS_INVALID_ARGUMENT; + } - auto cconn = (duckdb::Connection *)connection; + auto cconn = reinterpret_cast(connection); - auto arrow_scan = cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER((uintptr_t)input), - duckdb::Value::POINTER((uintptr_t)stream_produce), - duckdb::Value::POINTER((uintptr_t)stream_schema)}); + auto arrow_scan = + cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER(reinterpret_cast(input)), + duckdb::Value::POINTER(reinterpret_cast(stream_produce)), + duckdb::Value::POINTER(reinterpret_cast(stream_schema))}); try { - if (ingestion_mode == IngestionMode::CREATE) { - // We create the table based on an Arrow Scanner - arrow_scan->Create(table_name); - } else { + switch (ingestion_mode) { + case IngestionMode::CREATE: + if (schema) { + arrow_scan->Create(schema, table_name, temporary); + } else { + arrow_scan->Create(table_name, temporary); + } + break; + case IngestionMode::APPEND: { arrow_scan->CreateView("temp_adbc_view", true, true); - auto query = duckdb::StringUtil::Format("insert into \"%s\" select * from temp_adbc_view", table_name); + std::string query; + if (schema) { + query = duckdb::StringUtil::Format("insert into \"%s.%s\" select * from temp_adbc_view", schema, + table_name); + } else { + query = duckdb::StringUtil::Format("insert into \"%s\" select * from temp_adbc_view", table_name); + } auto result = cconn->Query(query); + break; + } } // After creating a table, the arrow array stream is released. Hence we must set it as released to avoid // double-releasing it input->release = nullptr; } catch (std::exception &ex) { if (error) { - ::duckdb::ErrorData parsed_error(ex); + duckdb::ErrorData parsed_error(ex); error->message = strdup(parsed_error.RawMessage().c_str()); } return ADBC_STATUS_INTERNAL; @@ -628,19 +655,21 @@ AdbcStatusCode StatementNew(struct AdbcConnection *connection, struct AdbcStatem statement->private_data = nullptr; - auto statement_wrapper = (DuckDBAdbcStatementWrapper *)malloc(sizeof(DuckDBAdbcStatementWrapper)); + auto statement_wrapper = static_cast(malloc(sizeof(DuckDBAdbcStatementWrapper))); if (!statement_wrapper) { SetError(error, "Allocation error"); return ADBC_STATUS_INVALID_ARGUMENT; } statement->private_data = statement_wrapper; - statement_wrapper->connection = (duckdb_connection)connection->private_data; + statement_wrapper->connection = static_cast(connection->private_data); statement_wrapper->statement = nullptr; statement_wrapper->result = nullptr; statement_wrapper->ingestion_stream.release = nullptr; statement_wrapper->ingestion_table_name = nullptr; + statement_wrapper->db_schema = nullptr; statement_wrapper->substrait_plan = nullptr; + statement_wrapper->temporary_table = false; statement_wrapper->ingestion_mode = IngestionMode::CREATE; return ADBC_STATUS_OK; @@ -650,7 +679,7 @@ AdbcStatusCode StatementRelease(struct AdbcStatement *statement, struct AdbcErro if (!statement || !statement->private_data) { return ADBC_STATUS_OK; } - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + auto wrapper = static_cast(statement->private_data); if (wrapper->statement) { duckdb_destroy_prepare(&wrapper->statement); wrapper->statement = nullptr; @@ -667,6 +696,10 @@ AdbcStatusCode StatementRelease(struct AdbcStatement *statement, struct AdbcErro free(wrapper->ingestion_table_name); wrapper->ingestion_table_name = nullptr; } + if (wrapper->db_schema) { + free(wrapper->db_schema); + wrapper->db_schema = nullptr; + } if (wrapper->substrait_plan) { free(wrapper->substrait_plan); wrapper->substrait_plan = nullptr; @@ -690,10 +723,10 @@ AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, stru SetError(error, "Missing schema object"); return ADBC_STATUS_INVALID_ARGUMENT; } - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + auto wrapper = static_cast(statement->private_data); // TODO: we might want to cache this, but then we need to return a deep copy anyways.., so I'm not sure if that // would be worth the extra management - auto res = duckdb_prepared_arrow_schema(wrapper->statement, (duckdb_arrow_schema *)&schema); + auto res = duckdb_prepared_arrow_schema(wrapper->statement, reinterpret_cast(&schema)); if (res != DuckDBSuccess) { return ADBC_STATUS_INVALID_ARGUMENT; } @@ -703,12 +736,13 @@ AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, stru AdbcStatusCode GetPreparedParameters(duckdb_connection connection, duckdb::unique_ptr &result, ArrowArrayStream *input, AdbcError *error) { - auto cconn = (duckdb::Connection *)connection; + auto cconn = reinterpret_cast(connection); try { - auto arrow_scan = cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER((uintptr_t)input), - duckdb::Value::POINTER((uintptr_t)stream_produce), - duckdb::Value::POINTER((uintptr_t)stream_schema)}); + auto arrow_scan = + cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER(reinterpret_cast(input)), + duckdb::Value::POINTER(reinterpret_cast(stream_produce)), + duckdb::Value::POINTER(reinterpret_cast(stream_schema))}); result = arrow_scan->Execute(); // After creating a table, the arrow array stream is released. Hence we must set it as released to avoid // double-releasing it @@ -735,7 +769,8 @@ static AdbcStatusCode IngestToTableFromBoundStream(DuckDBAdbcStatementWrapper *s statement->ingestion_stream.release = nullptr; // Ingest into a table from the bound stream - return Ingest(statement->connection, statement->ingestion_table_name, &stream, error, statement->ingestion_mode); + return Ingest(statement->connection, statement->ingestion_table_name, statement->db_schema, &stream, error, + statement->ingestion_mode, statement->temporary_table); } AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, @@ -748,7 +783,7 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr SetError(error, "Invalid statement object"); return ADBC_STATUS_INVALID_ARGUMENT; } - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + auto wrapper = static_cast(statement->private_data); // TODO: Set affected rows, careful with early return if (rows_affected) { @@ -767,8 +802,9 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr params.emplace_back(duckdb::Value::BLOB_RAW(plan_str)); duckdb::unique_ptr query_result; try { - query_result = - ((duckdb::Connection *)wrapper->connection)->TableFunction("from_substrait", params)->Execute(); + query_result = reinterpret_cast(wrapper->connection) + ->TableFunction("from_substrait", params) + ->Execute(); } catch (duckdb::Exception &e) { std::string error_msg = "It was not possible to execute substrait query. " + std::string(e.what()); SetError(error, error_msg); @@ -792,7 +828,7 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr } duckdb::unique_ptr chunk; auto prepared_statement_params = - reinterpret_cast(wrapper->statement)->statement->n_param; + reinterpret_cast(wrapper->statement)->statement->named_param_map.size(); while ((chunk = result->Fetch()) != nullptr) { if (chunk->size() == 0) { @@ -811,7 +847,7 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr duckdb_clear_bindings(wrapper->statement); for (idx_t col_idx = 0; col_idx < chunk->ColumnCount(); col_idx++) { auto val = chunk->GetValue(col_idx, 0); - auto duck_val = (duckdb_value)&val; + auto duck_val = reinterpret_cast(&val); auto res = duckdb_bind_value(wrapper->statement, 1 + col_idx, duck_val); if (res != DuckDBSuccess) { SetError(error, duckdb_prepare_error(wrapper->statement)); @@ -875,7 +911,7 @@ AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement *statement, const char return ADBC_STATUS_INVALID_ARGUMENT; } - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + auto wrapper = static_cast(statement->private_data); auto res = duckdb_prepare(wrapper->connection, query, &wrapper->statement); auto error_msg = duckdb_prepare_error(wrapper->statement); return CheckResult(res, error, error_msg); @@ -900,7 +936,7 @@ AdbcStatusCode StatementBind(struct AdbcStatement *statement, struct ArrowArray return ADBC_STATUS_INVALID_ARGUMENT; } - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + auto wrapper = static_cast(statement->private_data); if (wrapper->ingestion_stream.release) { // Free the stream that was previously bound wrapper->ingestion_stream.release(&wrapper->ingestion_stream); @@ -924,7 +960,7 @@ AdbcStatusCode StatementBindStream(struct AdbcStatement *statement, struct Arrow return ADBC_STATUS_INVALID_ARGUMENT; } - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + auto wrapper = static_cast(statement->private_data); if (wrapper->ingestion_stream.release) { // Release any resources currently held by the ingestion stream before we overwrite it wrapper->ingestion_stream.release(&wrapper->ingestion_stream); @@ -949,18 +985,41 @@ AdbcStatusCode StatementSetOption(struct AdbcStatement *statement, const char *k return ADBC_STATUS_INVALID_ARGUMENT; } - auto wrapper = (DuckDBAdbcStatementWrapper *)statement->private_data; + auto wrapper = static_cast(statement->private_data); if (strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) { wrapper->ingestion_table_name = strdup(value); + wrapper->temporary_table = false; return ADBC_STATUS_OK; } if (strcmp(key, ADBC_INGEST_OPTION_TEMPORARY) == 0) { - if (strcmp(value, "false") == 0) { - return ADBC_STATUS_NOT_IMPLEMENTED; + if (strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) { + if (wrapper->db_schema) { + SetError(error, "Temporary option is not supported with schema"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + wrapper->temporary_table = true; + return ADBC_STATUS_OK; + } else if (strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) { + wrapper->temporary_table = false; + return ADBC_STATUS_OK; + } else { + SetError( + error, + "ADBC_INGEST_OPTION_TEMPORARY, can only be ADBC_OPTION_VALUE_ENABLED or ADBC_OPTION_VALUE_DISABLED"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + } + + if (strcmp(key, ADBC_INGEST_OPTION_TARGET_DB_SCHEMA) == 0) { + if (wrapper->temporary_table) { + SetError(error, "Temporary option is not supported with schema"); + return ADBC_STATUS_INVALID_ARGUMENT; } + wrapper->db_schema = strdup(value); return ADBC_STATUS_OK; } + if (strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) { if (strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE) == 0) { wrapper->ingestion_mode = IngestionMode::CREATE; @@ -973,6 +1032,9 @@ AdbcStatusCode StatementSetOption(struct AdbcStatement *statement, const char *k return ADBC_STATUS_INVALID_ARGUMENT; } } + std::stringstream ss; + ss << "Statement Set Option " << key << " is not yet accepted by DuckDB"; + SetError(error, ss.str()); return ADBC_STATUS_INVALID_ARGUMENT; } @@ -1264,7 +1326,7 @@ AdbcStatusCode ConnectionGetObjects(struct AdbcConnection *connection, int depth AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection *connection, struct ArrowArrayStream *out, struct AdbcError *error) { - const char *q = "SELECT DISTINCT table_type FROM information_schema.tables ORDER BY table_type"; + const auto q = "SELECT DISTINCT table_type FROM information_schema.tables ORDER BY table_type"; return QueryInternal(connection, out, q, error); } diff --git a/src/duckdb/src/common/allocator.cpp b/src/duckdb/src/common/allocator.cpp index 772db6ee..d3ef18bb 100644 --- a/src/duckdb/src/common/allocator.cpp +++ b/src/duckdb/src/common/allocator.cpp @@ -4,6 +4,8 @@ #include "duckdb/common/atomic.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/helper.hpp" +#include "duckdb/common/numeric_utils.hpp" +#include "duckdb/common/types/timestamp.hpp" #include @@ -16,7 +18,8 @@ #endif #ifndef USE_JEMALLOC -#if defined(DUCKDB_EXTENSION_JEMALLOC_LINKED) && DUCKDB_EXTENSION_JEMALLOC_LINKED && !defined(WIN32) +#if defined(DUCKDB_EXTENSION_JEMALLOC_LINKED) && DUCKDB_EXTENSION_JEMALLOC_LINKED && !defined(WIN32) && \ + INTPTR_MAX == INT64_MAX #define USE_JEMALLOC #endif #endif @@ -25,6 +28,10 @@ #include "jemalloc_extension.hpp" #endif +#ifdef __GLIBC__ +#include +#endif + namespace duckdb { AllocatedData::AllocatedData() : allocator(nullptr), pointer(nullptr), allocated_size(0) { @@ -128,7 +135,9 @@ data_ptr_t Allocator::AllocateData(idx_t size) { auto result = allocate_function(private_data.get(), size); #ifdef DEBUG D_ASSERT(private_data); - private_data->debug_info->AllocateData(result, size); + if (private_data->free_type != AllocatorFreeType::DOES_NOT_REQUIRE_FREE) { + private_data->debug_info->AllocateData(result, size); + } #endif if (!result) { throw OutOfMemoryException("Failed to allocate block of %llu bytes (bad allocation)", size); @@ -143,7 +152,9 @@ void Allocator::FreeData(data_ptr_t pointer, idx_t size) { D_ASSERT(size > 0); #ifdef DEBUG D_ASSERT(private_data); - private_data->debug_info->FreeData(pointer, size); + if (private_data->free_type != AllocatorFreeType::DOES_NOT_REQUIRE_FREE) { + private_data->debug_info->FreeData(pointer, size); + } #endif free_function(private_data.get(), pointer, size); } @@ -161,7 +172,9 @@ data_ptr_t Allocator::ReallocateData(data_ptr_t pointer, idx_t old_size, idx_t s auto new_pointer = reallocate_function(private_data.get(), pointer, old_size, size); #ifdef DEBUG D_ASSERT(private_data); - private_data->debug_info->ReallocateData(pointer, new_pointer, old_size, size); + if (private_data->free_type != AllocatorFreeType::DOES_NOT_REQUIRE_FREE) { + private_data->debug_info->ReallocateData(pointer, new_pointer, old_size, size); + } #endif if (!new_pointer) { throw OutOfMemoryException("Failed to re-allocate block of %llu bytes (bad allocation)", size); @@ -207,15 +220,67 @@ Allocator &Allocator::DefaultAllocator() { return *DefaultAllocatorReference(); } -void Allocator::ThreadFlush(idx_t threshold) { +optional_idx Allocator::DecayDelay() { #ifdef USE_JEMALLOC - JemallocExtension::ThreadFlush(threshold); + return NumericCast(JemallocExtension::DecayDelay()); +#else + return optional_idx(); +#endif +} + +bool Allocator::SupportsFlush() { +#if defined(USE_JEMALLOC) || defined(__GLIBC__) + return true; +#else + return false; +#endif +} + +static void MallocTrim(idx_t pad) { +#ifdef __GLIBC__ + static constexpr int64_t TRIM_INTERVAL_MS = 100; + static atomic LAST_TRIM_TIMESTAMP_MS {0}; + + int64_t last_trim_timestamp_ms = LAST_TRIM_TIMESTAMP_MS.load(); + const int64_t current_timestamp_ms = Timestamp::GetEpochMs(Timestamp::GetCurrentTimestamp()); + + if (current_timestamp_ms - last_trim_timestamp_ms < TRIM_INTERVAL_MS) { + return; // We trimmed less than TRIM_INTERVAL_MS ago + } + if (!std::atomic_compare_exchange_weak(&LAST_TRIM_TIMESTAMP_MS, &last_trim_timestamp_ms, current_timestamp_ms)) { + return; // Another thread has updated LAST_TRIM_TIMESTAMP_MS since we loaded it + } + + // We succesfully updated LAST_TRIM_TIMESTAMP_MS, we can trim + malloc_trim(pad); +#endif +} + +void Allocator::ThreadFlush(bool allocator_background_threads, idx_t threshold, idx_t thread_count) { +#ifdef USE_JEMALLOC + if (!allocator_background_threads) { + JemallocExtension::ThreadFlush(threshold); + } +#endif + MallocTrim(thread_count * threshold); +} + +void Allocator::ThreadIdle() { +#ifdef USE_JEMALLOC + JemallocExtension::ThreadIdle(); #endif } void Allocator::FlushAll() { #ifdef USE_JEMALLOC JemallocExtension::FlushAll(); +#endif + MallocTrim(0); +} + +void Allocator::SetBackgroundThreads(bool enable) { +#ifdef USE_JEMALLOC + JemallocExtension::SetBackgroundThreads(enable); #endif } diff --git a/src/duckdb/src/common/arrow/appender/bool_data.cpp b/src/duckdb/src/common/arrow/appender/bool_data.cpp index d30b3933..78befb60 100644 --- a/src/duckdb/src/common/arrow/appender/bool_data.cpp +++ b/src/duckdb/src/common/arrow/appender/bool_data.cpp @@ -5,7 +5,7 @@ namespace duckdb { void ArrowBoolData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { auto byte_count = (capacity + 7) / 8; - result.main_buffer.reserve(byte_count); + result.GetMainBuffer().reserve(byte_count); (void)AppendValidity; // silence a compiler warning about unused static function } @@ -13,14 +13,15 @@ void ArrowBoolData::Append(ArrowAppendData &append_data, Vector &input, idx_t fr idx_t size = to - from; UnifiedVectorFormat format; input.ToUnifiedFormat(input_size, format); - + auto &main_buffer = append_data.GetMainBuffer(); + auto &validity_buffer = append_data.GetValidityBuffer(); // we initialize both the validity and the bit set to 1's - ResizeValidity(append_data.validity, append_data.row_count + size); - ResizeValidity(append_data.main_buffer, append_data.row_count + size); + ResizeValidity(validity_buffer, append_data.row_count + size); + ResizeValidity(main_buffer, append_data.row_count + size); auto data = UnifiedVectorFormat::GetData(format); - auto result_data = append_data.main_buffer.GetData(); - auto validity_data = append_data.validity.GetData(); + auto result_data = main_buffer.GetData(); + auto validity_data = validity_buffer.GetData(); uint8_t current_bit; idx_t current_byte; GetBitPosition(append_data.row_count, current_byte, current_bit); @@ -39,7 +40,7 @@ void ArrowBoolData::Append(ArrowAppendData &append_data, Vector &input, idx_t fr void ArrowBoolData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { result->n_buffers = 2; - result->buffers[1] = append_data.main_buffer.data(); + result->buffers[1] = append_data.GetMainBuffer().data(); } } // namespace duckdb diff --git a/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp b/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp index d546e16d..172144fd 100644 --- a/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp +++ b/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp @@ -19,7 +19,7 @@ void ArrowFixedSizeListData::Append(ArrowAppendData &append_data, Vector &input, input.ToUnifiedFormat(input_size, format); idx_t size = to - from; AppendValidity(append_data, format, from, to); - + input.Flatten(input_size); auto array_size = ArrayType::GetSize(input.GetType()); auto &child_vector = ArrayVector::GetEntry(input); auto &child_data = *append_data.child_data[0]; diff --git a/src/duckdb/src/common/arrow/appender/union_data.cpp b/src/duckdb/src/common/arrow/appender/union_data.cpp index 02acffe0..1e9f4f43 100644 --- a/src/duckdb/src/common/arrow/appender/union_data.cpp +++ b/src/duckdb/src/common/arrow/appender/union_data.cpp @@ -8,7 +8,7 @@ namespace duckdb { // Unions //===--------------------------------------------------------------------===// void ArrowUnionData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { - result.main_buffer.reserve(capacity * sizeof(int8_t)); + result.GetMainBuffer().reserve(capacity * sizeof(int8_t)); for (auto &child : UnionType::CopyMemberTypes(type)) { auto child_buffer = ArrowAppender::InitializeChild(child.second, capacity, result.options); @@ -22,7 +22,7 @@ void ArrowUnionData::Append(ArrowAppendData &append_data, Vector &input, idx_t f input.ToUnifiedFormat(input_size, format); idx_t size = to - from; - auto &types_buffer = append_data.main_buffer; + auto &types_buffer = append_data.GetMainBuffer(); duckdb::vector child_vectors; for (const auto &child : UnionType::CopyMemberTypes(input.GetType())) { @@ -43,8 +43,7 @@ void ArrowUnionData::Append(ArrowAppendData &append_data, Vector &input, idx_t f for (idx_t child_idx = 0; child_idx < child_vectors.size(); child_idx++) { child_vectors[child_idx].SetValue(input_idx, child_idx == tag ? resolved_value : Value(nullptr)); } - - types_buffer.data()[input_idx] = NumericCast(tag); + types_buffer.push_back(NumericCast(tag)); } for (idx_t child_idx = 0; child_idx < child_vectors.size(); child_idx++) { @@ -57,7 +56,7 @@ void ArrowUnionData::Append(ArrowAppendData &append_data, Vector &input, idx_t f void ArrowUnionData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { result->n_buffers = 1; - result->buffers[0] = append_data.main_buffer.data(); + result->buffers[0] = append_data.GetMainBuffer().data(); auto &child_types = UnionType::CopyMemberTypes(type); ArrowAppender::AddChildren(append_data, child_types.size()); diff --git a/src/duckdb/src/common/arrow/arrow_appender.cpp b/src/duckdb/src/common/arrow/arrow_appender.cpp index 8dd1f0cf..b478fdb3 100644 --- a/src/duckdb/src/common/arrow/arrow_appender.cpp +++ b/src/duckdb/src/common/arrow/arrow_appender.cpp @@ -14,10 +14,10 @@ namespace duckdb { // ArrowAppender //===--------------------------------------------------------------------===// -ArrowAppender::ArrowAppender(vector types_p, idx_t initial_capacity, ClientProperties options) +ArrowAppender::ArrowAppender(vector types_p, const idx_t initial_capacity, ClientProperties options) : types(std::move(types_p)) { for (auto &type : types) { - auto entry = ArrowAppender::InitializeChild(type, initial_capacity, options); + auto entry = InitializeChild(type, initial_capacity, options); root_data.push_back(std::move(entry)); } } @@ -35,6 +35,10 @@ void ArrowAppender::Append(DataChunk &input, idx_t from, idx_t to, idx_t input_s row_count += to - from; } +idx_t ArrowAppender::RowCount() const { + return row_count; +} + void ArrowAppender::ReleaseArray(ArrowArray *array) { if (!array || !array->release) { return; @@ -64,7 +68,7 @@ ArrowArray *ArrowAppender::FinalizeChild(const LogicalType &type, unique_ptrprivate_data = append_data_p.release(); - result->release = ArrowAppender::ReleaseArray; + result->release = ReleaseArray; result->n_children = 0; result->null_count = 0; result->offset = 0; @@ -72,7 +76,7 @@ ArrowArray *ArrowAppender::FinalizeChild(const LogicalType &type, unique_ptrbuffers = append_data.buffers.data(); result->null_count = NumericCast(append_data.null_count); result->length = NumericCast(append_data.row_count); - result->buffers[0] = append_data.validity.data(); + result->buffers[0] = append_data.GetValidityBuffer().data(); if (append_data.finalize) { append_data.finalize(append_data, type, result.get()); @@ -138,9 +142,14 @@ static void InitializeFunctionPointers(ArrowAppendData &append_data, const Logic case LogicalTypeId::INTEGER: InitializeAppenderForType>(append_data); break; - case LogicalTypeId::TIME_TZ: - InitializeAppenderForType>(append_data); + case LogicalTypeId::TIME_TZ: { + if (append_data.options.arrow_lossless_conversion) { + InitializeAppenderForType>(append_data); + } else { + InitializeAppenderForType>(append_data); + } break; + } case LogicalTypeId::TIME: case LogicalTypeId::TIMESTAMP_SEC: case LogicalTypeId::TIMESTAMP_MS: @@ -150,9 +159,23 @@ static void InitializeFunctionPointers(ArrowAppendData &append_data, const Logic case LogicalTypeId::BIGINT: InitializeAppenderForType>(append_data); break; + case LogicalTypeId::UUID: + if (append_data.options.arrow_lossless_conversion) { + InitializeAppenderForType>(append_data); + } else { + if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { + InitializeAppenderForType>(append_data); + } else { + InitializeAppenderForType>(append_data); + } + } + break; case LogicalTypeId::HUGEINT: InitializeAppenderForType>(append_data); break; + case LogicalTypeId::UHUGEINT: + InitializeAppenderForType>(append_data); + break; case LogicalTypeId::UTINYINT: InitializeAppenderForType>(append_data); break; @@ -190,19 +213,22 @@ static void InitializeFunctionPointers(ArrowAppendData &append_data, const Logic } break; case LogicalTypeId::VARCHAR: - case LogicalTypeId::BLOB: - case LogicalTypeId::BIT: - if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { - InitializeAppenderForType>(append_data); + if (append_data.options.produce_arrow_string_view) { + InitializeAppenderForType(append_data); } else { - InitializeAppenderForType>(append_data); + if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { + InitializeAppenderForType>(append_data); + } else { + InitializeAppenderForType>(append_data); + } } break; - case LogicalTypeId::UUID: + case LogicalTypeId::BLOB: + case LogicalTypeId::BIT: if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { - InitializeAppenderForType>(append_data); + InitializeAppenderForType>(append_data); } else { - InitializeAppenderForType>(append_data); + InitializeAppenderForType>(append_data); } break; case LogicalTypeId::ENUM: @@ -233,10 +259,18 @@ static void InitializeFunctionPointers(ArrowAppendData &append_data, const Logic InitializeAppenderForType(append_data); break; case LogicalTypeId::LIST: { - if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { - InitializeAppenderForType>(append_data); + if (append_data.options.arrow_use_list_view) { + if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { + InitializeAppenderForType>(append_data); + } else { + InitializeAppenderForType>(append_data); + } } else { - InitializeAppenderForType>(append_data); + if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { + InitializeAppenderForType>(append_data); + } else { + InitializeAppenderForType>(append_data); + } } break; } @@ -249,18 +283,18 @@ static void InitializeFunctionPointers(ArrowAppendData &append_data, const Logic } } -unique_ptr ArrowAppender::InitializeChild(const LogicalType &type, idx_t capacity, +unique_ptr ArrowAppender::InitializeChild(const LogicalType &type, const idx_t capacity, ClientProperties &options) { auto result = make_uniq(options); InitializeFunctionPointers(*result, type); - auto byte_count = (capacity + 7) / 8; - result->validity.reserve(byte_count); + const auto byte_count = (capacity + 7) / 8; + result->GetValidityBuffer().reserve(byte_count); result->initialize(*result, type, capacity); return result; } -void ArrowAppender::AddChildren(ArrowAppendData &data, idx_t count) { +void ArrowAppender::AddChildren(ArrowAppendData &data, const idx_t count) { data.child_pointers.resize(count); data.child_arrays.resize(count); for (idx_t i = 0; i < count; i++) { diff --git a/src/duckdb/src/common/arrow/arrow_converter.cpp b/src/duckdb/src/common/arrow/arrow_converter.cpp index 5e807f6d..3524dc89 100644 --- a/src/duckdb/src/common/arrow/arrow_converter.cpp +++ b/src/duckdb/src/common/arrow/arrow_converter.cpp @@ -11,6 +11,7 @@ #include "duckdb/common/vector.hpp" #include #include "duckdb/common/arrow/arrow_appender.hpp" +#include "duckdb/common/arrow/schema_metadata.hpp" namespace duckdb { @@ -43,6 +44,8 @@ struct DuckDBArrowSchemaHolder { //! This holds strings created to represent decimal types vector> owned_type_names; vector> owned_column_names; + //! This holds any values created for metadata info + vector> metadata_info; }; static void ReleaseDuckDBArrowSchema(ArrowSchema *schema) { @@ -121,24 +124,77 @@ void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, co case LogicalTypeId::FLOAT: child.format = "f"; break; - case LogicalTypeId::HUGEINT: - child.format = "d:38,0"; + case LogicalTypeId::HUGEINT: { + if (options.arrow_lossless_conversion) { + child.format = "w:16"; + auto schema_metadata = ArrowSchemaMetadata::MetadataFromName("duckdb.hugeint"); + root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); + child.metadata = root_holder.metadata_info.back().get(); + } else { + child.format = "d:38,0"; + } + break; + } + case LogicalTypeId::UHUGEINT: { + child.format = "w:16"; + auto schema_metadata = ArrowSchemaMetadata::MetadataFromName("duckdb.uhugeint"); + root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); + child.metadata = root_holder.metadata_info.back().get(); break; + } case LogicalTypeId::DOUBLE: child.format = "g"; break; - case LogicalTypeId::UUID: + case LogicalTypeId::UUID: { + if (options.arrow_lossless_conversion) { + // This is a canonical extension, hence needs the "arrow." prefix + child.format = "w:16"; + auto schema_metadata = ArrowSchemaMetadata::MetadataFromName("arrow.uuid"); + root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); + child.metadata = root_holder.metadata_info.back().get(); + } else { + if (options.produce_arrow_string_view) { + child.format = "vu"; + } else { + if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { + child.format = "U"; + } else { + child.format = "u"; + } + } + } + break; + } case LogicalTypeId::VARCHAR: - if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { - child.format = "U"; + if (type.IsJSONType()) { + auto schema_metadata = ArrowSchemaMetadata::MetadataFromName("arrow.json"); + root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); + child.metadata = root_holder.metadata_info.back().get(); + } + if (options.produce_arrow_string_view) { + child.format = "vu"; } else { - child.format = "u"; + if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { + child.format = "U"; + } else { + child.format = "u"; + } } break; case LogicalTypeId::DATE: child.format = "tdD"; break; - case LogicalTypeId::TIME_TZ: + case LogicalTypeId::TIME_TZ: { + if (options.arrow_lossless_conversion) { + child.format = "w:8"; + auto schema_metadata = ArrowSchemaMetadata::MetadataFromName("duckdb.time_tz"); + root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); + child.metadata = root_holder.metadata_info.back().get(); + } else { + child.format = "ttu"; + } + break; + } case LogicalTypeId::TIME: child.format = "ttu"; break; @@ -176,19 +232,38 @@ void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, co break; } case LogicalTypeId::BLOB: + if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { + child.format = "Z"; + } else { + child.format = "z"; + } + break; case LogicalTypeId::BIT: { if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { child.format = "Z"; } else { child.format = "z"; } + if (options.arrow_lossless_conversion) { + auto schema_metadata = ArrowSchemaMetadata::MetadataFromName("duckdb.bit"); + root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); + child.metadata = root_holder.metadata_info.back().get(); + } break; } case LogicalTypeId::LIST: { - if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { - child.format = "+L"; + if (options.arrow_use_list_view) { + if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { + child.format = "+vL"; + } else { + child.format = "+vl"; + } } else { - child.format = "+l"; + if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { + child.format = "+L"; + } else { + child.format = "+l"; + } } child.n_children = 1; root_holder.nested_children.emplace_back(); diff --git a/src/duckdb/src/common/arrow/arrow_merge_event.cpp b/src/duckdb/src/common/arrow/arrow_merge_event.cpp new file mode 100644 index 00000000..e899d818 --- /dev/null +++ b/src/duckdb/src/common/arrow/arrow_merge_event.cpp @@ -0,0 +1,142 @@ +#include "duckdb/common/arrow/arrow_merge_event.hpp" +#include "duckdb/storage/storage_info.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Arrow Batch Task +//===--------------------------------------------------------------------===// + +ArrowBatchTask::ArrowBatchTask(ArrowQueryResult &result, vector record_batch_indices, Executor &executor, + shared_ptr event_p, BatchCollectionChunkScanState scan_state, + vector names, idx_t batch_size) + : ExecutorTask(executor, event_p), result(result), record_batch_indices(std::move(record_batch_indices)), + event(std::move(event_p)), batch_size(batch_size), names(std::move(names)), scan_state(std::move(scan_state)) { +} + +void ArrowBatchTask::ProduceRecordBatches() { + auto &arrays = result.Arrays(); + auto arrow_options = executor.context.GetClientProperties(); + for (auto &index : record_batch_indices) { + auto &array = arrays[index]; + D_ASSERT(array); + idx_t count; + count = ArrowUtil::FetchChunk(scan_state, arrow_options, batch_size, &array->arrow_array); + (void)count; + D_ASSERT(count != 0); + } +} + +TaskExecutionResult ArrowBatchTask::ExecuteTask(TaskExecutionMode mode) { + ProduceRecordBatches(); + event->FinishTask(); + return TaskExecutionResult::TASK_FINISHED; +} + +//===--------------------------------------------------------------------===// +// Arrow Merge Event +//===--------------------------------------------------------------------===// + +ArrowMergeEvent::ArrowMergeEvent(ArrowQueryResult &result, BatchedDataCollection &batches, Pipeline &pipeline_p) + : BasePipelineEvent(pipeline_p), result(result), batches(batches) { + record_batch_size = result.BatchSize(); +} + +namespace { + +struct BatchesForTask { + idx_t tuple_count; + BatchedChunkIteratorRange batches; +}; + +struct BatchesToTaskTransformer { +public: + explicit BatchesToTaskTransformer(BatchedDataCollection &batches) : batches(batches), batch_index(0) { + batch_count = batches.BatchCount(); + } + idx_t GetIndex() const { + return batch_index; + } + bool TryGetNextBatchSize(idx_t &tuple_count) { + if (batch_index >= batch_count) { + return false; + } + auto internal_index = batches.IndexToBatchIndex(batch_index++); + auto tuples_in_batch = batches.BatchSize(internal_index); + tuple_count = tuples_in_batch; + return true; + } + +public: + BatchedDataCollection &batches; + idx_t batch_index; + idx_t batch_count; +}; + +} // namespace + +void ArrowMergeEvent::Schedule() { + vector> tasks; + + BatchesToTaskTransformer transformer(batches); + vector task_data; + bool finished = false; + // First we convert our list of batches into units of Storage::ROW_GROUP_SIZE tuples each + while (!finished) { + idx_t tuples_for_task = 0; + idx_t start_index = transformer.GetIndex(); + idx_t end_index = start_index; + while (tuples_for_task < Storage::ROW_GROUP_SIZE) { + idx_t batch_size; + if (!transformer.TryGetNextBatchSize(batch_size)) { + finished = true; + break; + } + end_index++; + tuples_for_task += batch_size; + } + if (start_index == end_index) { + break; + } + BatchesForTask batches_for_task; + batches_for_task.tuple_count = tuples_for_task; + batches_for_task.batches = batches.BatchRange(start_index, end_index); + task_data.push_back(batches_for_task); + } + + // Now we produce tasks from these units + // Every task is given a scan_state created from the range of batches + // and a vector of indices indicating the arrays (record batches) they should populate + idx_t record_batch_index = 0; + for (auto &data : task_data) { + const auto tuples = data.tuple_count; + + auto full_batches = tuples / record_batch_size; + auto remainder = tuples % record_batch_size; + auto total_batches = full_batches + !!remainder; + + vector record_batch_indices(total_batches); + for (idx_t i = 0; i < total_batches; i++) { + record_batch_indices[i] = record_batch_index++; + } + + BatchCollectionChunkScanState scan_state(batches, data.batches, pipeline->executor.context); + tasks.push_back(make_uniq(result, std::move(record_batch_indices), pipeline->executor, + shared_from_this(), std::move(scan_state), result.names, + record_batch_size)); + } + + // Allocate the list of record batches inside the query result + { + vector> arrays; + arrays.resize(record_batch_index); + for (idx_t i = 0; i < record_batch_index; i++) { + arrays[i] = make_uniq(); + } + result.SetArrowData(std::move(arrays)); + } + D_ASSERT(!tasks.empty()); + SetTasks(std::move(tasks)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/arrow_query_result.cpp b/src/duckdb/src/common/arrow/arrow_query_result.cpp new file mode 100644 index 00000000..396a9994 --- /dev/null +++ b/src/duckdb/src/common/arrow/arrow_query_result.cpp @@ -0,0 +1,56 @@ +#include "duckdb/common/arrow/arrow_query_result.hpp" +#include "duckdb/common/to_string.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/common/box_renderer.hpp" +#include "duckdb/common/arrow/arrow_converter.hpp" + +namespace duckdb { + +ArrowQueryResult::ArrowQueryResult(StatementType statement_type, StatementProperties properties, vector names_p, + vector types_p, ClientProperties client_properties, idx_t batch_size) + : QueryResult(QueryResultType::ARROW_RESULT, statement_type, std::move(properties), std::move(types_p), + std::move(names_p), std::move(client_properties)), + batch_size(batch_size) { +} + +ArrowQueryResult::ArrowQueryResult(ErrorData error) : QueryResult(QueryResultType::ARROW_RESULT, std::move(error)) { +} + +unique_ptr ArrowQueryResult::Fetch() { + throw NotImplementedException("Can't 'Fetch' from ArrowQueryResult"); +} +unique_ptr ArrowQueryResult::FetchRaw() { + throw NotImplementedException("Can't 'FetchRaw' from ArrowQueryResult"); +} + +string ArrowQueryResult::ToString() { + // FIXME: can't throw an exception here as it's used for verification + return ""; +} + +vector> ArrowQueryResult::ConsumeArrays() { + if (HasError()) { + throw InvalidInputException("Attempting to fetch ArrowArrays from an unsuccessful query result\n: Error %s", + GetError()); + } + return std::move(arrays); +} + +vector> &ArrowQueryResult::Arrays() { + if (HasError()) { + throw InvalidInputException("Attempting to fetch ArrowArrays from an unsuccessful query result\n: Error %s", + GetError()); + } + return arrays; +} + +void ArrowQueryResult::SetArrowData(vector> arrays) { + D_ASSERT(this->arrays.empty()); + this->arrays = std::move(arrays); +} + +idx_t ArrowQueryResult::BatchSize() const { + return batch_size; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/physical_arrow_batch_collector.cpp b/src/duckdb/src/common/arrow/physical_arrow_batch_collector.cpp new file mode 100644 index 00000000..11406c54 --- /dev/null +++ b/src/duckdb/src/common/arrow/physical_arrow_batch_collector.cpp @@ -0,0 +1,37 @@ +#include "duckdb/common/arrow/physical_arrow_batch_collector.hpp" +#include "duckdb/common/types/batched_data_collection.hpp" +#include "duckdb/common/arrow/arrow_query_result.hpp" +#include "duckdb/common/arrow/arrow_merge_event.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/common/arrow/physical_arrow_collector.hpp" + +namespace duckdb { + +unique_ptr PhysicalArrowBatchCollector::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(context, *this); +} + +SinkFinalizeType PhysicalArrowBatchCollector::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + + auto total_tuple_count = gstate.data.Count(); + if (total_tuple_count == 0) { + // Create the result containing a single empty result conversion + gstate.result = make_uniq(statement_type, properties, names, types, + context.GetClientProperties(), record_batch_size); + return SinkFinalizeType::READY; + } + + // Already create the final query result + gstate.result = make_uniq(statement_type, properties, names, types, context.GetClientProperties(), + record_batch_size); + // Spawn an event that will populate the conversion result + auto &arrow_result = gstate.result->Cast(); + auto new_event = make_shared_ptr(arrow_result, gstate.data, pipeline); + event.InsertEvent(std::move(new_event)); + + return SinkFinalizeType::READY; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/physical_arrow_collector.cpp b/src/duckdb/src/common/arrow/physical_arrow_collector.cpp new file mode 100644 index 00000000..d82246b4 --- /dev/null +++ b/src/duckdb/src/common/arrow/physical_arrow_collector.cpp @@ -0,0 +1,128 @@ +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/arrow/physical_arrow_collector.hpp" +#include "duckdb/common/arrow/physical_arrow_batch_collector.hpp" +#include "duckdb/common/arrow/arrow_query_result.hpp" +#include "duckdb/main/prepared_statement_data.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +unique_ptr PhysicalArrowCollector::Create(ClientContext &context, PreparedStatementData &data, + idx_t batch_size) { + if (!PhysicalPlanGenerator::PreserveInsertionOrder(context, *data.plan)) { + // the plan is not order preserving, so we just use the parallel materialized collector + return make_uniq_base(data, true, batch_size); + } else if (!PhysicalPlanGenerator::UseBatchIndex(context, *data.plan)) { + // the plan is order preserving, but we cannot use the batch index: use a single-threaded result collector + return make_uniq_base(data, false, batch_size); + } else { + return make_uniq_base(data, batch_size); + } +} + +SinkResultType PhysicalArrowCollector::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &lstate = input.local_state.Cast(); + // Append to the appender, up to chunk size + + auto count = chunk.size(); + auto &appender = lstate.appender; + D_ASSERT(count != 0); + + idx_t processed = 0; + do { + if (!appender) { + // Create the appender if we haven't started this chunk yet + auto properties = context.client.GetClientProperties(); + D_ASSERT(processed < count); + auto initial_capacity = MinValue(record_batch_size, count - processed); + appender = make_uniq(types, initial_capacity, properties); + } + + // Figure out how much we can still append to this chunk + auto row_count = appender->RowCount(); + D_ASSERT(record_batch_size > row_count); + auto to_append = MinValue(record_batch_size - row_count, count - processed); + + // Append and check if the chunk is finished + appender->Append(chunk, processed, processed + to_append, count); + processed += to_append; + row_count = appender->RowCount(); + if (row_count >= record_batch_size) { + lstate.FinishArray(); + } + } while (processed < count); + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType PhysicalArrowCollector::Combine(ExecutionContext &context, + OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + auto &last_appender = lstate.appender; + auto &arrays = lstate.finished_arrays; + if (arrays.empty() && !last_appender) { + // Nothing to do + return SinkCombineResultType::FINISHED; + } + if (last_appender) { + // FIXME: we could set these aside and merge them in a finalize event in an effort to create more balanced + // chunks out of these remnants + lstate.FinishArray(); + } + // Collect all the finished arrays + lock_guard l(gstate.glock); + // Move the arrays from our local state into the global state + gstate.chunks.insert(gstate.chunks.end(), std::make_move_iterator(arrays.begin()), + std::make_move_iterator(arrays.end())); + arrays.clear(); + gstate.tuple_count += lstate.tuple_count; + return SinkCombineResultType::FINISHED; +} + +unique_ptr PhysicalArrowCollector::GetResult(GlobalSinkState &state_p) { + auto &gstate = state_p.Cast(); + return std::move(gstate.result); +} + +unique_ptr PhysicalArrowCollector::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(); +} + +unique_ptr PhysicalArrowCollector::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(); +} + +SinkFinalizeType PhysicalArrowCollector::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, + OperatorSinkFinalizeInput &input) const { + auto &gstate = input.global_state.Cast(); + + if (gstate.chunks.empty()) { + if (gstate.tuple_count != 0) { + throw InternalException( + "PhysicalArrowCollector Finalize contains no chunks, but tuple_count is non-zero (%d)", + gstate.tuple_count); + } + gstate.result = make_uniq(statement_type, properties, names, types, + context.GetClientProperties(), record_batch_size); + return SinkFinalizeType::READY; + } + + gstate.result = make_uniq(statement_type, properties, names, types, context.GetClientProperties(), + record_batch_size); + auto &arrow_result = gstate.result->Cast(); + arrow_result.SetArrowData(std::move(gstate.chunks)); + + return SinkFinalizeType::READY; +} + +bool PhysicalArrowCollector::ParallelSink() const { + return parallel; +} + +bool PhysicalArrowCollector::SinkOrderDependent() const { + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/schema_metadata.cpp b/src/duckdb/src/common/arrow/schema_metadata.cpp new file mode 100644 index 00000000..acbf75c5 --- /dev/null +++ b/src/duckdb/src/common/arrow/schema_metadata.cpp @@ -0,0 +1,101 @@ +#include "duckdb/common/arrow/schema_metadata.hpp" + +namespace duckdb { +ArrowSchemaMetadata::ArrowSchemaMetadata(const char *metadata) { + if (metadata) { + // Read the number of key-value pairs (int32) + int32_t num_pairs; + memcpy(&num_pairs, metadata, sizeof(int32_t)); + metadata += sizeof(int32_t); + + // Loop through each key-value pair + for (int32_t i = 0; i < num_pairs; ++i) { + // Read the length of the key (int32) + int32_t key_length; + memcpy(&key_length, metadata, sizeof(int32_t)); + metadata += sizeof(int32_t); + + // Read the key + std::string key(metadata, static_cast(key_length)); + metadata += key_length; + + // Read the length of the value (int32) + int32_t value_length; + memcpy(&value_length, metadata, sizeof(int32_t)); + metadata += sizeof(int32_t); + + // Read the value + const std::string value(metadata, static_cast(value_length)); + metadata += value_length; + metadata_map[key] = value; + } + } +} + +void ArrowSchemaMetadata::AddOption(const string &key, const string &value) { + metadata_map[key] = value; +} +string ArrowSchemaMetadata::GetOption(const string &key) const { + return metadata_map.at(key); +} + +string ArrowSchemaMetadata::GetExtensionName() const { + return GetOption(ARROW_EXTENSION_NAME); +} + +ArrowSchemaMetadata ArrowSchemaMetadata::MetadataFromName(const string &extension_name) { + ArrowSchemaMetadata metadata; + metadata.AddOption(ArrowSchemaMetadata::ARROW_EXTENSION_NAME, extension_name); + metadata.AddOption(ArrowSchemaMetadata::ARROW_METADATA_KEY, ""); + return metadata; +} + +bool ArrowSchemaMetadata::HasExtension() { + if (metadata_map.find(ARROW_EXTENSION_NAME) == metadata_map.end()) { + return false; + } + auto arrow_extension = GetOption(ArrowSchemaMetadata::ARROW_EXTENSION_NAME); + // FIXME: We are currently ignoring the ogc extensions + return !arrow_extension.empty() && !StringUtil::StartsWith(arrow_extension, "ogc"); +} + +unsafe_unique_array ArrowSchemaMetadata::SerializeMetadata() const { + // First we have to figure out the total size: + // 1. number of key-value pairs (int32) + idx_t total_size = sizeof(int32_t); + for (const auto &option : metadata_map) { + // 2. Length of the key and value (2 * int32) + total_size += 2 * sizeof(int32_t); + // 3. Length of key + total_size += option.first.size(); + // 4. Length of value + total_size += option.second.size(); + } + auto metadata_array_ptr = make_unsafe_uniq_array(total_size); + auto metadata_ptr = metadata_array_ptr.get(); + // 1. number of key-value pairs (int32) + const idx_t map_size = metadata_map.size(); + memcpy(metadata_ptr, &map_size, sizeof(int32_t)); + metadata_ptr += sizeof(int32_t); + // Iterate through each key-value pair in the map + for (const auto &pair : metadata_map) { + const std::string &key = pair.first; + idx_t key_size = key.size(); + // Length of the key (int32) + memcpy(metadata_ptr, &key_size, sizeof(int32_t)); + metadata_ptr += sizeof(int32_t); + // Key + memcpy(metadata_ptr, key.c_str(), key_size); + metadata_ptr += key_size; + const std::string &value = pair.second; + const idx_t value_size = value.size(); + // Length of the value (int32) + memcpy(metadata_ptr, &value_size, sizeof(int32_t)); + metadata_ptr += sizeof(int32_t); + // Value + memcpy(metadata_ptr, value.c_str(), value_size); + metadata_ptr += value_size; + } + return metadata_array_ptr; +} +} // namespace duckdb diff --git a/src/duckdb/src/common/cgroups.cpp b/src/duckdb/src/common/cgroups.cpp new file mode 100644 index 00000000..b9d2b820 --- /dev/null +++ b/src/duckdb/src/common/cgroups.cpp @@ -0,0 +1,189 @@ +#include "duckdb/common/cgroups.hpp" + +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/operator/cast_operators.hpp" + +#include + +namespace duckdb { + +optional_idx CGroups::GetMemoryLimit(FileSystem &fs) { + // First, try cgroup v2 + auto cgroup_v2_limit = GetCGroupV2MemoryLimit(fs); + if (cgroup_v2_limit.IsValid()) { + return cgroup_v2_limit; + } + + // If cgroup v2 fails, try cgroup v1 + return GetCGroupV1MemoryLimit(fs); +} + +optional_idx CGroups::GetCGroupV2MemoryLimit(FileSystem &fs) { +#ifdef DUCKDB_WASM + return optional_idx(); +#else + const char *cgroup_self = "/proc/self/cgroup"; + const char *memory_max = "/sys/fs/cgroup/%s/memory.max"; + + if (!fs.FileExists(cgroup_self)) { + return optional_idx(); + } + + string cgroup_path = ReadCGroupPath(fs, cgroup_self); + if (cgroup_path.empty()) { + return optional_idx(); + } + + char memory_max_path[256]; + snprintf(memory_max_path, sizeof(memory_max_path), memory_max, cgroup_path.c_str()); + + if (!fs.FileExists(memory_max_path)) { + return optional_idx(); + } + + return ReadCGroupValue(fs, memory_max_path); +#endif +} + +optional_idx CGroups::GetCGroupV1MemoryLimit(FileSystem &fs) { +#ifdef DUCKDB_WASM + return optional_idx(); +#else + const char *cgroup_self = "/proc/self/cgroup"; + const char *memory_limit = "/sys/fs/cgroup/memory/%s/memory.limit_in_bytes"; + + if (!fs.FileExists(cgroup_self)) { + return optional_idx(); + } + + string memory_cgroup_path = ReadMemoryCGroupPath(fs, cgroup_self); + if (memory_cgroup_path.empty()) { + return optional_idx(); + } + + char memory_limit_path[256]; + snprintf(memory_limit_path, sizeof(memory_limit_path), memory_limit, memory_cgroup_path.c_str()); + + if (!fs.FileExists(memory_limit_path)) { + return optional_idx(); + } + + return ReadCGroupValue(fs, memory_limit_path); +#endif +} + +string CGroups::ReadCGroupPath(FileSystem &fs, const char *cgroup_file) { +#ifdef DUCKDB_WASM + return ""; +#else + auto handle = fs.OpenFile(cgroup_file, FileFlags::FILE_FLAGS_READ); + char buffer[1024]; + auto bytes_read = fs.Read(*handle, buffer, sizeof(buffer) - 1); + buffer[bytes_read] = '\0'; + + // For cgroup v2, we're looking for a single line with "0::/path" + string content(buffer); + auto pos = content.find("::"); + if (pos != string::npos) { + return content.substr(pos + 2); + } + + return ""; +#endif +} + +string CGroups::ReadMemoryCGroupPath(FileSystem &fs, const char *cgroup_file) { +#ifdef DUCKDB_WASM + return ""; +#else + auto handle = fs.OpenFile(cgroup_file, FileFlags::FILE_FLAGS_READ); + char buffer[1024]; + auto bytes_read = fs.Read(*handle, buffer, sizeof(buffer) - 1); + buffer[bytes_read] = '\0'; + + // For cgroup v1, we're looking for a line with "memory:/path" + string content(buffer); + size_t pos = 0; + string line; + while ((pos = content.find('\n')) != string::npos) { + line = content.substr(0, pos); + if (line.find("memory:") == 0) { + return line.substr(line.find(':') + 1); + } + content.erase(0, pos + 1); + } + + return ""; +#endif +} + +optional_idx CGroups::ReadCGroupValue(FileSystem &fs, const char *file_path) { +#ifdef DUCKDB_WASM + return optional_idx(); +#else + auto handle = fs.OpenFile(file_path, FileFlags::FILE_FLAGS_READ); + char buffer[100]; + auto bytes_read = fs.Read(*handle, buffer, 99); + buffer[bytes_read] = '\0'; + + idx_t value; + if (TryCast::Operation(string_t(buffer), value)) { + return optional_idx(value); + } + return optional_idx(); +#endif +} + +idx_t CGroups::GetCPULimit(FileSystem &fs, idx_t physical_cores) { +#ifdef DUCKDB_WASM + return physical_cores; +#else + + static constexpr const char *cpu_max = "/sys/fs/cgroup/cpu.max"; + static constexpr const char *cfs_quota = "/sys/fs/cgroup/cpu/cpu.cfs_quota_us"; + static constexpr const char *cfs_period = "/sys/fs/cgroup/cpu/cpu.cfs_period_us"; + + int64_t quota, period; + char byte_buffer[1000]; + unique_ptr handle; + int64_t read_bytes; + + if (fs.FileExists(cpu_max)) { + // cgroup v2 + handle = fs.OpenFile(cpu_max, FileFlags::FILE_FLAGS_READ); + read_bytes = fs.Read(*handle, (void *)byte_buffer, 999); + byte_buffer[read_bytes] = '\0'; + if (std::sscanf(byte_buffer, "%" SCNd64 " %" SCNd64 "", "a, &period) != 2) { + return physical_cores; + } + } else if (fs.FileExists(cfs_quota) && fs.FileExists(cfs_period)) { + // cgroup v1 + handle = fs.OpenFile(cfs_quota, FileFlags::FILE_FLAGS_READ); + read_bytes = fs.Read(*handle, (void *)byte_buffer, 999); + byte_buffer[read_bytes] = '\0'; + if (std::sscanf(byte_buffer, "%" SCNd64 "", "a) != 1) { + return physical_cores; + } + + handle = fs.OpenFile(cfs_period, FileFlags::FILE_FLAGS_READ); + read_bytes = fs.Read(*handle, (void *)byte_buffer, 999); + byte_buffer[read_bytes] = '\0'; + if (std::sscanf(byte_buffer, "%" SCNd64 "", &period) != 1) { + return physical_cores; + } + } else { + // No cgroup quota + return physical_cores; + } + if (quota > 0 && period > 0) { + return idx_t(std::ceil((double)quota / (double)period)); + } else { + return physical_cores; + } +#endif +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/compressed_file_system.cpp b/src/duckdb/src/common/compressed_file_system.cpp index 5727d4d7..b34c6c21 100644 --- a/src/duckdb/src/common/compressed_file_system.cpp +++ b/src/duckdb/src/common/compressed_file_system.cpp @@ -8,7 +8,6 @@ StreamWrapper::~StreamWrapper() { CompressedFile::CompressedFile(CompressedFileSystem &fs, unique_ptr child_handle_p, const string &path) : FileHandle(fs, path), compressed_fs(fs), child_handle(std::move(child_handle_p)) { - D_ASSERT(child_handle->SeekPosition() == 0); } CompressedFile::~CompressedFile() { @@ -32,6 +31,10 @@ void CompressedFile::Initialize(bool write) { stream_wrapper->Initialize(*this, write); } +idx_t CompressedFile::GetProgress() { + return current_position; +} + int64_t CompressedFile::ReadData(void *buffer, int64_t remaining) { idx_t total_read = 0; while (true) { @@ -46,7 +49,7 @@ int64_t CompressedFile::ReadData(void *buffer, int64_t remaining) { // increment the total read variables as required stream_data.out_buff_start += available; total_read += available; - remaining -= available; + remaining = UnsafeNumericCast(UnsafeNumericCast(remaining) - available); if (remaining == 0) { // done! read enough return UnsafeNumericCast(total_read); @@ -55,7 +58,7 @@ int64_t CompressedFile::ReadData(void *buffer, int64_t remaining) { if (!stream_wrapper) { return UnsafeNumericCast(total_read); } - + current_position += static_cast(stream_data.in_buff_end - stream_data.in_buff_start); // ran out of buffer: read more data from the child stream stream_data.out_buff_start = stream_data.out_buff.get(); stream_data.out_buff_end = stream_data.out_buff.get(); diff --git a/src/duckdb/src/common/encryption_state.cpp b/src/duckdb/src/common/encryption_state.cpp new file mode 100644 index 00000000..b6343b63 --- /dev/null +++ b/src/duckdb/src/common/encryption_state.cpp @@ -0,0 +1,38 @@ +#include "duckdb/common/encryption_state.hpp" + +namespace duckdb { + +EncryptionState::EncryptionState() { + // abstract class, no implementation needed +} + +EncryptionState::~EncryptionState() { +} + +bool EncryptionState::IsOpenSSL() { + throw NotImplementedException("EncryptionState Abstract Class is called"); +} + +void EncryptionState::InitializeEncryption(duckdb::const_data_ptr_t iv, duckdb::idx_t iv_len, const std::string *key) { + throw NotImplementedException("EncryptionState Abstract Class is called"); +} + +void EncryptionState::InitializeDecryption(duckdb::const_data_ptr_t iv, duckdb::idx_t iv_len, const std::string *key) { + throw NotImplementedException("EncryptionState Abstract Class is called"); +} + +size_t EncryptionState::Process(duckdb::const_data_ptr_t in, duckdb::idx_t in_len, duckdb::data_ptr_t out, + duckdb::idx_t out_len) { + throw NotImplementedException("EncryptionState Abstract Class is called"); +} + +size_t EncryptionState::Finalize(duckdb::data_ptr_t out, duckdb::idx_t out_len, duckdb::data_ptr_t tag, + duckdb::idx_t tag_len) { + throw NotImplementedException("EncryptionState Abstract Class is called"); +} + +void EncryptionState::GenerateRandomData(duckdb::data_ptr_t data, duckdb::idx_t len) { + throw NotImplementedException("EncryptionState Abstract Class is called"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/enum_util.cpp b/src/duckdb/src/common/enum_util.cpp index 33b64469..4023680b 100644 --- a/src/duckdb/src/common/enum_util.cpp +++ b/src/duckdb/src/common/enum_util.cpp @@ -22,6 +22,8 @@ #include "duckdb/common/enums/cte_materialize.hpp" #include "duckdb/common/enums/date_part_specifier.hpp" #include "duckdb/common/enums/debug_initialize.hpp" +#include "duckdb/common/enums/destroy_buffer_upon.hpp" +#include "duckdb/common/enums/explain_format.hpp" #include "duckdb/common/enums/expression_type.hpp" #include "duckdb/common/enums/file_compression_type.hpp" #include "duckdb/common/enums/file_glob_options.hpp" @@ -31,6 +33,7 @@ #include "duckdb/common/enums/joinref_type.hpp" #include "duckdb/common/enums/logical_operator_type.hpp" #include "duckdb/common/enums/memory_tag.hpp" +#include "duckdb/common/enums/metric_type.hpp" #include "duckdb/common/enums/on_create_conflict.hpp" #include "duckdb/common/enums/on_entry_not_found.hpp" #include "duckdb/common/enums/operator_result_type.hpp" @@ -48,6 +51,7 @@ #include "duckdb/common/enums/set_scope.hpp" #include "duckdb/common/enums/set_type.hpp" #include "duckdb/common/enums/statement_type.hpp" +#include "duckdb/common/enums/stream_execution_result.hpp" #include "duckdb/common/enums/subquery_type.hpp" #include "duckdb/common/enums/tableref_type.hpp" #include "duckdb/common/enums/undo_flags.hpp" @@ -81,23 +85,28 @@ #include "duckdb/execution/operator/csv_scanner/quote_rules.hpp" #include "duckdb/execution/reservoir_sample.hpp" #include "duckdb/function/aggregate_state.hpp" +#include "duckdb/function/copy_function.hpp" #include "duckdb/function/function.hpp" #include "duckdb/function/macro_function.hpp" #include "duckdb/function/scalar/compressed_materialization_functions.hpp" #include "duckdb/function/scalar/strftime_format.hpp" -#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" +#include "duckdb/function/table/arrow/enum/arrow_datetime_type.hpp" +#include "duckdb/function/table/arrow/enum/arrow_type_info_type.hpp" +#include "duckdb/function/table/arrow/enum/arrow_variable_size_type.hpp" #include "duckdb/function/table_function.hpp" #include "duckdb/main/appender.hpp" #include "duckdb/main/capi/capi_internal.hpp" #include "duckdb/main/client_properties.hpp" #include "duckdb/main/config.hpp" #include "duckdb/main/error_manager.hpp" +#include "duckdb/main/extension.hpp" #include "duckdb/main/extension_helper.hpp" #include "duckdb/main/extension_install_info.hpp" #include "duckdb/main/query_result.hpp" #include "duckdb/main/secret/secret.hpp" #include "duckdb/main/settings.hpp" #include "duckdb/parallel/interrupt.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" #include "duckdb/parallel/task.hpp" #include "duckdb/parser/constraint.hpp" #include "duckdb/parser/expression/parameter_expression.hpp" @@ -117,6 +126,7 @@ #include "duckdb/parser/query_node.hpp" #include "duckdb/parser/result_modifier.hpp" #include "duckdb/parser/simplified_token.hpp" +#include "duckdb/parser/statement/copy_statement.hpp" #include "duckdb/parser/statement/explain_statement.hpp" #include "duckdb/parser/statement/insert_statement.hpp" #include "duckdb/parser/tableref/showref.hpp" @@ -575,15 +585,55 @@ ArrowOffsetSize EnumUtil::FromString(const char *value) { throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template<> +const char* EnumUtil::ToChars(ArrowTypeInfoType value) { + switch(value) { + case ArrowTypeInfoType::LIST: + return "LIST"; + case ArrowTypeInfoType::STRUCT: + return "STRUCT"; + case ArrowTypeInfoType::DATE_TIME: + return "DATE_TIME"; + case ArrowTypeInfoType::STRING: + return "STRING"; + case ArrowTypeInfoType::ARRAY: + return "ARRAY"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ArrowTypeInfoType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "LIST")) { + return ArrowTypeInfoType::LIST; + } + if (StringUtil::Equals(value, "STRUCT")) { + return ArrowTypeInfoType::STRUCT; + } + if (StringUtil::Equals(value, "DATE_TIME")) { + return ArrowTypeInfoType::DATE_TIME; + } + if (StringUtil::Equals(value, "STRING")) { + return ArrowTypeInfoType::STRING; + } + if (StringUtil::Equals(value, "ARRAY")) { + return ArrowTypeInfoType::ARRAY; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + template<> const char* EnumUtil::ToChars(ArrowVariableSizeType value) { switch(value) { - case ArrowVariableSizeType::FIXED_SIZE: - return "FIXED_SIZE"; case ArrowVariableSizeType::NORMAL: return "NORMAL"; + case ArrowVariableSizeType::FIXED_SIZE: + return "FIXED_SIZE"; case ArrowVariableSizeType::SUPER_SIZE: return "SUPER_SIZE"; + case ArrowVariableSizeType::VIEW: + return "VIEW"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -591,15 +641,18 @@ const char* EnumUtil::ToChars(ArrowVariableSizeType value template<> ArrowVariableSizeType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "FIXED_SIZE")) { - return ArrowVariableSizeType::FIXED_SIZE; - } if (StringUtil::Equals(value, "NORMAL")) { return ArrowVariableSizeType::NORMAL; } + if (StringUtil::Equals(value, "FIXED_SIZE")) { + return ArrowVariableSizeType::FIXED_SIZE; + } if (StringUtil::Equals(value, "SUPER_SIZE")) { return ArrowVariableSizeType::SUPER_SIZE; } + if (StringUtil::Equals(value, "VIEW")) { + return ArrowVariableSizeType::VIEW; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -633,6 +686,8 @@ const char* EnumUtil::ToChars(BindingMode value) { return "STANDARD_BINDING"; case BindingMode::EXTRACT_NAMES: return "EXTRACT_NAMES"; + case BindingMode::EXTRACT_REPLACEMENT_SCANS: + return "EXTRACT_REPLACEMENT_SCANS"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -646,6 +701,9 @@ BindingMode EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "EXTRACT_NAMES")) { return BindingMode::EXTRACT_NAMES; } + if (StringUtil::Equals(value, "EXTRACT_REPLACEMENT_SCANS")) { + return BindingMode::EXTRACT_REPLACEMENT_SCANS; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -773,6 +831,8 @@ const char* EnumUtil::ToChars(CSVState value) { return "QUOTED_NEW_LINE"; case CSVState::EMPTY_SPACE: return "EMPTY_SPACE"; + case CSVState::COMMENT: + return "COMMENT"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -813,6 +873,9 @@ CSVState EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "EMPTY_SPACE")) { return CSVState::EMPTY_SPACE; } + if (StringUtil::Equals(value, "COMMENT")) { + return CSVState::COMMENT; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -1307,6 +1370,29 @@ ConstraintType EnumUtil::FromString(const char *value) { throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template<> +const char* EnumUtil::ToChars(CopyFunctionReturnType value) { + switch(value) { + case CopyFunctionReturnType::CHANGED_ROWS: + return "CHANGED_ROWS"; + case CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST: + return "CHANGED_ROWS_AND_FILE_LIST"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +CopyFunctionReturnType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "CHANGED_ROWS")) { + return CopyFunctionReturnType::CHANGED_ROWS; + } + if (StringUtil::Equals(value, "CHANGED_ROWS_AND_FILE_LIST")) { + return CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + template<> const char* EnumUtil::ToChars(CopyOverwriteMode value) { switch(value) { @@ -1316,6 +1402,8 @@ const char* EnumUtil::ToChars(CopyOverwriteMode value) { return "COPY_OVERWRITE"; case CopyOverwriteMode::COPY_OVERWRITE_OR_IGNORE: return "COPY_OVERWRITE_OR_IGNORE"; + case CopyOverwriteMode::COPY_APPEND: + return "COPY_APPEND"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -1332,6 +1420,32 @@ CopyOverwriteMode EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "COPY_OVERWRITE_OR_IGNORE")) { return CopyOverwriteMode::COPY_OVERWRITE_OR_IGNORE; } + if (StringUtil::Equals(value, "COPY_APPEND")) { + return CopyOverwriteMode::COPY_APPEND; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(CopyToType value) { + switch(value) { + case CopyToType::COPY_TO_FILE: + return "COPY_TO_FILE"; + case CopyToType::EXPORT_DATABASE: + return "EXPORT_DATABASE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +CopyToType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "COPY_TO_FILE")) { + return CopyToType::COPY_TO_FILE; + } + if (StringUtil::Equals(value, "EXPORT_DATABASE")) { + return CopyToType::EXPORT_DATABASE; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -1623,6 +1737,34 @@ DeprecatedIndexType EnumUtil::FromString(const char *value) throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template<> +const char* EnumUtil::ToChars(DestroyBufferUpon value) { + switch(value) { + case DestroyBufferUpon::BLOCK: + return "BLOCK"; + case DestroyBufferUpon::EVICTION: + return "EVICTION"; + case DestroyBufferUpon::UNPIN: + return "UNPIN"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +DestroyBufferUpon EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "BLOCK")) { + return DestroyBufferUpon::BLOCK; + } + if (StringUtil::Equals(value, "EVICTION")) { + return DestroyBufferUpon::EVICTION; + } + if (StringUtil::Equals(value, "UNPIN")) { + return DestroyBufferUpon::UNPIN; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + template<> const char* EnumUtil::ToChars(DistinctType value) { switch(value) { @@ -1799,6 +1941,8 @@ const char* EnumUtil::ToChars(ExceptionType value) { return "AUTOLOAD"; case ExceptionType::SEQUENCE: return "SEQUENCE"; + case ExceptionType::INVALID_CONFIGURATION: + return "INVALID_CONFIGURATION"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -1932,6 +2076,47 @@ ExceptionType EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "SEQUENCE")) { return ExceptionType::SEQUENCE; } + if (StringUtil::Equals(value, "INVALID_CONFIGURATION")) { + return ExceptionType::INVALID_CONFIGURATION; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(ExplainFormat value) { + switch(value) { + case ExplainFormat::DEFAULT: + return "DEFAULT"; + case ExplainFormat::TEXT: + return "TEXT"; + case ExplainFormat::JSON: + return "JSON"; + case ExplainFormat::HTML: + return "HTML"; + case ExplainFormat::GRAPHVIZ: + return "GRAPHVIZ"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ExplainFormat EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "DEFAULT")) { + return ExplainFormat::DEFAULT; + } + if (StringUtil::Equals(value, "TEXT")) { + return ExplainFormat::TEXT; + } + if (StringUtil::Equals(value, "JSON")) { + return ExplainFormat::JSON; + } + if (StringUtil::Equals(value, "HTML")) { + return ExplainFormat::HTML; + } + if (StringUtil::Equals(value, "GRAPHVIZ")) { + return ExplainFormat::GRAPHVIZ; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -2585,6 +2770,34 @@ ExpressionType EnumUtil::FromString(const char *value) { throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template<> +const char* EnumUtil::ToChars(ExtensionABIType value) { + switch(value) { + case ExtensionABIType::UNKNOWN: + return "UNKNOWN"; + case ExtensionABIType::CPP: + return "CPP"; + case ExtensionABIType::C_STRUCT: + return "C_STRUCT"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +ExtensionABIType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "UNKNOWN")) { + return ExtensionABIType::UNKNOWN; + } + if (StringUtil::Equals(value, "CPP")) { + return ExtensionABIType::CPP; + } + if (StringUtil::Equals(value, "C_STRUCT")) { + return ExtensionABIType::C_STRUCT; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + template<> const char* EnumUtil::ToChars(ExtensionInstallMode value) { switch(value) { @@ -3057,11 +3270,36 @@ FunctionStability EnumUtil::FromString(const char *value) { throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template<> +const char* EnumUtil::ToChars(GateStatus value) { + switch(value) { + case GateStatus::GATE_NOT_SET: + return "GATE_NOT_SET"; + case GateStatus::GATE_SET: + return "GATE_SET"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +GateStatus EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "GATE_NOT_SET")) { + return GateStatus::GATE_NOT_SET; + } + if (StringUtil::Equals(value, "GATE_SET")) { + return GateStatus::GATE_SET; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + template<> const char* EnumUtil::ToChars(HLLStorageType value) { switch(value) { - case HLLStorageType::UNCOMPRESSED: - return "UNCOMPRESSED"; + case HLLStorageType::HLL_V1: + return "HLL_V1"; + case HLLStorageType::HLL_V2: + return "HLL_V2"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -3069,8 +3307,11 @@ const char* EnumUtil::ToChars(HLLStorageType value) { template<> HLLStorageType EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "UNCOMPRESSED")) { - return HLLStorageType::UNCOMPRESSED; + if (StringUtil::Equals(value, "HLL_V1")) { + return HLLStorageType::HLL_V1; + } + if (StringUtil::Equals(value, "HLL_V2")) { + return HLLStorageType::HLL_V2; } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -3754,6 +3995,8 @@ const char* EnumUtil::ToChars(LogicalTypeId value) { return "STRING_LITERAL"; case LogicalTypeId::INTEGER_LITERAL: return "INTEGER_LITERAL"; + case LogicalTypeId::VARINT: + return "VARINT"; case LogicalTypeId::UHUGEINT: return "UHUGEINT"; case LogicalTypeId::HUGEINT: @@ -3885,6 +4128,9 @@ LogicalTypeId EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "INTEGER_LITERAL")) { return LogicalTypeId::INTEGER_LITERAL; } + if (StringUtil::Equals(value, "VARINT")) { + return LogicalTypeId::VARINT; + } if (StringUtil::Equals(value, "UHUGEINT")) { return LogicalTypeId::UHUGEINT; } @@ -4097,6 +4343,252 @@ MemoryTag EnumUtil::FromString(const char *value) { throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template<> +const char* EnumUtil::ToChars(MetaPipelineType value) { + switch(value) { + case MetaPipelineType::REGULAR: + return "REGULAR"; + case MetaPipelineType::JOIN_BUILD: + return "JOIN_BUILD"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +MetaPipelineType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "REGULAR")) { + return MetaPipelineType::REGULAR; + } + if (StringUtil::Equals(value, "JOIN_BUILD")) { + return MetaPipelineType::JOIN_BUILD; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + +template<> +const char* EnumUtil::ToChars(MetricsType value) { + switch(value) { + case MetricsType::QUERY_NAME: + return "QUERY_NAME"; + case MetricsType::BLOCKED_THREAD_TIME: + return "BLOCKED_THREAD_TIME"; + case MetricsType::CPU_TIME: + return "CPU_TIME"; + case MetricsType::EXTRA_INFO: + return "EXTRA_INFO"; + case MetricsType::CUMULATIVE_CARDINALITY: + return "CUMULATIVE_CARDINALITY"; + case MetricsType::OPERATOR_TYPE: + return "OPERATOR_TYPE"; + case MetricsType::OPERATOR_CARDINALITY: + return "OPERATOR_CARDINALITY"; + case MetricsType::CUMULATIVE_ROWS_SCANNED: + return "CUMULATIVE_ROWS_SCANNED"; + case MetricsType::OPERATOR_ROWS_SCANNED: + return "OPERATOR_ROWS_SCANNED"; + case MetricsType::OPERATOR_TIMING: + return "OPERATOR_TIMING"; + case MetricsType::RESULT_SET_SIZE: + return "RESULT_SET_SIZE"; + case MetricsType::ALL_OPTIMIZERS: + return "ALL_OPTIMIZERS"; + case MetricsType::CUMULATIVE_OPTIMIZER_TIMING: + return "CUMULATIVE_OPTIMIZER_TIMING"; + case MetricsType::PLANNER: + return "PLANNER"; + case MetricsType::PLANNER_BINDING: + return "PLANNER_BINDING"; + case MetricsType::PHYSICAL_PLANNER: + return "PHYSICAL_PLANNER"; + case MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING: + return "PHYSICAL_PLANNER_COLUMN_BINDING"; + case MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES: + return "PHYSICAL_PLANNER_RESOLVE_TYPES"; + case MetricsType::PHYSICAL_PLANNER_CREATE_PLAN: + return "PHYSICAL_PLANNER_CREATE_PLAN"; + case MetricsType::OPTIMIZER_EXPRESSION_REWRITER: + return "OPTIMIZER_EXPRESSION_REWRITER"; + case MetricsType::OPTIMIZER_FILTER_PULLUP: + return "OPTIMIZER_FILTER_PULLUP"; + case MetricsType::OPTIMIZER_FILTER_PUSHDOWN: + return "OPTIMIZER_FILTER_PUSHDOWN"; + case MetricsType::OPTIMIZER_CTE_FILTER_PUSHER: + return "OPTIMIZER_CTE_FILTER_PUSHER"; + case MetricsType::OPTIMIZER_REGEX_RANGE: + return "OPTIMIZER_REGEX_RANGE"; + case MetricsType::OPTIMIZER_IN_CLAUSE: + return "OPTIMIZER_IN_CLAUSE"; + case MetricsType::OPTIMIZER_JOIN_ORDER: + return "OPTIMIZER_JOIN_ORDER"; + case MetricsType::OPTIMIZER_DELIMINATOR: + return "OPTIMIZER_DELIMINATOR"; + case MetricsType::OPTIMIZER_UNNEST_REWRITER: + return "OPTIMIZER_UNNEST_REWRITER"; + case MetricsType::OPTIMIZER_UNUSED_COLUMNS: + return "OPTIMIZER_UNUSED_COLUMNS"; + case MetricsType::OPTIMIZER_STATISTICS_PROPAGATION: + return "OPTIMIZER_STATISTICS_PROPAGATION"; + case MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS: + return "OPTIMIZER_COMMON_SUBEXPRESSIONS"; + case MetricsType::OPTIMIZER_COMMON_AGGREGATE: + return "OPTIMIZER_COMMON_AGGREGATE"; + case MetricsType::OPTIMIZER_COLUMN_LIFETIME: + return "OPTIMIZER_COLUMN_LIFETIME"; + case MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE: + return "OPTIMIZER_BUILD_SIDE_PROBE_SIDE"; + case MetricsType::OPTIMIZER_LIMIT_PUSHDOWN: + return "OPTIMIZER_LIMIT_PUSHDOWN"; + case MetricsType::OPTIMIZER_TOP_N: + return "OPTIMIZER_TOP_N"; + case MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION: + return "OPTIMIZER_COMPRESSED_MATERIALIZATION"; + case MetricsType::OPTIMIZER_DUPLICATE_GROUPS: + return "OPTIMIZER_DUPLICATE_GROUPS"; + case MetricsType::OPTIMIZER_REORDER_FILTER: + return "OPTIMIZER_REORDER_FILTER"; + case MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN: + return "OPTIMIZER_JOIN_FILTER_PUSHDOWN"; + case MetricsType::OPTIMIZER_EXTENSION: + return "OPTIMIZER_EXTENSION"; + case MetricsType::OPTIMIZER_MATERIALIZED_CTE: + return "OPTIMIZER_MATERIALIZED_CTE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +MetricsType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "QUERY_NAME")) { + return MetricsType::QUERY_NAME; + } + if (StringUtil::Equals(value, "BLOCKED_THREAD_TIME")) { + return MetricsType::BLOCKED_THREAD_TIME; + } + if (StringUtil::Equals(value, "CPU_TIME")) { + return MetricsType::CPU_TIME; + } + if (StringUtil::Equals(value, "EXTRA_INFO")) { + return MetricsType::EXTRA_INFO; + } + if (StringUtil::Equals(value, "CUMULATIVE_CARDINALITY")) { + return MetricsType::CUMULATIVE_CARDINALITY; + } + if (StringUtil::Equals(value, "OPERATOR_TYPE")) { + return MetricsType::OPERATOR_TYPE; + } + if (StringUtil::Equals(value, "OPERATOR_CARDINALITY")) { + return MetricsType::OPERATOR_CARDINALITY; + } + if (StringUtil::Equals(value, "CUMULATIVE_ROWS_SCANNED")) { + return MetricsType::CUMULATIVE_ROWS_SCANNED; + } + if (StringUtil::Equals(value, "OPERATOR_ROWS_SCANNED")) { + return MetricsType::OPERATOR_ROWS_SCANNED; + } + if (StringUtil::Equals(value, "OPERATOR_TIMING")) { + return MetricsType::OPERATOR_TIMING; + } + if (StringUtil::Equals(value, "RESULT_SET_SIZE")) { + return MetricsType::RESULT_SET_SIZE; + } + if (StringUtil::Equals(value, "ALL_OPTIMIZERS")) { + return MetricsType::ALL_OPTIMIZERS; + } + if (StringUtil::Equals(value, "CUMULATIVE_OPTIMIZER_TIMING")) { + return MetricsType::CUMULATIVE_OPTIMIZER_TIMING; + } + if (StringUtil::Equals(value, "PLANNER")) { + return MetricsType::PLANNER; + } + if (StringUtil::Equals(value, "PLANNER_BINDING")) { + return MetricsType::PLANNER_BINDING; + } + if (StringUtil::Equals(value, "PHYSICAL_PLANNER")) { + return MetricsType::PHYSICAL_PLANNER; + } + if (StringUtil::Equals(value, "PHYSICAL_PLANNER_COLUMN_BINDING")) { + return MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING; + } + if (StringUtil::Equals(value, "PHYSICAL_PLANNER_RESOLVE_TYPES")) { + return MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES; + } + if (StringUtil::Equals(value, "PHYSICAL_PLANNER_CREATE_PLAN")) { + return MetricsType::PHYSICAL_PLANNER_CREATE_PLAN; + } + if (StringUtil::Equals(value, "OPTIMIZER_EXPRESSION_REWRITER")) { + return MetricsType::OPTIMIZER_EXPRESSION_REWRITER; + } + if (StringUtil::Equals(value, "OPTIMIZER_FILTER_PULLUP")) { + return MetricsType::OPTIMIZER_FILTER_PULLUP; + } + if (StringUtil::Equals(value, "OPTIMIZER_FILTER_PUSHDOWN")) { + return MetricsType::OPTIMIZER_FILTER_PUSHDOWN; + } + if (StringUtil::Equals(value, "OPTIMIZER_CTE_FILTER_PUSHER")) { + return MetricsType::OPTIMIZER_CTE_FILTER_PUSHER; + } + if (StringUtil::Equals(value, "OPTIMIZER_REGEX_RANGE")) { + return MetricsType::OPTIMIZER_REGEX_RANGE; + } + if (StringUtil::Equals(value, "OPTIMIZER_IN_CLAUSE")) { + return MetricsType::OPTIMIZER_IN_CLAUSE; + } + if (StringUtil::Equals(value, "OPTIMIZER_JOIN_ORDER")) { + return MetricsType::OPTIMIZER_JOIN_ORDER; + } + if (StringUtil::Equals(value, "OPTIMIZER_DELIMINATOR")) { + return MetricsType::OPTIMIZER_DELIMINATOR; + } + if (StringUtil::Equals(value, "OPTIMIZER_UNNEST_REWRITER")) { + return MetricsType::OPTIMIZER_UNNEST_REWRITER; + } + if (StringUtil::Equals(value, "OPTIMIZER_UNUSED_COLUMNS")) { + return MetricsType::OPTIMIZER_UNUSED_COLUMNS; + } + if (StringUtil::Equals(value, "OPTIMIZER_STATISTICS_PROPAGATION")) { + return MetricsType::OPTIMIZER_STATISTICS_PROPAGATION; + } + if (StringUtil::Equals(value, "OPTIMIZER_COMMON_SUBEXPRESSIONS")) { + return MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS; + } + if (StringUtil::Equals(value, "OPTIMIZER_COMMON_AGGREGATE")) { + return MetricsType::OPTIMIZER_COMMON_AGGREGATE; + } + if (StringUtil::Equals(value, "OPTIMIZER_COLUMN_LIFETIME")) { + return MetricsType::OPTIMIZER_COLUMN_LIFETIME; + } + if (StringUtil::Equals(value, "OPTIMIZER_BUILD_SIDE_PROBE_SIDE")) { + return MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE; + } + if (StringUtil::Equals(value, "OPTIMIZER_LIMIT_PUSHDOWN")) { + return MetricsType::OPTIMIZER_LIMIT_PUSHDOWN; + } + if (StringUtil::Equals(value, "OPTIMIZER_TOP_N")) { + return MetricsType::OPTIMIZER_TOP_N; + } + if (StringUtil::Equals(value, "OPTIMIZER_COMPRESSED_MATERIALIZATION")) { + return MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION; + } + if (StringUtil::Equals(value, "OPTIMIZER_DUPLICATE_GROUPS")) { + return MetricsType::OPTIMIZER_DUPLICATE_GROUPS; + } + if (StringUtil::Equals(value, "OPTIMIZER_REORDER_FILTER")) { + return MetricsType::OPTIMIZER_REORDER_FILTER; + } + if (StringUtil::Equals(value, "OPTIMIZER_JOIN_FILTER_PUSHDOWN")) { + return MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN; + } + if (StringUtil::Equals(value, "OPTIMIZER_EXTENSION")) { + return MetricsType::OPTIMIZER_EXTENSION; + } + if (StringUtil::Equals(value, "OPTIMIZER_MATERIALIZED_CTE")) { + return MetricsType::OPTIMIZER_MATERIALIZED_CTE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + template<> const char* EnumUtil::ToChars(NType value) { switch(value) { @@ -4114,6 +4606,12 @@ const char* EnumUtil::ToChars(NType value) { return "NODE_256"; case NType::LEAF_INLINED: return "LEAF_INLINED"; + case NType::NODE_7_LEAF: + return "NODE_7_LEAF"; + case NType::NODE_15_LEAF: + return "NODE_15_LEAF"; + case NType::NODE_256_LEAF: + return "NODE_256_LEAF"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -4142,18 +4640,29 @@ NType EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "LEAF_INLINED")) { return NType::LEAF_INLINED; } + if (StringUtil::Equals(value, "NODE_7_LEAF")) { + return NType::NODE_7_LEAF; + } + if (StringUtil::Equals(value, "NODE_15_LEAF")) { + return NType::NODE_15_LEAF; + } + if (StringUtil::Equals(value, "NODE_256_LEAF")) { + return NType::NODE_256_LEAF; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } template<> const char* EnumUtil::ToChars(NewLineIdentifier value) { switch(value) { - case NewLineIdentifier::SINGLE: - return "SINGLE"; + case NewLineIdentifier::SINGLE_N: + return "SINGLE_N"; case NewLineIdentifier::CARRY_ON: return "CARRY_ON"; case NewLineIdentifier::NOT_SET: return "NOT_SET"; + case NewLineIdentifier::SINGLE_R: + return "SINGLE_R"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -4161,8 +4670,8 @@ const char* EnumUtil::ToChars(NewLineIdentifier value) { template<> NewLineIdentifier EnumUtil::FromString(const char *value) { - if (StringUtil::Equals(value, "SINGLE")) { - return NewLineIdentifier::SINGLE; + if (StringUtil::Equals(value, "SINGLE_N")) { + return NewLineIdentifier::SINGLE_N; } if (StringUtil::Equals(value, "CARRY_ON")) { return NewLineIdentifier::CARRY_ON; @@ -4170,6 +4679,9 @@ NewLineIdentifier EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "NOT_SET")) { return NewLineIdentifier::NOT_SET; } + if (StringUtil::Equals(value, "SINGLE_R")) { + return NewLineIdentifier::SINGLE_R; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -4329,6 +4841,8 @@ const char* EnumUtil::ToChars(OptimizerType value) { return "FILTER_PULLUP"; case OptimizerType::FILTER_PUSHDOWN: return "FILTER_PUSHDOWN"; + case OptimizerType::CTE_FILTER_PUSHER: + return "CTE_FILTER_PUSHER"; case OptimizerType::REGEX_RANGE: return "REGEX_RANGE"; case OptimizerType::IN_CLAUSE: @@ -4349,6 +4863,10 @@ const char* EnumUtil::ToChars(OptimizerType value) { return "COMMON_AGGREGATE"; case OptimizerType::COLUMN_LIFETIME: return "COLUMN_LIFETIME"; + case OptimizerType::BUILD_SIDE_PROBE_SIDE: + return "BUILD_SIDE_PROBE_SIDE"; + case OptimizerType::LIMIT_PUSHDOWN: + return "LIMIT_PUSHDOWN"; case OptimizerType::TOP_N: return "TOP_N"; case OptimizerType::COMPRESSED_MATERIALIZATION: @@ -4357,8 +4875,12 @@ const char* EnumUtil::ToChars(OptimizerType value) { return "DUPLICATE_GROUPS"; case OptimizerType::REORDER_FILTER: return "REORDER_FILTER"; + case OptimizerType::JOIN_FILTER_PUSHDOWN: + return "JOIN_FILTER_PUSHDOWN"; case OptimizerType::EXTENSION: return "EXTENSION"; + case OptimizerType::MATERIALIZED_CTE: + return "MATERIALIZED_CTE"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -4378,6 +4900,9 @@ OptimizerType EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "FILTER_PUSHDOWN")) { return OptimizerType::FILTER_PUSHDOWN; } + if (StringUtil::Equals(value, "CTE_FILTER_PUSHER")) { + return OptimizerType::CTE_FILTER_PUSHER; + } if (StringUtil::Equals(value, "REGEX_RANGE")) { return OptimizerType::REGEX_RANGE; } @@ -4408,6 +4933,12 @@ OptimizerType EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "COLUMN_LIFETIME")) { return OptimizerType::COLUMN_LIFETIME; } + if (StringUtil::Equals(value, "BUILD_SIDE_PROBE_SIDE")) { + return OptimizerType::BUILD_SIDE_PROBE_SIDE; + } + if (StringUtil::Equals(value, "LIMIT_PUSHDOWN")) { + return OptimizerType::LIMIT_PUSHDOWN; + } if (StringUtil::Equals(value, "TOP_N")) { return OptimizerType::TOP_N; } @@ -4420,9 +4951,15 @@ OptimizerType EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "REORDER_FILTER")) { return OptimizerType::REORDER_FILTER; } + if (StringUtil::Equals(value, "JOIN_FILTER_PUSHDOWN")) { + return OptimizerType::JOIN_FILTER_PUSHDOWN; + } if (StringUtil::Equals(value, "EXTENSION")) { return OptimizerType::EXTENSION; } + if (StringUtil::Equals(value, "MATERIALIZED_CTE")) { + return OptimizerType::MATERIALIZED_CTE; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -4682,6 +5219,8 @@ const char* EnumUtil::ToChars(PartitionSortStage value) { return "MERGE"; case PartitionSortStage::SORTED: return "SORTED"; + case PartitionSortStage::FINISHED: + return "FINISHED"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -4704,6 +5243,9 @@ PartitionSortStage EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "SORTED")) { return PartitionSortStage::SORTED; } + if (StringUtil::Equals(value, "FINISHED")) { + return PartitionSortStage::FINISHED; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -4771,6 +5313,8 @@ const char* EnumUtil::ToChars(PendingExecutionResult val return "BLOCKED"; case PendingExecutionResult::NO_TASKS_AVAILABLE: return "NO_TASKS_AVAILABLE"; + case PendingExecutionResult::EXECUTION_FINISHED: + return "EXECUTION_FINISHED"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -4793,6 +5337,9 @@ PendingExecutionResult EnumUtil::FromString(const char * if (StringUtil::Equals(value, "NO_TASKS_AVAILABLE")) { return PendingExecutionResult::NO_TASKS_AVAILABLE; } + if (StringUtil::Equals(value, "EXECUTION_FINISHED")) { + return PendingExecutionResult::EXECUTION_FINISHED; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -4937,6 +5484,8 @@ const char* EnumUtil::ToChars(PhysicalOperatorType value) return "EXPORT"; case PhysicalOperatorType::SET: return "SET"; + case PhysicalOperatorType::SET_VARIABLE: + return "SET_VARIABLE"; case PhysicalOperatorType::LOAD: return "LOAD"; case PhysicalOperatorType::INOUT_FUNCTION: @@ -5167,6 +5716,9 @@ PhysicalOperatorType EnumUtil::FromString(const char *valu if (StringUtil::Equals(value, "SET")) { return PhysicalOperatorType::SET; } + if (StringUtil::Equals(value, "SET_VARIABLE")) { + return PhysicalOperatorType::SET_VARIABLE; + } if (StringUtil::Equals(value, "LOAD")) { return PhysicalOperatorType::LOAD; } @@ -5400,6 +5952,8 @@ const char* EnumUtil::ToChars(ProfilerPrintFormat value) { return "JSON"; case ProfilerPrintFormat::QUERY_TREE_OPTIMIZER: return "QUERY_TREE_OPTIMIZER"; + case ProfilerPrintFormat::NO_OUTPUT: + return "NO_OUTPUT"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -5416,6 +5970,9 @@ ProfilerPrintFormat EnumUtil::FromString(const char *value) if (StringUtil::Equals(value, "QUERY_TREE_OPTIMIZER")) { return ProfilerPrintFormat::QUERY_TREE_OPTIMIZER; } + if (StringUtil::Equals(value, "NO_OUTPUT")) { + return ProfilerPrintFormat::NO_OUTPUT; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -5504,6 +6061,8 @@ const char* EnumUtil::ToChars(QueryResultType value) { return "STREAM_RESULT"; case QueryResultType::PENDING_RESULT: return "PENDING_RESULT"; + case QueryResultType::ARROW_RESULT: + return "ARROW_RESULT"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -5520,6 +6079,9 @@ QueryResultType EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "PENDING_RESULT")) { return QueryResultType::PENDING_RESULT; } + if (StringUtil::Equals(value, "ARROW_RESULT")) { + return QueryResultType::ARROW_RESULT; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -5606,6 +6168,10 @@ const char* EnumUtil::ToChars(RelationType value) { return "VIEW_RELATION"; case RelationType::QUERY_RELATION: return "QUERY_RELATION"; + case RelationType::DELIM_JOIN_RELATION: + return "DELIM_JOIN_RELATION"; + case RelationType::DELIM_GET_RELATION: + return "DELIM_GET_RELATION"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -5691,6 +6257,12 @@ RelationType EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "QUERY_RELATION")) { return RelationType::QUERY_RELATION; } + if (StringUtil::Equals(value, "DELIM_JOIN_RELATION")) { + return RelationType::DELIM_JOIN_RELATION; + } + if (StringUtil::Equals(value, "DELIM_GET_RELATION")) { + return RelationType::DELIM_GET_RELATION; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -5972,6 +6544,8 @@ const char* EnumUtil::ToChars(SetScope value) { return "SESSION"; case SetScope::GLOBAL: return "GLOBAL"; + case SetScope::VARIABLE: + return "VARIABLE"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -5991,6 +6565,9 @@ SetScope EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "GLOBAL")) { return SetScope::GLOBAL; } + if (StringUtil::Equals(value, "VARIABLE")) { + return SetScope::VARIABLE; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -6024,6 +6601,8 @@ const char* EnumUtil::ToChars(SettingScope value) { return "GLOBAL"; case SettingScope::LOCAL: return "LOCAL"; + case SettingScope::SECRET: + return "SECRET"; case SettingScope::INVALID: return "INVALID"; default: @@ -6039,6 +6618,9 @@ SettingScope EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "LOCAL")) { return SettingScope::LOCAL; } + if (StringUtil::Equals(value, "SECRET")) { + return SettingScope::SECRET; + } if (StringUtil::Equals(value, "INVALID")) { return SettingScope::INVALID; } @@ -6706,6 +7288,54 @@ StrTimeSpecifier EnumUtil::FromString(const char *value) { throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template<> +const char* EnumUtil::ToChars(StreamExecutionResult value) { + switch(value) { + case StreamExecutionResult::CHUNK_READY: + return "CHUNK_READY"; + case StreamExecutionResult::CHUNK_NOT_READY: + return "CHUNK_NOT_READY"; + case StreamExecutionResult::EXECUTION_ERROR: + return "EXECUTION_ERROR"; + case StreamExecutionResult::EXECUTION_CANCELLED: + return "EXECUTION_CANCELLED"; + case StreamExecutionResult::BLOCKED: + return "BLOCKED"; + case StreamExecutionResult::NO_TASKS_AVAILABLE: + return "NO_TASKS_AVAILABLE"; + case StreamExecutionResult::EXECUTION_FINISHED: + return "EXECUTION_FINISHED"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +StreamExecutionResult EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "CHUNK_READY")) { + return StreamExecutionResult::CHUNK_READY; + } + if (StringUtil::Equals(value, "CHUNK_NOT_READY")) { + return StreamExecutionResult::CHUNK_NOT_READY; + } + if (StringUtil::Equals(value, "EXECUTION_ERROR")) { + return StreamExecutionResult::EXECUTION_ERROR; + } + if (StringUtil::Equals(value, "EXECUTION_CANCELLED")) { + return StreamExecutionResult::EXECUTION_CANCELLED; + } + if (StringUtil::Equals(value, "BLOCKED")) { + return StreamExecutionResult::BLOCKED; + } + if (StringUtil::Equals(value, "NO_TASKS_AVAILABLE")) { + return StreamExecutionResult::NO_TASKS_AVAILABLE; + } + if (StringUtil::Equals(value, "EXECUTION_FINISHED")) { + return StreamExecutionResult::EXECUTION_FINISHED; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + template<> const char* EnumUtil::ToChars(SubqueryType value) { switch(value) { @@ -6835,6 +7465,8 @@ const char* EnumUtil::ToChars(TableReferenceType value) { return "SHOW_REF"; case TableReferenceType::COLUMN_DATA: return "COLUMN_DATA"; + case TableReferenceType::DELIM_GET: + return "DELIM_GET"; default: throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); } @@ -6875,6 +7507,9 @@ TableReferenceType EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "COLUMN_DATA")) { return TableReferenceType::COLUMN_DATA; } + if (StringUtil::Equals(value, "DELIM_GET")) { + return TableReferenceType::DELIM_GET; + } throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } @@ -7000,6 +7635,34 @@ TimestampCastResult EnumUtil::FromString(const char *value) throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template<> +const char* EnumUtil::ToChars(TransactionModifierType value) { + switch(value) { + case TransactionModifierType::TRANSACTION_DEFAULT_MODIFIER: + return "TRANSACTION_DEFAULT_MODIFIER"; + case TransactionModifierType::TRANSACTION_READ_ONLY: + return "TRANSACTION_READ_ONLY"; + case TransactionModifierType::TRANSACTION_READ_WRITE: + return "TRANSACTION_READ_WRITE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +TransactionModifierType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "TRANSACTION_DEFAULT_MODIFIER")) { + return TransactionModifierType::TRANSACTION_DEFAULT_MODIFIER; + } + if (StringUtil::Equals(value, "TRANSACTION_READ_ONLY")) { + return TransactionModifierType::TRANSACTION_READ_ONLY; + } + if (StringUtil::Equals(value, "TRANSACTION_READ_WRITE")) { + return TransactionModifierType::TRANSACTION_READ_WRITE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + template<> const char* EnumUtil::ToChars(TransactionType value) { switch(value) { @@ -7416,6 +8079,8 @@ const char* EnumUtil::ToChars(WALType value) { return "DELETE_TUPLE"; case WALType::UPDATE_TUPLE: return "UPDATE_TUPLE"; + case WALType::ROW_GROUP_DATA: + return "ROW_GROUP_DATA"; case WALType::WAL_VERSION: return "WAL_VERSION"; case WALType::CHECKPOINT: @@ -7498,6 +8163,9 @@ WALType EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "UPDATE_TUPLE")) { return WALType::UPDATE_TUPLE; } + if (StringUtil::Equals(value, "ROW_GROUP_DATA")) { + return WALType::ROW_GROUP_DATA; + } if (StringUtil::Equals(value, "WAL_VERSION")) { return WALType::WAL_VERSION; } diff --git a/src/duckdb/src/common/enums/file_compression_type.cpp b/src/duckdb/src/common/enums/file_compression_type.cpp index 5df6add1..44066f32 100644 --- a/src/duckdb/src/common/enums/file_compression_type.cpp +++ b/src/duckdb/src/common/enums/file_compression_type.cpp @@ -19,4 +19,28 @@ FileCompressionType FileCompressionTypeFromString(const string &input) { } } +string CompressionExtensionFromType(const FileCompressionType type) { + switch (type) { + case FileCompressionType::GZIP: + return ".gz"; + case FileCompressionType::ZSTD: + return ".zst"; + default: + throw NotImplementedException("Compression Extension of file compression type is not implemented"); + } +} + +bool IsFileCompressed(string path, FileCompressionType type) { + auto extension = CompressionExtensionFromType(type); + std::size_t question_mark_pos = std::string::npos; + if (!StringUtil::StartsWith(path, "\\\\?\\")) { + question_mark_pos = path.find('?'); + } + path = path.substr(0, question_mark_pos); + if (StringUtil::EndsWith(path, extension)) { + return true; + } + return false; +} + } // namespace duckdb diff --git a/src/duckdb/src/common/enums/metric_type.cpp b/src/duckdb/src/common/enums/metric_type.cpp new file mode 100644 index 00000000..a317f2d6 --- /dev/null +++ b/src/duckdb/src/common/enums/metric_type.cpp @@ -0,0 +1,208 @@ +//------------------------------------------------------------------------- +// DuckDB +// +// +// duckdb/common/enums/metrics_type.hpp +// +// This file is automatically generated by scripts/generate_metric_enums.py +// Do not edit this file manually, your changes will be overwritten +//------------------------------------------------------------------------- + +#include "duckdb/common/enums/metric_type.hpp" +namespace duckdb { + +profiler_settings_t MetricsUtils::GetOptimizerMetrics() { + return { + MetricsType::OPTIMIZER_EXPRESSION_REWRITER, + MetricsType::OPTIMIZER_FILTER_PULLUP, + MetricsType::OPTIMIZER_FILTER_PUSHDOWN, + MetricsType::OPTIMIZER_CTE_FILTER_PUSHER, + MetricsType::OPTIMIZER_REGEX_RANGE, + MetricsType::OPTIMIZER_IN_CLAUSE, + MetricsType::OPTIMIZER_JOIN_ORDER, + MetricsType::OPTIMIZER_DELIMINATOR, + MetricsType::OPTIMIZER_UNNEST_REWRITER, + MetricsType::OPTIMIZER_UNUSED_COLUMNS, + MetricsType::OPTIMIZER_STATISTICS_PROPAGATION, + MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS, + MetricsType::OPTIMIZER_COMMON_AGGREGATE, + MetricsType::OPTIMIZER_COLUMN_LIFETIME, + MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE, + MetricsType::OPTIMIZER_LIMIT_PUSHDOWN, + MetricsType::OPTIMIZER_TOP_N, + MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION, + MetricsType::OPTIMIZER_DUPLICATE_GROUPS, + MetricsType::OPTIMIZER_REORDER_FILTER, + MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN, + MetricsType::OPTIMIZER_EXTENSION, + MetricsType::OPTIMIZER_MATERIALIZED_CTE, + }; +} + +profiler_settings_t MetricsUtils::GetPhaseTimingMetrics() { + return { + MetricsType::ALL_OPTIMIZERS, + MetricsType::CUMULATIVE_OPTIMIZER_TIMING, + MetricsType::PLANNER, + MetricsType::PLANNER_BINDING, + MetricsType::PHYSICAL_PLANNER, + MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING, + MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES, + MetricsType::PHYSICAL_PLANNER_CREATE_PLAN, + }; +} + +MetricsType MetricsUtils::GetOptimizerMetricByType(OptimizerType type) { + switch(type) { + case OptimizerType::EXPRESSION_REWRITER: + return MetricsType::OPTIMIZER_EXPRESSION_REWRITER; + case OptimizerType::FILTER_PULLUP: + return MetricsType::OPTIMIZER_FILTER_PULLUP; + case OptimizerType::FILTER_PUSHDOWN: + return MetricsType::OPTIMIZER_FILTER_PUSHDOWN; + case OptimizerType::CTE_FILTER_PUSHER: + return MetricsType::OPTIMIZER_CTE_FILTER_PUSHER; + case OptimizerType::REGEX_RANGE: + return MetricsType::OPTIMIZER_REGEX_RANGE; + case OptimizerType::IN_CLAUSE: + return MetricsType::OPTIMIZER_IN_CLAUSE; + case OptimizerType::JOIN_ORDER: + return MetricsType::OPTIMIZER_JOIN_ORDER; + case OptimizerType::DELIMINATOR: + return MetricsType::OPTIMIZER_DELIMINATOR; + case OptimizerType::UNNEST_REWRITER: + return MetricsType::OPTIMIZER_UNNEST_REWRITER; + case OptimizerType::UNUSED_COLUMNS: + return MetricsType::OPTIMIZER_UNUSED_COLUMNS; + case OptimizerType::STATISTICS_PROPAGATION: + return MetricsType::OPTIMIZER_STATISTICS_PROPAGATION; + case OptimizerType::COMMON_SUBEXPRESSIONS: + return MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS; + case OptimizerType::COMMON_AGGREGATE: + return MetricsType::OPTIMIZER_COMMON_AGGREGATE; + case OptimizerType::COLUMN_LIFETIME: + return MetricsType::OPTIMIZER_COLUMN_LIFETIME; + case OptimizerType::BUILD_SIDE_PROBE_SIDE: + return MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE; + case OptimizerType::LIMIT_PUSHDOWN: + return MetricsType::OPTIMIZER_LIMIT_PUSHDOWN; + case OptimizerType::TOP_N: + return MetricsType::OPTIMIZER_TOP_N; + case OptimizerType::COMPRESSED_MATERIALIZATION: + return MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION; + case OptimizerType::DUPLICATE_GROUPS: + return MetricsType::OPTIMIZER_DUPLICATE_GROUPS; + case OptimizerType::REORDER_FILTER: + return MetricsType::OPTIMIZER_REORDER_FILTER; + case OptimizerType::JOIN_FILTER_PUSHDOWN: + return MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN; + case OptimizerType::EXTENSION: + return MetricsType::OPTIMIZER_EXTENSION; + case OptimizerType::MATERIALIZED_CTE: + return MetricsType::OPTIMIZER_MATERIALIZED_CTE; + default: + throw InternalException("OptimizerType %s cannot be converted to a MetricsType", EnumUtil::ToString(type)); + }; +} + +OptimizerType MetricsUtils::GetOptimizerTypeByMetric(MetricsType type) { + switch(type) { + case MetricsType::OPTIMIZER_EXPRESSION_REWRITER: + return OptimizerType::EXPRESSION_REWRITER; + case MetricsType::OPTIMIZER_FILTER_PULLUP: + return OptimizerType::FILTER_PULLUP; + case MetricsType::OPTIMIZER_FILTER_PUSHDOWN: + return OptimizerType::FILTER_PUSHDOWN; + case MetricsType::OPTIMIZER_CTE_FILTER_PUSHER: + return OptimizerType::CTE_FILTER_PUSHER; + case MetricsType::OPTIMIZER_REGEX_RANGE: + return OptimizerType::REGEX_RANGE; + case MetricsType::OPTIMIZER_IN_CLAUSE: + return OptimizerType::IN_CLAUSE; + case MetricsType::OPTIMIZER_JOIN_ORDER: + return OptimizerType::JOIN_ORDER; + case MetricsType::OPTIMIZER_DELIMINATOR: + return OptimizerType::DELIMINATOR; + case MetricsType::OPTIMIZER_UNNEST_REWRITER: + return OptimizerType::UNNEST_REWRITER; + case MetricsType::OPTIMIZER_UNUSED_COLUMNS: + return OptimizerType::UNUSED_COLUMNS; + case MetricsType::OPTIMIZER_STATISTICS_PROPAGATION: + return OptimizerType::STATISTICS_PROPAGATION; + case MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS: + return OptimizerType::COMMON_SUBEXPRESSIONS; + case MetricsType::OPTIMIZER_COMMON_AGGREGATE: + return OptimizerType::COMMON_AGGREGATE; + case MetricsType::OPTIMIZER_COLUMN_LIFETIME: + return OptimizerType::COLUMN_LIFETIME; + case MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE: + return OptimizerType::BUILD_SIDE_PROBE_SIDE; + case MetricsType::OPTIMIZER_LIMIT_PUSHDOWN: + return OptimizerType::LIMIT_PUSHDOWN; + case MetricsType::OPTIMIZER_TOP_N: + return OptimizerType::TOP_N; + case MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION: + return OptimizerType::COMPRESSED_MATERIALIZATION; + case MetricsType::OPTIMIZER_DUPLICATE_GROUPS: + return OptimizerType::DUPLICATE_GROUPS; + case MetricsType::OPTIMIZER_REORDER_FILTER: + return OptimizerType::REORDER_FILTER; + case MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN: + return OptimizerType::JOIN_FILTER_PUSHDOWN; + case MetricsType::OPTIMIZER_EXTENSION: + return OptimizerType::EXTENSION; + case MetricsType::OPTIMIZER_MATERIALIZED_CTE: + return OptimizerType::MATERIALIZED_CTE; + default: + return OptimizerType::INVALID; + }; +} + +bool MetricsUtils::IsOptimizerMetric(MetricsType type) { + switch(type) { + case MetricsType::OPTIMIZER_EXPRESSION_REWRITER: + case MetricsType::OPTIMIZER_FILTER_PULLUP: + case MetricsType::OPTIMIZER_FILTER_PUSHDOWN: + case MetricsType::OPTIMIZER_CTE_FILTER_PUSHER: + case MetricsType::OPTIMIZER_REGEX_RANGE: + case MetricsType::OPTIMIZER_IN_CLAUSE: + case MetricsType::OPTIMIZER_JOIN_ORDER: + case MetricsType::OPTIMIZER_DELIMINATOR: + case MetricsType::OPTIMIZER_UNNEST_REWRITER: + case MetricsType::OPTIMIZER_UNUSED_COLUMNS: + case MetricsType::OPTIMIZER_STATISTICS_PROPAGATION: + case MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS: + case MetricsType::OPTIMIZER_COMMON_AGGREGATE: + case MetricsType::OPTIMIZER_COLUMN_LIFETIME: + case MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE: + case MetricsType::OPTIMIZER_LIMIT_PUSHDOWN: + case MetricsType::OPTIMIZER_TOP_N: + case MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION: + case MetricsType::OPTIMIZER_DUPLICATE_GROUPS: + case MetricsType::OPTIMIZER_REORDER_FILTER: + case MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN: + case MetricsType::OPTIMIZER_EXTENSION: + case MetricsType::OPTIMIZER_MATERIALIZED_CTE: + return true; + default: + return false; + }; +} + +bool MetricsUtils::IsPhaseTimingMetric(MetricsType type) { + switch(type) { + case MetricsType::ALL_OPTIMIZERS: + case MetricsType::CUMULATIVE_OPTIMIZER_TIMING: + case MetricsType::PLANNER: + case MetricsType::PLANNER_BINDING: + case MetricsType::PHYSICAL_PLANNER: + case MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING: + case MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES: + case MetricsType::PHYSICAL_PLANNER_CREATE_PLAN: + return true; + default: + return false; + }; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/enums/optimizer_type.cpp b/src/duckdb/src/common/enums/optimizer_type.cpp index d3f13108..460bbb3a 100644 --- a/src/duckdb/src/common/enums/optimizer_type.cpp +++ b/src/duckdb/src/common/enums/optimizer_type.cpp @@ -1,7 +1,8 @@ #include "duckdb/common/enums/optimizer_type.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/exception/parser_exception.hpp" + #include "duckdb/common/exception.hpp" +#include "duckdb/common/exception/parser_exception.hpp" +#include "duckdb/common/string_util.hpp" namespace duckdb { @@ -14,6 +15,7 @@ static const DefaultOptimizerType internal_optimizer_types[] = { {"expression_rewriter", OptimizerType::EXPRESSION_REWRITER}, {"filter_pullup", OptimizerType::FILTER_PULLUP}, {"filter_pushdown", OptimizerType::FILTER_PUSHDOWN}, + {"cte_filter_pusher", OptimizerType::CTE_FILTER_PUSHER}, {"regex_range", OptimizerType::REGEX_RANGE}, {"in_clause", OptimizerType::IN_CLAUSE}, {"join_order", OptimizerType::JOIN_ORDER}, @@ -24,11 +26,15 @@ static const DefaultOptimizerType internal_optimizer_types[] = { {"common_subexpressions", OptimizerType::COMMON_SUBEXPRESSIONS}, {"common_aggregate", OptimizerType::COMMON_AGGREGATE}, {"column_lifetime", OptimizerType::COLUMN_LIFETIME}, + {"limit_pushdown", OptimizerType::LIMIT_PUSHDOWN}, {"top_n", OptimizerType::TOP_N}, + {"build_side_probe_side", OptimizerType::BUILD_SIDE_PROBE_SIDE}, {"compressed_materialization", OptimizerType::COMPRESSED_MATERIALIZATION}, {"duplicate_groups", OptimizerType::DUPLICATE_GROUPS}, {"reorder_filter", OptimizerType::REORDER_FILTER}, + {"join_filter_pushdown", OptimizerType::JOIN_FILTER_PUSHDOWN}, {"extension", OptimizerType::EXTENSION}, + {"materialized_cte", OptimizerType::MATERIALIZED_CTE}, {nullptr, OptimizerType::INVALID}}; string OptimizerTypeToString(OptimizerType type) { diff --git a/src/duckdb/src/common/enums/physical_operator_type.cpp b/src/duckdb/src/common/enums/physical_operator_type.cpp index 4cb851c9..a48f4407 100644 --- a/src/duckdb/src/common/enums/physical_operator_type.cpp +++ b/src/duckdb/src/common/enums/physical_operator_type.cpp @@ -133,6 +133,8 @@ string PhysicalOperatorToString(PhysicalOperatorType type) { return "EXPORT"; case PhysicalOperatorType::SET: return "SET"; + case PhysicalOperatorType::SET_VARIABLE: + return "SET_VARIABLE"; case PhysicalOperatorType::RESET: return "RESET"; case PhysicalOperatorType::LOAD: diff --git a/src/duckdb/src/common/enums/relation_type.cpp b/src/duckdb/src/common/enums/relation_type.cpp index caac469a..4f58ed7c 100644 --- a/src/duckdb/src/common/enums/relation_type.cpp +++ b/src/duckdb/src/common/enums/relation_type.cpp @@ -9,6 +9,10 @@ string RelationTypeToString(RelationType type) { switch (type) { case RelationType::TABLE_RELATION: return "TABLE_RELATION"; + case RelationType::DELIM_GET_RELATION: + return "DELIM_GET_RELATION"; + case RelationType::DELIM_JOIN_RELATION: + return "DELIM_JOIN_RELATION"; case RelationType::PROJECTION_RELATION: return "PROJECTION_RELATION"; case RelationType::FILTER_RELATION: diff --git a/src/duckdb/src/common/enums/statement_type.cpp b/src/duckdb/src/common/enums/statement_type.cpp index 0250ff2c..98524b93 100644 --- a/src/duckdb/src/common/enums/statement_type.cpp +++ b/src/duckdb/src/common/enums/statement_type.cpp @@ -1,5 +1,7 @@ #include "duckdb/common/enums/statement_type.hpp" +#include "duckdb/catalog/catalog.hpp" + namespace duckdb { // LCOV_EXCL_START @@ -82,4 +84,17 @@ string StatementReturnTypeToString(StatementReturnType type) { } // LCOV_EXCL_STOP +void StatementProperties::RegisterDBRead(Catalog &catalog, ClientContext &context) { + auto catalog_identity = CatalogIdentity {catalog.GetOid(), catalog.GetCatalogVersion(context)}; + D_ASSERT(read_databases.count(catalog.GetName()) == 0 || read_databases[catalog.GetName()] == catalog_identity); + read_databases[catalog.GetName()] = catalog_identity; +} + +void StatementProperties::RegisterDBModify(Catalog &catalog, ClientContext &context) { + auto catalog_identity = CatalogIdentity {catalog.GetOid(), catalog.GetCatalogVersion(context)}; + D_ASSERT(modified_databases.count(catalog.GetName()) == 0 || + modified_databases[catalog.GetName()] == catalog_identity); + modified_databases[catalog.GetName()] = catalog_identity; +} + } // namespace duckdb diff --git a/src/duckdb/src/common/error_data.cpp b/src/duckdb/src/common/error_data.cpp index c7262cab..d0a427cc 100644 --- a/src/duckdb/src/common/error_data.cpp +++ b/src/duckdb/src/common/error_data.cpp @@ -17,10 +17,12 @@ ErrorData::ErrorData(const std::exception &ex) : ErrorData(ex.what()) { } ErrorData::ErrorData(ExceptionType type, const string &message) - : initialized(true), type(type), raw_message(SanitizeErrorMessage(message)) { + : initialized(true), type(type), raw_message(SanitizeErrorMessage(message)), + final_message(ConstructFinalMessage()) { } -ErrorData::ErrorData(const string &message) : initialized(true), type(ExceptionType::INVALID), raw_message(string()) { +ErrorData::ErrorData(const string &message) + : initialized(true), type(ExceptionType::INVALID), raw_message(string()), final_message(string()) { // parse the constructed JSON if (message.empty() || message[0] != '{') { @@ -29,11 +31,9 @@ ErrorData::ErrorData(const string &message) : initialized(true), type(ExceptionT if (message == std::bad_alloc().what()) { type = ExceptionType::OUT_OF_MEMORY; raw_message = "Allocation failure"; - return; + } else { + raw_message = message; } - - raw_message = message; - return; } else { auto info = StringUtil::ParseJSONMap(message); for (auto &entry : info) { @@ -46,27 +46,28 @@ ErrorData::ErrorData(const string &message) : initialized(true), type(ExceptionT } } } -} -const string &ErrorData::Message() { - if (final_message.empty()) { - if (type != ExceptionType::UNKNOWN_TYPE) { - final_message = Exception::ExceptionTypeToString(type) + " "; - } - final_message += "Error: " + raw_message; - if (type == ExceptionType::INTERNAL) { - final_message += "\nThis error signals an assertion failure within DuckDB. This usually occurs due to " - "unexpected conditions or errors in the program's logic.\nFor more information, see " - "https://duckdb.org/docs/dev/internal_errors"; - } - } - return final_message; + final_message = ConstructFinalMessage(); } string ErrorData::SanitizeErrorMessage(string error) { return StringUtil::Replace(std::move(error), string("\0", 1), "\\0"); } +string ErrorData::ConstructFinalMessage() const { + std::string error; + if (type != ExceptionType::UNKNOWN_TYPE) { + error = Exception::ExceptionTypeToString(type) + " "; + } + error += "Error: " + raw_message; + if (type == ExceptionType::INTERNAL) { + error += "\nThis error signals an assertion failure within DuckDB. This usually occurs due to " + "unexpected conditions or errors in the program's logic.\nFor more information, see " + "https://duckdb.org/docs/dev/internal_errors"; + } + return error; +} + void ErrorData::Throw(const string &prepended_message) const { D_ASSERT(initialized); if (!prepended_message.empty()) { @@ -107,6 +108,7 @@ void ErrorData::AddErrorLocation(const string &query) { return; } raw_message = QueryErrorContext::Format(query, raw_message, std::stoull(entry->second)); + final_message = ConstructFinalMessage(); } void ErrorData::AddQueryLocation(optional_idx query_location) { diff --git a/src/duckdb/src/common/exception.cpp b/src/duckdb/src/common/exception.cpp index 1ad9c751..b8aac720 100644 --- a/src/duckdb/src/common/exception.cpp +++ b/src/duckdb/src/common/exception.cpp @@ -160,7 +160,8 @@ static constexpr ExceptionEntry EXCEPTION_MAP[] = {{ExceptionType::INVALID, "Inv {ExceptionType::MISSING_EXTENSION, "Missing Extension"}, {ExceptionType::HTTP, "HTTP"}, {ExceptionType::AUTOLOAD, "Extension Autoloading"}, - {ExceptionType::SEQUENCE, "Sequence"}}; + {ExceptionType::SEQUENCE, "Sequence"}, + {ExceptionType::INVALID_CONFIGURATION, "Invalid Configuration"}}; string Exception::ExceptionTypeToString(ExceptionType type) { for (auto &e : EXCEPTION_MAP) { @@ -340,6 +341,15 @@ InvalidInputException::InvalidInputException(const string &msg, const unordered_ : Exception(ExceptionType::INVALID_INPUT, msg, extra_info) { } +InvalidConfigurationException::InvalidConfigurationException(const string &msg) + : Exception(ExceptionType::INVALID_CONFIGURATION, msg) { +} + +InvalidConfigurationException::InvalidConfigurationException(const string &msg, + const unordered_map &extra_info) + : Exception(ExceptionType::INVALID_CONFIGURATION, msg, extra_info) { +} + OutOfMemoryException::OutOfMemoryException(const string &msg) : Exception(ExceptionType::OUT_OF_MEMORY, msg) { } diff --git a/src/duckdb/src/common/exception/binder_exception.cpp b/src/duckdb/src/common/exception/binder_exception.cpp index 458c563c..55db8386 100644 --- a/src/duckdb/src/common/exception/binder_exception.cpp +++ b/src/duckdb/src/common/exception/binder_exception.cpp @@ -44,4 +44,9 @@ BinderException BinderException::NoMatchingFunction(const string &name, const ve extra_info); } +BinderException BinderException::Unsupported(ParsedExpression &expr, const string &message) { + auto extra_info = Exception::InitializeExtraInfo("UNSUPPORTED", expr.query_location); + return BinderException(message, extra_info); +} + } // namespace duckdb diff --git a/src/duckdb/src/common/extra_type_info.cpp b/src/duckdb/src/common/extra_type_info.cpp index 88dacaf3..6c09480c 100644 --- a/src/duckdb/src/common/extra_type_info.cpp +++ b/src/duckdb/src/common/extra_type_info.cpp @@ -428,6 +428,9 @@ IntegerLiteralTypeInfo::IntegerLiteralTypeInfo() : ExtraTypeInfo(ExtraTypeInfoTy IntegerLiteralTypeInfo::IntegerLiteralTypeInfo(Value constant_value_p) : ExtraTypeInfo(ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO), constant_value(std::move(constant_value_p)) { + if (constant_value.IsNull()) { + throw InternalException("Integer literal cannot be NULL"); + } } bool IntegerLiteralTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { diff --git a/src/duckdb/src/common/file_buffer.cpp b/src/duckdb/src/common/file_buffer.cpp index 01fa5d77..7cde4c6f 100644 --- a/src/duckdb/src/common/file_buffer.cpp +++ b/src/duckdb/src/common/file_buffer.cpp @@ -67,7 +67,7 @@ FileBuffer::MemoryRequirement FileBuffer::CalculateMemory(uint64_t user_size) { result.header_size = 0; result.alloc_size = user_size; } else { - result.header_size = Storage::BLOCK_HEADER_SIZE; + result.header_size = Storage::DEFAULT_BLOCK_HEADER_SIZE; result.alloc_size = AlignValue(result.header_size + user_size); } return result; diff --git a/src/duckdb/src/common/file_system.cpp b/src/duckdb/src/common/file_system.cpp index 27160adc..6e68402b 100644 --- a/src/duckdb/src/common/file_system.cpp +++ b/src/duckdb/src/common/file_system.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -56,6 +57,8 @@ constexpr FileOpenFlags FileFlags::FILE_FLAGS_APPEND; constexpr FileOpenFlags FileFlags::FILE_FLAGS_PRIVATE; constexpr FileOpenFlags FileFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS; constexpr FileOpenFlags FileFlags::FILE_FLAGS_PARALLEL_ACCESS; +constexpr FileOpenFlags FileFlags::FILE_FLAGS_EXCLUSIVE_CREATE; +constexpr FileOpenFlags FileFlags::FILE_FLAGS_NULL_IF_EXISTS; void FileOpenFlags::Verify() { #ifdef DEBUG @@ -65,6 +68,8 @@ void FileOpenFlags::Verify() { (flags & FileOpenFlags::FILE_FLAGS_FILE_CREATE) || (flags & FileOpenFlags::FILE_FLAGS_FILE_CREATE_NEW); bool is_private = (flags & FileOpenFlags::FILE_FLAGS_PRIVATE); bool null_if_not_exists = flags & FileOpenFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS; + bool exclusive_create = flags & FileOpenFlags::FILE_FLAGS_EXCLUSIVE_CREATE; + bool null_if_exists = flags & FileOpenFlags::FILE_FLAGS_NULL_IF_EXISTS; // require either READ or WRITE (or both) D_ASSERT(is_read || is_write); @@ -79,6 +84,10 @@ void FileOpenFlags::Verify() { D_ASSERT(!is_private || is_create); // FILE_FLAGS_NULL_IF_NOT_EXISTS cannot be combined with CREATE/CREATE_NEW D_ASSERT(!(null_if_not_exists && is_create)); + // FILE_FLAGS_EXCLUSIVE_CREATE only can be combined with CREATE/CREATE_NEW + D_ASSERT(!exclusive_create || is_create); + // FILE_FLAGS_NULL_IF_EXISTS only can be set with EXCLUSIVE_CREATE + D_ASSERT(!null_if_exists || exclusive_create); #endif } @@ -584,6 +593,10 @@ bool FileHandle::CanSeek() { return file_system.CanSeek(); } +FileCompressionType FileHandle::GetFileCompressionType() { + return FileCompressionType::UNCOMPRESSED; +} + bool FileHandle::IsPipe() { return file_system.IsPipe(path); } @@ -622,10 +635,19 @@ FileType FileHandle::GetType() { return file_system.GetFileType(*this); } +idx_t FileHandle::GetProgress() { + throw NotImplementedException("GetProgress is not implemented for this file handle"); +} + bool FileSystem::IsRemoteFile(const string &path) { - const string prefixes[] = {"http://", "https://", "s3://", "s3a://", "s3n://", "gcs://", "gs://", "r2://", "hf://"}; - for (auto &prefix : prefixes) { - if (StringUtil::StartsWith(path, prefix)) { + string extension = ""; + return IsRemoteFile(path, extension); +} + +bool FileSystem::IsRemoteFile(const string &path, string &extension) { + for (const auto &entry : EXTENSION_FILE_PREFIXES) { + if (StringUtil::StartsWith(path, entry.name)) { + extension = entry.extension; return true; } } diff --git a/src/duckdb/src/common/filename_pattern.cpp b/src/duckdb/src/common/filename_pattern.cpp index 89cc6c63..04851ad3 100644 --- a/src/duckdb/src/common/filename_pattern.cpp +++ b/src/duckdb/src/common/filename_pattern.cpp @@ -10,6 +10,7 @@ void FilenamePattern::SetFilenamePattern(const string &pattern) { base = pattern; pos = base.find(id_format); + uuid = false; if (pos != string::npos) { base = StringUtil::Replace(base, id_format, ""); uuid = false; diff --git a/src/duckdb/src/common/fsst.cpp b/src/duckdb/src/common/fsst.cpp index 6c8c3de4..74b747de 100644 --- a/src/duckdb/src/common/fsst.cpp +++ b/src/duckdb/src/common/fsst.cpp @@ -6,30 +6,31 @@ namespace duckdb { string_t FSSTPrimitives::DecompressValue(void *duckdb_fsst_decoder, Vector &result, const char *compressed_string, - idx_t compressed_string_len) { + const idx_t compressed_string_len, vector &decompress_buffer) { + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - unsigned char decompress_buffer[StringUncompressed::STRING_BLOCK_LIMIT + 1]; auto fsst_decoder = reinterpret_cast(duckdb_fsst_decoder); auto compressed_string_ptr = (unsigned char *)compressed_string; // NOLINT - auto decompressed_string_size = - duckdb_fsst_decompress(fsst_decoder, compressed_string_len, compressed_string_ptr, - StringUncompressed::STRING_BLOCK_LIMIT + 1, &decompress_buffer[0]); - D_ASSERT(decompressed_string_size <= StringUncompressed::STRING_BLOCK_LIMIT); + auto decompressed_string_size = duckdb_fsst_decompress(fsst_decoder, compressed_string_len, compressed_string_ptr, + decompress_buffer.size(), decompress_buffer.data()); - return StringVector::AddStringOrBlob(result, const_char_ptr_cast(decompress_buffer), decompressed_string_size); + D_ASSERT(!decompress_buffer.empty()); + D_ASSERT(decompressed_string_size <= decompress_buffer.size() - 1); + return StringVector::AddStringOrBlob(result, const_char_ptr_cast(decompress_buffer.data()), + decompressed_string_size); } Value FSSTPrimitives::DecompressValue(void *duckdb_fsst_decoder, const char *compressed_string, - idx_t compressed_string_len) { - unsigned char decompress_buffer[StringUncompressed::STRING_BLOCK_LIMIT + 1]; + const idx_t compressed_string_len, vector &decompress_buffer) { + auto compressed_string_ptr = (unsigned char *)compressed_string; // NOLINT auto fsst_decoder = reinterpret_cast(duckdb_fsst_decoder); - auto decompressed_string_size = - duckdb_fsst_decompress(fsst_decoder, compressed_string_len, compressed_string_ptr, - StringUncompressed::STRING_BLOCK_LIMIT + 1, &decompress_buffer[0]); - D_ASSERT(decompressed_string_size <= StringUncompressed::STRING_BLOCK_LIMIT); + auto decompressed_string_size = duckdb_fsst_decompress(fsst_decoder, compressed_string_len, compressed_string_ptr, + decompress_buffer.size(), decompress_buffer.data()); - return Value(string(char_ptr_cast(decompress_buffer), decompressed_string_size)); + D_ASSERT(!decompress_buffer.empty()); + D_ASSERT(decompressed_string_size <= decompress_buffer.size() - 1); + return Value(string(char_ptr_cast(decompress_buffer.data()), decompressed_string_size)); } } // namespace duckdb diff --git a/src/duckdb/src/common/gzip_file_system.cpp b/src/duckdb/src/common/gzip_file_system.cpp index 100721a3..ee0a2158 100644 --- a/src/duckdb/src/common/gzip_file_system.cpp +++ b/src/duckdb/src/common/gzip_file_system.cpp @@ -300,7 +300,9 @@ class GZipFile : public CompressedFile { : CompressedFile(gzip_fs, std::move(child_handle_p), path) { Initialize(write); } - + FileCompressionType GetFileCompressionType() override { + return FileCompressionType::GZIP; + } GZipFileSystem gzip_fs; }; diff --git a/src/duckdb/src/common/hive_partitioning.cpp b/src/duckdb/src/common/hive_partitioning.cpp index 3ae1f7f3..8ae9b685 100644 --- a/src/duckdb/src/common/hive_partitioning.cpp +++ b/src/duckdb/src/common/hive_partitioning.cpp @@ -2,36 +2,44 @@ #include "duckdb/common/uhugeint.hpp" #include "duckdb/execution/expression_executor.hpp" -#include "duckdb/optimizer/filter_combiner.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" -#include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/table_filter.hpp" -#include "re2/re2.h" +#include "duckdb/common/multi_file_list.hpp" namespace duckdb { -static unordered_map GetKnownColumnValues(string &filename, - unordered_map &column_map, - duckdb_re2::RE2 &compiled_regex, bool filename_col, - bool hive_partition_cols) { - unordered_map result; +struct PartitioningColumnValue { + explicit PartitioningColumnValue(string value_p) : value(std::move(value_p)) { + } + PartitioningColumnValue(string key_p, string value_p) : key(std::move(key_p)), value(std::move(value_p)) { + } + + string key; + string value; +}; - if (filename_col) { +static unordered_map +GetKnownColumnValues(const string &filename, const HivePartitioningFilterInfo &filter_info) { + unordered_map result; + + auto &column_map = filter_info.column_map; + if (filter_info.filename_enabled) { auto lookup_column_id = column_map.find("filename"); if (lookup_column_id != column_map.end()) { - result[lookup_column_id->second] = filename; + result.insert(make_pair(lookup_column_id->second, PartitioningColumnValue(filename))); } } - if (hive_partition_cols) { - auto partitions = HivePartitioning::Parse(filename, compiled_regex); + if (filter_info.hive_enabled) { + auto partitions = HivePartitioning::Parse(filename); for (auto &partition : partitions) { auto lookup_column_id = column_map.find(partition.first); if (lookup_column_id != column_map.end()) { - result[lookup_column_id->second] = partition.second; + result.insert( + make_pair(lookup_column_id->second, PartitioningColumnValue(partition.first, partition.second))); } } } @@ -40,8 +48,9 @@ static unordered_map GetKnownColumnValues(string &filename, } // Takes an expression and converts a list of known column_refs to constants -static void ConvertKnownColRefToConstants(unique_ptr &expr, - unordered_map &known_column_values, idx_t table_index) { +static void ConvertKnownColRefToConstants(ClientContext &context, unique_ptr &expr, + const unordered_map &known_column_values, + idx_t table_index) { if (expr->type == ExpressionType::BOUND_COLUMN_REF) { auto &bound_colref = expr->Cast(); @@ -52,70 +61,118 @@ static void ConvertKnownColRefToConstants(unique_ptr &expr, auto lookup = known_column_values.find(bound_colref.binding.column_index); if (lookup != known_column_values.end()) { - expr = make_uniq(Value(lookup->second).DefaultCastAs(bound_colref.return_type)); + auto &partition_val = lookup->second; + Value result_val; + if (partition_val.key.empty()) { + // filename column - use directly + result_val = Value(partition_val.value); + } else { + // hive partitioning column - cast the value to the target type + result_val = HivePartitioning::GetValue(context, partition_val.key, partition_val.value, + bound_colref.return_type); + } + expr = make_uniq(std::move(result_val)); } } else { ExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child) { - ConvertKnownColRefToConstants(child, known_column_values, table_index); + ConvertKnownColRefToConstants(context, child, known_column_values, table_index); }); } } +string HivePartitioning::Escape(const string &input) { + return StringUtil::URLEncode(input); +} + +string HivePartitioning::Unescape(const string &input) { + return StringUtil::URLDecode(input); +} + // matches hive partitions in file name. For example: // - s3://bucket/var1=value1/bla/bla/var2=value2 // - http(s)://domain(:port)/lala/kasdl/var1=value1/?not-a-var=not-a-value // - folder/folder/folder/../var1=value1/etc/.//var2=value2 -const string &HivePartitioning::RegexString() { - static string REGEX = "[\\/\\\\]([^\\/\\?\\\\]+)=([^\\/\\n\\?\\\\]*)"; - return REGEX; -} - -std::map HivePartitioning::Parse(const string &filename, duckdb_re2::RE2 ®ex) { +std::map HivePartitioning::Parse(const string &filename) { + idx_t partition_start = 0; + idx_t equality_sign = 0; + bool candidate_partition = true; std::map result; - duckdb_re2::StringPiece input(filename); // Wrap a StringPiece around it - - string var; - string value; - while (RE2::FindAndConsume(&input, regex, &var, &value)) { - result.insert(std::pair(var, value)); + for (idx_t c = 0; c < filename.size(); c++) { + if (filename[c] == '?' || filename[c] == '\n') { + // get parameter or newline - not a partition + candidate_partition = false; + } + if (filename[c] == '\\' || filename[c] == '/') { + // separator + if (candidate_partition && equality_sign > partition_start) { + // we found a partition with an equality sign + string key = filename.substr(partition_start, equality_sign - partition_start); + string value = filename.substr(equality_sign + 1, c - equality_sign - 1); + result.insert(make_pair(std::move(key), std::move(value))); + } + partition_start = c + 1; + candidate_partition = true; + } else if (filename[c] == '=') { + if (equality_sign > partition_start) { + // multiple equality signs - not a partition + candidate_partition = false; + } + equality_sign = c; + } } return result; } -std::map HivePartitioning::Parse(const string &filename) { - duckdb_re2::RE2 regex(RegexString()); - return Parse(filename, regex); +Value HivePartitioning::GetValue(ClientContext &context, const string &key, const string &str_val, + const LogicalType &type) { + // Handle nulls + if (StringUtil::CIEquals(str_val, "NULL")) { + return Value(type); + } + if (type.id() == LogicalTypeId::VARCHAR) { + // for string values we can directly return the type + return Value(Unescape(str_val)); + } + if (str_val.empty()) { + // empty strings are NULL for non-string types + return Value(type); + } + + // cast to the target type + Value value(Unescape(str_val)); + if (!value.TryCastAs(context, type)) { + throw InvalidInputException("Unable to cast '%s' (from hive partition column '%s') to: '%s'", value.ToString(), + StringUtil::Upper(key), type.ToString()); + } + return value; } // TODO: this can still be improved by removing the parts of filter expressions that are true for all remaining files. // currently, only expressions that cannot be evaluated during pushdown are removed. void HivePartitioning::ApplyFiltersToFileList(ClientContext &context, vector &files, vector> &filters, - unordered_map &column_map, LogicalGet &get, - bool hive_enabled, bool filename_enabled) { + const HivePartitioningFilterInfo &filter_info, + MultiFilePushdownInfo &info) { vector pruned_files; vector have_preserved_filter(filters.size(), false); vector> pruned_filters; unordered_set filters_applied_to_files; - duckdb_re2::RE2 regex(RegexString()); - auto table_index = get.table_index; + auto table_index = info.table_index; - if ((!filename_enabled && !hive_enabled) || filters.empty()) { + if ((!filter_info.filename_enabled && !filter_info.hive_enabled) || filters.empty()) { return; } for (idx_t i = 0; i < files.size(); i++) { auto &file = files[i]; bool should_prune_file = false; - auto known_values = GetKnownColumnValues(file, column_map, regex, filename_enabled, hive_enabled); - - FilterCombiner combiner(context); + auto known_values = GetKnownColumnValues(file, filter_info); for (idx_t j = 0; j < filters.size(); j++) { auto &filter = filters[j]; unique_ptr filter_copy = filter->Copy(); - ConvertKnownColRefToConstants(filter_copy, known_values, table_index); + ConvertKnownColRefToConstants(context, filter_copy, known_values, table_index); // Evaluate the filter, if it can be evaluated here, we can not prune this filter Value result_value; @@ -126,12 +183,12 @@ void HivePartitioning::ApplyFiltersToFileList(ClientContext &context, vectorCopy()); have_preserved_filter[j] = true; } - } else if (!result_value.GetValue()) { + } else if (result_value.IsNull() || !result_value.GetValue()) { // filter evaluates to false should_prune_file = true; // convert the filter to a table filter. if (filters_applied_to_files.find(j) == filters_applied_to_files.end()) { - get.extra_info.file_filters += filter->ToString(); + info.extra_info.file_filters += filter->ToString(); filters_applied_to_files.insert(j); } } @@ -144,6 +201,9 @@ void HivePartitioning::ApplyFiltersToFileList(ClientContext &context, vector= pruned_filters.size()); + info.extra_info.total_files = files.size(); + info.extra_info.filtered_files = pruned_files.size(); + filters = std::move(pruned_filters); files = std::move(pruned_files); } diff --git a/src/duckdb/src/common/http_util.cpp b/src/duckdb/src/common/http_util.cpp new file mode 100644 index 00000000..4c1c73d1 --- /dev/null +++ b/src/duckdb/src/common/http_util.cpp @@ -0,0 +1,25 @@ +#include "duckdb/common/http_util.hpp" + +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +void HTTPUtil::ParseHTTPProxyHost(string &proxy_value, string &hostname_out, idx_t &port_out, idx_t default_port) { + auto proxy_split = StringUtil::Split(proxy_value, ":"); + if (proxy_split.size() == 1) { + hostname_out = proxy_split[0]; + port_out = default_port; + } else if (proxy_split.size() == 2) { + idx_t port; + if (!TryCast::Operation(proxy_split[1], port, false)) { + throw InvalidInputException("Failed to parse port from http_proxy '%s'", proxy_value); + } + hostname_out = proxy_split[0]; + port_out = port; + } else { + throw InvalidInputException("Failed to parse http_proxy '%s' into a host and port", proxy_value); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/local_file_system.cpp b/src/duckdb/src/common/local_file_system.cpp index e650d79b..70246d26 100644 --- a/src/duckdb/src/common/local_file_system.cpp +++ b/src/duckdb/src/common/local_file_system.cpp @@ -29,6 +29,7 @@ #ifdef __MINGW32__ // need to manually define this for mingw extern "C" WINBASEAPI BOOL WINAPI GetPhysicallyInstalledSystemMemory(PULONGLONG); +extern "C" WINBASEAPI BOOL QueryFullProcessImageNameW(HANDLE, DWORD, LPWSTR, PDWORD); #endif #undef FILE_CREATE // woo mingw @@ -332,6 +333,10 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF filesec = 0666; } + if (flags.ExclusiveCreate()) { + open_flags |= O_EXCL; + } + // Open the file int fd = open(path.c_str(), open_flags, filesec); @@ -339,6 +344,9 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF if (flags.ReturnNullIfNotExists() && errno == ENOENT) { return nullptr; } + if (flags.ReturnNullIfExists() && errno == EEXIST) { + return nullptr; + } throw IOException("Cannot open file \"%s\": %s", {{"errno", std::to_string(errno)}}, path, strerror(errno)); } // #if defined(__DARWIN__) || defined(__APPLE__) @@ -364,32 +372,48 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF rc = fcntl(fd, F_SETLK, &fl); // Retain the original error. int retained_errno = errno; - if (rc == -1) { - string message; - // try to find out who is holding the lock using F_GETLK - rc = fcntl(fd, F_GETLK, &fl); - if (rc == -1) { // fnctl does not want to help us - message = strerror(errno); - } else { - message = AdditionalProcessInfo(*this, fl.l_pid); + bool has_error = rc == -1; + string extended_error; + if (has_error) { + if (retained_errno == ENOTSUP) { + // file lock not supported for this file system + if (flags.Lock() == FileLockType::READ_LOCK) { + // for read-only, we ignore not-supported errors + has_error = false; + errno = 0; + } else { + extended_error = "File locks are not supported for this file system, cannot open the file in " + "read-write mode. Try opening the file in read-only mode"; + } } - - if (flags.Lock() == FileLockType::WRITE_LOCK) { - // maybe we can get a read lock instead and tell this to the user. - fl.l_type = F_RDLCK; - rc = fcntl(fd, F_SETLK, &fl); - if (rc != -1) { // success! - message += ". However, you would be able to open this database in read-only mode, e.g. by " - "using the -readonly parameter in the CLI"; + } + if (has_error) { + if (extended_error.empty()) { + // try to find out who is holding the lock using F_GETLK + rc = fcntl(fd, F_GETLK, &fl); + if (rc == -1) { // fnctl does not want to help us + extended_error = strerror(errno); + } else { + extended_error = AdditionalProcessInfo(*this, fl.l_pid); + } + if (flags.Lock() == FileLockType::WRITE_LOCK) { + // maybe we can get a read lock instead and tell this to the user. + fl.l_type = F_RDLCK; + rc = fcntl(fd, F_SETLK, &fl); + if (rc != -1) { // success! + extended_error += + ". However, you would be able to open this database in read-only mode, e.g. by " + "using the -readonly parameter in the CLI"; + } } } rc = close(fd); if (rc == -1) { - message += ". Also, failed closing file"; + extended_error += ". Also, failed closing file"; } - message += ". See also https://duckdb.org/docs/connect/concurrency"; + extended_error += ". See also https://duckdb.org/docs/connect/concurrency"; throw IOException("Could not set lock on file \"%s\": %s", {{"errno", std::to_string(retained_errno)}}, - path, message); + path, extended_error); } } } @@ -490,7 +514,8 @@ bool LocalFileSystem::Trim(FileHandle &handle, idx_t offset_bytes, idx_t length_ return false; #else int fd = handle.Cast().fd; - int res = fallocate(fd, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, offset_bytes, length_bytes); + int res = fallocate(fd, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, UnsafeNumericCast(offset_bytes), + UnsafeNumericCast(length_bytes)); return res == 0; #endif #else @@ -611,10 +636,6 @@ void LocalFileSystem::RemoveFile(const string &filename, optional_ptr &callback, FileOpener *opener) { - if (!DirectoryExists(directory, opener)) { - return false; - } - auto dir = opendir(directory.c_str()); if (!dir) { return false; @@ -633,11 +654,11 @@ bool LocalFileSystem::ListFiles(const string &directory, const std::function &column_names, + const vector &column_ids, ExtraOperatorInfo &extra_info) + : table_index(table_index), column_names(column_names), column_ids(column_ids), extra_info(extra_info) { +} + // Helper method to do Filter Pushdown into a MultiFileList -bool PushdownInternal(ClientContext &context, const MultiFileReaderOptions &options, LogicalGet &get, +bool PushdownInternal(ClientContext &context, const MultiFileReaderOptions &options, MultiFilePushdownInfo &info, vector> &filters, vector &expanded_files) { - unordered_map column_map; - for (idx_t i = 0; i < get.column_ids.size(); i++) { - if (!IsRowIdColumnId(get.column_ids[i])) { - column_map.insert({get.names[get.column_ids[i]], i}); + HivePartitioningFilterInfo filter_info; + for (idx_t i = 0; i < info.column_ids.size(); i++) { + if (!IsRowIdColumnId(info.column_ids[i])) { + filter_info.column_map.insert({info.column_names[info.column_ids[i]], i}); } } + filter_info.hive_enabled = options.hive_partitioning; + filter_info.filename_enabled = options.filename; auto start_files = expanded_files.size(); - HivePartitioning::ApplyFiltersToFileList(context, expanded_files, filters, column_map, get, - options.hive_partitioning, options.filename); + HivePartitioning::ApplyFiltersToFileList(context, expanded_files, filters, filter_info, info); if (expanded_files.size() != start_files) { return true; @@ -34,6 +45,29 @@ bool PushdownInternal(ClientContext &context, const MultiFileReaderOptions &opti return false; } +bool PushdownInternal(ClientContext &context, const MultiFileReaderOptions &options, const vector &names, + const vector &types, const vector &column_ids, + const TableFilterSet &filters, vector &expanded_files) { + idx_t table_index = 0; + ExtraOperatorInfo extra_info; + + // construct the pushdown info + MultiFilePushdownInfo info(table_index, names, column_ids, extra_info); + + // construct the set of expressions from the table filters + vector> filter_expressions; + for (auto &entry : filters.filters) { + auto column_idx = column_ids[entry.first]; + auto column_ref = + make_uniq(types[column_idx], ColumnBinding(table_index, entry.first)); + auto filter_expr = entry.second->ToExpression(*column_ref); + filter_expressions.push_back(std::move(filter_expr)); + } + + // call the original PushdownInternal method + return PushdownInternal(context, options, info, filter_expressions, expanded_files); +} + //===--------------------------------------------------------------------===// // MultiFileListIterator //===--------------------------------------------------------------------===// @@ -124,12 +158,25 @@ bool MultiFileList::Scan(MultiFileListScanData &iterator, string &result_file) { } unique_ptr MultiFileList::ComplexFilterPushdown(ClientContext &context, - const MultiFileReaderOptions &options, LogicalGet &get, + const MultiFileReaderOptions &options, + MultiFilePushdownInfo &info, vector> &filters) { // By default the filter pushdown into a multifilelist does nothing return nullptr; } +unique_ptr +MultiFileList::DynamicFilterPushdown(ClientContext &context, const MultiFileReaderOptions &options, + const vector &names, const vector &types, + const vector &column_ids, TableFilterSet &filters) const { + // By default the filter pushdown into a multifilelist does nothing + return nullptr; +} + +unique_ptr MultiFileList::GetCardinality(ClientContext &context) { + return nullptr; +} + string MultiFileList::GetFirstFile() { return GetFile(0); } @@ -147,7 +194,7 @@ SimpleMultiFileList::SimpleMultiFileList(vector paths_p) unique_ptr SimpleMultiFileList::ComplexFilterPushdown(ClientContext &context_p, const MultiFileReaderOptions &options, - LogicalGet &get, + MultiFilePushdownInfo &info, vector> &filters) { if (!options.hive_partitioning && !options.filename) { return nullptr; @@ -155,7 +202,7 @@ unique_ptr SimpleMultiFileList::ComplexFilterPushdown(ClientConte // FIXME: don't copy list until first file is filtered auto file_copy = paths; - auto res = PushdownInternal(context_p, options, get, filters, file_copy); + auto res = PushdownInternal(context_p, options, info, filters, file_copy); if (res) { return make_uniq(file_copy); @@ -164,6 +211,24 @@ unique_ptr SimpleMultiFileList::ComplexFilterPushdown(ClientConte return nullptr; } +unique_ptr +SimpleMultiFileList::DynamicFilterPushdown(ClientContext &context, const MultiFileReaderOptions &options, + const vector &names, const vector &types, + const vector &column_ids, TableFilterSet &filters) const { + if (!options.hive_partitioning && !options.filename) { + return nullptr; + } + + // FIXME: don't copy list until first file is filtered + auto file_copy = paths; + auto res = PushdownInternal(context, options, names, types, column_ids, filters, file_copy); + if (res) { + return make_uniq(file_copy); + } + + return nullptr; +} + vector SimpleMultiFileList::GetAllFiles() { return paths; } @@ -199,21 +264,20 @@ GlobMultiFileList::GlobMultiFileList(ClientContext &context_p, vector pa unique_ptr GlobMultiFileList::ComplexFilterPushdown(ClientContext &context_p, const MultiFileReaderOptions &options, - LogicalGet &get, + MultiFilePushdownInfo &info, vector> &filters) { lock_guard lck(lock); // Expand all // FIXME: lazy expansion // FIXME: push down filters into glob - while (ExpandPathInternal()) { + while (ExpandNextPath()) { } if (!options.hive_partitioning && !options.filename) { return nullptr; } - auto res = PushdownInternal(context, options, get, filters, expanded_files); - + auto res = PushdownInternal(context, options, info, filters, expanded_files); if (res) { return make_uniq(expanded_files); } @@ -221,16 +285,40 @@ unique_ptr GlobMultiFileList::ComplexFilterPushdown(ClientContext return nullptr; } +unique_ptr +GlobMultiFileList::DynamicFilterPushdown(ClientContext &context, const MultiFileReaderOptions &options, + const vector &names, const vector &types, + const vector &column_ids, TableFilterSet &filters) const { + if (!options.hive_partitioning && !options.filename) { + return nullptr; + } + lock_guard lck(lock); + + // Expand all paths into a copy + // FIXME: lazy expansion and push filters into glob + idx_t path_index = current_path; + auto file_list = expanded_files; + while (ExpandPathInternal(path_index, file_list)) { + } + + auto res = PushdownInternal(context, options, names, types, column_ids, filters, file_list); + if (res) { + return make_uniq(file_list); + } + + return nullptr; +} + vector GlobMultiFileList::GetAllFiles() { lock_guard lck(lock); - while (ExpandPathInternal()) { + while (ExpandNextPath()) { } return expanded_files; } idx_t GlobMultiFileList::GetTotalFileCount() { lock_guard lck(lock); - while (ExpandPathInternal()) { + while (ExpandNextPath()) { } return expanded_files.size(); } @@ -255,7 +343,7 @@ string GlobMultiFileList::GetFile(idx_t i) { string GlobMultiFileList::GetFileInternal(idx_t i) { while (expanded_files.size() <= i) { - if (!ExpandPathInternal()) { + if (!ExpandNextPath()) { return ""; } } @@ -263,22 +351,25 @@ string GlobMultiFileList::GetFileInternal(idx_t i) { return expanded_files[i]; } -bool GlobMultiFileList::ExpandPathInternal() { - if (IsFullyExpanded()) { +bool GlobMultiFileList::ExpandPathInternal(idx_t ¤t_path, vector &result) const { + if (current_path >= paths.size()) { return false; } auto &fs = FileSystem::GetFileSystem(context); auto glob_files = fs.GlobFiles(paths[current_path], context, glob_options); std::sort(glob_files.begin(), glob_files.end()); - expanded_files.insert(expanded_files.end(), glob_files.begin(), glob_files.end()); + result.insert(result.end(), glob_files.begin(), glob_files.end()); current_path++; - return true; } -bool GlobMultiFileList::IsFullyExpanded() { +bool GlobMultiFileList::ExpandNextPath() { + return ExpandPathInternal(current_path, expanded_files); +} + +bool GlobMultiFileList::IsFullyExpanded() const { return current_path == paths.size(); } diff --git a/src/duckdb/src/common/multi_file_reader.cpp b/src/duckdb/src/common/multi_file_reader.cpp index fec6664c..b964a935 100644 --- a/src/duckdb/src/common/multi_file_reader.cpp +++ b/src/duckdb/src/common/multi_file_reader.cpp @@ -7,7 +7,7 @@ #include "duckdb/function/function_set.hpp" #include "duckdb/function/table_function.hpp" #include "duckdb/main/config.hpp" -#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/common/string_util.hpp" #include @@ -38,8 +38,16 @@ unique_ptr MultiFileReader::CreateDefault(const string &functio return res; } +Value MultiFileReader::CreateValueFromFileList(const vector &file_list) { + vector files; + for (auto &file : file_list) { + files.push_back(file); + } + return Value::LIST(std::move(files)); +} + void MultiFileReader::AddParameters(TableFunction &table_function) { - table_function.named_parameters["filename"] = LogicalType::BOOLEAN; + table_function.named_parameters["filename"] = LogicalType::ANY; table_function.named_parameters["hive_partitioning"] = LogicalType::BOOLEAN; table_function.named_parameters["union_by_name"] = LogicalType::BOOLEAN; table_function.named_parameters["hive_types"] = LogicalType::ANY; @@ -95,7 +103,18 @@ bool MultiFileReader::ParseOption(const string &key, const Value &val, MultiFile ClientContext &context) { auto loption = StringUtil::Lower(key); if (loption == "filename") { - options.filename = BooleanValue::Get(val); + if (val.type() == LogicalType::VARCHAR) { + // If not, we interpret it as the name of the column containing the filename + options.filename = true; + options.filename_column = StringValue::Get(val); + } else { + Value boolean_value; + string error_message; + if (val.DefaultTryCastAs(LogicalType::BOOLEAN, boolean_value, &error_message)) { + // If the argument can be cast to boolean, we just interpret it as a boolean + options.filename = BooleanValue::Get(boolean_value); + } + } } else if (loption == "hive_partitioning") { options.hive_partitioning = BooleanValue::Get(val); options.auto_detect_hive_partitioning = false; @@ -130,9 +149,19 @@ bool MultiFileReader::ParseOption(const string &key, const Value &val, MultiFile } unique_ptr MultiFileReader::ComplexFilterPushdown(ClientContext &context, MultiFileList &files, - const MultiFileReaderOptions &options, LogicalGet &get, + const MultiFileReaderOptions &options, + MultiFilePushdownInfo &info, vector> &filters) { - return files.ComplexFilterPushdown(context, options, get, filters); + return files.ComplexFilterPushdown(context, options, info, filters); +} + +unique_ptr MultiFileReader::DynamicFilterPushdown(ClientContext &context, const MultiFileList &files, + const MultiFileReaderOptions &options, + const vector &names, + const vector &types, + const vector &column_ids, + TableFilterSet &filters) { + return files.DynamicFilterPushdown(context, options, names, types, column_ids, filters); } bool MultiFileReader::Bind(MultiFileReaderOptions &options, MultiFileList &files, vector &return_types, @@ -146,12 +175,14 @@ void MultiFileReader::BindOptions(MultiFileReaderOptions &options, MultiFileList MultiFileReaderBindData &bind_data) { // Add generated constant column for filename if (options.filename) { - if (std::find(names.begin(), names.end(), "filename") != names.end()) { - throw BinderException("Using filename option on file with column named filename is not supported"); + if (std::find(names.begin(), names.end(), options.filename_column) != names.end()) { + throw BinderException("Option filename adds column \"%s\", but a column with this name is also in the " + "file. Try setting a different name: filename=''", + options.filename_column); } bind_data.filename_idx = names.size(); return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("filename"); + names.emplace_back(options.filename_column); } // Add generated constant columns from hive partitioning scheme @@ -392,7 +423,7 @@ HivePartitioningIndex::HivePartitioningIndex(string value_p, idx_t index) : valu } void MultiFileReaderOptions::AddBatchInfo(BindInfo &bind_info) const { - bind_info.InsertOption("filename", Value::BOOLEAN(filename)); + bind_info.InsertOption("filename", Value(filename_column)); bind_info.InsertOption("hive_partitioning", Value::BOOLEAN(hive_partitioning)); bind_info.InsertOption("auto_detect_hive_partitioning", Value::BOOLEAN(auto_detect_hive_partitioning)); bind_info.InsertOption("union_by_name", Value::BOOLEAN(union_by_name)); @@ -421,35 +452,23 @@ void UnionByName::CombineUnionTypes(const vector &col_names, const vecto } bool MultiFileReaderOptions::AutoDetectHivePartitioningInternal(MultiFileList &files, ClientContext &context) { - std::unordered_set partitions; - auto &fs = FileSystem::GetFileSystem(context); - auto first_file = files.GetFirstFile(); - auto splits_first_file = StringUtil::Split(first_file, fs.PathSeparator(first_file)); - if (splits_first_file.size() < 2) { - return false; - } - for (auto &split : splits_first_file) { - auto partition = StringUtil::Split(split, "="); - if (partition.size() == 2) { - partitions.insert(partition.front()); - } - } + auto partitions = HivePartitioning::Parse(first_file); if (partitions.empty()) { + // no partitions found in first file return false; } for (const auto &file : files.Files()) { - auto splits = StringUtil::Split(file, fs.PathSeparator(file)); - if (splits.size() != splits_first_file.size()) { + auto new_partitions = HivePartitioning::Parse(file); + if (new_partitions.size() != partitions.size()) { + // partition count mismatch return false; } - for (auto it = splits.begin(); it != std::prev(splits.end()); it++) { - auto part = StringUtil::Split(*it, "="); - if (part.size() != 2) { - continue; - } - if (partitions.find(part.front()) == partitions.end()) { + for (auto &part : new_partitions) { + auto entry = partitions.find(part.first); + if (entry == partitions.end()) { + // differing partitions between files return false; } } @@ -459,21 +478,9 @@ bool MultiFileReaderOptions::AutoDetectHivePartitioningInternal(MultiFileList &f void MultiFileReaderOptions::AutoDetectHiveTypesInternal(MultiFileList &files, ClientContext &context) { const LogicalType candidates[] = {LogicalType::DATE, LogicalType::TIMESTAMP, LogicalType::BIGINT}; - auto &fs = FileSystem::GetFileSystem(context); - unordered_map detected_types; for (const auto &file : files.Files()) { - unordered_map partitions; - auto splits = StringUtil::Split(file, fs.PathSeparator(file)); - if (splits.size() < 2) { - return; - } - for (auto it = splits.begin(); it != std::prev(splits.end()); it++) { - auto part = StringUtil::Split(*it, "="); - if (part.size() == 2) { - partitions[part.front()] = part.back(); - } - } + auto partitions = HivePartitioning::Parse(file); if (partitions.empty()) { return; } @@ -545,24 +552,18 @@ LogicalType MultiFileReaderOptions::GetHiveLogicalType(const string &hive_partit } return LogicalType::VARCHAR; } -Value MultiFileReaderOptions::GetHivePartitionValue(const string &base, const string &entry, - ClientContext &context) const { - Value value(base); - auto it = hive_types_schema.find(entry); - if (it == hive_types_schema.end()) { - return value; - } - // Handle nulls - if (base.empty() || StringUtil::CIEquals(base, "NULL")) { - return Value(it->second); - } +bool MultiFileReaderOptions::AnySet() { + return filename || hive_partitioning || union_by_name; +} - if (!value.TryCastAs(context, it->second)) { - throw InvalidInputException("Unable to cast '%s' (from hive partition column '%s') to: '%s'", value.ToString(), - StringUtil::Upper(it->first), it->second.ToString()); +Value MultiFileReaderOptions::GetHivePartitionValue(const string &value, const string &key, + ClientContext &context) const { + auto it = hive_types_schema.find(key); + if (it == hive_types_schema.end()) { + return HivePartitioning::GetValue(context, key, value, LogicalType::VARCHAR); } - return value; + return HivePartitioning::GetValue(context, key, value, it->second); } } // namespace duckdb diff --git a/src/duckdb/src/common/operator/cast_operators.cpp b/src/duckdb/src/common/operator/cast_operators.cpp index 4d69aaf1..0459ef82 100644 --- a/src/duckdb/src/common/operator/cast_operators.cpp +++ b/src/duckdb/src/common/operator/cast_operators.cpp @@ -922,26 +922,45 @@ bool TryCast::Operation(double input, double &result, bool strict) { //===--------------------------------------------------------------------===// template <> bool TryCast::Operation(string_t input, bool &result, bool strict) { - auto input_data = input.GetData(); + auto input_data = reinterpret_cast(input.GetData()); auto input_size = input.GetSize(); switch (input_size) { case 1: { - char c = UnsafeNumericCast(std::tolower(*input_data)); - if (c == 't' || (!strict && c == '1')) { + unsigned char c = UnsafeNumericCast(std::tolower(*input_data)); + if (c == 't' || (!strict && c == 'y') || (!strict && c == '1')) { result = true; return true; - } else if (c == 'f' || (!strict && c == '0')) { + } else if (c == 'f' || (!strict && c == 'n') || (!strict && c == '0')) { result = false; return true; } return false; } + case 2: { + unsigned char n = UnsafeNumericCast(std::tolower(input_data[0])); + unsigned char o = UnsafeNumericCast(std::tolower(input_data[1])); + if (n == 'n' && o == 'o') { + result = false; + return true; + } + return false; + } + case 3: { + unsigned char y = UnsafeNumericCast(std::tolower(input_data[0])); + unsigned char e = UnsafeNumericCast(std::tolower(input_data[1])); + unsigned char s = UnsafeNumericCast(std::tolower(input_data[2])); + if (y == 'y' && e == 'e' && s == 's') { + result = true; + return true; + } + return false; + } case 4: { - char t = UnsafeNumericCast(std::tolower(input_data[0])); - char r = UnsafeNumericCast(std::tolower(input_data[1])); - char u = UnsafeNumericCast(std::tolower(input_data[2])); - char e = UnsafeNumericCast(std::tolower(input_data[3])); + unsigned char t = UnsafeNumericCast(std::tolower(input_data[0])); + unsigned char r = UnsafeNumericCast(std::tolower(input_data[1])); + unsigned char u = UnsafeNumericCast(std::tolower(input_data[2])); + unsigned char e = UnsafeNumericCast(std::tolower(input_data[3])); if (t == 't' && r == 'r' && u == 'u' && e == 'e') { result = true; return true; @@ -949,11 +968,11 @@ bool TryCast::Operation(string_t input, bool &result, bool strict) { return false; } case 5: { - char f = UnsafeNumericCast(std::tolower(input_data[0])); - char a = UnsafeNumericCast(std::tolower(input_data[1])); - char l = UnsafeNumericCast(std::tolower(input_data[2])); - char s = UnsafeNumericCast(std::tolower(input_data[3])); - char e = UnsafeNumericCast(std::tolower(input_data[4])); + unsigned char f = UnsafeNumericCast(std::tolower(input_data[0])); + unsigned char a = UnsafeNumericCast(std::tolower(input_data[1])); + unsigned char l = UnsafeNumericCast(std::tolower(input_data[2])); + unsigned char s = UnsafeNumericCast(std::tolower(input_data[3])); + unsigned char e = UnsafeNumericCast(std::tolower(input_data[4])); if (f == 'f' && a == 'a' && l == 'l' && s == 's' && e == 'e') { result = false; return true; @@ -1125,8 +1144,8 @@ bool TryCast::Operation(interval_t input, interval_t &result, bool strict) { // Non-Standard Timestamps //===--------------------------------------------------------------------===// template <> -duckdb::string_t CastFromTimestampNS::Operation(duckdb::timestamp_t input, Vector &result) { - return StringCast::Operation(CastTimestampNsToUs::Operation(input), result); +duckdb::string_t CastFromTimestampNS::Operation(duckdb::timestamp_ns_t input, Vector &result) { + return StringCast::Operation(input, result); } template <> duckdb::string_t CastFromTimestampMS::Operation(duckdb::timestamp_t input, Vector &result) { @@ -1258,20 +1277,8 @@ dtime_t CastTimestampSecToTime::Operation(timestamp_t input) { // Cast To Timestamp //===--------------------------------------------------------------------===// template <> -bool TryCastToTimestampNS::Operation(string_t input, timestamp_t &result, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - return false; - } - if (!Timestamp::IsFinite(result)) { - return true; - } - - int64_t nanoseconds; - if (!Timestamp::TryGetEpochNanoSeconds(result, nanoseconds)) { - throw ConversionException("Could not convert VARCHAR value '%s' to Timestamp(NS)", input.GetString()); - } - result = nanoseconds; - return true; +bool TryCastToTimestampNS::Operation(string_t input, timestamp_ns_t &result, bool strict) { + return TryCast::Operation(input, result, strict); } template <> @@ -1293,7 +1300,7 @@ bool TryCastToTimestampSec::Operation(string_t input, timestamp_t &result, bool } template <> -bool TryCastToTimestampNS::Operation(date_t input, timestamp_t &result, bool strict) { +bool TryCastToTimestampNS::Operation(date_t input, timestamp_ns_t &result, bool strict) { if (!TryCast::Operation(input, result, strict)) { return false; } @@ -1558,11 +1565,27 @@ bool TryCast::Operation(string_t input, timestamp_t &result, bool strict) { return Timestamp::TryConvertTimestamp(input.GetData(), input.GetSize(), result) == TimestampCastResult::SUCCESS; } +template <> +bool TryCast::Operation(string_t input, timestamp_ns_t &result, bool strict) { + return Timestamp::TryConvertTimestamp(input.GetData(), input.GetSize(), result) == TimestampCastResult::SUCCESS; +} + template <> timestamp_t Cast::Operation(string_t input) { return Timestamp::FromCString(input.GetData(), input.GetSize()); } +template <> +timestamp_ns_t Cast::Operation(string_t input) { + int32_t nanos; + const auto ts = Timestamp::FromCString(input.GetData(), input.GetSize(), &nanos); + timestamp_ns_t result; + if (!Timestamp::TryFromTimestampNanos(ts, nanos, result)) { + throw ConversionException(Timestamp::ConversionError(input)); + } + return result; +} + //===--------------------------------------------------------------------===// // Cast From Interval //===--------------------------------------------------------------------===// @@ -2275,9 +2298,6 @@ bool TryCastToDecimal::Operation(uhugeint_t input, hugeint_t &result, CastParame template bool DoubleToDecimalCast(SRC input, DST &result, CastParameters ¶meters, uint8_t width, uint8_t scale) { double value = input * NumericHelper::DOUBLE_POWERS_OF_TEN[scale]; - // Add the sign (-1, 0, 1) times a tiny value to fix floating point issues (issue 3091) - double sign = (double(0) < value) - (value < double(0)); - value += 1e-9 * sign; if (value <= -NumericHelper::DOUBLE_POWERS_OF_TEN[width] || value >= NumericHelper::DOUBLE_POWERS_OF_TEN[width]) { string error = StringUtil::Format("Could not cast value %f to DECIMAL(%d,%d)", value, width, scale); HandleCastError::AssignError(error, parameters); @@ -2611,9 +2631,88 @@ bool TryCastFromDecimal::Operation(hugeint_t input, uhugeint_t &result, CastPara //===--------------------------------------------------------------------===// // Decimal -> Float/Double Cast //===--------------------------------------------------------------------===// +template +static bool IsRepresentableExactly(SRC input, DST); + +template <> +bool IsRepresentableExactly(int16_t input, float dst) { + return true; +} + +const int64_t MAX_INT_REPRESENTABLE_IN_FLOAT = 0x001000000LL; +const int64_t MAX_INT_REPRESENTABLE_IN_DOUBLE = 0x0020000000000000LL; + +template <> +bool IsRepresentableExactly(int32_t input, float dst) { + return (input <= MAX_INT_REPRESENTABLE_IN_FLOAT && input >= -MAX_INT_REPRESENTABLE_IN_FLOAT); +} + +template <> +bool IsRepresentableExactly(int64_t input, float dst) { + return (input <= MAX_INT_REPRESENTABLE_IN_FLOAT && input >= -MAX_INT_REPRESENTABLE_IN_FLOAT); +} + +template <> +bool IsRepresentableExactly(hugeint_t input, float dst) { + return (input <= MAX_INT_REPRESENTABLE_IN_FLOAT && input >= -MAX_INT_REPRESENTABLE_IN_FLOAT); +} + +template <> +bool IsRepresentableExactly(int16_t input, double dst) { + return true; +} + +template <> +bool IsRepresentableExactly(int32_t input, double dst) { + return true; +} + +template <> +bool IsRepresentableExactly(int64_t input, double dst) { + return (input <= MAX_INT_REPRESENTABLE_IN_DOUBLE && input >= -MAX_INT_REPRESENTABLE_IN_DOUBLE); +} + +template <> +bool IsRepresentableExactly(hugeint_t input, double dst) { + return (input <= MAX_INT_REPRESENTABLE_IN_DOUBLE && input >= -MAX_INT_REPRESENTABLE_IN_DOUBLE); +} + +template +static SRC GetPowerOfTen(SRC input, uint8_t scale) { + return static_cast(NumericHelper::POWERS_OF_TEN[scale]); +} + +template <> +hugeint_t GetPowerOfTen(hugeint_t input, uint8_t scale) { + return Hugeint::POWERS_OF_TEN[scale]; +} + +template +static void GetDivMod(SRC lhs, SRC rhs, SRC &div, SRC &mod) { + div = lhs / rhs; + mod = lhs % rhs; +} + +template <> +void GetDivMod(hugeint_t lhs, hugeint_t rhs, hugeint_t &div, hugeint_t &mod) { + div = Hugeint::DivMod(lhs, rhs, mod); +} + template bool TryCastDecimalToFloatingPoint(SRC input, DST &result, uint8_t scale) { - result = Cast::Operation(input) / DST(NumericHelper::DOUBLE_POWERS_OF_TEN[scale]); + if (IsRepresentableExactly(input, DST(0.0)) || scale == 0) { + // Fast path, integer is representable exaclty as a float/double + result = Cast::Operation(input) / DST(NumericHelper::DOUBLE_POWERS_OF_TEN[scale]); + return true; + } + auto power_of_ten = GetPowerOfTen(input, scale); + + SRC div = 0; + SRC mod = 0; + GetDivMod(input, power_of_ten, div, mod); + + result = Cast::Operation(div) + + Cast::Operation(mod) / DST(NumericHelper::DOUBLE_POWERS_OF_TEN[scale]); return true; } diff --git a/src/duckdb/src/common/operator/string_cast.cpp b/src/duckdb/src/common/operator/string_cast.cpp index 0e0665a9..64d9c53e 100644 --- a/src/duckdb/src/common/operator/string_cast.cpp +++ b/src/duckdb/src/common/operator/string_cast.cpp @@ -75,7 +75,7 @@ string_t StringCast::Operation(double input, Vector &vector) { template <> string_t StringCast::Operation(interval_t input, Vector &vector) { - char buffer[70]; + char buffer[70] = {}; idx_t length = IntervalToStringCast::Format(input, buffer); return StringVector::AddString(vector, buffer, length); } @@ -113,7 +113,7 @@ duckdb::string_t StringCast::Operation(dtime_t input, Vector &vector) { int32_t time[4]; Time::Convert(input, time[0], time[1], time[2], time[3]); - char micro_buffer[10]; + char micro_buffer[10] = {}; idx_t length = TimeToStringCast::Length(time, micro_buffer); string_t result = StringVector::EmptyString(vector, length); @@ -125,8 +125,8 @@ duckdb::string_t StringCast::Operation(dtime_t input, Vector &vector) { return result; } -template <> -duckdb::string_t StringCast::Operation(timestamp_t input, Vector &vector) { +template +duckdb::string_t StringFromTimestamp(timestamp_t input, Vector &vector) { if (input == timestamp_t::infinity()) { return StringVector::AddString(vector, Date::PINF); } else if (input == timestamp_t::ninfinity()) { @@ -134,7 +134,16 @@ duckdb::string_t StringCast::Operation(timestamp_t input, Vector &vector) { } date_t date_entry; dtime_t time_entry; - Timestamp::Convert(input, date_entry, time_entry); + int32_t picos = 0; + if (HAS_NANOS) { + timestamp_ns_t ns; + ns.value = input.value; + Timestamp::Convert(ns, date_entry, time_entry, picos); + // Use picoseconds so we have 6 digits + picos *= 1000; + } else { + Timestamp::Convert(input, date_entry, time_entry); + } int32_t date[3], time[4]; Date::Convert(date_entry, date[0], date[1], date[2]); @@ -143,22 +152,44 @@ duckdb::string_t StringCast::Operation(timestamp_t input, Vector &vector) { // format for timestamp is DATE TIME (separated by space) idx_t year_length; bool add_bc; - char micro_buffer[6]; + char micro_buffer[6] = {}; + char nano_buffer[6] = {}; idx_t date_length = DateToStringCast::Length(date, year_length, add_bc); idx_t time_length = TimeToStringCast::Length(time, micro_buffer); - idx_t length = date_length + time_length + 1; + idx_t nano_length = 0; + if (picos) { + // If there are ps, we need all the µs + time_length = 15; + nano_length = 6; + nano_length -= NumericCast(TimeToStringCast::FormatMicros(picos, nano_buffer)); + } + const idx_t length = date_length + 1 + time_length + nano_length; string_t result = StringVector::EmptyString(vector, length); auto data = result.GetDataWriteable(); DateToStringCast::Format(data, date, year_length, add_bc); - data[date_length] = ' '; - TimeToStringCast::Format(data + date_length + 1, time_length, time, micro_buffer); + data += date_length; + *data++ = ' '; + TimeToStringCast::Format(data, time_length, time, micro_buffer); + data += time_length; + memcpy(data, nano_buffer, nano_length); + D_ASSERT(data + nano_length <= result.GetDataWriteable() + length); result.Finalize(); return result; } +template <> +duckdb::string_t StringCast::Operation(timestamp_t input, Vector &vector) { + return StringFromTimestamp(input, vector); +} + +template <> +duckdb::string_t StringCast::Operation(timestamp_ns_t input, Vector &vector) { + return StringFromTimestamp(input, vector); +} + template <> duckdb::string_t StringCast::Operation(duckdb::string_t input, Vector &result) { return StringVector::AddStringOrBlob(result, input); @@ -169,7 +200,7 @@ string_t StringCastTZ::Operation(dtime_tz_t input, Vector &vector) { int32_t time[4]; Time::Convert(input.time(), time[0], time[1], time[2], time[3]); - char micro_buffer[10]; + char micro_buffer[10] = {}; const auto time_length = TimeToStringCast::Length(time, micro_buffer); idx_t length = time_length; @@ -243,7 +274,7 @@ string_t StringCastTZ::Operation(timestamp_t input, Vector &vector) { // format for timestamptz is DATE TIME+00 (separated by space) idx_t year_length; bool add_bc; - char micro_buffer[6]; + char micro_buffer[6] = {}; const idx_t date_length = DateToStringCast::Length(date, year_length, add_bc); const idx_t time_length = TimeToStringCast::Length(time, micro_buffer); const idx_t length = date_length + 1 + time_length + 3; diff --git a/src/duckdb/src/common/progress_bar/progress_bar.cpp b/src/duckdb/src/common/progress_bar/progress_bar.cpp index 720f5499..13d1d509 100644 --- a/src/duckdb/src/common/progress_bar/progress_bar.cpp +++ b/src/duckdb/src/common/progress_bar/progress_bar.cpp @@ -84,7 +84,7 @@ bool ProgressBar::ShouldPrint(bool final) const { return false; } // FIXME - do we need to check supported before running `profiler.Elapsed()` ? - auto sufficient_time_elapsed = profiler.Elapsed() > show_progress_after / 1000.0; + auto sufficient_time_elapsed = profiler.Elapsed() > static_cast(show_progress_after) / 1000.0; if (!sufficient_time_elapsed) { // Don't print yet return false; @@ -121,7 +121,7 @@ void ProgressBar::Update(bool final) { if (final) { FinishProgressBarPrint(); } else { - PrintProgress(NumericCast(query_progress.percentage.load())); + PrintProgress(LossyNumericCast(query_progress.percentage.load())); } #endif } diff --git a/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp b/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp index e2079712..912b8ccc 100644 --- a/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp +++ b/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp @@ -40,7 +40,7 @@ void TerminalProgressBarDisplay::PrintProgressInternal(int32_t percentage) { } if (i < PROGRESS_BAR_WIDTH) { // print a partial block based on the percentage of the progress bar remaining - idx_t index = idx_t((blocks_to_draw - idx_t(blocks_to_draw)) * PARTIAL_BLOCK_COUNT); + idx_t index = idx_t((blocks_to_draw - static_cast(idx_t(blocks_to_draw))) * PARTIAL_BLOCK_COUNT); if (index >= PARTIAL_BLOCK_COUNT) { index = PARTIAL_BLOCK_COUNT - 1; } diff --git a/src/duckdb/src/common/radix_partitioning.cpp b/src/duckdb/src/common/radix_partitioning.cpp index 69b26e05..3e8dee30 100644 --- a/src/duckdb/src/common/radix_partitioning.cpp +++ b/src/duckdb/src/common/radix_partitioning.cpp @@ -1,7 +1,6 @@ #include "duckdb/common/radix_partitioning.hpp" #include "duckdb/common/types/column/partitioned_column_data.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/vector_operations/binary_executor.hpp" #include "duckdb/common/vector_operations/unary_executor.hpp" @@ -13,20 +12,20 @@ template struct RadixPartitioningConstants { public: //! Bitmask of the upper bits starting at the 5th byte - static constexpr const idx_t NUM_PARTITIONS = RadixPartitioning::NumberOfPartitions(radix_bits); - static constexpr const idx_t SHIFT = RadixPartitioning::Shift(radix_bits); - static constexpr const hash_t MASK = RadixPartitioning::Mask(radix_bits); + static constexpr idx_t NUM_PARTITIONS = RadixPartitioning::NumberOfPartitions(radix_bits); + static constexpr idx_t SHIFT = RadixPartitioning::Shift(radix_bits); + static constexpr hash_t MASK = RadixPartitioning::Mask(radix_bits); public: //! Apply bitmask and right shift to get a number between 0 and NUM_PARTITIONS - static inline hash_t ApplyMask(hash_t hash) { + static hash_t ApplyMask(const hash_t hash) { D_ASSERT((hash & MASK) >> SHIFT < NUM_PARTITIONS); return (hash & MASK) >> SHIFT; } }; template -RETURN_TYPE RadixBitsSwitch(idx_t radix_bits, ARGS &&... args) { +RETURN_TYPE RadixBitsSwitch(const idx_t radix_bits, ARGS &&... args) { D_ASSERT(radix_bits <= RadixPartitioning::MAX_RADIX_BITS); switch (radix_bits) { case 0: @@ -71,7 +70,7 @@ struct RadixLessThan { struct SelectFunctor { template - static idx_t Operation(Vector &hashes, const SelectionVector *sel, idx_t count, idx_t cutoff, + static idx_t Operation(Vector &hashes, const SelectionVector *sel, const idx_t count, const idx_t cutoff, SelectionVector *true_sel, SelectionVector *false_sel) { Vector cutoff_vector(Value::HASH(cutoff)); return BinaryExecutor::Select>(hashes, cutoff_vector, sel, count, @@ -79,18 +78,24 @@ struct SelectFunctor { } }; -idx_t RadixPartitioning::Select(Vector &hashes, const SelectionVector *sel, idx_t count, idx_t radix_bits, idx_t cutoff, - SelectionVector *true_sel, SelectionVector *false_sel) { +idx_t RadixPartitioning::Select(Vector &hashes, const SelectionVector *sel, const idx_t count, const idx_t radix_bits, + const idx_t cutoff, SelectionVector *true_sel, SelectionVector *false_sel) { return RadixBitsSwitch(radix_bits, hashes, sel, count, cutoff, true_sel, false_sel); } struct ComputePartitionIndicesFunctor { template - static void Operation(Vector &hashes, Vector &partition_indices, idx_t count) { - UnaryExecutor::Execute(hashes, partition_indices, count, [&](hash_t hash) { - using CONSTANTS = RadixPartitioningConstants; - return CONSTANTS::ApplyMask(hash); - }); + static void Operation(Vector &hashes, Vector &partition_indices, const SelectionVector &append_sel, + const idx_t append_count) { + using CONSTANTS = RadixPartitioningConstants; + if (append_sel.IsSet()) { + auto hashes_sliced = Vector(hashes, append_sel, append_count); + UnaryExecutor::Execute(hashes_sliced, partition_indices, append_count, + [&](hash_t hash) { return CONSTANTS::ApplyMask(hash); }); + } else { + UnaryExecutor::Execute(hashes, partition_indices, append_count, + [&](hash_t hash) { return CONSTANTS::ApplyMask(hash); }); + } } }; @@ -130,20 +135,23 @@ void RadixPartitionedColumnData::InitializeAppendStateInternal(PartitionedColumn partitions[i]->InitializeAppend(*state.partition_append_states[i]); state.partition_buffers.emplace_back(CreatePartitionBuffer()); } + + // Initialize fixed-size map + state.fixed_partition_entries.resize(RadixPartitioning::NumberOfPartitions(radix_bits)); } void RadixPartitionedColumnData::ComputePartitionIndices(PartitionedColumnDataAppendState &state, DataChunk &input) { D_ASSERT(partitions.size() == RadixPartitioning::NumberOfPartitions(radix_bits)); D_ASSERT(state.partition_buffers.size() == RadixPartitioning::NumberOfPartitions(radix_bits)); RadixBitsSwitch(radix_bits, input.data[hash_col_idx], state.partition_indices, - input.size()); + *FlatVector::IncrementalSelectionVector(), input.size()); } //===--------------------------------------------------------------------===// // Tuple Data Partitioning //===--------------------------------------------------------------------===// RadixPartitionedTupleData::RadixPartitionedTupleData(BufferManager &buffer_manager, const TupleDataLayout &layout_p, - idx_t radix_bits_p, idx_t hash_col_idx_p) + const idx_t radix_bits_p, const idx_t hash_col_idx_p) : PartitionedTupleData(PartitionedTupleDataType::RADIX, buffer_manager, layout_p.Copy()), radix_bits(radix_bits_p), hash_col_idx(hash_col_idx_p) { D_ASSERT(radix_bits <= RadixPartitioning::MAX_RADIX_BITS); @@ -172,12 +180,12 @@ void RadixPartitionedTupleData::Initialize() { } void RadixPartitionedTupleData::InitializeAppendStateInternal(PartitionedTupleDataAppendState &state, - TupleDataPinProperties properties) const { + const TupleDataPinProperties properties) const { // Init pin state per partition const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); state.partition_pin_states.reserve(num_partitions); for (idx_t i = 0; i < num_partitions; i++) { - state.partition_pin_states.emplace_back(make_uniq()); + state.partition_pin_states.emplace_back(make_unsafe_uniq()); partitions[i]->InitializeAppend(*state.partition_pin_states[i], properties); } @@ -194,10 +202,11 @@ void RadixPartitionedTupleData::InitializeAppendStateInternal(PartitionedTupleDa state.fixed_partition_entries.resize(RadixPartitioning::NumberOfPartitions(radix_bits)); } -void RadixPartitionedTupleData::ComputePartitionIndices(PartitionedTupleDataAppendState &state, DataChunk &input) { +void RadixPartitionedTupleData::ComputePartitionIndices(PartitionedTupleDataAppendState &state, DataChunk &input, + const SelectionVector &append_sel, const idx_t append_count) { D_ASSERT(partitions.size() == RadixPartitioning::NumberOfPartitions(radix_bits)); RadixBitsSwitch(radix_bits, input.data[hash_col_idx], state.partition_indices, - input.size()); + append_sel, append_count); } void RadixPartitionedTupleData::ComputePartitionIndices(Vector &row_locations, idx_t count, @@ -205,7 +214,8 @@ void RadixPartitionedTupleData::ComputePartitionIndices(Vector &row_locations, i Vector intermediate(LogicalType::HASH); partitions[0]->Gather(row_locations, *FlatVector::IncrementalSelectionVector(), count, hash_col_idx, intermediate, *FlatVector::IncrementalSelectionVector(), nullptr); - RadixBitsSwitch(radix_bits, intermediate, partition_indices, count); + RadixBitsSwitch(radix_bits, intermediate, partition_indices, + *FlatVector::IncrementalSelectionVector(), count); } void RadixPartitionedTupleData::RepartitionFinalizeStates(PartitionedTupleData &old_partitioned_data, diff --git a/src/duckdb/src/common/random_engine.cpp b/src/duckdb/src/common/random_engine.cpp index acbda8b3..e51f7100 100644 --- a/src/duckdb/src/common/random_engine.cpp +++ b/src/duckdb/src/common/random_engine.cpp @@ -35,6 +35,10 @@ uint32_t RandomEngine::NextRandomInteger() { return random_state->pcg(); } +uint32_t RandomEngine::NextRandomInteger(uint32_t min, uint32_t max) { + return min + static_cast(NextRandom() * double(max - min)); +} + void RandomEngine::SetSeed(uint32_t seed) { random_state->pcg.seed(seed); } diff --git a/src/duckdb/src/common/re2_regex.cpp b/src/duckdb/src/common/re2_regex.cpp index 0b981062..e934e105 100644 --- a/src/duckdb/src/common/re2_regex.cpp +++ b/src/duckdb/src/common/re2_regex.cpp @@ -1,4 +1,5 @@ #include "duckdb/common/numeric_utils.hpp" +#include "duckdb/common/exception.hpp" #include "duckdb/common/vector.hpp" #include @@ -7,56 +8,90 @@ namespace duckdb_re2 { +static size_t GetMultibyteCharLength(const char c) { + if ((c & 0x80) == 0) { + return 1; // 1-byte character (ASCII) + } else if ((c & 0xE0) == 0xC0) { + return 2; // 2-byte character + } else if ((c & 0xF0) == 0xE0) { + return 3; // 3-byte character + } else if ((c & 0xF8) == 0xF0) { + return 4; // 4-byte character + } else { + return 0; // invalid UTF-8leading byte + } +} + Regex::Regex(const std::string &pattern, RegexOptions options) { RE2::Options o; o.set_case_sensitive(options == RegexOptions::CASE_INSENSITIVE); regex = duckdb::make_shared_ptr(StringPiece(pattern), o); } -bool RegexSearchInternal(const char *input, Match &match, const Regex &r, RE2::Anchor anchor, size_t start, - size_t end) { - auto ®ex = r.GetRegex(); +bool RegexSearchInternal(const char *input_data, size_t input_size, Match &match, const RE2 ®ex, RE2::Anchor anchor, + size_t start, size_t end) { duckdb::vector target_groups; auto group_count = duckdb::UnsafeNumericCast(regex.NumberOfCapturingGroups() + 1); target_groups.resize(group_count); match.groups.clear(); - if (!regex.Match(StringPiece(input), start, end, anchor, target_groups.data(), + if (!regex.Match(StringPiece(input_data, input_size), start, end, anchor, target_groups.data(), duckdb::UnsafeNumericCast(group_count))) { return false; } for (auto &group : target_groups) { GroupMatch group_match; group_match.text = group.ToString(); - group_match.position = group.data() != nullptr ? duckdb::NumericCast(group.data() - input) : 0; + group_match.position = group.data() != nullptr ? duckdb::NumericCast(group.data() - input_data) : 0; match.groups.emplace_back(group_match); } return true; } bool RegexSearch(const std::string &input, Match &match, const Regex ®ex) { - return RegexSearchInternal(input.c_str(), match, regex, RE2::UNANCHORED, 0, input.size()); + auto input_sz = input.size(); + return RegexSearchInternal(input.c_str(), input_sz, match, regex.GetRegex(), RE2::UNANCHORED, 0, input_sz); } bool RegexMatch(const std::string &input, Match &match, const Regex ®ex) { - return RegexSearchInternal(input.c_str(), match, regex, RE2::ANCHOR_BOTH, 0, input.size()); + auto input_sz = input.size(); + return RegexSearchInternal(input.c_str(), input_sz, match, regex.GetRegex(), RE2::ANCHOR_BOTH, 0, input_sz); } bool RegexMatch(const char *start, const char *end, Match &match, const Regex ®ex) { - return RegexSearchInternal(start, match, regex, RE2::ANCHOR_BOTH, 0, - duckdb::UnsafeNumericCast(end - start)); + auto sz = duckdb::UnsafeNumericCast(end - start); + return RegexSearchInternal(start, sz, match, regex.GetRegex(), RE2::ANCHOR_BOTH, 0, sz); } bool RegexMatch(const std::string &input, const Regex ®ex) { Match nop_match; - return RegexSearchInternal(input.c_str(), nop_match, regex, RE2::ANCHOR_BOTH, 0, input.size()); + auto input_sz = input.size(); + return RegexSearchInternal(input.c_str(), input_sz, nop_match, regex.GetRegex(), RE2::ANCHOR_BOTH, 0, input_sz); } duckdb::vector RegexFindAll(const std::string &input, const Regex ®ex) { + return RegexFindAll(input.c_str(), input.size(), regex.GetRegex()); +} + +duckdb::vector RegexFindAll(const char *input_data, size_t input_size, const RE2 ®ex) { duckdb::vector matches; size_t position = 0; Match match; - while (RegexSearchInternal(input.c_str(), match, regex, RE2::UNANCHORED, position, input.size())) { - position += match.position(0) + match.length(0); + while (RegexSearchInternal(input_data, input_size, match, regex, RE2::UNANCHORED, position, input_size)) { + if (match.length(0)) { + position = match.position(0) + match.length(0); + } else { // match.length(0) == 0 + auto next_char_length = GetMultibyteCharLength(input_data[match.position(0)]); + if (!next_char_length) { + throw duckdb::InvalidInputException("Invalid UTF-8 leading byte at position " + + std::to_string(match.position(0) + 1)); + } + if (match.position(0) + next_char_length < input_size) { + position = match.position(0) + next_char_length; + } else { + matches.emplace_back(match); + break; + } + } matches.emplace_back(match); } return matches; diff --git a/src/duckdb/src/common/render_tree.cpp b/src/duckdb/src/common/render_tree.cpp new file mode 100644 index 00000000..9c7dca21 --- /dev/null +++ b/src/duckdb/src/common/render_tree.cpp @@ -0,0 +1,243 @@ +#include "duckdb/common/render_tree.hpp" +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" +#include "duckdb/execution/operator/join/physical_delim_join.hpp" +#include "duckdb/execution/operator/scan/physical_positional_scan.hpp" + +namespace duckdb { + +struct PipelineRenderNode { + explicit PipelineRenderNode(const PhysicalOperator &op) : op(op) { + } + + const PhysicalOperator &op; + unique_ptr child; +}; + +} // namespace duckdb + +namespace { + +using duckdb::MaxValue; +using duckdb::PhysicalDelimJoin; +using duckdb::PhysicalOperator; +using duckdb::PhysicalOperatorType; +using duckdb::PhysicalPositionalScan; +using duckdb::PipelineRenderNode; +using duckdb::RenderTreeNode; + +class TreeChildrenIterator { +public: + template + static bool HasChildren(const T &op) { + return !op.children.empty(); + } + template + static void Iterate(const T &op, const std::function &callback) { + for (auto &child : op.children) { + callback(*child); + } + } +}; + +template <> +bool TreeChildrenIterator::HasChildren(const PhysicalOperator &op) { + switch (op.type) { + case PhysicalOperatorType::LEFT_DELIM_JOIN: + case PhysicalOperatorType::RIGHT_DELIM_JOIN: + case PhysicalOperatorType::POSITIONAL_SCAN: + return true; + default: + return !op.children.empty(); + } +} +template <> +void TreeChildrenIterator::Iterate(const PhysicalOperator &op, + const std::function &callback) { + for (auto &child : op.children) { + callback(*child); + } + if (op.type == PhysicalOperatorType::LEFT_DELIM_JOIN || op.type == PhysicalOperatorType::RIGHT_DELIM_JOIN) { + auto &delim = op.Cast(); + callback(*delim.join); + } else if ((op.type == PhysicalOperatorType::POSITIONAL_SCAN)) { + auto &pscan = op.Cast(); + for (auto &table : pscan.child_tables) { + callback(*table); + } + } +} + +template <> +bool TreeChildrenIterator::HasChildren(const PipelineRenderNode &op) { + return op.child.get(); +} + +template <> +void TreeChildrenIterator::Iterate(const PipelineRenderNode &op, + const std::function &callback) { + if (op.child) { + callback(*op.child); + } +} + +} // namespace + +namespace duckdb { + +template +static void GetTreeWidthHeight(const T &op, idx_t &width, idx_t &height) { + if (!TreeChildrenIterator::HasChildren(op)) { + width = 1; + height = 1; + return; + } + width = 0; + height = 0; + + TreeChildrenIterator::Iterate(op, [&](const T &child) { + idx_t child_width, child_height; + GetTreeWidthHeight(child, child_width, child_height); + width += child_width; + height = MaxValue(height, child_height); + }); + height++; +} + +static unique_ptr CreateNode(const LogicalOperator &op) { + return make_uniq(op.GetName(), op.ParamsToString()); +} + +static unique_ptr CreateNode(const PhysicalOperator &op) { + return make_uniq(op.GetName(), op.ParamsToString()); +} + +static unique_ptr CreateNode(const PipelineRenderNode &op) { + return CreateNode(op.op); +} + +static unique_ptr CreateNode(const ProfilingNode &op) { + auto &info = op.GetProfilingInfo(); + InsertionOrderPreservingMap extra_info; + if (info.Enabled(MetricsType::EXTRA_INFO)) { + extra_info = op.GetProfilingInfo().extra_info; + } + + string node_name = "QUERY"; + if (op.depth > 0) { + node_name = info.GetMetricAsString(MetricsType::OPERATOR_TYPE); + } + + auto result = make_uniq(node_name, extra_info); + if (info.Enabled(MetricsType::OPERATOR_CARDINALITY)) { + result->extra_text[RenderTreeNode::CARDINALITY] = info.GetMetricAsString(MetricsType::OPERATOR_CARDINALITY); + } + if (info.Enabled(MetricsType::OPERATOR_TIMING)) { + string timing = StringUtil::Format("%.2f", info.metrics.at(MetricsType::OPERATOR_TIMING).GetValue()); + result->extra_text[RenderTreeNode::TIMING] = timing + "s"; + } + return result; +} + +template +static idx_t CreateTreeRecursive(RenderTree &result, const T &op, idx_t x, idx_t y) { + auto node = CreateNode(op); + + if (!TreeChildrenIterator::HasChildren(op)) { + result.SetNode(x, y, std::move(node)); + return 1; + } + idx_t width = 0; + // render the children of this node + TreeChildrenIterator::Iterate(op, [&](const T &child) { + auto child_x = x + width; + auto child_y = y + 1; + node->AddChildPosition(child_x, child_y); + width += CreateTreeRecursive(result, child, child_x, child_y); + }); + result.SetNode(x, y, std::move(node)); + return width; +} + +template +static unique_ptr CreateTree(const T &op) { + idx_t width, height; + GetTreeWidthHeight(op, width, height); + + auto result = make_uniq(width, height); + + // now fill in the tree + CreateTreeRecursive(*result, op, 0, 0); + return result; +} + +RenderTree::RenderTree(idx_t width_p, idx_t height_p) : width(width_p), height(height_p) { + nodes = make_uniq_array>((width + 1) * (height + 1)); +} + +optional_ptr RenderTree::GetNode(idx_t x, idx_t y) { + if (x >= width || y >= height) { + return nullptr; + } + return nodes[GetPosition(x, y)].get(); +} + +bool RenderTree::HasNode(idx_t x, idx_t y) { + if (x >= width || y >= height) { + return false; + } + return nodes[GetPosition(x, y)].get() != nullptr; +} + +idx_t RenderTree::GetPosition(idx_t x, idx_t y) { + return y * width + x; +} + +void RenderTree::SetNode(idx_t x, idx_t y, unique_ptr node) { + nodes[GetPosition(x, y)] = std::move(node); +} + +unique_ptr RenderTree::CreateRenderTree(const LogicalOperator &op) { + return CreateTree(op); +} + +unique_ptr RenderTree::CreateRenderTree(const PhysicalOperator &op) { + return CreateTree(op); +} + +unique_ptr RenderTree::CreateRenderTree(const ProfilingNode &op) { + return CreateTree(op); +} + +void RenderTree::SanitizeKeyNames() { + for (idx_t i = 0; i < width * height; i++) { + if (!nodes[i]) { + continue; + } + InsertionOrderPreservingMap new_map; + for (auto &entry : nodes[i]->extra_text) { + auto key = entry.first; + if (StringUtil::StartsWith(key, "__")) { + key = StringUtil::Replace(key, "__", ""); + key = StringUtil::Replace(key, "_", " "); + key = StringUtil::Title(key); + } + auto &value = entry.second; + new_map.insert(make_pair(key, value)); + } + nodes[i]->extra_text = std::move(new_map); + } +} + +unique_ptr RenderTree::CreateRenderTree(const Pipeline &pipeline) { + auto operators = pipeline.GetOperators(); + D_ASSERT(!operators.empty()); + unique_ptr node; + for (auto &op : operators) { + auto new_node = make_uniq(op.get()); + new_node->child = std::move(node); + node = std::move(new_node); + } + return CreateTree(*node); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_aggregate.cpp b/src/duckdb/src/common/row_operations/row_aggregate.cpp index 0ea80035..fb433eb5 100644 --- a/src/duckdb/src/common/row_operations/row_aggregate.cpp +++ b/src/duckdb/src/common/row_operations/row_aggregate.cpp @@ -25,7 +25,7 @@ void RowOperations::InitializeStates(TupleDataLayout &layout, Vector &addresses, for (idx_t i = 0; i < count; ++i) { auto row_idx = sel.get_index(i); auto row = pointers[row_idx]; - aggr.function.initialize(row + offsets[aggr_idx]); + aggr.function.initialize(aggr.function, row + offsets[aggr_idx]); } ++aggr_idx; } diff --git a/src/duckdb/src/common/row_operations/row_gather.cpp b/src/duckdb/src/common/row_operations/row_gather.cpp index 0073cefc..b421fc1c 100644 --- a/src/duckdb/src/common/row_operations/row_gather.cpp +++ b/src/duckdb/src/common/row_operations/row_gather.cpp @@ -96,8 +96,8 @@ static void GatherNestedVector(Vector &rows, const SelectionVector &row_sel, Vec auto ptrs = FlatVector::GetData(rows); // Build the gather locations - auto data_locations = make_unsafe_uniq_array(count); - auto mask_locations = make_unsafe_uniq_array(count); + auto data_locations = make_unsafe_uniq_array_uninitialized(count); + auto mask_locations = make_unsafe_uniq_array_uninitialized(count); for (idx_t i = 0; i < count; i++) { auto row_idx = row_sel.get_index(i); auto row = ptrs[row_idx]; diff --git a/src/duckdb/src/common/row_operations/row_matcher.cpp b/src/duckdb/src/common/row_operations/row_matcher.cpp index 23b16946..41b9f211 100644 --- a/src/duckdb/src/common/row_operations/row_matcher.cpp +++ b/src/duckdb/src/common/row_operations/row_matcher.cpp @@ -8,10 +8,10 @@ namespace duckdb { using ValidityBytes = TupleDataLayout::ValidityBytes; -template -static idx_t TemplatedMatch(Vector &, const TupleDataVectorFormat &lhs_format, SelectionVector &sel, const idx_t count, - const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, const idx_t col_idx, - const vector &, SelectionVector *no_match_sel, idx_t &no_match_count) { +template +static idx_t TemplatedMatchLoop(const TupleDataVectorFormat &lhs_format, SelectionVector &sel, const idx_t count, + const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, const idx_t col_idx, + SelectionVector *no_match_sel, idx_t &no_match_count) { using COMPARISON_OP = ComparisonOperationWrapper; // LHS @@ -31,7 +31,7 @@ static idx_t TemplatedMatch(Vector &, const TupleDataVectorFormat &lhs_format, S const auto idx = sel.get_index(i); const auto lhs_idx = lhs_sel.get_index(idx); - const auto lhs_null = lhs_validity.AllValid() ? false : !lhs_validity.RowIsValid(lhs_idx); + const auto lhs_null = LHS_ALL_VALID ? false : !lhs_validity.RowIsValid(lhs_idx); const auto &rhs_location = rhs_locations[idx]; const ValidityBytes rhs_mask(rhs_location); @@ -47,6 +47,19 @@ static idx_t TemplatedMatch(Vector &, const TupleDataVectorFormat &lhs_format, S return match_count; } +template +static idx_t TemplatedMatch(Vector &, const TupleDataVectorFormat &lhs_format, SelectionVector &sel, const idx_t count, + const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, const idx_t col_idx, + const vector &, SelectionVector *no_match_sel, idx_t &no_match_count) { + if (lhs_format.unified.validity.AllValid()) { + return TemplatedMatchLoop(lhs_format, sel, count, rhs_layout, rhs_row_locations, + col_idx, no_match_sel, no_match_count); + } else { + return TemplatedMatchLoop(lhs_format, sel, count, rhs_layout, rhs_row_locations, + col_idx, no_match_sel, no_match_count); + } +} + template static idx_t StructMatchEquality(Vector &lhs_vector, const TupleDataVectorFormat &lhs_format, SelectionVector &sel, const idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, @@ -198,6 +211,22 @@ void RowMatcher::Initialize(const bool no_match_sel, const TupleDataLayout &layo } } +void RowMatcher::Initialize(const bool no_match_sel, const TupleDataLayout &layout, const Predicates &predicates, + vector &columns) { + + // The columns must have the same size as the predicates vector + D_ASSERT(columns.size() == predicates.size()); + + // The largest column_id must be smaller than the number of types to not cause an out-of-bounds error + D_ASSERT(*max_element(columns.begin(), columns.end()) < layout.GetTypes().size()); + + match_functions.reserve(predicates.size()); + for (idx_t idx = 0; idx < predicates.size(); idx++) { + column_t col_idx = columns[idx]; + match_functions.push_back(GetMatchFunction(no_match_sel, layout.GetTypes()[col_idx], predicates[idx])); + } +} + idx_t RowMatcher::Match(DataChunk &lhs, const vector &lhs_formats, SelectionVector &sel, idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, SelectionVector *no_match_sel, idx_t &no_match_count) { @@ -211,6 +240,30 @@ idx_t RowMatcher::Match(DataChunk &lhs, const vector &lhs return count; } +idx_t RowMatcher::Match(DataChunk &lhs, const vector &lhs_formats, SelectionVector &sel, + idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, + SelectionVector *no_match_sel, idx_t &no_match_count, const vector &columns) { + D_ASSERT(!match_functions.empty()); + + // The column_ids must have the same size as the match_functions vector + D_ASSERT(columns.size() == match_functions.size()); + + // The largest column_id must be smaller than the number columns to not cause an out-of-bounds error + D_ASSERT(*max_element(columns.begin(), columns.end()) < lhs.ColumnCount()); + + for (idx_t fun_idx = 0; fun_idx < match_functions.size(); fun_idx++) { + // if we only care about specific columns, we need to use the column_ids to get the correct column index + // otherwise, we just use the fun_idx + const auto col_idx = columns[fun_idx]; + + const auto &match_function = match_functions[fun_idx]; + count = + match_function.function(lhs.data[col_idx], lhs_formats[col_idx], sel, count, rhs_layout, rhs_row_locations, + col_idx, match_function.child_functions, no_match_sel, no_match_count); + } + return count; +} + MatchFunction RowMatcher::GetMatchFunction(const bool no_match_sel, const LogicalType &type, const ExpressionType predicate) { return no_match_sel ? GetMatchFunction(type, predicate) : GetMatchFunction(type, predicate); diff --git a/src/duckdb/src/common/row_operations/row_radix_scatter.cpp b/src/duckdb/src/common/row_operations/row_radix_scatter.cpp index 73242d13..95a230ef 100644 --- a/src/duckdb/src/common/row_operations/row_radix_scatter.cpp +++ b/src/duckdb/src/common/row_operations/row_radix_scatter.cpp @@ -114,62 +114,67 @@ void RadixScatterListVector(Vector &v, UnifiedVectorFormat &vdata, const Selecti for (idx_t i = 0; i < add_count; i++) { auto idx = sel.get_index(i); auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t key_location = key_locations[i] + 1; + data_ptr_t &key_location = key_locations[i]; + const data_ptr_t key_location_start = key_location; // write validity and according value if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - key_locations[i]++; + *key_location++ = valid; auto &list_entry = list_data[source_idx]; if (list_entry.length > 0) { // denote that the list is not empty with a 1 - key_locations[i][0] = 1; - key_locations[i]++; + *key_location++ = 1; RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, key_locations + i, false, true, false, prefix_len, width - 2, list_entry.offset); } else { // denote that the list is empty with a 0 - key_locations[i][0] = 0; - key_locations[i]++; - memset(key_locations[i], '\0', width - 2); + *key_location++ = 0; + // mark rest of bits as empty + memset(key_location, '\0', width - 2); + key_location += width - 2; } // invert bits if desc if (desc) { - for (idx_t s = 0; s < width - 1; s++) { - *(key_location + s) = ~*(key_location + s); + // skip over validity byte, handled by nulls first/last + for (key_location = key_location_start + 1; key_location < key_location_start + width; + key_location++) { + *key_location = ~*key_location; } } } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', width - 1); - key_locations[i] += width; + *key_location++ = invalid; + memset(key_location, '\0', width - 1); + key_location += width - 1; } + D_ASSERT(key_location == key_location_start + width); } } else { for (idx_t i = 0; i < add_count; i++) { auto idx = sel.get_index(i); auto source_idx = vdata.sel->get_index(idx) + offset; auto &list_entry = list_data[source_idx]; - data_ptr_t key_location = key_locations[i]; + data_ptr_t &key_location = key_locations[i]; + const data_ptr_t key_location_start = key_location; if (list_entry.length > 0) { // denote that the list is not empty with a 1 - key_locations[i][0] = 1; - key_locations[i]++; + *key_location++ = 1; RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, key_locations + i, false, true, false, prefix_len, width - 1, list_entry.offset); } else { // denote that the list is empty with a 0 - key_locations[i][0] = 0; - key_locations[i]++; - memset(key_locations[i], '\0', width - 1); + *key_location++ = 0; + // mark rest of bits as empty + memset(key_location, '\0', width - 1); + key_location += width - 1; } // invert bits if desc if (desc) { - for (idx_t s = 0; s < width; s++) { - *(key_location + s) = ~*(key_location + s); + for (key_location = key_location_start; key_location < key_location_start + width; key_location++) { + *key_location = ~*key_location; } } + D_ASSERT(key_location == key_location_start + width); } } } @@ -177,7 +182,9 @@ void RadixScatterListVector(Vector &v, UnifiedVectorFormat &vdata, const Selecti void RadixScatterArrayVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount, const SelectionVector &sel, idx_t add_count, data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, const idx_t prefix_len, idx_t width, const idx_t offset) { - // serialize null values + auto &child_vector = ArrayVector::GetEntry(v); + auto array_size = ArrayType::GetSize(v.GetType()); + if (has_null) { auto &validity = vdata.validity; const data_t valid = nulls_first ? 1 : 0; @@ -186,33 +193,48 @@ void RadixScatterArrayVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount for (idx_t i = 0; i < add_count; i++) { auto idx = sel.get_index(i); auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value + data_ptr_t &key_location = key_locations[i]; + const data_ptr_t key_location_start = key_location; + if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; + *key_location++ = valid; + + auto array_offset = source_idx * array_size; + RowOperations::RadixScatter(child_vector, array_size, *FlatVector::IncrementalSelectionVector(), 1, + key_locations + i, false, true, false, prefix_len, width - 1, array_offset); + + // invert bits if desc + if (desc) { + // skip over validity byte, handled by nulls first/last + for (key_location = key_location_start + 1; key_location < key_location_start + width; + key_location++) { + *key_location = ~*key_location; + } + } } else { - key_locations[i][0] = invalid; + *key_location++ = invalid; + memset(key_location, '\0', width - 1); + key_location += width - 1; } - key_locations[i]++; + D_ASSERT(key_location == key_location_start + width); } - width--; - } - - // serialize the inner child - auto &child_vector = ArrayVector::GetEntry(v); - auto array_size = ArrayType::GetSize(v.GetType()); - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - auto array_offset = source_idx * array_size; - data_ptr_t key_location = key_locations[i]; + } else { + for (idx_t i = 0; i < add_count; i++) { + auto idx = sel.get_index(i); + auto source_idx = vdata.sel->get_index(idx) + offset; + data_ptr_t &key_location = key_locations[i]; + const data_ptr_t key_location_start = key_location; - RowOperations::RadixScatter(child_vector, array_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 1, array_offset); - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < width; s++) { - *(key_location + s) = ~*(key_location + s); + auto array_offset = source_idx * array_size; + RowOperations::RadixScatter(child_vector, array_size, *FlatVector::IncrementalSelectionVector(), 1, + key_locations + i, false, true, false, prefix_len, width, array_offset); + // invert bits if desc + if (desc) { + for (key_location = key_location_start; key_location < key_location_start + width; key_location++) { + *key_location = ~*key_location; + } } + D_ASSERT(key_location == key_location_start + width); } } } @@ -256,6 +278,14 @@ void RadixScatterStructVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcoun void RowOperations::RadixScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, data_ptr_t *key_locations, bool desc, bool has_null, bool nulls_first, idx_t prefix_len, idx_t width, idx_t offset) { +#ifdef DEBUG + // initialize to verify written width later + auto key_locations_copy = make_uniq_array(ser_count); + for (idx_t i = 0; i < ser_count; i++) { + key_locations_copy[i] = key_locations[i]; + } +#endif + UnifiedVectorFormat vdata; v.ToUnifiedFormat(vcount, vdata); switch (v.GetType().InternalType()) { @@ -317,6 +347,12 @@ void RowOperations::RadixScatter(Vector &v, idx_t vcount, const SelectionVector default: throw NotImplementedException("Cannot ORDER BY column with type %s", v.GetType().ToString()); } + +#ifdef DEBUG + for (idx_t i = 0; i < ser_count; i++) { + D_ASSERT(key_locations[i] == key_locations_copy[i] + width); + } +#endif } } // namespace duckdb diff --git a/src/duckdb/src/common/serializer/binary_deserializer.cpp b/src/duckdb/src/common/serializer/binary_deserializer.cpp index a13e3578..0fec7bea 100644 --- a/src/duckdb/src/common/serializer/binary_deserializer.cpp +++ b/src/duckdb/src/common/serializer/binary_deserializer.cpp @@ -113,7 +113,7 @@ string BinaryDeserializer::ReadString() { if (len == 0) { return string(); } - auto buffer = make_unsafe_uniq_array(len); + auto buffer = make_unsafe_uniq_array_uninitialized(len); ReadData(buffer.get(), len); return string(const_char_ptr_cast(buffer.get()), len); } diff --git a/src/duckdb/src/common/serializer/buffered_file_reader.cpp b/src/duckdb/src/common/serializer/buffered_file_reader.cpp index a6ed87e9..96cd2e08 100644 --- a/src/duckdb/src/common/serializer/buffered_file_reader.cpp +++ b/src/duckdb/src/common/serializer/buffered_file_reader.cpp @@ -1,21 +1,23 @@ #include "duckdb/common/serializer/buffered_file_reader.hpp" -#include "duckdb/common/serializer/buffered_file_writer.hpp" + #include "duckdb/common/exception.hpp" +#include "duckdb/common/serializer/buffered_file_writer.hpp" -#include #include +#include namespace duckdb { BufferedFileReader::BufferedFileReader(FileSystem &fs, const char *path, FileLockType lock_type, optional_ptr opener) - : fs(fs), data(make_unsafe_uniq_array(FILE_BUFFER_SIZE)), offset(0), read_data(0), total_read(0) { + : fs(fs), data(make_unsafe_uniq_array_uninitialized(FILE_BUFFER_SIZE)), offset(0), read_data(0), + total_read(0) { handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ | lock_type, opener.get()); file_size = NumericCast(fs.GetFileSize(*handle)); } BufferedFileReader::BufferedFileReader(FileSystem &fs, unique_ptr handle_p) - : fs(fs), data(make_unsafe_uniq_array(FILE_BUFFER_SIZE)), offset(0), read_data(0), + : fs(fs), data(make_unsafe_uniq_array_uninitialized(FILE_BUFFER_SIZE)), offset(0), read_data(0), handle(std::move(handle_p)), total_read(0) { file_size = NumericCast(fs.GetFileSize(*handle)); } diff --git a/src/duckdb/src/common/serializer/buffered_file_writer.cpp b/src/duckdb/src/common/serializer/buffered_file_writer.cpp index be4f51fc..a378c4ac 100644 --- a/src/duckdb/src/common/serializer/buffered_file_writer.cpp +++ b/src/duckdb/src/common/serializer/buffered_file_writer.cpp @@ -1,7 +1,9 @@ #include "duckdb/common/serializer/buffered_file_writer.hpp" -#include "duckdb/common/exception.hpp" + #include "duckdb/common/algorithm.hpp" +#include "duckdb/common/exception.hpp" #include "duckdb/common/typedefs.hpp" + #include namespace duckdb { @@ -10,12 +12,13 @@ namespace duckdb { constexpr FileOpenFlags BufferedFileWriter::DEFAULT_OPEN_FLAGS; BufferedFileWriter::BufferedFileWriter(FileSystem &fs, const string &path_p, FileOpenFlags open_flags) - : fs(fs), path(path_p), data(make_unsafe_uniq_array(FILE_BUFFER_SIZE)), offset(0), total_written(0) { + : fs(fs), path(path_p), data(make_unsafe_uniq_array_uninitialized(FILE_BUFFER_SIZE)), offset(0), + total_written(0) { handle = fs.OpenFile(path, open_flags | FileLockType::WRITE_LOCK); } -int64_t BufferedFileWriter::GetFileSize() { - return fs.GetFileSize(*handle) + NumericCast(offset); +idx_t BufferedFileWriter::GetFileSize() { + return NumericCast(fs.GetFileSize(*handle)) + offset; } idx_t BufferedFileWriter::GetTotalWritten() { @@ -65,20 +68,26 @@ void BufferedFileWriter::Flush() { offset = 0; } +void BufferedFileWriter::Close() { + Flush(); + handle->Close(); + handle.reset(); +} + void BufferedFileWriter::Sync() { Flush(); handle->Sync(); } -void BufferedFileWriter::Truncate(int64_t size) { - auto persistent = fs.GetFileSize(*handle); - D_ASSERT(size <= persistent + NumericCast(offset)); +void BufferedFileWriter::Truncate(idx_t size) { + auto persistent = NumericCast(fs.GetFileSize(*handle)); + D_ASSERT(size <= persistent + offset); if (persistent <= size) { // truncating into the pending write buffer. - offset = NumericCast(size - persistent); + offset = size - persistent; } else { // truncate the physical file on disk - handle->Truncate(size); + handle->Truncate(NumericCast(size)); // reset anything written in the buffer offset = 0; } diff --git a/src/duckdb/src/common/serializer/memory_stream.cpp b/src/duckdb/src/common/serializer/memory_stream.cpp index 1fd0ff81..a4869696 100644 --- a/src/duckdb/src/common/serializer/memory_stream.cpp +++ b/src/duckdb/src/common/serializer/memory_stream.cpp @@ -3,6 +3,7 @@ namespace duckdb { MemoryStream::MemoryStream(idx_t capacity) : position(0), capacity(capacity), owns_data(true) { + D_ASSERT(capacity != 0 && IsPowerOfTwo(capacity)); auto data_malloc_result = malloc(capacity); if (!data_malloc_result) { throw std::bad_alloc(); diff --git a/src/duckdb/src/common/sort/partition_state.cpp b/src/duckdb/src/common/sort/partition_state.cpp index 5346d057..d87e31fd 100644 --- a/src/duckdb/src/common/sort/partition_state.cpp +++ b/src/duckdb/src/common/sort/partition_state.cpp @@ -11,7 +11,7 @@ namespace duckdb { PartitionGlobalHashGroup::PartitionGlobalHashGroup(BufferManager &buffer_manager, const Orders &partitions, const Orders &orders, const Types &payload_types, bool external) - : count(0), batch_base(0) { + : count(0) { RowLayout payload_layout; payload_layout.Initialize(payload_types); @@ -94,7 +94,7 @@ PartitionGlobalSinkState::PartitionGlobalSinkState(ClientContext &context, memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context); external = ClientConfig::GetConfig(context).force_external; - const auto thread_pages = PreviousPowerOfTwo(memory_per_thread / (4 * idx_t(Storage::BLOCK_ALLOC_SIZE))); + const auto thread_pages = PreviousPowerOfTwo(memory_per_thread / (4 * buffer_manager.GetBlockAllocSize())); while (max_bits < 10 && (thread_pages >> max_bits) > 1) { ++max_bits; } @@ -316,9 +316,10 @@ void PartitionLocalSinkState::Sink(DataChunk &input_chunk) { // No sorts, so build paged row chunks if (!rows) { const auto entry_size = payload_layout.GetRowWidth(); - const auto capacity = MaxValue(STANDARD_VECTOR_SIZE, (Storage::BLOCK_SIZE / entry_size) + 1); + const auto block_size = gstate.buffer_manager.GetBlockSize(); + const auto capacity = MaxValue(STANDARD_VECTOR_SIZE, block_size / entry_size + 1); rows = make_uniq(gstate.buffer_manager, capacity, entry_size); - strings = make_uniq(gstate.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U, true); + strings = make_uniq(gstate.buffer_manager, block_size, 1U, true); } const auto row_count = input_chunk.size(); const auto row_sel = FlatVector::IncrementalSelectionVector(); @@ -401,11 +402,11 @@ void PartitionLocalSinkState::Combine() { PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data_p, hash_t hash_bin) - : sink(sink), group_data(std::move(group_data_p)), memory_per_thread(sink.memory_per_thread), + : sink(sink), group_data(std::move(group_data_p)), group_idx(sink.hash_groups.size()), + memory_per_thread(sink.memory_per_thread), num_threads(NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())), stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0), tasks_completed(0) { - const auto group_idx = sink.hash_groups.size(); auto new_group = make_uniq(sink.buffer_manager, sink.partitions, sink.orders, sink.payload_types, sink.external); sink.hash_groups.emplace_back(std::move(new_group)); @@ -423,12 +424,11 @@ PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &s } PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink) - : sink(sink), memory_per_thread(sink.memory_per_thread), + : sink(sink), group_idx(0), memory_per_thread(sink.memory_per_thread), num_threads(NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())), stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0), tasks_completed(0) { const hash_t hash_bin = 0; - const size_t group_idx = 0; hash_group = sink.hash_groups[group_idx].get(); global_sort = sink.hash_groups[group_idx]->global_sort.get(); @@ -448,6 +448,10 @@ void PartitionLocalMergeState::Merge() { merge_sorter.PerformInMergeRound(); } +void PartitionLocalMergeState::Sorted() { + merge_state->sink.OnSortedPartition(merge_state->group_idx); +} + void PartitionLocalMergeState::ExecuteTask() { switch (stage) { case PartitionSortStage::SCAN: @@ -459,6 +463,9 @@ void PartitionLocalMergeState::ExecuteTask() { case PartitionSortStage::MERGE: Merge(); break; + case PartitionSortStage::SORTED: + Sorted(); + break; default: throw InternalException("Unexpected PartitionSortStage in ExecuteTask!"); } @@ -513,30 +520,36 @@ bool PartitionGlobalMergeState::TryPrepareNextStage() { return true; case PartitionSortStage::PREPARE: - total_tasks = global_sort->sorted_blocks.size() / 2; - if (!total_tasks) { + if (!(global_sort->sorted_blocks.size() / 2)) { break; } stage = PartitionSortStage::MERGE; global_sort->InitializeMergeRound(); + total_tasks = num_threads; return true; case PartitionSortStage::MERGE: global_sort->CompleteMergeRound(true); - total_tasks = global_sort->sorted_blocks.size() / 2; - if (!total_tasks) { + if (!(global_sort->sorted_blocks.size() / 2)) { break; } global_sort->InitializeMergeRound(); + total_tasks = num_threads; return true; case PartitionSortStage::SORTED: - break; + stage = PartitionSortStage::FINISHED; + total_tasks = 0; + return false; + + case PartitionSortStage::FINISHED: + return false; } stage = PartitionSortStage::SORTED; + total_tasks = 1; - return false; + return true; } PartitionGlobalMergeStates::PartitionGlobalMergeStates(PartitionGlobalSinkState &sink) { @@ -559,13 +572,15 @@ PartitionGlobalMergeStates::PartitionGlobalMergeStates(PartitionGlobalSinkState auto state = make_uniq(sink); states.emplace_back(std::move(state)); } + + sink.OnBeginMerge(); } class PartitionMergeTask : public ExecutorTask { public: PartitionMergeTask(shared_ptr event_p, ClientContext &context_p, PartitionGlobalMergeStates &hash_groups_p, - PartitionGlobalSinkState &gstate) - : ExecutorTask(context_p, std::move(event_p)), local_state(gstate), hash_groups(hash_groups_p) { + PartitionGlobalSinkState &gstate, const PhysicalOperator &op) + : ExecutorTask(context_p, std::move(event_p), op), local_state(gstate), hash_groups(hash_groups_p) { } TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; @@ -602,7 +617,7 @@ bool PartitionGlobalMergeStates::ExecuteTask(PartitionLocalMergeState &local_sta // Thread is done with its assigned task, try to fetch new work for (auto group = sorted; group < states.size(); ++group) { auto &global_state = states[group]; - if (global_state->IsSorted()) { + if (global_state->IsFinished()) { // This hash group is done // Update the high water mark of densely completed groups if (sorted == group) { @@ -665,7 +680,7 @@ void PartitionMergeEvent::Schedule() { vector> merge_tasks; for (idx_t tnum = 0; tnum < num_threads; tnum++) { - merge_tasks.emplace_back(make_uniq(shared_from_this(), context, merge_states, gstate)); + merge_tasks.emplace_back(make_uniq(shared_from_this(), context, merge_states, gstate, op)); } SetTasks(std::move(merge_tasks)); } diff --git a/src/duckdb/src/common/sort/radix_sort.cpp b/src/duckdb/src/common/sort/radix_sort.cpp index a917631a..b193cee6 100644 --- a/src/duckdb/src/common/sort/radix_sort.cpp +++ b/src/duckdb/src/common/sort/radix_sort.cpp @@ -17,7 +17,7 @@ static void SortTiedBlobs(BufferManager &buffer_manager, const data_ptr_t datapt return; } // Fill pointer array for sorting - auto ptr_block = make_unsafe_uniq_array(end - start); + auto ptr_block = make_unsafe_uniq_array_uninitialized(end - start); auto entry_ptrs = (data_ptr_t *)ptr_block.get(); for (idx_t i = start; i < end; i++) { entry_ptrs[i - start] = row_ptr; @@ -73,7 +73,7 @@ static void SortTiedBlobs(BufferManager &buffer_manager, SortedBlock &sb, bool * continue; } idx_t j; - for (j = i; j < count; j++) { + for (j = i + 1; j < count; j++) { if (!ties[j]) { break; } @@ -158,7 +158,7 @@ inline void InsertionSort(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const data_ptr_t target_ptr = swap ? orig_ptr : temp_ptr; if (count > 1) { const idx_t total_offset = col_offset + offset; - auto temp_val = make_unsafe_uniq_array(row_width); + auto temp_val = make_unsafe_uniq_array_uninitialized(row_width); const data_ptr_t val = temp_val.get(); const auto comp_width = total_comp_width - offset; for (idx_t i = 1; i < count; i++) { @@ -238,22 +238,29 @@ void RadixSortMSD(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const id //! Calls different sort functions, depending on the count and sorting sizes void RadixSort(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, const idx_t &sorting_size, const SortLayout &sort_layout, bool contains_string) { + if (contains_string) { auto begin = duckdb_pdqsort::PDQIterator(dataptr, sort_layout.entry_size); auto end = begin + count; duckdb_pdqsort::PDQConstants constants(sort_layout.entry_size, col_offset, sorting_size, *end); - duckdb_pdqsort::pdqsort_branchless(begin, begin + count, constants); - } else if (count <= SortConstants::INSERTION_SORT_THRESHOLD) { - InsertionSort(dataptr, nullptr, count, col_offset, sort_layout.entry_size, sorting_size, 0, false); - } else if (sorting_size <= SortConstants::MSD_RADIX_SORT_SIZE_THRESHOLD) { - RadixSortLSD(buffer_manager, dataptr, count, col_offset, sort_layout.entry_size, sorting_size); - } else { - auto temp_block = buffer_manager.Allocate(MemoryTag::ORDER_BY, - MaxValue(count * sort_layout.entry_size, (idx_t)Storage::BLOCK_SIZE)); - auto preallocated_array = make_unsafe_uniq_array(sorting_size * SortConstants::MSD_RADIX_LOCATIONS); - RadixSortMSD(dataptr, temp_block.Ptr(), count, col_offset, sort_layout.entry_size, sorting_size, 0, - preallocated_array.get(), false); + return duckdb_pdqsort::pdqsort_branchless(begin, begin + count, constants); + } + + if (count <= SortConstants::INSERTION_SORT_THRESHOLD) { + return InsertionSort(dataptr, nullptr, count, col_offset, sort_layout.entry_size, sorting_size, 0, false); } + + if (sorting_size <= SortConstants::MSD_RADIX_SORT_SIZE_THRESHOLD) { + return RadixSortLSD(buffer_manager, dataptr, count, col_offset, sort_layout.entry_size, sorting_size); + } + + const auto block_size = buffer_manager.GetBlockSize(); + auto temp_block = + buffer_manager.Allocate(MemoryTag::ORDER_BY, MaxValue(count * sort_layout.entry_size, block_size)); + auto pre_allocated_array = + make_unsafe_uniq_array_uninitialized(sorting_size * SortConstants::MSD_RADIX_LOCATIONS); + RadixSortMSD(dataptr, temp_block.Ptr(), count, col_offset, sort_layout.entry_size, sorting_size, 0, + pre_allocated_array.get(), false); } //! Identifies sequences of rows that are tied, and calls radix sort on these @@ -306,7 +313,7 @@ void LocalSortState::SortInMemory() { if (!ties) { // This is the first sort RadixSort(*buffer_manager, dataptr, count, col_offset, sorting_size, *sort_layout, contains_string); - ties_ptr = make_unsafe_uniq_array(count); + ties_ptr = make_unsafe_uniq_array_uninitialized(count); ties = ties_ptr.get(); std::fill_n(ties, count - 1, true); ties[count - 1] = false; diff --git a/src/duckdb/src/common/sort/sort_state.cpp b/src/duckdb/src/common/sort/sort_state.cpp index 27650b46..386f3498 100644 --- a/src/duckdb/src/common/sort/sort_state.cpp +++ b/src/duckdb/src/common/sort/sort_state.cpp @@ -163,22 +163,25 @@ void LocalSortState::Initialize(GlobalSortState &global_sort_state, BufferManage sort_layout = &global_sort_state.sort_layout; payload_layout = &global_sort_state.payload_layout; buffer_manager = &buffer_manager_p; + const auto block_size = buffer_manager->GetBlockSize(); + // Radix sorting data - radix_sorting_data = make_uniq( - *buffer_manager, RowDataCollection::EntriesPerBlock(sort_layout->entry_size), sort_layout->entry_size); + auto entries_per_block = RowDataCollection::EntriesPerBlock(sort_layout->entry_size, block_size); + radix_sorting_data = make_uniq(*buffer_manager, entries_per_block, sort_layout->entry_size); + // Blob sorting data if (!sort_layout->all_constant) { auto blob_row_width = sort_layout->blob_layout.GetRowWidth(); - blob_sorting_data = make_uniq( - *buffer_manager, RowDataCollection::EntriesPerBlock(blob_row_width), blob_row_width); - blob_sorting_heap = make_uniq(*buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U, true); + entries_per_block = RowDataCollection::EntriesPerBlock(blob_row_width, block_size); + blob_sorting_data = make_uniq(*buffer_manager, entries_per_block, blob_row_width); + blob_sorting_heap = make_uniq(*buffer_manager, block_size, 1U, true); } + // Payload data auto payload_row_width = payload_layout->GetRowWidth(); - payload_data = make_uniq(*buffer_manager, RowDataCollection::EntriesPerBlock(payload_row_width), - payload_row_width); - payload_heap = make_uniq(*buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U, true); - // Init done + entries_per_block = RowDataCollection::EntriesPerBlock(payload_row_width, block_size); + payload_data = make_uniq(*buffer_manager, entries_per_block, payload_row_width); + payload_heap = make_uniq(*buffer_manager, block_size, 1U, true); initialized = true; } @@ -266,17 +269,17 @@ unique_ptr LocalSortState::ConcatenateBlocks(RowDataCollection &ro return new_block; } // Create block with the correct capacity - auto buffer_manager = &row_data.buffer_manager; + auto &buffer_manager = row_data.buffer_manager; const idx_t &entry_size = row_data.entry_size; - idx_t capacity = MaxValue(((idx_t)Storage::BLOCK_SIZE + entry_size - 1) / entry_size, row_data.count); - auto new_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, capacity, entry_size); + idx_t capacity = MaxValue((buffer_manager.GetBlockSize() + entry_size - 1) / entry_size, row_data.count); + auto new_block = make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, entry_size); new_block->count = row_data.count; - auto new_block_handle = buffer_manager->Pin(new_block->block); + auto new_block_handle = buffer_manager.Pin(new_block->block); data_ptr_t new_block_ptr = new_block_handle.Ptr(); // Copy the data of the blocks into a single block for (idx_t i = 0; i < row_data.blocks.size(); i++) { auto &block = row_data.blocks[i]; - auto block_handle = buffer_manager->Pin(block->block); + auto block_handle = buffer_manager.Pin(block->block); memcpy(new_block_ptr, block_handle.Ptr(), block->count * entry_size); new_block_ptr += block->count * entry_size; block.reset(); @@ -322,7 +325,7 @@ void LocalSortState::ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataColl idx_t total_byte_offset = std::accumulate(heap.blocks.begin(), heap.blocks.end(), (idx_t)0, [](idx_t a, const unique_ptr &b) { return a + b->byte_offset; }); - idx_t heap_block_size = MaxValue(total_byte_offset, (idx_t)Storage::BLOCK_SIZE); + idx_t heap_block_size = MaxValue(total_byte_offset, buffer_manager->GetBlockSize()); auto ordered_heap_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, heap_block_size, 1U); ordered_heap_block->count = count; ordered_heap_block->byte_offset = total_byte_offset; @@ -401,7 +404,7 @@ void GlobalSortState::PrepareMergePhase() { idx_t total_heap_size = std::accumulate(sorted_blocks.begin(), sorted_blocks.end(), (idx_t)0, [](idx_t a, const unique_ptr &b) { return a + b->HeapSize(); }); - if (external || (pinned_blocks.empty() && total_heap_size > 0.25 * buffer_manager.GetQueryMaxMemory())) { + if (external || (pinned_blocks.empty() && total_heap_size * 4 > buffer_manager.GetQueryMaxMemory())) { external = true; } // Use the data that we have to determine which partition size to use during the merge diff --git a/src/duckdb/src/common/sort/sorted_block.cpp b/src/duckdb/src/common/sort/sorted_block.cpp index 9539302c..c4766c95 100644 --- a/src/duckdb/src/common/sort/sorted_block.cpp +++ b/src/duckdb/src/common/sort/sorted_block.cpp @@ -25,12 +25,11 @@ idx_t SortedData::Count() { } void SortedData::CreateBlock() { - auto capacity = - MaxValue(((idx_t)Storage::BLOCK_SIZE + layout.GetRowWidth() - 1) / layout.GetRowWidth(), state.block_capacity); + const auto block_size = buffer_manager.GetBlockSize(); + auto capacity = MaxValue((block_size + layout.GetRowWidth() - 1) / layout.GetRowWidth(), state.block_capacity); data_blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, layout.GetRowWidth())); if (!layout.AllConstant() && state.external) { - heap_blocks.push_back( - make_uniq(MemoryTag::ORDER_BY, buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U)); + heap_blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, block_size, 1U)); D_ASSERT(data_blocks.size() == heap_blocks.size()); } } @@ -104,8 +103,8 @@ void SortedBlock::InitializeWrite() { } void SortedBlock::CreateBlock() { - auto capacity = MaxValue(((idx_t)Storage::BLOCK_SIZE + sort_layout.entry_size - 1) / sort_layout.entry_size, - state.block_capacity); + const auto block_size = buffer_manager.GetBlockSize(); + auto capacity = MaxValue((block_size + sort_layout.entry_size - 1) / sort_layout.entry_size, state.block_capacity); radix_sorting_data.push_back( make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, sort_layout.entry_size)); } @@ -289,12 +288,13 @@ void SBScanState::SetIndices(idx_t block_idx_to, idx_t entry_idx_to) { PayloadScanner::PayloadScanner(SortedData &sorted_data, GlobalSortState &global_sort_state, bool flush_p) { auto count = sorted_data.Count(); auto &layout = sorted_data.layout; + const auto block_size = global_sort_state.buffer_manager.GetBlockSize(); // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U); + rows = make_uniq(global_sort_state.buffer_manager, block_size, 1U); rows->count = count; - heap = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U); + heap = make_uniq(global_sort_state.buffer_manager, block_size, 1U); if (!sorted_data.layout.AllConstant()) { heap->count = count; } @@ -328,9 +328,10 @@ PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, idx_t block_i auto &sorted_data = *global_sort_state.sorted_blocks[0]->payload_data; auto count = sorted_data.data_blocks[block_idx]->count; auto &layout = sorted_data.layout; + const auto block_size = global_sort_state.buffer_manager.GetBlockSize(); // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U); + rows = make_uniq(global_sort_state.buffer_manager, block_size, 1U); if (flush_p) { rows->blocks.emplace_back(std::move(sorted_data.data_blocks[block_idx])); } else { @@ -338,7 +339,7 @@ PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, idx_t block_i } rows->count = count; - heap = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U); + heap = make_uniq(global_sort_state.buffer_manager, block_size, 1U); if (!sorted_data.layout.AllConstant() && sorted_data.swizzled) { if (flush_p) { heap->blocks.emplace_back(std::move(sorted_data.heap_blocks[block_idx])); diff --git a/src/duckdb/src/common/string_util.cpp b/src/duckdb/src/common/string_util.cpp index 6177129c..dd57bda7 100644 --- a/src/duckdb/src/common/string_util.cpp +++ b/src/duckdb/src/common/string_util.cpp @@ -5,6 +5,7 @@ #include "duckdb/common/to_string.hpp" #include "duckdb/common/helper.hpp" #include "duckdb/function/scalar/string_functions.hpp" +#include "jaro_winkler.hpp" #include #include @@ -27,9 +28,8 @@ string StringUtil::GenerateRandomName(idx_t length) { std::uniform_int_distribution<> dis(0, 15); std::stringstream ss; - ss << std::hex; for (idx_t i = 0; i < length; i++) { - ss << dis(gen); + ss << "0123456789abcdef"[dis(gen)]; } return ss.str(); } @@ -212,6 +212,26 @@ string StringUtil::Lower(const string &str) { return (copy); } +string StringUtil::Title(const string &str) { + string copy; + bool first_character = true; + for (auto c : str) { + bool is_alpha = StringUtil::CharacterIsAlpha(c); + if (is_alpha) { + if (first_character) { + copy += StringUtil::CharacterToUpper(c); + first_character = false; + } else { + copy += StringUtil::CharacterToLower(c); + } + } else { + first_character = true; + copy += c; + } + } + return copy; +} + bool StringUtil::IsLower(const string &str) { return str == Lower(str); } @@ -308,17 +328,17 @@ string StringUtil::Replace(string source, const string &from, const string &to) return source; } -vector StringUtil::TopNStrings(vector> scores, idx_t n, idx_t threshold) { +vector StringUtil::TopNStrings(vector> scores, idx_t n, double threshold) { if (scores.empty()) { return vector(); } - sort(scores.begin(), scores.end(), [](const pair &a, const pair &b) -> bool { - return a.second < b.second || (a.second == b.second && a.first.size() < b.first.size()); + sort(scores.begin(), scores.end(), [](const pair &a, const pair &b) -> bool { + return a.second > b.second || (a.second == b.second && a.first.size() < b.first.size()); }); vector result; result.push_back(scores[0].first); for (idx_t i = 1; i < MinValue(scores.size(), n); i++) { - if (scores[i].second > threshold) { + if (scores[i].second < threshold) { break; } result.push_back(scores[i].first); @@ -326,6 +346,27 @@ vector StringUtil::TopNStrings(vector> scores, idx_t return result; } +static double NormalizeScore(idx_t score, idx_t max_score) { + return 1.0 - static_cast(score) / static_cast(max_score); +} + +vector StringUtil::TopNStrings(const vector> &scores, idx_t n, idx_t threshold) { + // obtain the max score to normalize + idx_t max_score = threshold; + for (auto &score : scores) { + if (score.second > max_score) { + max_score = score.second; + } + } + + // normalize + vector> normalized_scores; + for (auto &score : scores) { + normalized_scores.push_back(make_pair(score.first, NormalizeScore(score.second, max_score))); + } + return TopNStrings(std::move(normalized_scores), n, NormalizeScore(threshold, max_score)); +} + struct LevenshteinArray { LevenshteinArray(idx_t len1, idx_t len2) : len1(len1) { dist = make_unsafe_uniq_array(len1 * len2); @@ -385,6 +426,11 @@ idx_t StringUtil::SimilarityScore(const string &s1, const string &s2) { return LevenshteinDistance(s1, s2, 3); } +double StringUtil::SimilarityRating(const string &s1, const string &s2) { + return duckdb_jaro_winkler::jaro_winkler_similarity(s1.data(), s1.data() + s1.size(), s2.data(), + s2.data() + s2.size()); +} + vector StringUtil::TopNLevenshtein(const vector &strings, const string &target, idx_t n, idx_t threshold) { vector> scores; @@ -399,6 +445,16 @@ vector StringUtil::TopNLevenshtein(const vector &strings, const return TopNStrings(scores, n, threshold); } +vector StringUtil::TopNJaroWinkler(const vector &strings, const string &target, idx_t n, + double threshold) { + vector> scores; + scores.reserve(strings.size()); + for (auto &str : strings) { + scores.emplace_back(str, SimilarityRating(str, target)); + } + return TopNStrings(scores, n, threshold); +} + string StringUtil::CandidatesMessage(const vector &candidates, const string &candidate) { string result_str; if (!candidates.empty()) { @@ -472,8 +528,8 @@ string StringUtil::ToJSONMap(ExceptionType type, const string &message, const un yyjson_write_err err; size_t len; - yyjson_write_flag flags = YYJSON_WRITE_ALLOW_INVALID_UNICODE; - const char *json = yyjson_mut_write_opts(doc, flags, nullptr, &len, &err); + constexpr yyjson_write_flag flags = YYJSON_WRITE_ALLOW_INVALID_UNICODE; + char *json = yyjson_mut_write_opts(doc, flags, nullptr, &len, &err); if (!json) { yyjson_mut_doc_free(doc); throw SerializationException("Failed to write JSON string: %s", err.msg); @@ -482,7 +538,7 @@ string StringUtil::ToJSONMap(ExceptionType type, const string &message, const un string result(json, len); // Free the JSON and the document - free((void *)json); + free(json); yyjson_mut_doc_free(doc); // Return the result @@ -537,7 +593,6 @@ string StringUtil::GetFileStem(const string &file_name) { } string StringUtil::GetFilePath(const string &file_path) { - // Trim the trailing slashes auto end = file_path.size() - 1; while (end > 0 && (file_path[end] == '/' || file_path[end] == '\\')) { @@ -556,4 +611,106 @@ string StringUtil::GetFilePath(const string &file_path) { return file_path.substr(0, pos + 1); } +struct URLEncodeLength { + using RESULT_TYPE = idx_t; + + static void ProcessCharacter(idx_t &result, char) { + result++; + } + + static void ProcessHex(idx_t &result, const char *, idx_t) { + result++; + } +}; + +struct URLEncodeWrite { + using RESULT_TYPE = char *; + + static void ProcessCharacter(char *&result, char c) { + *result = c; + result++; + } + + static void ProcessHex(char *&result, const char *input, idx_t idx) { + uint32_t hex_first = StringUtil::GetHexValue(input[idx + 1]); + uint32_t hex_second = StringUtil::GetHexValue(input[idx + 2]); + uint32_t hex_value = (hex_first << 4) + hex_second; + ProcessCharacter(result, static_cast(hex_value)); + } +}; + +template +void URLEncodeInternal(const char *input, idx_t input_size, typename OP::RESULT_TYPE &result, bool encode_slash) { + // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + static const char *HEX_DIGIT = "0123456789ABCDEF"; + for (idx_t i = 0; i < input_size; i++) { + char ch = input[i]; + if ((ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '_' || + ch == '-' || ch == '~' || ch == '.') { + OP::ProcessCharacter(result, ch); + } else if (ch == '/' && !encode_slash) { + OP::ProcessCharacter(result, ch); + } else { + OP::ProcessCharacter(result, '%'); + OP::ProcessCharacter(result, HEX_DIGIT[static_cast(ch) >> 4]); + OP::ProcessCharacter(result, HEX_DIGIT[static_cast(ch) & 15]); + } + } +} + +idx_t StringUtil::URLEncodeSize(const char *input, idx_t input_size, bool encode_slash) { + idx_t result_length = 0; + URLEncodeInternal(input, input_size, result_length, encode_slash); + return result_length; +} + +void StringUtil::URLEncodeBuffer(const char *input, idx_t input_size, char *output, bool encode_slash) { + URLEncodeInternal(input, input_size, output, encode_slash); +} + +string StringUtil::URLEncode(const string &input, bool encode_slash) { + idx_t result_size = URLEncodeSize(input.c_str(), input.size(), encode_slash); + auto result_data = make_uniq_array(result_size); + URLEncodeBuffer(input.c_str(), input.size(), result_data.get(), encode_slash); + return string(result_data.get(), result_size); +} + +template +void URLDecodeInternal(const char *input, idx_t input_size, typename OP::RESULT_TYPE &result, bool plus_to_space) { + for (idx_t i = 0; i < input_size; i++) { + char ch = input[i]; + if (plus_to_space && ch == '+') { + OP::ProcessCharacter(result, ' '); + } else if (ch == '%' && i + 2 < input_size && StringUtil::CharacterIsHex(input[i + 1]) && + StringUtil::CharacterIsHex(input[i + 2])) { + OP::ProcessHex(result, input, i); + i += 2; + } else { + OP::ProcessCharacter(result, ch); + } + } +} + +idx_t StringUtil::URLDecodeSize(const char *input, idx_t input_size, bool plus_to_space) { + idx_t result_length = 0; + URLDecodeInternal(input, input_size, result_length, plus_to_space); + return result_length; +} + +void StringUtil::URLDecodeBuffer(const char *input, idx_t input_size, char *output, bool plus_to_space) { + char *output_start = output; + URLDecodeInternal(input, input_size, output, plus_to_space); + if (!Utf8Proc::IsValid(output_start, NumericCast(output - output_start))) { + throw InvalidInputException("Failed to decode string \"%s\" using URL decoding - decoded value is invalid UTF8", + string(input, input_size)); + } +} + +string StringUtil::URLDecode(const string &input, bool plus_to_space) { + idx_t result_size = URLDecodeSize(input.c_str(), input.size(), plus_to_space); + auto result_data = make_uniq_array(result_size); + URLDecodeBuffer(input.c_str(), input.size(), result_data.get(), plus_to_space); + return string(result_data.get(), result_size); +} + } // namespace duckdb diff --git a/src/duckdb/src/common/tree_renderer.cpp b/src/duckdb/src/common/tree_renderer.cpp index aa295364..c924ff26 100644 --- a/src/duckdb/src/common/tree_renderer.cpp +++ b/src/duckdb/src/common/tree_renderer.cpp @@ -1,519 +1,27 @@ #include "duckdb/common/tree_renderer.hpp" - -#include "duckdb/common/pair.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" -#include "duckdb/execution/operator/join/physical_delim_join.hpp" -#include "duckdb/execution/operator/scan/physical_positional_scan.hpp" -#include "duckdb/execution/physical_operator.hpp" -#include "duckdb/parallel/pipeline.hpp" -#include "duckdb/planner/logical_operator.hpp" -#include "utf8proc_wrapper.hpp" +#include "duckdb/common/tree_renderer/text_tree_renderer.hpp" +#include "duckdb/common/tree_renderer/json_tree_renderer.hpp" +#include "duckdb/common/tree_renderer/html_tree_renderer.hpp" +#include "duckdb/common/tree_renderer/graphviz_tree_renderer.hpp" #include namespace duckdb { -RenderTree::RenderTree(idx_t width_p, idx_t height_p) : width(width_p), height(height_p) { - nodes = unique_ptr[]>(new unique_ptr[(width + 1) * (height + 1)]); -} - -RenderTreeNode *RenderTree::GetNode(idx_t x, idx_t y) { - if (x >= width || y >= height) { - return nullptr; - } - return nodes[GetPosition(x, y)].get(); -} - -bool RenderTree::HasNode(idx_t x, idx_t y) { - if (x >= width || y >= height) { - return false; - } - return nodes[GetPosition(x, y)].get() != nullptr; -} - -idx_t RenderTree::GetPosition(idx_t x, idx_t y) { - return y * width + x; -} - -void RenderTree::SetNode(idx_t x, idx_t y, unique_ptr node) { - nodes[GetPosition(x, y)] = std::move(node); -} - -void TreeRenderer::RenderTopLayer(RenderTree &root, std::ostream &ss, idx_t y) { - for (idx_t x = 0; x < root.width; x++) { - if (x * config.node_render_width >= config.maximum_render_width) { - break; - } - if (root.HasNode(x, y)) { - ss << config.LTCORNER; - ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2 - 1); - if (y == 0) { - // top level node: no node above this one - ss << config.HORIZONTAL; - } else { - // render connection to node above this one - ss << config.DMIDDLE; - } - ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2 - 1); - ss << config.RTCORNER; - } else { - ss << StringUtil::Repeat(" ", config.node_render_width); - } - } - ss << '\n'; -} - -void TreeRenderer::RenderBottomLayer(RenderTree &root, std::ostream &ss, idx_t y) { - for (idx_t x = 0; x <= root.width; x++) { - if (x * config.node_render_width >= config.maximum_render_width) { - break; - } - if (root.HasNode(x, y)) { - ss << config.LDCORNER; - ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2 - 1); - if (root.HasNode(x, y + 1)) { - // node below this one: connect to that one - ss << config.TMIDDLE; - } else { - // no node below this one: end the box - ss << config.HORIZONTAL; - } - ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2 - 1); - ss << config.RDCORNER; - } else if (root.HasNode(x, y + 1)) { - ss << StringUtil::Repeat(" ", config.node_render_width / 2); - ss << config.VERTICAL; - ss << StringUtil::Repeat(" ", config.node_render_width / 2); - } else { - ss << StringUtil::Repeat(" ", config.node_render_width); - } - } - ss << '\n'; -} - -string AdjustTextForRendering(string source, idx_t max_render_width) { - idx_t cpos = 0; - idx_t render_width = 0; - vector> render_widths; - while (cpos < source.size()) { - idx_t char_render_width = Utf8Proc::RenderWidth(source.c_str(), source.size(), cpos); - cpos = Utf8Proc::NextGraphemeCluster(source.c_str(), source.size(), cpos); - render_width += char_render_width; - render_widths.emplace_back(cpos, render_width); - if (render_width > max_render_width) { - break; - } - } - if (render_width > max_render_width) { - // need to find a position to truncate - for (idx_t pos = render_widths.size(); pos > 0; pos--) { - if (render_widths[pos - 1].second < max_render_width - 4) { - return source.substr(0, render_widths[pos - 1].first) + "..." + - string(max_render_width - render_widths[pos - 1].second - 3, ' '); - } - } - source = "..."; - } - // need to pad with spaces - idx_t total_spaces = max_render_width - render_width; - idx_t half_spaces = total_spaces / 2; - idx_t extra_left_space = total_spaces % 2 == 0 ? 0 : 1; - return string(half_spaces + extra_left_space, ' ') + source + string(half_spaces, ' '); -} - -static bool NodeHasMultipleChildren(RenderTree &root, idx_t x, idx_t y) { - for (; x < root.width && !root.HasNode(x + 1, y); x++) { - if (root.HasNode(x + 1, y + 1)) { - return true; - } - } - return false; -} - -void TreeRenderer::RenderBoxContent(RenderTree &root, std::ostream &ss, idx_t y) { - // we first need to figure out how high our boxes are going to be - vector> extra_info; - idx_t extra_height = 0; - extra_info.resize(root.width); - for (idx_t x = 0; x < root.width; x++) { - auto node = root.GetNode(x, y); - if (node) { - SplitUpExtraInfo(node->extra_text, extra_info[x]); - if (extra_info[x].size() > extra_height) { - extra_height = extra_info[x].size(); - } - } - } - extra_height = MinValue(extra_height, config.max_extra_lines); - idx_t halfway_point = (extra_height + 1) / 2; - // now we render the actual node - for (idx_t render_y = 0; render_y <= extra_height; render_y++) { - for (idx_t x = 0; x < root.width; x++) { - if (x * config.node_render_width >= config.maximum_render_width) { - break; - } - auto node = root.GetNode(x, y); - if (!node) { - if (render_y == halfway_point) { - bool has_child_to_the_right = NodeHasMultipleChildren(root, x, y); - if (root.HasNode(x, y + 1)) { - // node right below this one - ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2); - ss << config.RTCORNER; - if (has_child_to_the_right) { - // but we have another child to the right! keep rendering the line - ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2); - } else { - // only a child below this one: fill the rest with spaces - ss << StringUtil::Repeat(" ", config.node_render_width / 2); - } - } else if (has_child_to_the_right) { - // child to the right, but no child right below this one: render a full line - ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width); - } else { - // empty spot: render spaces - ss << StringUtil::Repeat(" ", config.node_render_width); - } - } else if (render_y >= halfway_point) { - if (root.HasNode(x, y + 1)) { - // we have a node below this empty spot: render a vertical line - ss << StringUtil::Repeat(" ", config.node_render_width / 2); - ss << config.VERTICAL; - ss << StringUtil::Repeat(" ", config.node_render_width / 2); - } else { - // empty spot: render spaces - ss << StringUtil::Repeat(" ", config.node_render_width); - } - } else { - // empty spot: render spaces - ss << StringUtil::Repeat(" ", config.node_render_width); - } - } else { - ss << config.VERTICAL; - // figure out what to render - string render_text; - if (render_y == 0) { - render_text = node->name; - } else { - if (render_y <= extra_info[x].size()) { - render_text = extra_info[x][render_y - 1]; - } - } - render_text = AdjustTextForRendering(render_text, config.node_render_width - 2); - ss << render_text; - - if (render_y == halfway_point && NodeHasMultipleChildren(root, x, y)) { - ss << config.LMIDDLE; - } else { - ss << config.VERTICAL; - } - } - } - ss << '\n'; - } -} - -string TreeRenderer::ToString(const LogicalOperator &op) { - std::stringstream ss; - Render(op, ss); - return ss.str(); -} - -string TreeRenderer::ToString(const PhysicalOperator &op) { - std::stringstream ss; - Render(op, ss); - return ss.str(); -} - -string TreeRenderer::ToString(const QueryProfiler::TreeNode &op) { - std::stringstream ss; - Render(op, ss); - return ss.str(); -} - -string TreeRenderer::ToString(const Pipeline &op) { - std::stringstream ss; - Render(op, ss); - return ss.str(); -} - -void TreeRenderer::Render(const LogicalOperator &op, std::ostream &ss) { - auto tree = CreateTree(op); - ToStream(*tree, ss); -} - -void TreeRenderer::Render(const PhysicalOperator &op, std::ostream &ss) { - auto tree = CreateTree(op); - ToStream(*tree, ss); -} - -void TreeRenderer::Render(const QueryProfiler::TreeNode &op, std::ostream &ss) { - auto tree = CreateTree(op); - ToStream(*tree, ss); -} - -void TreeRenderer::Render(const Pipeline &op, std::ostream &ss) { - auto tree = CreateTree(op); - ToStream(*tree, ss); -} - -void TreeRenderer::ToStream(RenderTree &root, std::ostream &ss) { - while (root.width * config.node_render_width > config.maximum_render_width) { - if (config.node_render_width - 2 < config.minimum_render_width) { - break; - } - config.node_render_width -= 2; - } - - for (idx_t y = 0; y < root.height; y++) { - // start by rendering the top layer - RenderTopLayer(root, ss, y); - // now we render the content of the boxes - RenderBoxContent(root, ss, y); - // render the bottom layer of each of the boxes - RenderBottomLayer(root, ss, y); - } -} - -bool TreeRenderer::CanSplitOnThisChar(char l) { - return (l < '0' || (l > '9' && l < 'A') || (l > 'Z' && l < 'a')) && l != '_'; -} - -bool TreeRenderer::IsPadding(char l) { - return l == ' ' || l == '\t' || l == '\n' || l == '\r'; -} - -string TreeRenderer::RemovePadding(string l) { - idx_t start = 0, end = l.size(); - while (start < l.size() && IsPadding(l[start])) { - start++; - } - while (end > 0 && IsPadding(l[end - 1])) { - end--; - } - return l.substr(start, end - start); -} - -void TreeRenderer::SplitStringBuffer(const string &source, vector &result) { - D_ASSERT(Utf8Proc::IsValid(source.c_str(), source.size())); - idx_t max_line_render_size = config.node_render_width - 2; - // utf8 in prompt, get render width - idx_t cpos = 0; - idx_t start_pos = 0; - idx_t render_width = 0; - idx_t last_possible_split = 0; - while (cpos < source.size()) { - // check if we can split on this character - if (CanSplitOnThisChar(source[cpos])) { - last_possible_split = cpos; - } - size_t char_render_width = Utf8Proc::RenderWidth(source.c_str(), source.size(), cpos); - idx_t next_cpos = Utf8Proc::NextGraphemeCluster(source.c_str(), source.size(), cpos); - if (render_width + char_render_width > max_line_render_size) { - if (last_possible_split <= start_pos + 8) { - last_possible_split = cpos; - } - result.push_back(source.substr(start_pos, last_possible_split - start_pos)); - start_pos = last_possible_split; - cpos = last_possible_split; - render_width = 0; - } - cpos = next_cpos; - render_width += char_render_width; - } - if (source.size() > start_pos) { - result.push_back(source.substr(start_pos, source.size() - start_pos)); - } -} - -void TreeRenderer::SplitUpExtraInfo(const string &extra_info, vector &result) { - if (extra_info.empty()) { - return; - } - if (!Utf8Proc::IsValid(extra_info.c_str(), extra_info.size())) { - return; - } - auto splits = StringUtil::Split(extra_info, "\n"); - if (!splits.empty() && splits[0] != "[INFOSEPARATOR]") { - result.push_back(ExtraInfoSeparator()); - } - for (auto &split : splits) { - if (split == "[INFOSEPARATOR]") { - result.push_back(ExtraInfoSeparator()); - continue; - } - string str = RemovePadding(split); - if (str.empty()) { - continue; - } - SplitStringBuffer(str, result); - } -} - -string TreeRenderer::ExtraInfoSeparator() { - return StringUtil::Repeat(string(config.HORIZONTAL) + " ", (config.node_render_width - 7) / 2); -} - -unique_ptr TreeRenderer::CreateRenderNode(string name, string extra_info) { - auto result = make_uniq(); - result->name = std::move(name); - result->extra_text = std::move(extra_info); - return result; -} - -class TreeChildrenIterator { -public: - template - static bool HasChildren(const T &op) { - return !op.children.empty(); - } - template - static void Iterate(const T &op, const std::function &callback) { - for (auto &child : op.children) { - callback(*child); - } - } -}; - -template <> -bool TreeChildrenIterator::HasChildren(const PhysicalOperator &op) { - switch (op.type) { - case PhysicalOperatorType::LEFT_DELIM_JOIN: - case PhysicalOperatorType::RIGHT_DELIM_JOIN: - case PhysicalOperatorType::POSITIONAL_SCAN: - return true; +unique_ptr TreeRenderer::CreateRenderer(ExplainFormat format) { + switch (format) { + case ExplainFormat::DEFAULT: + case ExplainFormat::TEXT: + return make_uniq(); + case ExplainFormat::JSON: + return make_uniq(); + case ExplainFormat::HTML: + return make_uniq(); + case ExplainFormat::GRAPHVIZ: + return make_uniq(); default: - return !op.children.empty(); - } -} -template <> -void TreeChildrenIterator::Iterate(const PhysicalOperator &op, - const std::function &callback) { - for (auto &child : op.children) { - callback(*child); - } - if (op.type == PhysicalOperatorType::LEFT_DELIM_JOIN || op.type == PhysicalOperatorType::RIGHT_DELIM_JOIN) { - auto &delim = op.Cast(); - callback(*delim.join); - } else if ((op.type == PhysicalOperatorType::POSITIONAL_SCAN)) { - auto &pscan = op.Cast(); - for (auto &table : pscan.child_tables) { - callback(*table); - } - } -} - -struct PipelineRenderNode { - explicit PipelineRenderNode(const PhysicalOperator &op) : op(op) { - } - - const PhysicalOperator &op; - unique_ptr child; -}; - -template <> -bool TreeChildrenIterator::HasChildren(const PipelineRenderNode &op) { - return op.child.get(); -} - -template <> -void TreeChildrenIterator::Iterate(const PipelineRenderNode &op, - const std::function &callback) { - if (op.child) { - callback(*op.child); - } -} - -template -static void GetTreeWidthHeight(const T &op, idx_t &width, idx_t &height) { - if (!TreeChildrenIterator::HasChildren(op)) { - width = 1; - height = 1; - return; - } - width = 0; - height = 0; - - TreeChildrenIterator::Iterate(op, [&](const T &child) { - idx_t child_width, child_height; - GetTreeWidthHeight(child, child_width, child_height); - width += child_width; - height = MaxValue(height, child_height); - }); - height++; -} - -template -idx_t TreeRenderer::CreateRenderTreeRecursive(RenderTree &result, const T &op, idx_t x, idx_t y) { - auto node = TreeRenderer::CreateNode(op); - result.SetNode(x, y, std::move(node)); - - if (!TreeChildrenIterator::HasChildren(op)) { - return 1; - } - idx_t width = 0; - // render the children of this node - TreeChildrenIterator::Iterate( - op, [&](const T &child) { width += CreateRenderTreeRecursive(result, child, x + width, y + 1); }); - return width; -} - -template -unique_ptr TreeRenderer::CreateRenderTree(const T &op) { - idx_t width, height; - GetTreeWidthHeight(op, width, height); - - auto result = make_uniq(width, height); - - // now fill in the tree - CreateRenderTreeRecursive(*result, op, 0, 0); - return result; -} - -unique_ptr TreeRenderer::CreateNode(const LogicalOperator &op) { - return CreateRenderNode(op.GetName(), op.ParamsToString()); -} - -unique_ptr TreeRenderer::CreateNode(const PhysicalOperator &op) { - return CreateRenderNode(op.GetName(), op.ParamsToString()); -} - -unique_ptr TreeRenderer::CreateNode(const PipelineRenderNode &op) { - return CreateNode(op.op); -} - -unique_ptr TreeRenderer::CreateNode(const QueryProfiler::TreeNode &op) { - auto result = TreeRenderer::CreateRenderNode(op.name, op.extra_info); - result->extra_text += "\n[INFOSEPARATOR]"; - result->extra_text += "\n" + to_string(op.info.elements); - string timing = StringUtil::Format("%.2f", op.info.time); - result->extra_text += "\n(" + timing + "s)"; - return result; -} - -unique_ptr TreeRenderer::CreateTree(const LogicalOperator &op) { - return CreateRenderTree(op); -} - -unique_ptr TreeRenderer::CreateTree(const PhysicalOperator &op) { - return CreateRenderTree(op); -} - -unique_ptr TreeRenderer::CreateTree(const QueryProfiler::TreeNode &op) { - return CreateRenderTree(op); -} - -unique_ptr TreeRenderer::CreateTree(const Pipeline &pipeline) { - auto operators = pipeline.GetOperators(); - D_ASSERT(!operators.empty()); - unique_ptr node; - for (auto &op : operators) { - auto new_node = make_uniq(op.get()); - new_node->child = std::move(node); - node = std::move(new_node); + throw NotImplementedException("ExplainFormat %s not implemented", EnumUtil::ToString(format)); } - return CreateRenderTree(*node); } } // namespace duckdb diff --git a/src/duckdb/src/common/tree_renderer/graphviz_tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/graphviz_tree_renderer.cpp new file mode 100644 index 00000000..40a93b69 --- /dev/null +++ b/src/duckdb/src/common/tree_renderer/graphviz_tree_renderer.cpp @@ -0,0 +1,108 @@ +#include "duckdb/common/tree_renderer/graphviz_tree_renderer.hpp" + +#include "duckdb/common/pair.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" +#include "duckdb/execution/operator/join/physical_delim_join.hpp" +#include "duckdb/execution/operator/scan/physical_positional_scan.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "utf8proc_wrapper.hpp" + +#include + +namespace duckdb { + +string GRAPHVIZTreeRenderer::ToString(const LogicalOperator &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string GRAPHVIZTreeRenderer::ToString(const PhysicalOperator &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string GRAPHVIZTreeRenderer::ToString(const ProfilingNode &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string GRAPHVIZTreeRenderer::ToString(const Pipeline &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +void GRAPHVIZTreeRenderer::Render(const LogicalOperator &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void GRAPHVIZTreeRenderer::Render(const PhysicalOperator &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void GRAPHVIZTreeRenderer::Render(const ProfilingNode &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void GRAPHVIZTreeRenderer::Render(const Pipeline &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void GRAPHVIZTreeRenderer::ToStreamInternal(RenderTree &root, std::ostream &ss) { + const string digraph_format = R"( +digraph G { + node [shape=box, style=rounded, fontname="Courier New", fontsize=10]; +%s +%s +} + )"; + + vector nodes; + vector edges; + + const string node_format = R"( node_%d_%d [label="%s"];)"; + + for (idx_t y = 0; y < root.height; y++) { + for (idx_t x = 0; x < root.width; x++) { + auto node = root.GetNode(x, y); + if (!node) { + continue; + } + + // Create Node + vector body; + body.push_back(node->name); + for (auto &item : node->extra_text) { + auto &key = item.first; + auto &value_raw = item.second; + + auto value = QueryProfiler::JSONSanitize(value_raw); + body.push_back(StringUtil::Format("%s:\\n%s", key, value)); + } + nodes.push_back(StringUtil::Format(node_format, x, y, StringUtil::Join(body, "\\n───\\n"))); + + // Create Edge(s) + for (auto &coord : node->child_positions) { + edges.push_back(StringUtil::Format(" node_%d_%d -> node_%d_%d;", x, y, coord.x, coord.y)); + } + } + } + auto node_lines = StringUtil::Join(nodes, "\n"); + auto edge_lines = StringUtil::Join(edges, "\n"); + + string result = StringUtil::Format(digraph_format, node_lines, edge_lines); + ss << result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/tree_renderer/html_tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/html_tree_renderer.cpp new file mode 100644 index 00000000..66a8f2d1 --- /dev/null +++ b/src/duckdb/src/common/tree_renderer/html_tree_renderer.cpp @@ -0,0 +1,267 @@ +#include "duckdb/common/tree_renderer/html_tree_renderer.hpp" + +#include "duckdb/common/pair.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" +#include "duckdb/execution/operator/join/physical_delim_join.hpp" +#include "duckdb/execution/operator/scan/physical_positional_scan.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "utf8proc_wrapper.hpp" + +#include + +namespace duckdb { + +string HTMLTreeRenderer::ToString(const LogicalOperator &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string HTMLTreeRenderer::ToString(const PhysicalOperator &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string HTMLTreeRenderer::ToString(const ProfilingNode &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string HTMLTreeRenderer::ToString(const Pipeline &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +void HTMLTreeRenderer::Render(const LogicalOperator &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void HTMLTreeRenderer::Render(const PhysicalOperator &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void HTMLTreeRenderer::Render(const ProfilingNode &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void HTMLTreeRenderer::Render(const Pipeline &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +static string CreateStyleSection(RenderTree &root) { + return R"( + + )"; +} + +static string CreateHeadSection(RenderTree &root) { + string head_section = R"( + + + + + + + DuckDB Query Plan + %s + + )"; + return StringUtil::Format(head_section, CreateStyleSection(root)); +} + +static string CreateGridItemContent(RenderTreeNode &node) { + const string content_format = R"( +
+%s +
+ )"; + + vector items; + for (auto &item : node.extra_text) { + auto &key = item.first; + auto &value = item.second; + if (value.empty()) { + continue; + } + items.push_back(StringUtil::Format(R"(
%s
)", key)); + auto splits = StringUtil::Split(value, "\n"); + for (auto &split : splits) { + items.push_back(StringUtil::Format(R"(
%s
)", split)); + } + } + string result; + if (!items.empty()) { + result = StringUtil::Format(content_format, StringUtil::Join(items, "\n")); + } + if (!node.child_positions.empty()) { + result += ""; + } + return result; +} + +static string CreateGridItem(RenderTree &root, idx_t x, idx_t y) { + const string grid_item_format = R"( +
+
%s
%s +
+ )"; + + auto node = root.GetNode(x, y); + if (!node) { + return ""; + } + + auto title = node->name; + auto content = CreateGridItemContent(*node); + return StringUtil::Format(grid_item_format, title, content); +} + +static string CreateTreeRecursive(RenderTree &root, idx_t x, idx_t y) { + string result; + + result += "
  • "; + result += CreateGridItem(root, x, y); + auto node = root.GetNode(x, y); + if (!node->child_positions.empty()) { + result += "
      "; + for (auto &coord : node->child_positions) { + result += CreateTreeRecursive(root, coord.x, coord.y); + } + result += "
    "; + } + result += "
  • "; + return result; +} + +static string CreateBodySection(RenderTree &root) { + const string body_section = R"( + +
    +
      %s
    +
    + + + + + + )"; + return StringUtil::Format(body_section, CreateTreeRecursive(root, 0, 0)); +} + +void HTMLTreeRenderer::ToStreamInternal(RenderTree &root, std::ostream &ss) { + string result; + result += CreateHeadSection(root); + result += CreateBodySection(root); + ss << result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/tree_renderer/json_tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/json_tree_renderer.cpp new file mode 100644 index 00000000..edc9309b --- /dev/null +++ b/src/duckdb/src/common/tree_renderer/json_tree_renderer.cpp @@ -0,0 +1,116 @@ +#include "duckdb/common/tree_renderer/json_tree_renderer.hpp" + +#include "duckdb/common/pair.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" +#include "duckdb/execution/operator/join/physical_delim_join.hpp" +#include "duckdb/execution/operator/scan/physical_positional_scan.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "utf8proc_wrapper.hpp" + +#include "yyjson.hpp" + +#include + +using namespace duckdb_yyjson; // NOLINT + +namespace duckdb { + +string JSONTreeRenderer::ToString(const LogicalOperator &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string JSONTreeRenderer::ToString(const PhysicalOperator &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string JSONTreeRenderer::ToString(const ProfilingNode &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string JSONTreeRenderer::ToString(const Pipeline &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +void JSONTreeRenderer::Render(const LogicalOperator &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void JSONTreeRenderer::Render(const PhysicalOperator &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void JSONTreeRenderer::Render(const ProfilingNode &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void JSONTreeRenderer::Render(const Pipeline &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +static yyjson_mut_val *RenderRecursive(yyjson_mut_doc *doc, RenderTree &tree, idx_t x, idx_t y) { + auto node_p = tree.GetNode(x, y); + D_ASSERT(node_p); + auto &node = *node_p; + + auto object = yyjson_mut_obj(doc); + auto children = yyjson_mut_arr(doc); + for (auto &child_pos : node.child_positions) { + auto child_object = RenderRecursive(doc, tree, child_pos.x, child_pos.y); + yyjson_mut_arr_append(children, child_object); + } + yyjson_mut_obj_add_str(doc, object, "name", node.name.c_str()); + yyjson_mut_obj_add_val(doc, object, "children", children); + auto extra_info = yyjson_mut_obj(doc); + for (auto &it : node.extra_text) { + auto &key = it.first; + auto &value = it.second; + auto splits = StringUtil::Split(value, "\n"); + if (splits.size() > 1) { + auto list_items = yyjson_mut_arr(doc); + for (auto &split : splits) { + yyjson_mut_arr_add_strcpy(doc, list_items, split.c_str()); + } + yyjson_mut_obj_add_val(doc, extra_info, key.c_str(), list_items); + } else { + yyjson_mut_obj_add_strcpy(doc, extra_info, key.c_str(), value.c_str()); + } + } + yyjson_mut_obj_add_val(doc, object, "extra_info", extra_info); + return object; +} + +void JSONTreeRenderer::ToStreamInternal(RenderTree &root, std::ostream &ss) { + auto doc = yyjson_mut_doc_new(nullptr); + auto result_obj = yyjson_mut_arr(doc); + yyjson_mut_doc_set_root(doc, result_obj); + + auto plan = RenderRecursive(doc, root, 0, 0); + yyjson_mut_arr_append(result_obj, plan); + + auto data = yyjson_mut_val_write_opts(result_obj, YYJSON_WRITE_ALLOW_INF_AND_NAN | YYJSON_WRITE_PRETTY, nullptr, + nullptr, nullptr); + if (!data) { + yyjson_mut_doc_free(doc); + throw InternalException("The plan could not be rendered as JSON, yyjson failed"); + } + ss << string(data); + free(data); + yyjson_mut_doc_free(doc); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp new file mode 100644 index 00000000..8e0fa425 --- /dev/null +++ b/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp @@ -0,0 +1,482 @@ +#include "duckdb/common/tree_renderer/text_tree_renderer.hpp" + +#include "duckdb/common/pair.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "utf8proc_wrapper.hpp" +#include "duckdb/common/typedefs.hpp" + +#include + +namespace duckdb { + +namespace { + +struct StringSegment { +public: + StringSegment(idx_t start, idx_t width) : start(start), width(width) { + } + +public: + idx_t start; + idx_t width; +}; + +} // namespace + +void TextTreeRenderer::RenderTopLayer(RenderTree &root, std::ostream &ss, idx_t y) { + for (idx_t x = 0; x < root.width; x++) { + if (x * config.node_render_width >= config.maximum_render_width) { + break; + } + if (root.HasNode(x, y)) { + ss << config.LTCORNER; + ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2 - 1); + if (y == 0) { + // top level node: no node above this one + ss << config.HORIZONTAL; + } else { + // render connection to node above this one + ss << config.DMIDDLE; + } + ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2 - 1); + ss << config.RTCORNER; + } else { + bool has_adjacent_nodes = false; + for (idx_t i = 0; x + i < root.width; i++) { + has_adjacent_nodes = has_adjacent_nodes || root.HasNode(x + i, y); + } + if (!has_adjacent_nodes) { + // There are no nodes to the right side of this position + // no need to fill the empty space + continue; + } + // there are nodes next to this, fill the space + ss << StringUtil::Repeat(" ", config.node_render_width); + } + } + ss << '\n'; +} + +static bool NodeHasMultipleChildren(RenderTreeNode &node) { + return node.child_positions.size() > 1; +} + +static bool ShouldRenderWhitespace(RenderTree &root, idx_t x, idx_t y) { + idx_t found_children = 0; + for (;; x--) { + auto node = root.GetNode(x, y); + if (root.HasNode(x, y + 1)) { + found_children++; + } + if (node) { + if (NodeHasMultipleChildren(*node)) { + if (found_children < node->child_positions.size()) { + return true; + } + } + return false; + } + if (x == 0) { + break; + } + } + return false; +} + +void TextTreeRenderer::RenderBottomLayer(RenderTree &root, std::ostream &ss, idx_t y) { + for (idx_t x = 0; x <= root.width; x++) { + if (x * config.node_render_width >= config.maximum_render_width) { + break; + } + bool has_adjacent_nodes = false; + for (idx_t i = 0; x + i < root.width; i++) { + has_adjacent_nodes = has_adjacent_nodes || root.HasNode(x + i, y); + } + auto node = root.GetNode(x, y); + if (node) { + ss << config.LDCORNER; + ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2 - 1); + if (root.HasNode(x, y + 1)) { + // node below this one: connect to that one + ss << config.TMIDDLE; + } else { + // no node below this one: end the box + ss << config.HORIZONTAL; + } + ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2 - 1); + ss << config.RDCORNER; + } else if (root.HasNode(x, y + 1)) { + ss << StringUtil::Repeat(" ", config.node_render_width / 2); + ss << config.VERTICAL; + if (has_adjacent_nodes || ShouldRenderWhitespace(root, x, y)) { + ss << StringUtil::Repeat(" ", config.node_render_width / 2); + } + } else { + if (has_adjacent_nodes || ShouldRenderWhitespace(root, x, y)) { + ss << StringUtil::Repeat(" ", config.node_render_width); + } + } + } + ss << '\n'; +} + +string AdjustTextForRendering(string source, idx_t max_render_width) { + const idx_t size = source.size(); + const char *input = source.c_str(); + + idx_t render_width = 0; + + // For every character in the input, create a StringSegment + vector render_widths; + idx_t current_position = 0; + while (current_position < size) { + idx_t char_render_width = Utf8Proc::RenderWidth(input, size, current_position); + current_position = Utf8Proc::NextGraphemeCluster(input, size, current_position); + render_width += char_render_width; + render_widths.push_back(StringSegment(current_position, render_width)); + if (render_width > max_render_width) { + break; + } + } + + if (render_width > max_render_width) { + // need to find a position to truncate + for (idx_t pos = render_widths.size(); pos > 0; pos--) { + auto &source_range = render_widths[pos - 1]; + if (source_range.width < max_render_width - 4) { + return source.substr(0, source_range.start) + string("...") + + string(max_render_width - source_range.width - 3, ' '); + } + } + source = "..."; + } + // need to pad with spaces + idx_t total_spaces = max_render_width - render_width; + idx_t half_spaces = total_spaces / 2; + idx_t extra_left_space = total_spaces % 2 == 0 ? 0 : 1; + return string(half_spaces + extra_left_space, ' ') + source + string(half_spaces, ' '); +} + +void TextTreeRenderer::RenderBoxContent(RenderTree &root, std::ostream &ss, idx_t y) { + // we first need to figure out how high our boxes are going to be + vector> extra_info; + idx_t extra_height = 0; + extra_info.resize(root.width); + for (idx_t x = 0; x < root.width; x++) { + auto node = root.GetNode(x, y); + if (node) { + SplitUpExtraInfo(node->extra_text, extra_info[x]); + if (extra_info[x].size() > extra_height) { + extra_height = extra_info[x].size(); + } + } + } + extra_height = MinValue(extra_height, config.max_extra_lines); + idx_t halfway_point = (extra_height + 1) / 2; + // now we render the actual node + for (idx_t render_y = 0; render_y <= extra_height; render_y++) { + for (idx_t x = 0; x < root.width; x++) { + if (x * config.node_render_width >= config.maximum_render_width) { + break; + } + bool has_adjacent_nodes = false; + for (idx_t i = 0; x + i < root.width; i++) { + has_adjacent_nodes = has_adjacent_nodes || root.HasNode(x + i, y); + } + auto node = root.GetNode(x, y); + if (!node) { + if (render_y == halfway_point) { + bool has_child_to_the_right = ShouldRenderWhitespace(root, x, y); + if (root.HasNode(x, y + 1)) { + // node right below this one + ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2); + if (has_child_to_the_right) { + ss << config.TMIDDLE; + // but we have another child to the right! keep rendering the line + ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2); + } else { + ss << config.RTCORNER; + if (has_adjacent_nodes) { + // only a child below this one: fill the rest with spaces + ss << StringUtil::Repeat(" ", config.node_render_width / 2); + } + } + } else if (has_child_to_the_right) { + // child to the right, but no child right below this one: render a full line + ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width); + } else { + if (has_adjacent_nodes) { + // empty spot: render spaces + ss << StringUtil::Repeat(" ", config.node_render_width); + } + } + } else if (render_y >= halfway_point) { + if (root.HasNode(x, y + 1)) { + // we have a node below this empty spot: render a vertical line + ss << StringUtil::Repeat(" ", config.node_render_width / 2); + ss << config.VERTICAL; + if (has_adjacent_nodes || ShouldRenderWhitespace(root, x, y)) { + ss << StringUtil::Repeat(" ", config.node_render_width / 2); + } + } else { + if (has_adjacent_nodes || ShouldRenderWhitespace(root, x, y)) { + // empty spot: render spaces + ss << StringUtil::Repeat(" ", config.node_render_width); + } + } + } else { + if (has_adjacent_nodes) { + // empty spot: render spaces + ss << StringUtil::Repeat(" ", config.node_render_width); + } + } + } else { + ss << config.VERTICAL; + // figure out what to render + string render_text; + if (render_y == 0) { + render_text = node->name; + } else { + if (render_y <= extra_info[x].size()) { + render_text = extra_info[x][render_y - 1]; + } + } + if (render_y + 1 == extra_height && render_text.empty()) { + auto entry = node->extra_text.find(RenderTreeNode::CARDINALITY); + if (entry != node->extra_text.end()) { + render_text = entry->second + " Rows"; + } + } + if (render_y == extra_height && render_text.empty()) { + auto timing_entry = node->extra_text.find(RenderTreeNode::TIMING); + if (timing_entry != node->extra_text.end()) { + render_text = "(" + timing_entry->second + ")"; + } else if (node->extra_text.find(RenderTreeNode::CARDINALITY) == node->extra_text.end()) { + // we only render estimated cardinality if there is no real cardinality + auto entry = node->extra_text.find(RenderTreeNode::ESTIMATED_CARDINALITY); + if (entry != node->extra_text.end()) { + render_text = "~" + entry->second + " Rows"; + } + } + if (node->extra_text.find(RenderTreeNode::CARDINALITY) == node->extra_text.end()) { + // we only render estimated cardinality if there is no real cardinality + auto entry = node->extra_text.find(RenderTreeNode::ESTIMATED_CARDINALITY); + if (entry != node->extra_text.end()) { + render_text = "~" + entry->second + " Rows"; + } + } + } + render_text = AdjustTextForRendering(render_text, config.node_render_width - 2); + ss << render_text; + + if (render_y == halfway_point && NodeHasMultipleChildren(*node)) { + ss << config.LMIDDLE; + } else { + ss << config.VERTICAL; + } + } + } + ss << '\n'; + } +} + +string TextTreeRenderer::ToString(const LogicalOperator &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string TextTreeRenderer::ToString(const PhysicalOperator &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string TextTreeRenderer::ToString(const ProfilingNode &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +string TextTreeRenderer::ToString(const Pipeline &op) { + std::stringstream ss; + Render(op, ss); + return ss.str(); +} + +void TextTreeRenderer::Render(const LogicalOperator &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void TextTreeRenderer::Render(const PhysicalOperator &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void TextTreeRenderer::Render(const ProfilingNode &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void TextTreeRenderer::Render(const Pipeline &op, std::ostream &ss) { + auto tree = RenderTree::CreateRenderTree(op); + ToStream(*tree, ss); +} + +void TextTreeRenderer::ToStreamInternal(RenderTree &root, std::ostream &ss) { + while (root.width * config.node_render_width > config.maximum_render_width) { + if (config.node_render_width - 2 < config.minimum_render_width) { + break; + } + config.node_render_width -= 2; + } + + for (idx_t y = 0; y < root.height; y++) { + // start by rendering the top layer + RenderTopLayer(root, ss, y); + // now we render the content of the boxes + RenderBoxContent(root, ss, y); + // render the bottom layer of each of the boxes + RenderBottomLayer(root, ss, y); + } +} + +bool TextTreeRenderer::CanSplitOnThisChar(char l) { + return (l < '0' || (l > '9' && l < 'A') || (l > 'Z' && l < 'a')) && l != '_'; +} + +bool TextTreeRenderer::IsPadding(char l) { + return l == ' ' || l == '\t' || l == '\n' || l == '\r'; +} + +string TextTreeRenderer::RemovePadding(string l) { + idx_t start = 0, end = l.size(); + while (start < l.size() && IsPadding(l[start])) { + start++; + } + while (end > 0 && IsPadding(l[end - 1])) { + end--; + } + return l.substr(start, end - start); +} + +void TextTreeRenderer::SplitStringBuffer(const string &source, vector &result) { + D_ASSERT(Utf8Proc::IsValid(source.c_str(), source.size())); + const idx_t max_line_render_size = config.node_render_width - 2; + // utf8 in prompt, get render width + idx_t character_pos = 0; + idx_t start_pos = 0; + idx_t render_width = 0; + idx_t last_possible_split = 0; + + const idx_t size = source.size(); + const char *input = source.c_str(); + + while (character_pos < size) { + size_t char_render_width = Utf8Proc::RenderWidth(input, size, character_pos); + idx_t next_character_pos = Utf8Proc::NextGraphemeCluster(input, size, character_pos); + + // Does the next character make us exceed the line length? + if (render_width + char_render_width > max_line_render_size) { + if (start_pos + 8 > last_possible_split) { + // The last character we can split on is one of the first 8 characters of the line + // to not create very small lines we instead split on the current character + last_possible_split = character_pos; + } + result.push_back(source.substr(start_pos, last_possible_split - start_pos)); + render_width = character_pos - last_possible_split; + start_pos = last_possible_split; + character_pos = last_possible_split; + } + // check if we can split on this character + if (CanSplitOnThisChar(source[character_pos])) { + last_possible_split = character_pos; + } + character_pos = next_character_pos; + render_width += char_render_width; + } + if (size > start_pos) { + // append the remainder of the input + result.push_back(source.substr(start_pos, size - start_pos)); + } +} + +void TextTreeRenderer::SplitUpExtraInfo(const InsertionOrderPreservingMap &extra_info, vector &result) { + if (extra_info.empty()) { + return; + } + for (auto &item : extra_info) { + auto &text = item.second; + if (!Utf8Proc::IsValid(text.c_str(), text.size())) { + return; + } + } + result.push_back(ExtraInfoSeparator()); + + bool requires_padding = false; + bool was_inlined = false; + for (auto &item : extra_info) { + string str = RemovePadding(item.second); + if (str.empty()) { + continue; + } + bool is_inlined = false; + if (!StringUtil::StartsWith(item.first, "__")) { + // the name is not internal (i.e. not __text__) - so we display the name in addition to the entry + const idx_t available_width = (config.node_render_width - 7); + idx_t total_size = item.first.size() + str.size() + 2; + bool is_multiline = StringUtil::Contains(str, "\n"); + if (!is_multiline && total_size < available_width) { + // we can inline the full entry - no need for any separators unless the previous entry explicitly + // requires it + str = item.first + ": " + str; + is_inlined = true; + } else { + str = item.first + ":\n" + str; + } + } + if (is_inlined && was_inlined) { + // we can skip the padding if we have multiple inlined entries in a row + requires_padding = false; + } + if (requires_padding) { + result.emplace_back(); + } + // cardinality, timing and estimated cardinality are rendered separately + // this is to allow alignment horizontally across nodes + if (item.first == RenderTreeNode::CARDINALITY) { + // cardinality - need to reserve space for cardinality AND timing + result.emplace_back(); + if (extra_info.find(RenderTreeNode::TIMING) != extra_info.end()) { + result.emplace_back(); + } + break; + } + if (item.first == RenderTreeNode::ESTIMATED_CARDINALITY) { + // estimated cardinality - reserve space for estimate + if (extra_info.find(RenderTreeNode::CARDINALITY) != extra_info.end()) { + // if we have a true cardinality render that instead of the estimate + result.pop_back(); + continue; + } + result.emplace_back(); + break; + } + auto splits = StringUtil::Split(str, "\n"); + for (auto &split : splits) { + SplitStringBuffer(split, result); + } + requires_padding = true; + was_inlined = is_inlined; + } +} + +string TextTreeRenderer::ExtraInfoSeparator() { + return StringUtil::Repeat(string(config.HORIZONTAL), (config.node_render_width - 9)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/tree_renderer/tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/tree_renderer.cpp new file mode 100644 index 00000000..b0a58b59 --- /dev/null +++ b/src/duckdb/src/common/tree_renderer/tree_renderer.cpp @@ -0,0 +1,12 @@ +#include "duckdb/common/tree_renderer.hpp" + +namespace duckdb { + +void TreeRenderer::ToStream(RenderTree &root, std::ostream &ss) { + if (!UsesRawKeyNames()) { + root.SanitizeKeyNames(); + } + return ToStreamInternal(root, ss); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types.cpp b/src/duckdb/src/common/types.cpp index 18b594ef..a440b496 100644 --- a/src/duckdb/src/common/types.cpp +++ b/src/duckdb/src/common/types.cpp @@ -29,7 +29,6 @@ #include "duckdb/parser/keyword_helper.hpp" #include "duckdb/parser/parser.hpp" #include "duckdb/main/config.hpp" - #include namespace duckdb { @@ -117,6 +116,7 @@ PhysicalType LogicalType::GetInternalType() { case LogicalTypeId::CHAR: case LogicalTypeId::BLOB: case LogicalTypeId::BIT: + case LogicalTypeId::VARINT: return PhysicalType::VARCHAR; case LogicalTypeId::INTERVAL: return PhysicalType::INTERVAL; @@ -203,6 +203,8 @@ constexpr const LogicalTypeId LogicalType::VARCHAR; constexpr const LogicalTypeId LogicalType::BLOB; constexpr const LogicalTypeId LogicalType::BIT; +constexpr const LogicalTypeId LogicalType::VARINT; + constexpr const LogicalTypeId LogicalType::INTERVAL; constexpr const LogicalTypeId LogicalType::ROW_TYPE; @@ -236,14 +238,14 @@ const vector LogicalType::Real() { const vector LogicalType::AllTypes() { vector types = { - LogicalType::BOOLEAN, LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, - LogicalType::BIGINT, LogicalType::DATE, LogicalType::TIMESTAMP, LogicalType::DOUBLE, - LogicalType::FLOAT, LogicalType::VARCHAR, LogicalType::BLOB, LogicalType::BIT, - LogicalType::INTERVAL, LogicalType::HUGEINT, LogicalTypeId::DECIMAL, LogicalType::UTINYINT, - LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT, LogicalType::UHUGEINT, - LogicalType::TIME, LogicalTypeId::LIST, LogicalTypeId::STRUCT, LogicalType::TIME_TZ, - LogicalType::TIMESTAMP_TZ, LogicalTypeId::MAP, LogicalTypeId::UNION, LogicalType::UUID, - LogicalTypeId::ARRAY}; + LogicalType::BOOLEAN, LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, + LogicalType::BIGINT, LogicalType::DATE, LogicalType::TIMESTAMP, LogicalType::DOUBLE, + LogicalType::FLOAT, LogicalType::VARCHAR, LogicalType::BLOB, LogicalType::BIT, + LogicalType::VARINT, LogicalType::INTERVAL, LogicalType::HUGEINT, LogicalTypeId::DECIMAL, + LogicalType::UTINYINT, LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT, + LogicalType::UHUGEINT, LogicalType::TIME, LogicalTypeId::LIST, LogicalTypeId::STRUCT, + LogicalType::TIME_TZ, LogicalType::TIMESTAMP_TZ, LogicalTypeId::MAP, LogicalTypeId::UNION, + LogicalType::UUID, LogicalTypeId::ARRAY}; return types; } @@ -524,7 +526,67 @@ LogicalType TransformStringToLogicalType(const string &str) { if (StringUtil::Lower(str) == "null") { return LogicalType::SQLNULL; } - return Parser::ParseColumnList("dummy " + str).GetColumn(LogicalIndex(0)).Type(); + ColumnList column_list; + try { + column_list = Parser::ParseColumnList("dummy " + str); + } catch (const std::runtime_error &e) { + const vector suggested_types {"BIGINT", + "INT8", + "LONG", + "BIT", + "BITSTRING", + "BLOB", + "BYTEA", + "BINARY,", + "VARBINARY", + "BOOLEAN", + "BOOL", + "LOGICAL", + "DATE", + "DECIMAL(prec, scale)", + "DOUBLE", + "FLOAT8", + "FLOAT", + "FLOAT4", + "REAL", + "HUGEINT", + "INTEGER", + "INT4", + "INT", + "SIGNED", + "INTERVAL", + "SMALLINT", + "INT2", + "SHORT", + "TIME", + "TIMESTAMPTZ ", + "TIMESTAMP", + "DATETIME", + "TINYINT", + "INT1", + "UBIGINT", + "UHUGEINT", + "UINTEGER", + "USMALLINT", + "UTINYINT", + "UUID", + "VARCHAR", + "CHAR", + "BPCHAR", + "TEXT", + "STRING", + "MAP(INTEGER, VARCHAR)", + "UNION(num INTEGER, text VARCHAR)"}; + std::ostringstream error; + error << "Value \"" << str << "\" can not be converted to a DuckDB Type." << '\n'; + error << "Possible examples as suggestions: " << '\n'; + auto suggestions = StringUtil::TopNJaroWinkler(suggested_types, str); + for (auto &suggestion : suggestions) { + error << "* " << suggestion << '\n'; + } + throw InvalidInputException(error.str()); + } + return column_list.GetColumn(LogicalIndex(0)).Type(); } LogicalType GetUserTypeRecursive(const LogicalType &type, ClientContext &context) { @@ -596,12 +658,24 @@ bool LogicalType::IsNumeric() const { } } -bool LogicalType::IsValid() const { - return id() != LogicalTypeId::INVALID && id() != LogicalTypeId::UNKNOWN; +bool LogicalType::IsTemporal() const { + switch (id_) { + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIME_TZ: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + return true; + default: + return false; + } } -bool LogicalType::Contains(LogicalTypeId type_id) const { - return Contains([&](const LogicalType &type) { return type.id() == type_id; }); +bool LogicalType::IsValid() const { + return id() != LogicalTypeId::INVALID && id() != LogicalTypeId::UNKNOWN; } bool LogicalType::GetDecimalProperties(uint8_t &width, uint8_t &scale) const { @@ -752,6 +826,8 @@ LogicalType LogicalType::NormalizeType(const LogicalType &type) { return LogicalType::VARCHAR; case LogicalTypeId::INTEGER_LITERAL: return IntegerLiteral::GetType(type); + case LogicalTypeId::UNKNOWN: + throw ParameterNotResolvedException(); default: return type; } @@ -767,7 +843,7 @@ static bool CombineUnequalTypes(const LogicalType &left, const LogicalType &righ return OP::Operation(left, LogicalType::VARCHAR, result); } // NULL/string literals/unknown (parameter) types always take the other type - LogicalTypeId other_types[] = {LogicalTypeId::UNKNOWN, LogicalTypeId::SQLNULL, LogicalTypeId::STRING_LITERAL}; + LogicalTypeId other_types[] = {LogicalTypeId::SQLNULL, LogicalTypeId::UNKNOWN, LogicalTypeId::STRING_LITERAL}; for (auto &other_type : other_types) { if (left.id() == other_type) { result = LogicalType::NormalizeType(right); @@ -1048,6 +1124,8 @@ static idx_t GetLogicalTypeScore(const LogicalType &type) { return 101; case LogicalTypeId::UUID: return 102; + case LogicalTypeId::VARINT: + return 103; // nested types case LogicalTypeId::STRUCT: return 125; @@ -1370,7 +1448,7 @@ bool StructType::IsUnnamed(const LogicalType &type) { if (child_types.empty()) { return false; } - return child_types[0].first.empty(); + return child_types[0].first.empty(); // NOLINT } LogicalType LogicalType::STRUCT(child_list_t children) { @@ -1624,16 +1702,18 @@ LogicalType ArrayType::ConvertToList(const LogicalType &type) { } } -LogicalType LogicalType::ARRAY(const LogicalType &child, idx_t size) { - D_ASSERT(size > 0); - D_ASSERT(size <= ArrayType::MAX_ARRAY_SIZE); - auto info = make_shared_ptr(child, size); - return LogicalType(LogicalTypeId::ARRAY, std::move(info)); -} - -LogicalType LogicalType::ARRAY(const LogicalType &child) { - auto info = make_shared_ptr(child, 0); - return LogicalType(LogicalTypeId::ARRAY, std::move(info)); +LogicalType LogicalType::ARRAY(const LogicalType &child, optional_idx size) { + if (!size.IsValid()) { + // Create an incomplete ARRAY type, used for binding + auto info = make_shared_ptr(child, 0); + return LogicalType(LogicalTypeId::ARRAY, std::move(info)); + } else { + auto array_size = size.GetIndex(); + D_ASSERT(array_size > 0); + D_ASSERT(array_size <= ArrayType::MAX_ARRAY_SIZE); + auto info = make_shared_ptr(child, array_size); + return LogicalType(LogicalTypeId::ARRAY, std::move(info)); + } } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/common/types/batched_data_collection.cpp b/src/duckdb/src/common/types/batched_data_collection.cpp index 072e0c73..fd25dbc1 100644 --- a/src/duckdb/src/common/types/batched_data_collection.cpp +++ b/src/duckdb/src/common/types/batched_data_collection.cpp @@ -11,6 +11,11 @@ BatchedDataCollection::BatchedDataCollection(ClientContext &context_p, vector types_p, batch_map_t batches, + bool buffer_managed_p) + : context(context_p), types(std::move(types_p)), buffer_managed(buffer_managed_p), data(std::move(batches)) { +} + void BatchedDataCollection::Append(DataChunk &input, idx_t batch_index) { D_ASSERT(batch_index != DConstants::INVALID_INDEX); optional_ptr collection; @@ -50,28 +55,34 @@ void BatchedDataCollection::Merge(BatchedDataCollection &other) { other.data.clear(); } -void BatchedDataCollection::InitializeScan(BatchedChunkScanState &state) { - state.iterator = data.begin(); - if (state.iterator == data.end()) { +void BatchedDataCollection::InitializeScan(BatchedChunkScanState &state, const BatchedChunkIteratorRange &range) { + state.range = range; + if (state.range.begin == state.range.end) { return; } - state.iterator->second->InitializeScan(state.scan_state); + state.range.begin->second->InitializeScan(state.scan_state); +} + +void BatchedDataCollection::InitializeScan(BatchedChunkScanState &state) { + auto range = BatchRange(); + return InitializeScan(state, range); } void BatchedDataCollection::Scan(BatchedChunkScanState &state, DataChunk &output) { - while (state.iterator != data.end()) { + while (state.range.begin != state.range.end) { // check if there is a chunk remaining in this collection - auto collection = state.iterator->second.get(); + auto collection = state.range.begin->second.get(); collection->Scan(state.scan_state, output); if (output.size() > 0) { return; } // there isn't! move to the next collection - state.iterator++; - if (state.iterator == data.end()) { + state.range.begin->second.reset(); + state.range.begin++; + if (state.range.begin == state.range.end) { return; } - state.iterator->second->InitializeScan(state.scan_state); + state.range.begin->second->InitializeScan(state.scan_state); } } @@ -92,6 +103,64 @@ unique_ptr BatchedDataCollection::FetchCollection() { return result; } +const vector &BatchedDataCollection::Types() const { + return types; +} + +idx_t BatchedDataCollection::Count() const { + idx_t count = 0; + for (auto &collection : data) { + count += collection.second->Count(); + } + return count; +} + +idx_t BatchedDataCollection::BatchCount() const { + return data.size(); +} + +idx_t BatchedDataCollection::IndexToBatchIndex(idx_t index) const { + if (index >= data.size()) { + throw InternalException("Index %d is out of range for this collection, it only contains %d batches", index, + data.size()); + } + auto entry = data.begin(); + std::advance(entry, index); + return entry->first; +} + +idx_t BatchedDataCollection::BatchSize(idx_t batch_index) const { + auto &collection = Batch(batch_index); + return collection.Count(); +} + +const ColumnDataCollection &BatchedDataCollection::Batch(idx_t batch_index) const { + auto entry = data.find(batch_index); + if (entry == data.end()) { + throw InternalException("This batched data collection does not contain a collection for batch_index %d", + batch_index); + } + return *entry->second; +} + +BatchedChunkIteratorRange BatchedDataCollection::BatchRange(idx_t begin_idx, idx_t end_idx) { + D_ASSERT(begin_idx < end_idx); + if (end_idx > data.size()) { + // Limit the iterator to the end + end_idx = DConstants::INVALID_INDEX; + } + BatchedChunkIteratorRange range; + range.begin = data.begin(); + std::advance(range.begin, begin_idx); + if (end_idx == DConstants::INVALID_INDEX) { + range.end = data.end(); + } else { + range.end = data.begin(); + std::advance(range.end, end_idx); + } + return range; +} + string BatchedDataCollection::ToString() const { string result; result += "Batched Data Collection\n"; diff --git a/src/duckdb/src/common/types/bit.cpp b/src/duckdb/src/common/types/bit.cpp index 83f0ce83..ebd4ec3a 100644 --- a/src/duckdb/src/common/types/bit.cpp +++ b/src/duckdb/src/common/types/bit.cpp @@ -9,7 +9,7 @@ namespace duckdb { // **** helper functions **** static char ComputePadding(idx_t len) { - return (8 - (len % 8)) % 8; + return UnsafeNumericCast((8 - (len % 8)) % 8); } idx_t Bit::ComputeBitstringLen(idx_t len) { @@ -89,7 +89,7 @@ void Bit::ToString(string_t bits, char *output) { string Bit::ToString(string_t str) { auto len = BitLength(str); - auto buffer = make_unsafe_uniq_array(len); + auto buffer = make_unsafe_uniq_array_uninitialized(len); ToString(str, buffer.get()); return string(buffer.get(), len); } @@ -150,7 +150,7 @@ void Bit::ToBit(string_t str, string_t &output_str) { string Bit::ToBit(string_t str) { auto bit_len = GetBitSize(str); - auto buffer = make_unsafe_uniq_array(bit_len); + auto buffer = make_unsafe_uniq_array_uninitialized(bit_len); string_t output_str(buffer.get(), UnsafeNumericCast(bit_len)); Bit::ToBit(str, output_str); return output_str.GetString(); @@ -166,7 +166,7 @@ void Bit::BlobToBit(string_t blob, string_t &output_str) { } string Bit::BlobToBit(string_t blob) { - auto buffer = make_unsafe_uniq_array(blob.GetSize() + 1); + auto buffer = make_unsafe_uniq_array_uninitialized(blob.GetSize() + 1); string_t output_str(buffer.get(), UnsafeNumericCast(blob.GetSize() + 1)); Bit::BlobToBit(blob, output_str); return output_str.GetString(); @@ -192,7 +192,7 @@ void Bit::BitToBlob(string_t bit, string_t &output_blob) { string Bit::BitToBlob(string_t bit) { D_ASSERT(bit.GetSize() > 1); - auto buffer = make_unsafe_uniq_array(bit.GetSize() - 1); + auto buffer = make_unsafe_uniq_array_uninitialized(bit.GetSize() - 1); string_t output_str(buffer.get(), UnsafeNumericCast(bit.GetSize() - 1)); Bit::BitToBlob(bit, output_str); return output_str.GetString(); @@ -281,7 +281,7 @@ idx_t Bit::GetBitInternal(string_t bit_string, idx_t n) { const char *buf = bit_string.GetData(); auto idx = Bit::GetBitIndex(n); D_ASSERT(idx < bit_string.GetSize()); - char byte = buf[idx] >> (7 - (n % 8)); + auto byte = buf[idx] >> (7 - (n % 8)); return (byte & 1 ? 1 : 0); } @@ -291,7 +291,7 @@ void Bit::SetBit(string_t &bit_string, idx_t n, idx_t new_value) { } void Bit::SetBitInternal(string_t &bit_string, idx_t n, idx_t new_value) { - auto buf = bit_string.GetDataWriteable(); + uint8_t *buf = reinterpret_cast(bit_string.GetDataWriteable()); auto idx = Bit::GetBitIndex(n); D_ASSERT(idx < bit_string.GetSize()); @@ -306,8 +306,9 @@ void Bit::SetBitInternal(string_t &bit_string, idx_t n, idx_t new_value) { // **** BITWISE operators **** void Bit::RightShift(const string_t &bit_string, const idx_t &shift, string_t &result) { - char *res_buf = result.GetDataWriteable(); - const char *buf = bit_string.GetData(); + uint8_t *res_buf = reinterpret_cast(result.GetDataWriteable()); + const uint8_t *buf = reinterpret_cast(bit_string.GetData()); + res_buf[0] = buf[0]; for (idx_t i = 0; i < Bit::BitLength(result); i++) { if (i < shift) { @@ -321,8 +322,9 @@ void Bit::RightShift(const string_t &bit_string, const idx_t &shift, string_t &r } void Bit::LeftShift(const string_t &bit_string, const idx_t &shift, string_t &result) { - char *res_buf = result.GetDataWriteable(); - const char *buf = bit_string.GetData(); + uint8_t *res_buf = reinterpret_cast(result.GetDataWriteable()); + const uint8_t *buf = reinterpret_cast(bit_string.GetData()); + res_buf[0] = buf[0]; for (idx_t i = 0; i < Bit::BitLength(bit_string); i++) { if (i < (Bit::BitLength(bit_string) - shift)) { @@ -340,9 +342,9 @@ void Bit::BitwiseAnd(const string_t &rhs, const string_t &lhs, string_t &result) throw InvalidInputException("Cannot AND bit strings of different sizes"); } - char *buf = result.GetDataWriteable(); - const char *r_buf = rhs.GetData(); - const char *l_buf = lhs.GetData(); + uint8_t *buf = reinterpret_cast(result.GetDataWriteable()); + const uint8_t *r_buf = reinterpret_cast(rhs.GetData()); + const uint8_t *l_buf = reinterpret_cast(lhs.GetData()); buf[0] = l_buf[0]; for (idx_t i = 1; i < lhs.GetSize(); i++) { @@ -356,9 +358,9 @@ void Bit::BitwiseOr(const string_t &rhs, const string_t &lhs, string_t &result) throw InvalidInputException("Cannot OR bit strings of different sizes"); } - char *buf = result.GetDataWriteable(); - const char *r_buf = rhs.GetData(); - const char *l_buf = lhs.GetData(); + uint8_t *buf = reinterpret_cast(result.GetDataWriteable()); + const uint8_t *r_buf = reinterpret_cast(rhs.GetData()); + const uint8_t *l_buf = reinterpret_cast(lhs.GetData()); buf[0] = l_buf[0]; for (idx_t i = 1; i < lhs.GetSize(); i++) { @@ -372,9 +374,9 @@ void Bit::BitwiseXor(const string_t &rhs, const string_t &lhs, string_t &result) throw InvalidInputException("Cannot XOR bit strings of different sizes"); } - char *buf = result.GetDataWriteable(); - const char *r_buf = rhs.GetData(); - const char *l_buf = lhs.GetData(); + uint8_t *buf = reinterpret_cast(result.GetDataWriteable()); + const uint8_t *r_buf = reinterpret_cast(rhs.GetData()); + const uint8_t *l_buf = reinterpret_cast(lhs.GetData()); buf[0] = l_buf[0]; for (idx_t i = 1; i < lhs.GetSize(); i++) { @@ -384,8 +386,8 @@ void Bit::BitwiseXor(const string_t &rhs, const string_t &lhs, string_t &result) } void Bit::BitwiseNot(const string_t &input, string_t &result) { - char *result_buf = result.GetDataWriteable(); - const char *buf = input.GetData(); + uint8_t *result_buf = reinterpret_cast(result.GetDataWriteable()); + const uint8_t *buf = reinterpret_cast(input.GetData()); result_buf[0] = buf[0]; for (idx_t i = 1; i < input.GetSize(); i++) { diff --git a/src/duckdb/src/common/types/blob.cpp b/src/duckdb/src/common/types/blob.cpp index 6f472c0e..11cd47b7 100644 --- a/src/duckdb/src/common/types/blob.cpp +++ b/src/duckdb/src/common/types/blob.cpp @@ -1,10 +1,11 @@ -#include "duckdb/common/types/string_type.hpp" #include "duckdb/common/types/blob.hpp" + #include "duckdb/common/assert.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/string_util.hpp" #include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/string_type.hpp" namespace duckdb { @@ -65,7 +66,7 @@ void Blob::ToString(string_t blob, char *output) { string Blob::ToString(string_t blob) { auto str_len = GetStringSize(blob); - auto buffer = make_unsafe_uniq_array(str_len); + auto buffer = make_unsafe_uniq_array_uninitialized(str_len); Blob::ToString(blob, buffer.get()); return string(buffer.get(), str_len); } @@ -77,15 +78,16 @@ bool Blob::TryGetBlobSize(string_t str, idx_t &str_len, CastParameters ¶mete for (idx_t i = 0; i < len; i++) { if (data[i] == '\\') { if (i + 3 >= len) { - string error = "Invalid hex escape code encountered in string -> blob conversion: " - "unterminated escape code at end of blob"; + string error = StringUtil::Format("Invalid hex escape code encountered in string -> blob conversion of " + "string \"%s\": unterminated escape code at end of blob", + str.GetString()); HandleCastError::AssignError(error, parameters); return false; } if (data[i + 1] != 'x' || Blob::HEX_MAP[data[i + 2]] < 0 || Blob::HEX_MAP[data[i + 3]] < 0) { - string error = - StringUtil::Format("Invalid hex escape code encountered in string -> blob conversion: %s", - string(const_char_ptr_cast(data) + i, 4)); + string error = StringUtil::Format( + "Invalid hex escape code encountered in string -> blob conversion of string \"%s\": %s", + str.GetString(), string(const_char_ptr_cast(data) + i, 4)); HandleCastError::AssignError(error, parameters); return false; } @@ -94,8 +96,10 @@ bool Blob::TryGetBlobSize(string_t str, idx_t &str_len, CastParameters ¶mete } else if (data[i] <= 127) { str_len++; } else { - string error = "Invalid byte encountered in STRING -> BLOB conversion. All non-ascii characters " - "must be escaped with hex codes (e.g. \\xAA)"; + string error = StringUtil::Format( + "Invalid byte encountered in STRING -> BLOB conversion of string \"%s\". All non-ascii characters " + "must be escaped with hex codes (e.g. \\xAA)", + str.GetString()); HandleCastError::AssignError(error, parameters); return false; } @@ -147,7 +151,7 @@ string Blob::ToBlob(string_t str) { string Blob::ToBlob(string_t str, CastParameters ¶meters) { auto blob_len = GetBlobSize(str, parameters); - auto buffer = make_unsafe_uniq_array(blob_len); + auto buffer = make_unsafe_uniq_array_uninitialized(blob_len); Blob::ToBlob(str, data_ptr_cast(buffer.get())); return string(buffer.get(), blob_len); } diff --git a/src/duckdb/src/common/types/column/column_data_allocator.cpp b/src/duckdb/src/common/types/column/column_data_allocator.cpp index d766545a..0081a0e4 100644 --- a/src/duckdb/src/common/types/column/column_data_allocator.cpp +++ b/src/duckdb/src/common/types/column/column_data_allocator.cpp @@ -61,13 +61,13 @@ BufferHandle ColumnDataAllocator::Pin(uint32_t block_id) { BufferHandle ColumnDataAllocator::AllocateBlock(idx_t size) { D_ASSERT(type == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR || type == ColumnDataAllocatorType::HYBRID); - auto block_size = MaxValue(size, Storage::BLOCK_SIZE); + auto max_size = MaxValue(size, GetBufferManager().GetBlockSize()); BlockMetaData data; data.size = 0; - data.capacity = NumericCast(block_size); - auto pin = alloc.buffer_manager->Allocate(MemoryTag::COLUMN_DATA, block_size, false, &data.handle); + data.capacity = NumericCast(max_size); + auto pin = alloc.buffer_manager->Allocate(MemoryTag::COLUMN_DATA, max_size, false, &data.handle); blocks.push_back(std::move(data)); - allocated_size += block_size; + allocated_size += max_size; return pin; } @@ -75,7 +75,7 @@ void ColumnDataAllocator::AllocateEmptyBlock(idx_t size) { auto allocation_amount = MaxValue(NextPowerOfTwo(size), 4096); if (!blocks.empty()) { idx_t last_capacity = blocks.back().capacity; - auto next_capacity = MinValue(last_capacity * 2, last_capacity + Storage::BLOCK_SIZE); + auto next_capacity = MinValue(last_capacity * 2, last_capacity + Storage::DEFAULT_BLOCK_SIZE); allocation_amount = MaxValue(next_capacity, allocation_amount); } D_ASSERT(type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR); @@ -220,13 +220,22 @@ void ColumnDataAllocator::UnswizzlePointers(ChunkManagementState &state, Vector } } -void ColumnDataAllocator::DeleteBlock(uint32_t block_id) { - blocks[block_id].handle->SetCanDestroy(true); +void ColumnDataAllocator::SetDestroyBufferUponUnpin(uint32_t block_id) { + blocks[block_id].handle->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); } Allocator &ColumnDataAllocator::GetAllocator() { - return type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR ? *alloc.allocator - : alloc.buffer_manager->GetBufferAllocator(); + if (type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { + return *alloc.allocator; + } + return alloc.buffer_manager->GetBufferAllocator(); +} + +BufferManager &ColumnDataAllocator::GetBufferManager() { + if (type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { + throw InternalException("cannot obtain the buffer manager for in memory allocations"); + } + return *alloc.buffer_manager; } void ColumnDataAllocator::InitializeChunkState(ChunkManagementState &state, ChunkMetaData &chunk) { diff --git a/src/duckdb/src/common/types/column/column_data_collection.cpp b/src/duckdb/src/common/types/column/column_data_collection.cpp index c3752a98..c3740444 100644 --- a/src/duckdb/src/common/types/column/column_data_collection.cpp +++ b/src/duckdb/src/common/types/column/column_data_collection.cpp @@ -467,6 +467,7 @@ void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVector auto current_index = meta_data.vector_data_index; idx_t remaining = copy_count; + auto block_size = meta_data.segment.allocator->GetBufferManager().GetBlockSize(); while (remaining > 0) { // how many values fit in the current string vector idx_t vector_remaining = @@ -485,19 +486,18 @@ void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVector if (entry.IsInlined()) { continue; } - if (heap_size + entry.GetSize() > Storage::BLOCK_SIZE) { + if (heap_size + entry.GetSize() > block_size) { break; } heap_size += entry.GetSize(); } if (vector_remaining != 0 && append_count == 0) { - // single string is longer than Storage::BLOCK_SIZE - // we allocate one block at a time for long strings + // The string exceeds Storage::DEFAULT_BLOCK_SIZE, so we allocate one block at a time for long strings. auto source_idx = source_data.sel->get_index(offset + append_count); D_ASSERT(source_data.validity.RowIsValid(source_idx)); D_ASSERT(!source_entries[source_idx].IsInlined()); - D_ASSERT(source_entries[source_idx].GetSize() > Storage::BLOCK_SIZE); + D_ASSERT(source_entries[source_idx].GetSize() > block_size); heap_size += source_entries[source_idx].GetSize(); append_count++; } diff --git a/src/duckdb/src/common/types/column/column_data_collection_segment.cpp b/src/duckdb/src/common/types/column/column_data_collection_segment.cpp index 918680c1..fd2e39ff 100644 --- a/src/duckdb/src/common/types/column/column_data_collection_segment.cpp +++ b/src/duckdb/src/common/types/column/column_data_collection_segment.cpp @@ -24,9 +24,9 @@ VectorDataIndex ColumnDataCollectionSegment::AllocateVectorInternal(const Logica meta_data.count = 0; auto internal_type = type.InternalType(); - auto type_size = ((internal_type == PhysicalType::STRUCT) || (internal_type == PhysicalType::ARRAY)) - ? 0 - : GetTypeIdSize(internal_type); + auto struct_or_array = internal_type == PhysicalType::STRUCT || internal_type == PhysicalType::ARRAY; + auto type_size = struct_or_array ? 0 : GetTypeIdSize(internal_type); + allocator->AllocateData(GetDataSize(type_size) + ValidityMask::STANDARD_MASK_SIZE, meta_data.block_id, meta_data.offset, chunk_state); if (allocator->GetType() == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR || @@ -77,7 +77,6 @@ VectorDataIndex ColumnDataCollectionSegment::AllocateStringHeap(idx_t size, Chun VectorMetaData meta_data; meta_data.count = 0; - allocator->AllocateData(AlignValue(size), meta_data.block_id, meta_data.offset, &append_state.current_chunk_state); chunk_meta.block_ids.insert(meta_data.block_id); diff --git a/src/duckdb/src/common/types/column/column_data_consumer.cpp b/src/duckdb/src/common/types/column/column_data_consumer.cpp index d9fb4fdb..fa20f1d5 100644 --- a/src/duckdb/src/common/types/column/column_data_consumer.cpp +++ b/src/duckdb/src/common/types/column/column_data_consumer.cpp @@ -88,13 +88,13 @@ void ColumnDataConsumer::ConsumeChunks(idx_t delete_index_start, idx_t delete_in if (prev_allocator != curr_allocator) { // Moved to the next allocator, delete all remaining blocks in the previous one for (uint32_t block_id = prev_min_block_id; block_id < prev_allocator->BlockCount(); block_id++) { - prev_allocator->DeleteBlock(block_id); + prev_allocator->SetDestroyBufferUponUnpin(block_id); } continue; } // Same allocator, see if we can delete blocks for (uint32_t block_id = prev_min_block_id; block_id < curr_min_block_id; block_id++) { - prev_allocator->DeleteBlock(block_id); + prev_allocator->SetDestroyBufferUponUnpin(block_id); } } } diff --git a/src/duckdb/src/common/types/column/partitioned_column_data.cpp b/src/duckdb/src/common/types/column/partitioned_column_data.cpp index d59659d3..78e5b367 100644 --- a/src/duckdb/src/common/types/column/partitioned_column_data.cpp +++ b/src/duckdb/src/common/types/column/partitioned_column_data.cpp @@ -34,6 +34,10 @@ void PartitionedColumnData::InitializeAppendState(PartitionedColumnDataAppendSta InitializeAppendStateInternal(state); } +bool PartitionedColumnData::UseFixedSizeMap() const { + return MaxPartitionIndex() < PartitionedTupleDataAppendState::MAP_THRESHOLD; +} + unique_ptr PartitionedColumnData::CreatePartitionBuffer() const { auto result = make_uniq(); result->Initialize(BufferAllocator::Get(context), types, BufferSize()); @@ -44,59 +48,104 @@ void PartitionedColumnData::Append(PartitionedColumnDataAppendState &state, Data // Compute partition indices and store them in state.partition_indices ComputePartitionIndices(state, input); - // Compute the counts per partition - const auto count = input.size(); - const auto partition_indices = FlatVector::GetData(state.partition_indices); - auto &partition_entries = state.partition_entries; + // Build the selection vector for the partitions + BuildPartitionSel(state, input.size()); + + // Early out: check if everything belongs to a single partition + const auto partition_index = state.GetPartitionIndexIfSinglePartition(UseFixedSizeMap()); + if (partition_index.IsValid()) { + auto &partition = *partitions[partition_index.GetIndex()]; + auto &partition_append_state = *state.partition_append_states[partition_index.GetIndex()]; + partition.Append(partition_append_state, input); + return; + } + + if (UseFixedSizeMap()) { + AppendInternal(state, input); + } else { + AppendInternal(state, input); + } +} + +void PartitionedColumnData::BuildPartitionSel(PartitionedColumnDataAppendState &state, const idx_t append_count) const { + if (UseFixedSizeMap()) { + BuildPartitionSel(state, append_count); + } else { + BuildPartitionSel(state, append_count); + } +} + +template +MAP_TYPE &PartitionedColumnDataGetMap(PartitionedColumnDataAppendState &) { + throw InternalException("Unknown MAP_TYPE for PartitionedTupleDataGetMap"); +} + +template <> +fixed_size_map_t &PartitionedColumnDataGetMap(PartitionedColumnDataAppendState &state) { + return state.fixed_partition_entries; +} + +template <> +perfect_map_t &PartitionedColumnDataGetMap(PartitionedColumnDataAppendState &state) { + return state.partition_entries; +} + +template +void PartitionedColumnData::BuildPartitionSel(PartitionedColumnDataAppendState &state, const idx_t append_count) { + using GETTER = TemplatedMapGetter; + auto &partition_entries = state.GetMap(); partition_entries.clear(); + const auto partition_indices = FlatVector::GetData(state.partition_indices); switch (state.partition_indices.GetVectorType()) { case VectorType::FLAT_VECTOR: - for (idx_t i = 0; i < count; i++) { + for (idx_t i = 0; i < append_count; i++) { const auto &partition_index = partition_indices[i]; auto partition_entry = partition_entries.find(partition_index); if (partition_entry == partition_entries.end()) { partition_entries[partition_index] = list_entry_t(0, 1); } else { - partition_entry->second.length++; + GETTER::GetValue(partition_entry).length++; } } break; case VectorType::CONSTANT_VECTOR: - partition_entries[partition_indices[0]] = list_entry_t(0, count); + partition_entries[partition_indices[0]] = list_entry_t(0, append_count); break; default: - throw InternalException("Unexpected VectorType in PartitionedColumnData::Append"); + throw InternalException("Unexpected VectorType in PartitionedTupleData::Append"); } // Early out: check if everything belongs to a single partition if (partition_entries.size() == 1) { - const auto &partition_index = partition_entries.begin()->first; - auto &partition = *partitions[partition_index]; - auto &partition_append_state = *state.partition_append_states[partition_index]; - partition.Append(partition_append_state, input); return; } // Compute offsets from the counts idx_t offset = 0; - for (auto &pc : partition_entries) { - auto &partition_entry = pc.second; + for (auto it = partition_entries.begin(); it != partition_entries.end(); ++it) { + auto &partition_entry = GETTER::GetValue(it); partition_entry.offset = offset; offset += partition_entry.length; } // Now initialize a single selection vector that acts as a selection vector for every partition - auto &all_partitions_sel = state.partition_sel; - for (idx_t i = 0; i < count; i++) { + auto &partition_sel = state.partition_sel; + for (idx_t i = 0; i < append_count; i++) { const auto &partition_index = partition_indices[i]; auto &partition_offset = partition_entries[partition_index].offset; - all_partitions_sel[partition_offset++] = NumericCast(i); + partition_sel[partition_offset++] = UnsafeNumericCast(i); } +} + +template +void PartitionedColumnData::AppendInternal(PartitionedColumnDataAppendState &state, DataChunk &input) { + using GETTER = TemplatedMapGetter; + const auto &partition_entries = state.GetMap(); // Loop through the partitions to append the new data to the partition buffers, and flush the buffers if necessary SelectionVector partition_sel; - for (auto &pc : partition_entries) { - const auto &partition_index = pc.first; + for (auto it = partition_entries.begin(); it != partition_entries.end(); ++it) { + const auto &partition_index = GETTER::GetKey(it); // Partition, buffer, and append state for this partition index auto &partition = *partitions[partition_index]; @@ -104,12 +153,12 @@ void PartitionedColumnData::Append(PartitionedColumnDataAppendState &state, Data auto &partition_append_state = *state.partition_append_states[partition_index]; // Length and offset into the selection vector for this chunk, for this partition - const auto &partition_entry = pc.second; + const auto &partition_entry = GETTER::GetValue(it); const auto &partition_length = partition_entry.length; const auto partition_offset = partition_entry.offset - partition_length; // Create a selection vector for this partition using the offset into the single selection vector - partition_sel.Initialize(all_partitions_sel.data() + partition_offset); + partition_sel.Initialize(state.partition_sel.data() + partition_offset); if (partition_length >= HalfBufferSize()) { // Slice the input chunk using the selection vector diff --git a/src/duckdb/src/common/types/data_chunk.cpp b/src/duckdb/src/common/types/data_chunk.cpp index 22ecb963..eea02568 100644 --- a/src/duckdb/src/common/types/data_chunk.cpp +++ b/src/duckdb/src/common/types/data_chunk.cpp @@ -37,6 +37,15 @@ void DataChunk::Initialize(ClientContext &context, const vector &ty Initialize(Allocator::Get(context), types, capacity_p); } +idx_t DataChunk::GetAllocationSize() const { + idx_t total_size = 0; + auto cardinality = size(); + for (auto &vec : data) { + total_size += vec.GetAllocationSize(cardinality); + } + return total_size; +} + void DataChunk::Initialize(Allocator &allocator, vector::const_iterator begin, vector::const_iterator end, idx_t capacity_p) { D_ASSERT(data.empty()); // can only be initialized once @@ -134,7 +143,7 @@ void DataChunk::Copy(DataChunk &other, idx_t offset) const { void DataChunk::Copy(DataChunk &other, const SelectionVector &sel, const idx_t source_count, const idx_t offset) const { D_ASSERT(ColumnCount() == other.ColumnCount()); D_ASSERT(other.size() == 0); - D_ASSERT((offset + source_count) <= size()); + D_ASSERT(source_count <= size()); for (idx_t i = 0; i < ColumnCount(); i++) { D_ASSERT(other.data[i].GetVectorType() == VectorType::FLAT_VECTOR); diff --git a/src/duckdb/src/common/types/date.cpp b/src/duckdb/src/common/types/date.cpp index be334903..7c9f78ca 100644 --- a/src/duckdb/src/common/types/date.cpp +++ b/src/duckdb/src/common/types/date.cpp @@ -250,11 +250,16 @@ bool Date::TryConvertDate(const char *buf, idx_t len, idx_t &pos, date_t &result return pos == len; } // first parse the year + idx_t year_length = 0; for (; pos < len && StringUtil::CharacterIsDigit(buf[pos]); pos++) { if (year >= 100000000) { return false; } year = (buf[pos] - '0') + year * 10; + year_length++; + } + if (year_length < 2 && strict) { + return false; } if (yearneg) { year = -year; @@ -362,7 +367,7 @@ string Date::ToString(date_t date) { Date::Convert(date, date_units[0], date_units[1], date_units[2]); auto length = DateToStringCast::Length(date_units, year_length, add_bc); - auto buffer = make_unsafe_uniq_array(length); + auto buffer = make_unsafe_uniq_array_uninitialized(length); DateToStringCast::Format(buffer.get(), date_units, year_length, add_bc); return string(buffer.get(), length); } @@ -450,22 +455,6 @@ int64_t Date::EpochMilliseconds(date_t date) { return result; } -int32_t Date::ExtractYear(date_t d, int32_t *last_year) { - auto n = d.days; - // cached look up: check if year of this date is the same as the last one we looked up - // note that this only works for years in the range [1970, 2370] - if (n >= Date::CUMULATIVE_YEAR_DAYS[*last_year] && n < Date::CUMULATIVE_YEAR_DAYS[*last_year + 1]) { - return Date::EPOCH_YEAR + *last_year; - } - int32_t year; - Date::ExtractYearOffset(n, year, *last_year); - return year; -} - -int32_t Date::ExtractYear(timestamp_t ts, int32_t *last_year) { - return Date::ExtractYear(Timestamp::GetDate(ts), last_year); -} - int32_t Date::ExtractYear(date_t d) { int32_t year, year_offset; Date::ExtractYearOffset(d.days, year, year_offset); @@ -515,10 +504,10 @@ int32_t Date::ExtractISODayOfTheWeek(date_t date) { // 7 = 4 if (date.days < 0) { // negative date: start off at 4 and cycle downwards - return (7 - ((-int64_t(date.days) + 3) % 7)); + return UnsafeNumericCast((7 - ((-int64_t(date.days) + 3) % 7))); } else { // positive date: start off at 4 and cycle upwards - return ((int64_t(date.days) + 3) % 7) + 1; + return UnsafeNumericCast(((int64_t(date.days) + 3) % 7) + 1); } } diff --git a/src/duckdb/src/common/types/decimal.cpp b/src/duckdb/src/common/types/decimal.cpp index d05e4a00..5ecb39a0 100644 --- a/src/duckdb/src/common/types/decimal.cpp +++ b/src/duckdb/src/common/types/decimal.cpp @@ -1,4 +1,5 @@ #include "duckdb/common/types/decimal.hpp" + #include "duckdb/common/types/cast_helpers.hpp" namespace duckdb { @@ -6,7 +7,7 @@ namespace duckdb { template string TemplatedDecimalToString(SIGNED value, uint8_t width, uint8_t scale) { auto len = DecimalToString::DecimalLength(value, width, scale); - auto data = make_unsafe_uniq_array(UnsafeNumericCast(len + 1)); + auto data = make_unsafe_uniq_array_uninitialized(UnsafeNumericCast(len + 1)); DecimalToString::FormatDecimal(value, width, scale, data.get(), UnsafeNumericCast(len)); return string(data.get(), UnsafeNumericCast(len)); } @@ -25,7 +26,7 @@ string Decimal::ToString(int64_t value, uint8_t width, uint8_t scale) { string Decimal::ToString(hugeint_t value, uint8_t width, uint8_t scale) { auto len = DecimalToString::DecimalLength(value, width, scale); - auto data = make_unsafe_uniq_array(UnsafeNumericCast(len + 1)); + auto data = make_unsafe_uniq_array_uninitialized(UnsafeNumericCast(len + 1)); DecimalToString::FormatDecimal(value, width, scale, data.get(), UnsafeNumericCast(len)); return string(data.get(), UnsafeNumericCast(len)); } diff --git a/src/duckdb/src/common/types/hugeint.cpp b/src/duckdb/src/common/types/hugeint.cpp index d83d81ca..b19a03b7 100644 --- a/src/duckdb/src/common/types/hugeint.cpp +++ b/src/duckdb/src/common/types/hugeint.cpp @@ -64,7 +64,7 @@ const hugeint_t Hugeint::POWERS_OF_TEN[] { template <> void Hugeint::NegateInPlace(hugeint_t &input) { - input.lower = NumericLimits::Maximum() - input.lower + 1; + input.lower = NumericLimits::Maximum() - input.lower + 1ull; input.upper = -1 - input.upper + (input.lower == 0); } @@ -77,6 +77,14 @@ bool Hugeint::TryNegate(hugeint_t input, hugeint_t &result) { return true; } +hugeint_t Hugeint::Abs(hugeint_t n) { + if (n < 0) { + return Hugeint::Negate(n); + } else { + return n; + } +} + //===--------------------------------------------------------------------===// // Divide //===--------------------------------------------------------------------===// @@ -646,7 +654,7 @@ bool CastBigintToFloating(hugeint_t input, REAL_T &result) { result = -REAL_T(NumericLimits::Maximum() - input.lower) - 1; break; default: - result = REAL_T(input.lower) + REAL_T(input.upper) * REAL_T(NumericLimits::Maximum()); + result = REAL_T(input.lower) + REAL_T(input.upper) * (REAL_T(NumericLimits::Maximum()) + 1); break; } return true; @@ -864,7 +872,7 @@ hugeint_t hugeint_t::operator<<(const hugeint_t &rhs) const { } else { D_ASSERT(shift < 128); result.lower = 0; - result.upper = (lower << (shift - 64)) & 0x7FFFFFFFFFFFFFFF; + result.upper = UnsafeNumericCast((lower << (shift - 64)) & 0x7FFFFFFFFFFFFFFF); } return result; } diff --git a/src/duckdb/src/common/types/hyperloglog.cpp b/src/duckdb/src/common/types/hyperloglog.cpp index e662738d..3ccd1f0d 100644 --- a/src/duckdb/src/common/types/hyperloglog.cpp +++ b/src/duckdb/src/common/types/hyperloglog.cpp @@ -1,285 +1,270 @@ #include "duckdb/common/types/hyperloglog.hpp" #include "duckdb/common/exception.hpp" -#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/limits.hpp" #include "duckdb/common/serializer/deserializer.hpp" - +#include "duckdb/common/serializer/serializer.hpp" #include "hyperloglog.hpp" -namespace duckdb { +#include -HyperLogLog::HyperLogLog() : hll(nullptr) { - hll = duckdb_hll::hll_create(); - // Insert into a dense hll can be vectorized, sparse cannot, so we immediately convert - duckdb_hll::hllSparseToDense(hll); +namespace duckdb_hll { +struct robj; // NOLINT } -HyperLogLog::HyperLogLog(duckdb_hll::robj *hll) : hll(hll) { -} +namespace duckdb { -HyperLogLog::~HyperLogLog() { - duckdb_hll::hll_destroy(hll); +idx_t HyperLogLog::Count() const { + uint32_t c[Q + 2] = {0}; + ExtractCounts(c); + return static_cast(EstimateCardinality(c)); } -void HyperLogLog::Add(data_ptr_t element, idx_t size) { - if (duckdb_hll::hll_add(hll, element, size) == HLL_C_ERR) { - throw InternalException("Could not add to HLL?"); +//! Algorithm 2 +void HyperLogLog::Merge(const HyperLogLog &other) { + for (idx_t i = 0; i < M; ++i) { + Update(i, other.k[i]); } } -idx_t HyperLogLog::Count() const { - // exception from size_t ban - size_t result; - - if (duckdb_hll::hll_count(hll, &result) != HLL_C_OK) { - throw InternalException("Could not count HLL?"); +//! Algorithm 4 +void HyperLogLog::ExtractCounts(uint32_t *c) const { + for (idx_t i = 0; i < M; ++i) { + c[k[i]]++; } - return result; } -unique_ptr HyperLogLog::Merge(HyperLogLog &other) { - duckdb_hll::robj *hlls[2]; - hlls[0] = hll; - hlls[1] = other.hll; - auto new_hll = duckdb_hll::hll_merge(hlls, 2); - if (!new_hll) { - throw InternalException("Could not merge HLLs"); +//! Taken from redis code +static double HLLSigma(double x) { + if (x == 1.) { + return std::numeric_limits::infinity(); } - return unique_ptr(new HyperLogLog(new_hll)); + double z_prime; + double y = 1; + double z = x; + do { + x *= x; + z_prime = z; + z += x * y; + y += y; + } while (z_prime != z); + return z; } -HyperLogLog *HyperLogLog::MergePointer(HyperLogLog &other) { - duckdb_hll::robj *hlls[2]; - hlls[0] = hll; - hlls[1] = other.hll; - auto new_hll = duckdb_hll::hll_merge(hlls, 2); - if (!new_hll) { - throw InternalException("Could not merge HLLs"); +//! Taken from redis code +static double HLLTau(double x) { + if (x == 0. || x == 1.) { + return 0.; } - return new HyperLogLog(new_hll); + double z_prime; + double y = 1.0; + double z = 1 - x; + do { + x = sqrt(x); + z_prime = z; + y *= 0.5; + z -= pow(1 - x, 2) * y; + } while (z_prime != z); + return z / 3; } -unique_ptr HyperLogLog::Merge(HyperLogLog logs[], idx_t count) { - auto hlls_uptr = unique_ptr { - new duckdb_hll::robj *[count] - }; - auto hlls = hlls_uptr.get(); - for (idx_t i = 0; i < count; i++) { - hlls[i] = logs[i].hll; - } - auto new_hll = duckdb_hll::hll_merge(hlls, count); - if (!new_hll) { - throw InternalException("Could not merge HLLs"); +//! Algorithm 6 +int64_t HyperLogLog::EstimateCardinality(uint32_t *c) { + auto z = M * HLLTau((double(M) - c[Q]) / double(M)); + + for (idx_t k = Q; k >= 1; --k) { + z += c[k]; + z *= 0.5; } - return unique_ptr(new HyperLogLog(new_hll)); -} -idx_t HyperLogLog::GetSize() { - return duckdb_hll::get_size(); + z += M * HLLSigma(c[0] / double(M)); + + return llroundl(ALPHA * M * M / z); } -data_ptr_t HyperLogLog::GetPtr() const { - return data_ptr_cast((hll)->ptr); +void HyperLogLog::Update(Vector &input, Vector &hash_vec, const idx_t count) { + UnifiedVectorFormat idata; + input.ToUnifiedFormat(count, idata); + + UnifiedVectorFormat hdata; + hash_vec.ToUnifiedFormat(count, hdata); + const auto hashes = UnifiedVectorFormat::GetData(hdata); + + if (hash_vec.GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (idata.validity.RowIsValid(0)) { + InsertElement(hashes[0]); + } + } else { + D_ASSERT(hash_vec.GetVectorType() == VectorType::FLAT_VECTOR); + if (idata.validity.AllValid()) { + for (idx_t i = 0; i < count; ++i) { + const auto hash = hashes[i]; + InsertElement(hash); + } + } else { + for (idx_t i = 0; i < count; ++i) { + if (idata.validity.RowIsValid(idata.sel->get_index(i))) { + const auto hash = hashes[i]; + InsertElement(hash); + } + } + } + } } -unique_ptr HyperLogLog::Copy() { +unique_ptr HyperLogLog::Copy() const { auto result = make_uniq(); - lock_guard guard(lock); - memcpy(result->GetPtr(), GetPtr(), GetSize()); + memcpy(result->k, this->k, sizeof(k)); D_ASSERT(result->Count() == Count()); return result; } -void HyperLogLog::Serialize(Serializer &serializer) const { - serializer.WriteProperty(100, "type", HLLStorageType::UNCOMPRESSED); - serializer.WriteProperty(101, "data", GetPtr(), GetSize()); -} +class HLLV1 { +public: + HLLV1() { + hll = duckdb_hll::hll_create(); + duckdb_hll::hllSparseToDense(hll); + } -unique_ptr HyperLogLog::Deserialize(Deserializer &deserializer) { - auto result = make_uniq(); - auto storage_type = deserializer.ReadProperty(100, "type"); - switch (storage_type) { - case HLLStorageType::UNCOMPRESSED: - deserializer.ReadProperty(101, "data", result->GetPtr(), GetSize()); - break; - default: - throw SerializationException("Unknown HyperLogLog storage type!"); + ~HLLV1() { + duckdb_hll::hll_destroy(hll); } - return result; -} -//===--------------------------------------------------------------------===// -// Vectorized HLL implementation -//===--------------------------------------------------------------------===// -//! Taken from https://nullprogram.com/blog/2018/07/31/ -template -inline uint64_t TemplatedHash(const T &elem) { - uint64_t x = elem; - x ^= x >> 30; - x *= UINT64_C(0xbf58476d1ce4e5b9); - x ^= x >> 27; - x *= UINT64_C(0x94d049bb133111eb); - x ^= x >> 31; - return x; -} +public: + static idx_t GetSize() { + return duckdb_hll::get_size(); + } -template <> -inline uint64_t TemplatedHash(const hugeint_t &elem) { - return TemplatedHash(Load(const_data_ptr_cast(&elem.upper))) ^ - TemplatedHash(elem.lower); -} + data_ptr_t GetPtr() const { + return data_ptr_cast((hll)->ptr); + } -template <> -inline uint64_t TemplatedHash(const uhugeint_t &elem) { - return TemplatedHash(Load(const_data_ptr_cast(&elem.upper))) ^ - TemplatedHash(elem.lower); -} + void ToNew(HyperLogLog &new_hll) const { + const idx_t mult = duckdb_hll::num_registers() / HyperLogLog::M; + // Old implementation used more registers, so we compress the registers, losing some accuracy + for (idx_t i = 0; i < HyperLogLog::M; i++) { + uint8_t max_old = 0; + for (idx_t j = 0; j < mult; j++) { + D_ASSERT(i * mult + j < duckdb_hll::num_registers()); + max_old = MaxValue(max_old, duckdb_hll::get_register(hll, i * mult + j)); + } + new_hll.Update(i, max_old); + } + } -template -inline void CreateIntegerRecursive(const_data_ptr_t &data, uint64_t &x) { - x ^= (uint64_t)data[rest - 1] << ((rest - 1) * 8); - return CreateIntegerRecursive(data, x); -} + void FromNew(const HyperLogLog &new_hll) { + const auto new_hll_count = new_hll.Count(); + if (new_hll_count == 0) { + return; + } -template <> -inline void CreateIntegerRecursive<1>(const_data_ptr_t &data, uint64_t &x) { - x ^= (uint64_t)data[0]; -} + const idx_t mult = duckdb_hll::num_registers() / HyperLogLog::M; + // When going from less to more registers, we cannot just duplicate the registers, + // as each register in the new HLL is the minimum of 'mult' registers in the old HLL. + // Duplicating will make for VERY large over-estimations. Instead, we do the following: + + // Set the first of every 'mult' registers in the old HLL to the value in the new HLL + // This ensures that we can convert NEW to OLD and back to NEW without loss of information + double avg = 0; + for (idx_t i = 0; i < HyperLogLog::M; i++) { + const auto max_new = MinValue(new_hll.GetRegister(i), duckdb_hll::maximum_zeros()); + duckdb_hll::set_register(hll, i * mult, max_new); + avg += static_cast(max_new); + } + avg /= static_cast(HyperLogLog::M); -inline uint64_t HashOtherSize(const_data_ptr_t &data, const idx_t &len) { - uint64_t x = 0; - switch (len & 7) { - case 7: - CreateIntegerRecursive<7>(data, x); - break; - case 6: - CreateIntegerRecursive<6>(data, x); - break; - case 5: - CreateIntegerRecursive<5>(data, x); - break; - case 4: - CreateIntegerRecursive<4>(data, x); - break; - case 3: - CreateIntegerRecursive<3>(data, x); - break; - case 2: - CreateIntegerRecursive<2>(data, x); - break; - case 1: - CreateIntegerRecursive<1>(data, x); - break; - case 0: - default: - D_ASSERT((len & 7) == 0); - break; - } - return TemplatedHash(x); -} + // Using the average will ALWAYS overestimate, so we reduce it a bit here + if (avg > 10) { + avg *= 0.75; + } else if (avg > 2) { + avg -= 2; + } -template <> -inline uint64_t TemplatedHash(const string_t &elem) { - auto data = const_data_ptr_cast(elem.GetData()); - const auto &len = elem.GetSize(); - uint64_t h = 0; - for (idx_t i = 0; i + sizeof(uint64_t) <= len; i += sizeof(uint64_t)) { - h ^= TemplatedHash(Load(data)); - data += sizeof(uint64_t); - } - switch (len & (sizeof(uint64_t) - 1)) { - case 4: - h ^= TemplatedHash(Load(data)); - break; - case 2: - h ^= TemplatedHash(Load(data)); - break; - case 1: - h ^= TemplatedHash(Load(data)); - break; - default: - h ^= HashOtherSize(data, len); + // Set all other registers to a default value, starting with 0 (the initialization value) + // We optimize the default value in 5 iterations or until OLD count is close to NEW count + double default_val = 0; + for (idx_t opt_idx = 0; opt_idx < 5; opt_idx++) { + if (IsWithinAcceptableRange(new_hll_count, Count())) { + break; + } + + // Delta is half the average, then a quarter, etc. + const double delta = avg / static_cast(1 << (opt_idx + 1)); + if (Count() > new_hll_count) { + default_val = delta > default_val ? 0 : default_val - delta; + } else { + default_val += delta; + } + + // If the default value is, e.g., 3.3, then the first 70% gets value 3, and the rest gets value 4 + const double floor_fraction = 1 - (default_val - floor(default_val)); + for (idx_t i = 0; i < HyperLogLog::M; i++) { + const auto max_new = MinValue(new_hll.GetRegister(i), duckdb_hll::maximum_zeros()); + uint8_t register_value; + if (static_cast(i) / static_cast(HyperLogLog::M) < floor_fraction) { + register_value = ExactNumericCast(floor(default_val)); + } else { + register_value = ExactNumericCast(ceil(default_val)); + } + register_value = MinValue(register_value, max_new); + for (idx_t j = 1; j < mult; j++) { + D_ASSERT(i * mult + j < duckdb_hll::num_registers()); + duckdb_hll::set_register(hll, i * mult + j, register_value); + } + } + } } - return h; -} -template -void TemplatedComputeHashes(UnifiedVectorFormat &vdata, const idx_t &count, uint64_t hashes[]) { - auto data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - if (vdata.validity.RowIsValid(idx)) { - hashes[i] = TemplatedHash(data[idx]); - } else { - hashes[i] = 0; +private: + idx_t Count() const { + size_t result; + if (duckdb_hll::hll_count(hll, &result) != HLL_C_OK) { + throw InternalException("Could not count HLL?"); } + return result; } -} -static void ComputeHashes(UnifiedVectorFormat &vdata, const LogicalType &type, uint64_t hashes[], idx_t count) { - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - case PhysicalType::UINT8: - return TemplatedComputeHashes(vdata, count, hashes); - case PhysicalType::INT16: - case PhysicalType::UINT16: - return TemplatedComputeHashes(vdata, count, hashes); - case PhysicalType::INT32: - case PhysicalType::UINT32: - case PhysicalType::FLOAT: - return TemplatedComputeHashes(vdata, count, hashes); - case PhysicalType::INT64: - case PhysicalType::UINT64: - case PhysicalType::DOUBLE: - return TemplatedComputeHashes(vdata, count, hashes); - case PhysicalType::INT128: - case PhysicalType::UINT128: - case PhysicalType::INTERVAL: - static_assert(sizeof(uhugeint_t) == sizeof(interval_t), "ComputeHashes assumes these are the same size!"); - return TemplatedComputeHashes(vdata, count, hashes); - case PhysicalType::VARCHAR: - return TemplatedComputeHashes(vdata, count, hashes); - default: - throw InternalException("Unimplemented type for HyperLogLog::ComputeHashes"); + bool IsWithinAcceptableRange(const idx_t &new_hll_count, const idx_t &old_hll_count) const { + const auto newd = static_cast(new_hll_count); + const auto oldd = static_cast(old_hll_count); + return MaxValue(newd, oldd) / MinValue(newd, oldd) < ACCEPTABLE_Q_ERROR; } -} -//! Taken from https://stackoverflow.com/a/72088344 -static inline uint8_t CountTrailingZeros(uint64_t &x) { - static constexpr const uint64_t DEBRUIJN = 0x03f79d71b4cb0a89; - static constexpr const uint8_t LOOKUP[] = {0, 47, 1, 56, 48, 27, 2, 60, 57, 49, 41, 37, 28, 16, 3, 61, - 54, 58, 35, 52, 50, 42, 21, 44, 38, 32, 29, 23, 17, 11, 4, 62, - 46, 55, 26, 59, 40, 36, 15, 53, 34, 51, 20, 43, 31, 22, 10, 45, - 25, 39, 14, 33, 19, 30, 9, 24, 13, 18, 8, 12, 7, 6, 5, 63}; - return LOOKUP[(DEBRUIJN * (x ^ (x - 1))) >> 58]; -} - -static inline void ComputeIndexAndCount(uint64_t &hash, uint8_t &prefix) { - uint64_t index = hash & ((1 << 12) - 1); /* Register index. */ - hash >>= 12; /* Remove bits used to address the register. */ - hash |= ((uint64_t)1 << (64 - 12)); /* Make sure the count will be <= Q+1. */ +private: + static constexpr double ACCEPTABLE_Q_ERROR = 1.2; + duckdb_hll::robj *hll; +}; - prefix = CountTrailingZeros(hash) + 1; /* Add 1 since we count the "00000...1" pattern. */ - hash = index; -} - -void HyperLogLog::ProcessEntries(UnifiedVectorFormat &vdata, const LogicalType &type, uint64_t hashes[], - uint8_t counts[], idx_t count) { - ComputeHashes(vdata, type, hashes, count); - for (idx_t i = 0; i < count; i++) { - ComputeIndexAndCount(hashes[i], counts[i]); +void HyperLogLog::Serialize(Serializer &serializer) const { + if (serializer.ShouldSerialize(3)) { + serializer.WriteProperty(100, "type", HLLStorageType::HLL_V2); + serializer.WriteProperty(101, "data", k, sizeof(k)); + } else { + auto old = make_uniq(); + old->FromNew(*this); + + serializer.WriteProperty(100, "type", HLLStorageType::HLL_V1); + serializer.WriteProperty(101, "data", old->GetPtr(), old->GetSize()); } } -void HyperLogLog::AddToLogs(UnifiedVectorFormat &vdata, idx_t count, uint64_t indices[], uint8_t counts[], - HyperLogLog **logs[], const SelectionVector *log_sel) { - AddToLogsInternal(vdata, count, indices, counts, reinterpret_cast(logs), log_sel); -} - -void HyperLogLog::AddToLog(UnifiedVectorFormat &vdata, idx_t count, uint64_t indices[], uint8_t counts[]) { - lock_guard guard(lock); - AddToSingleLogInternal(vdata, count, indices, counts, hll); +unique_ptr HyperLogLog::Deserialize(Deserializer &deserializer) { + auto result = make_uniq(); + auto storage_type = deserializer.ReadProperty(100, "type"); + switch (storage_type) { + case HLLStorageType::HLL_V1: { + auto old = make_uniq(); + deserializer.ReadProperty(101, "data", old->GetPtr(), old->GetSize()); + old->ToNew(*result); + break; + } + case HLLStorageType::HLL_V2: + deserializer.ReadProperty(101, "data", result->k, sizeof(k)); + break; + default: + throw SerializationException("Unknown HyperLogLog storage type!"); + } + return result; } } // namespace duckdb diff --git a/src/duckdb/src/common/types/interval.cpp b/src/duckdb/src/common/types/interval.cpp index 7a31cbfe..719e9ee2 100644 --- a/src/duckdb/src/common/types/interval.cpp +++ b/src/duckdb/src/common/types/interval.cpp @@ -448,7 +448,7 @@ interval_t Interval::GetDifference(timestamp_t timestamp_1, timestamp_t timestam interval_t Interval::FromMicro(int64_t delta_us) { interval_t result; result.months = 0; - result.days = delta_us / Interval::MICROS_PER_DAY; + result.days = UnsafeNumericCast(delta_us / Interval::MICROS_PER_DAY); result.micros = delta_us % Interval::MICROS_PER_DAY; return result; diff --git a/src/duckdb/src/common/types/list_segment.cpp b/src/duckdb/src/common/types/list_segment.cpp index 3b2cfb7a..8145cf07 100644 --- a/src/duckdb/src/common/types/list_segment.cpp +++ b/src/duckdb/src/common/types/list_segment.cpp @@ -29,6 +29,21 @@ static const T *GetPrimitiveData(const ListSegment *segment) { segment->capacity * sizeof(bool)); } +//===--------------------------------------------------------------------===// +// Strings +//===--------------------------------------------------------------------===// +static idx_t GetStringAllocationSize(uint16_t capacity) { + return AlignValue(sizeof(ListSegment) + (capacity * (sizeof(char)))); +} + +static data_ptr_t AllocateStringData(ArenaAllocator &allocator, uint16_t capacity) { + return allocator.Allocate(GetStringAllocationSize(capacity)); +} + +static char *GetStringData(ListSegment *segment) { + return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment)); +} + //===--------------------------------------------------------------------===// // Lists //===--------------------------------------------------------------------===// @@ -125,7 +140,17 @@ static uint16_t GetCapacityForNewSegment(uint16_t capacity) { template static ListSegment *CreatePrimitiveSegment(const ListSegmentFunctions &, ArenaAllocator &allocator, uint16_t capacity) { // allocate data and set the header - auto segment = (ListSegment *)AllocatePrimitiveData(allocator, capacity); + auto segment = reinterpret_cast(AllocatePrimitiveData(allocator, capacity)); + segment->capacity = capacity; + segment->count = 0; + segment->next = nullptr; + return segment; +} + +static ListSegment *CreateVarcharDataSegment(const ListSegmentFunctions &, ArenaAllocator &allocator, + uint16_t capacity) { + // allocate data and set the header + auto segment = reinterpret_cast(AllocateStringData(allocator, capacity)); segment->capacity = capacity; segment->count = 0; segment->next = nullptr; @@ -190,11 +215,9 @@ static ListSegment *GetSegment(const ListSegmentFunctions &functions, ArenaAlloc // determine segment if (!linked_list.last_segment) { // empty linked list, create the first (and last) segment - auto capacity = ListSegment::INITIAL_CAPACITY; - segment = functions.create_segment(functions, allocator, UnsafeNumericCast(capacity)); + segment = functions.create_segment(functions, allocator, functions.initial_capacity); linked_list.first_segment = segment; linked_list.last_segment = segment; - } else if (linked_list.last_segment->capacity == linked_list.last_segment->count) { // the last segment of the linked list is full, create a new one and append it auto capacity = GetCapacityForNewSegment(linked_list.last_segment->capacity); @@ -245,31 +268,29 @@ static void WriteDataToVarcharSegment(const ListSegmentFunctions &functions, Are // set the length of this string auto str_length_data = GetListLengthData(segment); - uint64_t str_length = 0; - - // get the string - string_t str_entry; - if (valid) { - str_entry = UnifiedVectorFormat::GetData(input_data.unified)[sel_entry_idx]; - str_length = str_entry.GetSize(); - } // we can reconstruct the offset from the length - Store(str_length, data_ptr_cast(str_length_data + segment->count)); if (!valid) { + Store(0, data_ptr_cast(str_length_data + segment->count)); return; } + auto &str_entry = UnifiedVectorFormat::GetData(input_data.unified)[sel_entry_idx]; + auto str_data = str_entry.GetData(); + idx_t str_size = str_entry.GetSize(); + Store(str_size, data_ptr_cast(str_length_data + segment->count)); // write the characters to the linked list of child segments auto child_segments = Load(data_ptr_cast(GetListChildData(segment))); - for (char &c : str_entry.GetString()) { + idx_t current_offset = 0; + while (current_offset < str_size) { auto child_segment = GetSegment(functions.child_functions.back(), allocator, child_segments); - auto data = GetPrimitiveData(child_segment); - data[child_segment->count] = c; - child_segment->count++; - child_segments.total_capacity++; + auto data = GetStringData(child_segment); + idx_t copy_count = MinValue(str_size - current_offset, child_segment->capacity - child_segment->count); + memcpy(data + child_segment->count, str_data + current_offset, copy_count); + current_offset += copy_count; + child_segment->count += copy_count; } - + child_segments.total_capacity += str_size; // store the updated linked list Store(child_segments, data_ptr_cast(GetListChildData(segment))); } @@ -394,42 +415,48 @@ static void ReadDataFromPrimitiveSegment(const ListSegmentFunctions &, const Lis static void ReadDataFromVarcharSegment(const ListSegmentFunctions &, const ListSegment *segment, Vector &result, idx_t &total_count) { - auto &aggr_vector_validity = FlatVector::Validity(result); - // set NULLs - auto null_mask = GetNullMask(segment); - for (idx_t i = 0; i < segment->count; i++) { - if (null_mask[i]) { - aggr_vector_validity.SetInvalid(total_count + i); - } - } - - // append all the child chars to one string - string str = ""; - auto linked_child_list = Load(const_data_ptr_cast(GetListChildData(segment))); - while (linked_child_list.first_segment) { - auto child_segment = linked_child_list.first_segment; - auto data = GetPrimitiveData(child_segment); - str.append(data, child_segment->count); - linked_child_list.first_segment = child_segment->next; - } - linked_child_list.last_segment = nullptr; - // use length and (reconstructed) offset to get the correct substrings auto aggr_vector_data = FlatVector::GetData(result); auto str_length_data = GetListLengthData(segment); - // get the substrings and write them to the result vector - idx_t offset = 0; + auto null_mask = GetNullMask(segment); + auto linked_child_list = Load(const_data_ptr_cast(GetListChildData(segment))); + auto current_segment = linked_child_list.first_segment; + idx_t child_offset = 0; for (idx_t i = 0; i < segment->count; i++) { - if (!null_mask[i]) { - auto str_length = Load(const_data_ptr_cast(str_length_data + i)); - auto substr = str.substr(offset, str_length); - auto str_t = StringVector::AddStringOrBlob(result, substr); - aggr_vector_data[total_count + i] = str_t; - offset += str_length; + if (null_mask[i]) { + // set to null + aggr_vector_validity.SetInvalid(total_count + i); + continue; + } + // read the string + auto &result_str = aggr_vector_data[total_count + i]; + auto str_length = Load(const_data_ptr_cast(str_length_data + i)); + // allocate an empty string for the given size + result_str = StringVector::EmptyString(result, str_length); + auto result_data = result_str.GetDataWriteable(); + // copy over the data + idx_t current_offset = 0; + while (current_offset < str_length) { + if (!current_segment) { + throw InternalException("Insufficient data to read string"); + } + auto child_data = GetStringData(current_segment); + idx_t max_copy = MinValue(str_length - current_offset, current_segment->capacity - child_offset); + memcpy(result_data + current_offset, child_data + child_offset, max_copy); + current_offset += max_copy; + child_offset += max_copy; + if (child_offset >= current_segment->capacity) { + D_ASSERT(child_offset == current_segment->capacity); + current_segment = current_segment->next; + child_offset = 0; + } } + + // finalize the str + result_str.Finalize(); } } @@ -553,9 +580,11 @@ void GetSegmentDataFunctions(ListSegmentFunctions &functions, const LogicalType case PhysicalType::BIT: case PhysicalType::BOOL: SegmentPrimitiveFunction(functions); + functions.initial_capacity = 8; break; case PhysicalType::INT8: SegmentPrimitiveFunction(functions); + functions.initial_capacity = 8; break; case PhysicalType::INT16: SegmentPrimitiveFunction(functions); @@ -568,6 +597,7 @@ void GetSegmentDataFunctions(ListSegmentFunctions &functions, const LogicalType break; case PhysicalType::UINT8: SegmentPrimitiveFunction(functions); + functions.initial_capacity = 8; break; case PhysicalType::UINT16: SegmentPrimitiveFunction(functions); @@ -598,8 +628,12 @@ void GetSegmentDataFunctions(ListSegmentFunctions &functions, const LogicalType functions.write_data = WriteDataToVarcharSegment; functions.read_data = ReadDataFromVarcharSegment; - functions.child_functions.emplace_back(); - SegmentPrimitiveFunction(functions.child_functions.back()); + ListSegmentFunctions child_function; + child_function.create_segment = CreateVarcharDataSegment; + child_function.write_data = nullptr; + child_function.read_data = nullptr; + child_function.initial_capacity = 16; + functions.child_functions.push_back(child_function); break; } case PhysicalType::LIST: { diff --git a/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp b/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp index 870bd3aa..17cd306f 100644 --- a/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp +++ b/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp @@ -13,7 +13,7 @@ PartitionedTupleData::PartitionedTupleData(PartitionedTupleDataType type_p, Buff } PartitionedTupleData::PartitionedTupleData(const PartitionedTupleData &other) - : type(other.type), buffer_manager(other.buffer_manager), layout(other.layout.Copy()) { + : type(other.type), buffer_manager(other.buffer_manager), layout(other.layout.Copy()), count(0), data_size(0) { } PartitionedTupleData::~PartitionedTupleData() { @@ -50,22 +50,13 @@ void PartitionedTupleData::AppendUnified(PartitionedTupleDataAppendState &state, const idx_t actual_append_count = append_count == DConstants::INVALID_INDEX ? input.size() : append_count; // Compute partition indices and store them in state.partition_indices - ComputePartitionIndices(state, input); + ComputePartitionIndices(state, input, append_sel, actual_append_count); // Build the selection vector for the partitions BuildPartitionSel(state, append_sel, actual_append_count); // Early out: check if everything belongs to a single partition - optional_idx partition_index; - if (UseFixedSizeMap()) { - if (state.fixed_partition_entries.size() == 1) { - partition_index = state.fixed_partition_entries.begin().GetKey(); - } - } else { - if (state.partition_entries.size() == 1) { - partition_index = state.partition_entries.begin()->first; - } - } + const auto partition_index = state.GetPartitionIndexIfSinglePartition(UseFixedSizeMap()); if (partition_index.IsValid()) { auto &partition = *partitions[partition_index.GetIndex()]; auto &partition_pin_state = *state.partition_pin_states[partition_index.GetIndex()]; @@ -99,17 +90,7 @@ void PartitionedTupleData::Append(PartitionedTupleDataAppendState &state, TupleD BuildPartitionSel(state, *FlatVector::IncrementalSelectionVector(), append_count); // Early out: check if everything belongs to a single partition - optional_idx partition_index; - if (UseFixedSizeMap()) { - if (state.fixed_partition_entries.size() == 1) { - partition_index = state.fixed_partition_entries.begin().GetKey(); - } - } else { - if (state.partition_entries.size() == 1) { - partition_index = state.partition_entries.begin()->first; - } - } - + auto partition_index = state.GetPartitionIndexIfSinglePartition(UseFixedSizeMap()); if (partition_index.IsValid()) { auto &partition = *partitions[partition_index.GetIndex()]; auto &partition_pin_state = *state.partition_pin_states[partition_index.GetIndex()]; @@ -135,68 +116,26 @@ void PartitionedTupleData::Append(PartitionedTupleDataAppendState &state, TupleD Verify(); } -// LCOV_EXCL_START -template -struct UnorderedMapGetter { - static inline const typename MAP_TYPE::key_type &GetKey(typename MAP_TYPE::iterator &iterator) { - return iterator->first; - } - - static inline const typename MAP_TYPE::key_type &GetKey(const typename MAP_TYPE::const_iterator &iterator) { - return iterator->first; - } - - static inline typename MAP_TYPE::mapped_type &GetValue(typename MAP_TYPE::iterator &iterator) { - return iterator->second; - } - - static inline const typename MAP_TYPE::mapped_type &GetValue(const typename MAP_TYPE::const_iterator &iterator) { - return iterator->second; - } -}; - -template -struct FixedSizeMapGetter { - static inline const idx_t &GetKey(fixed_size_map_iterator_t &iterator) { - return iterator.GetKey(); - } - - static inline const idx_t &GetKey(const const_fixed_size_map_iterator_t &iterator) { - return iterator.GetKey(); - } - - static inline T &GetValue(fixed_size_map_iterator_t &iterator) { - return iterator.GetValue(); - } - - static inline const T &GetValue(const const_fixed_size_map_iterator_t &iterator) { - return iterator.GetValue(); - } -}; -// LCOV_EXCL_STOP - void PartitionedTupleData::BuildPartitionSel(PartitionedTupleDataAppendState &state, const SelectionVector &append_sel, - const idx_t append_count) { + const idx_t append_count) const { if (UseFixedSizeMap()) { - BuildPartitionSel, FixedSizeMapGetter>( - state, state.fixed_partition_entries, append_sel, append_count); + BuildPartitionSel(state, append_sel, append_count); } else { - BuildPartitionSel, UnorderedMapGetter>>( - state, state.partition_entries, append_sel, append_count); + BuildPartitionSel(state, append_sel, append_count); } } -template -void PartitionedTupleData::BuildPartitionSel(PartitionedTupleDataAppendState &state, MAP_TYPE &partition_entries, - const SelectionVector &append_sel, const idx_t append_count) { +template +void PartitionedTupleData::BuildPartitionSel(PartitionedTupleDataAppendState &state, const SelectionVector &append_sel, + const idx_t append_count) { + using GETTER = TemplatedMapGetter; + auto &partition_entries = state.GetMap(); const auto partition_indices = FlatVector::GetData(state.partition_indices); partition_entries.clear(); - switch (state.partition_indices.GetVectorType()) { case VectorType::FLAT_VECTOR: for (idx_t i = 0; i < append_count; i++) { - const auto index = append_sel.get_index(i); - const auto &partition_index = partition_indices[index]; + const auto &partition_index = partition_indices[i]; auto partition_entry = partition_entries.find(partition_index); if (partition_entry == partition_entries.end()) { partition_entries[partition_index] = list_entry_t(0, 1); @@ -215,9 +154,9 @@ void PartitionedTupleData::BuildPartitionSel(PartitionedTupleDataAppendState &st // Early out: check if everything belongs to a single partition if (partition_entries.size() == 1) { // This needs to be initialized, even if we go the short path here - for (idx_t i = 0; i < append_count; i++) { + for (sel_t i = 0; i < append_count; i++) { const auto index = append_sel.get_index(i); - state.reverse_partition_sel[index] = NumericCast(i); + state.reverse_partition_sel[index] = i; } return; } @@ -235,7 +174,7 @@ void PartitionedTupleData::BuildPartitionSel(PartitionedTupleDataAppendState &st auto &reverse_partition_sel = state.reverse_partition_sel; for (idx_t i = 0; i < append_count; i++) { const auto index = append_sel.get_index(i); - const auto &partition_index = partition_indices[index]; + const auto &partition_index = partition_indices[i]; auto &partition_offset = partition_entries[partition_index].offset; reverse_partition_sel[index] = UnsafeNumericCast(partition_offset); partition_sel[partition_offset++] = UnsafeNumericCast(index); @@ -244,16 +183,16 @@ void PartitionedTupleData::BuildPartitionSel(PartitionedTupleDataAppendState &st void PartitionedTupleData::BuildBufferSpace(PartitionedTupleDataAppendState &state) { if (UseFixedSizeMap()) { - BuildBufferSpace, FixedSizeMapGetter>( - state, state.fixed_partition_entries); + BuildBufferSpace(state); } else { - BuildBufferSpace, UnorderedMapGetter>>( - state, state.partition_entries); + BuildBufferSpace(state); } } -template -void PartitionedTupleData::BuildBufferSpace(PartitionedTupleDataAppendState &state, const MAP_TYPE &partition_entries) { +template +void PartitionedTupleData::BuildBufferSpace(PartitionedTupleDataAppendState &state) { + using GETTER = TemplatedMapGetter; + const auto &partition_entries = state.GetMap(); for (auto it = partition_entries.begin(); it != partition_entries.end(); ++it) { const auto &partition_index = GETTER::GetKey(it); diff --git a/src/duckdb/src/common/types/row/row_data_collection.cpp b/src/duckdb/src/common/types/row/row_data_collection.cpp index 72299232..b178b7fb 100644 --- a/src/duckdb/src/common/types/row/row_data_collection.cpp +++ b/src/duckdb/src/common/types/row/row_data_collection.cpp @@ -6,7 +6,7 @@ RowDataCollection::RowDataCollection(BufferManager &buffer_manager, idx_t block_ bool keep_pinned) : buffer_manager(buffer_manager), count(0), block_capacity(block_capacity), entry_size(entry_size), keep_pinned(keep_pinned) { - D_ASSERT(block_capacity * entry_size + entry_size > Storage::BLOCK_SIZE); + D_ASSERT(block_capacity * entry_size + entry_size > buffer_manager.GetBlockSize()); } idx_t RowDataCollection::AppendToBlock(RowDataBlock &block, BufferHandle &handle, @@ -114,7 +114,7 @@ void RowDataCollection::Merge(RowDataCollection &other) { if (other.count == 0) { return; } - RowDataCollection temp(buffer_manager, Storage::BLOCK_SIZE, 1); + RowDataCollection temp(buffer_manager, buffer_manager.GetBlockSize(), 1); { // One lock at a time to avoid deadlocks lock_guard read_lock(other.rdc_lock); diff --git a/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp b/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp index efbd072e..9b3a4be0 100644 --- a/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp +++ b/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp @@ -101,7 +101,7 @@ void RowDataCollectionScanner::AlignHeapBlocks(RowDataCollection &swizzled_block // Finally, we allocate a new heap block and copy data to it swizzled_string_heap.blocks.emplace_back(make_uniq( - MemoryTag::ORDER_BY, buffer_manager, MaxValue(total_size, (idx_t)Storage::BLOCK_SIZE), 1U)); + MemoryTag::ORDER_BY, buffer_manager, MaxValue(total_size, buffer_manager.GetBlockSize()), 1U)); auto new_heap_handle = buffer_manager.Pin(swizzled_string_heap.blocks.back()->block); auto new_heap_ptr = new_heap_handle.Ptr(); for (auto &ptr_and_size : ptrs_and_sizes) { @@ -183,7 +183,7 @@ RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, Ro ValidateUnscannedBlock(); } -void RowDataCollectionScanner::SwizzleBlock(RowDataBlock &data_block, RowDataBlock &heap_block) { +void RowDataCollectionScanner::SwizzleBlockInternal(RowDataBlock &data_block, RowDataBlock &heap_block) { // Pin the data block and swizzle the pointers within the rows D_ASSERT(!data_block.block->IsSwizzled()); auto data_handle = rows.buffer_manager.Pin(data_block.block); @@ -198,6 +198,22 @@ void RowDataCollectionScanner::SwizzleBlock(RowDataBlock &data_block, RowDataBlo RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block.count, NumericCast(heap_offset)); } +void RowDataCollectionScanner::SwizzleBlock(idx_t block_idx) { + if (rows.count == 0) { + return; + } + + if (!unswizzling) { + // No swizzled blocks! + return; + } + + auto &data_block = rows.blocks[block_idx]; + if (data_block->block && !data_block->block->IsSwizzled()) { + SwizzleBlockInternal(*data_block, *heap.blocks[block_idx]); + } +} + void RowDataCollectionScanner::ReSwizzle() { if (rows.count == 0) { return; @@ -212,7 +228,7 @@ void RowDataCollectionScanner::ReSwizzle() { for (idx_t i = 0; i < rows.blocks.size(); ++i) { auto &data_block = rows.blocks[i]; if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlock(*data_block, *heap.blocks[i]); + SwizzleBlockInternal(*data_block, *heap.blocks[i]); } } } @@ -297,7 +313,7 @@ void RowDataCollectionScanner::Scan(DataChunk &chunk) { for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { auto &data_block = rows.blocks[i]; if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlock(*data_block, *heap.blocks[i]); + SwizzleBlockInternal(*data_block, *heap.blocks[i]); } } } diff --git a/src/duckdb/src/common/types/row/tuple_data_allocator.cpp b/src/duckdb/src/common/types/row/tuple_data_allocator.cpp index cb6f0bb1..f00c10fa 100644 --- a/src/duckdb/src/common/types/row/tuple_data_allocator.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_allocator.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/fast_mem.hpp" #include "duckdb/common/types/row/tuple_data_segment.hpp" #include "duckdb/common/types/row/tuple_data_states.hpp" +#include "duckdb/storage/buffer/block_handle.hpp" #include "duckdb/storage/buffer_manager.hpp" namespace duckdb { @@ -13,7 +14,7 @@ TupleDataBlock::TupleDataBlock(BufferManager &buffer_manager, idx_t capacity_p) buffer_manager.Allocate(MemoryTag::HASH_TABLE, capacity, false, &handle); } -TupleDataBlock::TupleDataBlock(TupleDataBlock &&other) noexcept { +TupleDataBlock::TupleDataBlock(TupleDataBlock &&other) noexcept : capacity(0), size(0) { std::swap(handle, other.handle); std::swap(capacity, other.capacity); std::swap(size, other.size); @@ -34,6 +35,23 @@ TupleDataAllocator::TupleDataAllocator(TupleDataAllocator &allocator) : buffer_manager(allocator.buffer_manager), layout(allocator.layout.Copy()) { } +void TupleDataAllocator::SetDestroyBufferUponUnpin() { + for (auto &block : row_blocks) { + if (block.handle) { + block.handle->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); + } + } + for (auto &block : heap_blocks) { + if (block.handle) { + block.handle->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); + } + } +} + +TupleDataAllocator::~TupleDataAllocator() { + SetDestroyBufferUponUnpin(); +} + BufferManager &TupleDataAllocator::GetBufferManager() { return buffer_manager; } @@ -118,10 +136,11 @@ TupleDataChunkPart TupleDataAllocator::BuildChunkPart(TupleDataPinState &pin_sta TupleDataChunk &chunk) { D_ASSERT(append_count != 0); TupleDataChunkPart result(*chunk.lock); + const auto block_size = buffer_manager.GetBlockSize(); // Allocate row block (if needed) if (row_blocks.empty() || row_blocks.back().RemainingCapacity() < layout.GetRowWidth()) { - row_blocks.emplace_back(buffer_manager, (idx_t)Storage::BLOCK_SIZE); + row_blocks.emplace_back(buffer_manager, block_size); } result.row_block_index = NumericCast(row_blocks.size() - 1); auto &row_block = row_blocks[result.row_block_index]; @@ -142,9 +161,8 @@ TupleDataChunkPart TupleDataAllocator::BuildChunkPart(TupleDataPinState &pin_sta if (total_heap_size == 0) { result.SetHeapEmpty(); } else { - const auto heap_remaining = MaxValue(heap_blocks.empty() ? (idx_t)Storage::BLOCK_SIZE - : heap_blocks.back().RemainingCapacity(), - heap_sizes[append_offset]); + const auto heap_remaining = MaxValue( + heap_blocks.empty() ? block_size : heap_blocks.back().RemainingCapacity(), heap_sizes[append_offset]); if (total_heap_size <= heap_remaining) { // Everything fits @@ -167,7 +185,7 @@ TupleDataChunkPart TupleDataAllocator::BuildChunkPart(TupleDataPinState &pin_sta } else { // Allocate heap block (if needed) if (heap_blocks.empty() || heap_blocks.back().RemainingCapacity() < heap_sizes[append_offset]) { - const auto size = MaxValue((idx_t)Storage::BLOCK_SIZE, heap_sizes[append_offset]); + const auto size = MaxValue(block_size, heap_sizes[append_offset]); heap_blocks.emplace_back(buffer_manager, size); } result.heap_block_index = NumericCast(heap_blocks.size() - 1); @@ -440,7 +458,10 @@ void TupleDataAllocator::ReleaseOrStoreHandlesInternal( case TupleDataPinProperties::ALREADY_PINNED: break; case TupleDataPinProperties::DESTROY_AFTER_DONE: - blocks[block_id].handle = nullptr; + // Prevent it from being added to the eviction queue + blocks[block_id].handle->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); + // Destroy + blocks[block_id].handle.reset(); break; default: D_ASSERT(properties == TupleDataPinProperties::INVALID); diff --git a/src/duckdb/src/common/types/row/tuple_data_collection.cpp b/src/duckdb/src/common/types/row/tuple_data_collection.cpp index 86cbb144..a5215d03 100644 --- a/src/duckdb/src/common/types/row/tuple_data_collection.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_collection.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/fast_mem.hpp" #include "duckdb/common/printer.hpp" #include "duckdb/common/row_operations/row_operations.hpp" +#include "duckdb/common/type_visitor.hpp" #include "duckdb/common/types/row/tuple_data_allocator.hpp" #include @@ -163,7 +164,7 @@ void TupleDataCollection::InitializeChunkState(TupleDataChunkState &chunk_state, for (auto &col : column_ids) { auto &type = types[col]; - if (type.Contains(LogicalTypeId::ARRAY)) { + if (TypeVisitor::Contains(type, LogicalTypeId::ARRAY)) { auto cast_type = ArrayType::ConvertToList(type); chunk_state.cached_cast_vector_cache.push_back( make_uniq(Allocator::DefaultAllocator(), cast_type)); @@ -226,15 +227,13 @@ static inline void ToUnifiedFormatInternal(TupleDataVectorFormat &format, Vector auto &entries = StructVector::GetEntries(vector); D_ASSERT(format.children.size() == entries.size()); for (idx_t struct_col_idx = 0; struct_col_idx < entries.size(); struct_col_idx++) { - ToUnifiedFormatInternal(reinterpret_cast(format.children[struct_col_idx]), - *entries[struct_col_idx], count); + ToUnifiedFormatInternal(format.children[struct_col_idx], *entries[struct_col_idx], count); } break; } case PhysicalType::LIST: D_ASSERT(format.children.size() == 1); - ToUnifiedFormatInternal(reinterpret_cast(format.children[0]), - ListVector::GetEntry(vector), ListVector::GetListSize(vector)); + ToUnifiedFormatInternal(format.children[0], ListVector::GetEntry(vector), ListVector::GetListSize(vector)); break; case PhysicalType::ARRAY: { D_ASSERT(format.children.size() == 1); @@ -246,19 +245,20 @@ static inline void ToUnifiedFormatInternal(TupleDataVectorFormat &format, Vector // How many list_entry_t's do we need to cover the whole child array? // Make sure we round up so its all covered auto child_array_total_size = ArrayVector::GetTotalSize(vector); - auto list_entry_t_count = MaxValue((child_array_total_size + array_size) / array_size, count); + auto list_entry_t_count = + MaxValue((child_array_total_size + array_size) / array_size, format.unified.validity.TargetCount()); // Create list entries! - format.array_list_entries = make_uniq_array(list_entry_t_count); + format.array_list_entries = make_unsafe_uniq_array(list_entry_t_count); for (idx_t i = 0; i < list_entry_t_count; i++) { format.array_list_entries[i].length = array_size; format.array_list_entries[i].offset = i * array_size; } format.unified.data = reinterpret_cast(format.array_list_entries.get()); - ToUnifiedFormatInternal(reinterpret_cast(format.children[0]), - ArrayVector::GetEntry(vector), count * array_size); - } break; + ToUnifiedFormatInternal(format.children[0], ArrayVector::GetEntry(vector), child_array_total_size); + break; + } default: break; } @@ -384,6 +384,17 @@ void TupleDataCollection::InitializeChunk(DataChunk &chunk) const { chunk.Initialize(allocator->GetAllocator(), layout.GetTypes()); } +void TupleDataCollection::InitializeChunk(DataChunk &chunk, const vector &columns) const { + vector chunk_types(columns.size()); + // keep the order of the columns + for (idx_t i = 0; i < columns.size(); i++) { + auto column_idx = columns[i]; + D_ASSERT(column_idx < layout.ColumnCount()); + chunk_types[i] = layout.GetTypes()[column_idx]; + } + chunk.Initialize(allocator->GetAllocator(), chunk_types); +} + void TupleDataCollection::InitializeScanChunk(TupleDataScanState &state, DataChunk &chunk) const { auto &column_ids = state.chunk_state.column_ids; D_ASSERT(!column_ids.empty()); @@ -419,7 +430,7 @@ void TupleDataCollection::InitializeScan(TupleDataScanState &state, vector(Allocator::DefaultAllocator(), cast_type)); @@ -525,14 +536,18 @@ void TupleDataCollection::ScanAtIndex(TupleDataPinState &pin_state, TupleDataChu segment.allocator->InitializeChunkState(segment, pin_state, chunk_state, chunk_index, false); result.Reset(); + ResetCachedCastVectors(chunk_state, column_ids); + Gather(chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), chunk.count, column_ids, result, + *FlatVector::IncrementalSelectionVector(), chunk_state.cached_cast_vectors); + result.SetCardinality(chunk.count); +} + +void TupleDataCollection::ResetCachedCastVectors(TupleDataChunkState &chunk_state, const vector &column_ids) { for (idx_t i = 0; i < column_ids.size(); i++) { if (chunk_state.cached_cast_vectors[i]) { chunk_state.cached_cast_vectors[i]->ResetFromCache(*chunk_state.cached_cast_vector_cache[i]); } } - Gather(chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), chunk.count, column_ids, result, - *FlatVector::IncrementalSelectionVector(), chunk_state.cached_cast_vectors); - result.SetCardinality(chunk.count); } // LCOV_EXCL_START diff --git a/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp b/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp index 341f239f..fb9e8a6d 100644 --- a/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp @@ -1,5 +1,6 @@ #include "duckdb/common/enum_util.hpp" #include "duckdb/common/fast_mem.hpp" +#include "duckdb/common/type_visitor.hpp" #include "duckdb/common/types/null_value.hpp" #include "duckdb/common/types/row/tuple_data_collection.hpp" #include "duckdb/common/uhugeint.hpp" @@ -19,8 +20,8 @@ constexpr idx_t TupleDataWithinListFixedSize() { } template -static inline void TupleDataValueStore(const T &source, const data_ptr_t &row_location, const idx_t offset_in_row, - data_ptr_t &heap_location) { +static void TupleDataValueStore(const T &source, const data_ptr_t &row_location, const idx_t offset_in_row, + data_ptr_t &) { Store(source, row_location + offset_in_row); } @@ -33,7 +34,7 @@ inline void TupleDataValueStore(const string_t &source, const data_ptr_t &row_lo if (source.IsInlined()) { Store(source, row_location + offset_in_row); } else { - memcpy(heap_location, source.GetData(), source.GetSize()); + FastMemcpy(heap_location, source.GetData(), source.GetSize()); Store(string_t(const_char_ptr_cast(heap_location), UnsafeNumericCast(source.GetSize())), row_location + offset_in_row); heap_location += source.GetSize(); @@ -41,8 +42,7 @@ inline void TupleDataValueStore(const string_t &source, const data_ptr_t &row_lo } template -static inline void TupleDataWithinListValueStore(const T &source, const data_ptr_t &location, - data_ptr_t &heap_location) { +static void TupleDataWithinListValueStore(const T &source, const data_ptr_t &location, data_ptr_t &) { Store(source, location); } @@ -52,13 +52,13 @@ inline void TupleDataWithinListValueStore(const string_t &source, const data_ptr #ifdef DEBUG source.VerifyCharacters(); #endif - Store(NumericCast(source.GetSize()), location); - memcpy(heap_location, source.GetData(), source.GetSize()); + Store(UnsafeNumericCast(source.GetSize()), location); + FastMemcpy(heap_location, source.GetData(), source.GetSize()); heap_location += source.GetSize(); } template -inline void TupleDataValueVerify(const LogicalType &type, const T &value) { +void TupleDataValueVerify(const LogicalType &, const T &) { #ifdef DEBUG // NOP #endif @@ -74,7 +74,7 @@ inline void TupleDataValueVerify(const LogicalType &type, const string_t &value) } template -static inline T TupleDataWithinListValueLoad(const data_ptr_t &location, data_ptr_t &heap_location) { +static T TupleDataWithinListValueLoad(const data_ptr_t &location, data_ptr_t &) { return Load(location); } @@ -86,7 +86,7 @@ inline string_t TupleDataWithinListValueLoad(const data_ptr_t &location, data_pt return result; } -static inline void ResetCombinedListData(vector &vector_data) { +static void ResetCombinedListData(vector &vector_data) { #ifdef DEBUG for (auto &vd : vector_data) { vd.combined_list_data = nullptr; @@ -100,17 +100,16 @@ void TupleDataCollection::ComputeHeapSizes(TupleDataChunkState &chunk_state, con ResetCombinedListData(chunk_state.vector_data); auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); - std::fill_n(heap_sizes, new_chunk.size(), 0); + std::fill_n(heap_sizes, append_count, 0); for (idx_t col_idx = 0; col_idx < new_chunk.ColumnCount(); col_idx++) { auto &source_v = new_chunk.data[col_idx]; auto &source_format = chunk_state.vector_data[col_idx]; - TupleDataCollection::ComputeHeapSizes(chunk_state.heap_sizes, source_v, source_format, append_sel, - append_count); + ComputeHeapSizes(chunk_state.heap_sizes, source_v, source_format, append_sel, append_count); } } -static inline idx_t StringHeapSize(const string_t &val) { +static idx_t StringHeapSize(const string_t &val) { return val.IsInlined() ? 0 : val.GetSize(); } @@ -150,8 +149,7 @@ void TupleDataCollection::ComputeHeapSizes(Vector &heap_sizes_v, const Vector &s for (idx_t struct_col_idx = 0; struct_col_idx < struct_sources.size(); struct_col_idx++) { const auto &struct_source = struct_sources[struct_col_idx]; auto &struct_format = source_format.children[struct_col_idx]; - TupleDataCollection::ComputeHeapSizes(heap_sizes_v, *struct_source, struct_format, append_sel, - append_count); + ComputeHeapSizes(heap_sizes_v, *struct_source, struct_format, append_sel, append_count); } break; } @@ -168,8 +166,8 @@ void TupleDataCollection::ComputeHeapSizes(Vector &heap_sizes_v, const Vector &s D_ASSERT(source_format.children.size() == 1); auto &child_source_v = ListVector::GetEntry(source_v); auto &child_format = source_format.children[0]; - TupleDataCollection::WithinCollectionComputeHeapSizes(heap_sizes_v, child_source_v, child_format, append_sel, - append_count, source_vector_data); + WithinCollectionComputeHeapSizes(heap_sizes_v, child_source_v, child_format, append_sel, append_count, + source_vector_data); break; } case PhysicalType::ARRAY: { @@ -185,8 +183,8 @@ void TupleDataCollection::ComputeHeapSizes(Vector &heap_sizes_v, const Vector &s D_ASSERT(source_format.children.size() == 1); auto &child_source_v = ArrayVector::GetEntry(source_v); auto &child_format = source_format.children[0]; - TupleDataCollection::WithinCollectionComputeHeapSizes(heap_sizes_v, child_source_v, child_format, append_sel, - append_count, source_vector_data); + WithinCollectionComputeHeapSizes(heap_sizes_v, child_source_v, child_format, append_sel, append_count, + source_vector_data); break; } default: @@ -200,26 +198,26 @@ void TupleDataCollection::WithinCollectionComputeHeapSizes(Vector &heap_sizes_v, const UnifiedVectorFormat &list_data) { auto type = source_v.GetType().InternalType(); if (TypeIsConstantSize(type)) { - TupleDataCollection::ComputeFixedWithinCollectionHeapSizes(heap_sizes_v, source_v, source_format, append_sel, - append_count, list_data); + ComputeFixedWithinCollectionHeapSizes(heap_sizes_v, source_v, source_format, append_sel, append_count, + list_data); return; } switch (type) { case PhysicalType::VARCHAR: - TupleDataCollection::StringWithinCollectionComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, - append_count, list_data); + StringWithinCollectionComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, append_count, + list_data); break; case PhysicalType::STRUCT: - TupleDataCollection::StructWithinCollectionComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, - append_count, list_data); + StructWithinCollectionComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, append_count, + list_data); break; case PhysicalType::LIST: - TupleDataCollection::CollectionWithinCollectionComputeHeapSizes(heap_sizes_v, source_v, source_format, - append_sel, append_count, list_data); + CollectionWithinCollectionComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, append_count, + list_data); break; case PhysicalType::ARRAY: - TupleDataCollection::CollectionWithinCollectionComputeHeapSizes(heap_sizes_v, source_v, source_format, - append_sel, append_count, list_data); + CollectionWithinCollectionComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, append_count, + list_data); break; default: throw NotImplementedException("WithinListHeapComputeSizes for %s", EnumUtil::ToString(source_v.GetType().id())); @@ -227,7 +225,7 @@ void TupleDataCollection::WithinCollectionComputeHeapSizes(Vector &heap_sizes_v, } void TupleDataCollection::ComputeFixedWithinCollectionHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, + TupleDataVectorFormat &, const SelectionVector &append_sel, const idx_t append_count, const UnifiedVectorFormat &list_data) { @@ -260,7 +258,7 @@ void TupleDataCollection::ComputeFixedWithinCollectionHeapSizes(Vector &heap_siz } } -void TupleDataCollection::StringWithinCollectionComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, +void TupleDataCollection::StringWithinCollectionComputeHeapSizes(Vector &heap_sizes_v, const Vector &, TupleDataVectorFormat &source_format, const SelectionVector &append_sel, const idx_t append_count, @@ -343,8 +341,8 @@ void TupleDataCollection::StructWithinCollectionComputeHeapSizes(Vector &heap_si auto &struct_source = *struct_sources[struct_col_idx]; auto &struct_format = source_format.children[struct_col_idx]; - TupleDataCollection::WithinCollectionComputeHeapSizes(heap_sizes_v, struct_source, struct_format, append_sel, - append_count, list_data); + WithinCollectionComputeHeapSizes(heap_sizes_v, struct_source, struct_format, append_sel, append_count, + list_data); } } @@ -507,8 +505,56 @@ void TupleDataCollection::CollectionWithinCollectionComputeHeapSizes(Vector &hea combined_child_list_data.validity.Initialize(combined_validity); // Recurse - TupleDataCollection::WithinCollectionComputeHeapSizes(heap_sizes_v, child_source, child_format, append_sel, - append_count, combined_child_list_data); + WithinCollectionComputeHeapSizes(heap_sizes_v, child_source, child_format, append_sel, append_count, + combined_child_list_data); +} + +template +static void TemplatedInitializeValidityMask(const data_ptr_t row_locations[], const idx_t append_count) { + for (idx_t i = 0; i < append_count; i++) { + Store(T(-1), row_locations[i]); + } +} + +template +static void TemplatedInitializeValidityMask(const data_ptr_t row_locations[], const idx_t append_count) { + for (idx_t i = 0; i < append_count; i++) { + memset(row_locations[i], ~0, validity_bytes); + } +} + +static void InitializeValidityMask(const data_ptr_t row_locations[], const idx_t append_count, + const idx_t validity_bytes) { + switch (validity_bytes) { + case 1: + TemplatedInitializeValidityMask(row_locations, append_count); + break; + case 2: + TemplatedInitializeValidityMask(row_locations, append_count); + break; + case 3: + TemplatedInitializeValidityMask<3>(row_locations, append_count); + break; + case 4: + TemplatedInitializeValidityMask(row_locations, append_count); + break; + case 5: + TemplatedInitializeValidityMask<5>(row_locations, append_count); + break; + case 6: + TemplatedInitializeValidityMask<6>(row_locations, append_count); + break; + case 7: + TemplatedInitializeValidityMask<7>(row_locations, append_count); + break; + case 8: + TemplatedInitializeValidityMask(row_locations, append_count); + break; + default: + for (idx_t i = 0; i < append_count; i++) { + FastMemset(row_locations[i], ~0, validity_bytes); + } + } } void TupleDataCollection::Scatter(TupleDataChunkState &chunk_state, const DataChunk &new_chunk, @@ -516,24 +562,25 @@ void TupleDataCollection::Scatter(TupleDataChunkState &chunk_state, const DataCh #ifdef DEBUG Vector heap_locations_copy(LogicalType::POINTER); if (!layout.AllConstant()) { - VectorOperations::Copy(chunk_state.heap_locations, heap_locations_copy, append_count, 0, 0); + const auto heap_locations = FlatVector::GetData(chunk_state.heap_locations); + const auto copied_heap_locations = FlatVector::GetData(heap_locations_copy); + for (idx_t i = 0; i < append_count; i++) { + copied_heap_locations[i] = heap_locations[i]; + } } #endif const auto row_locations = FlatVector::GetData(chunk_state.row_locations); // Set the validity mask for each row before inserting data - const auto validity_bytes = ValidityBytes::SizeInBytes(layout.ColumnCount()); - for (idx_t i = 0; i < append_count; i++) { - FastMemset(row_locations[i], ~0, validity_bytes); - } + InitializeValidityMask(row_locations, append_count, ValidityBytes::SizeInBytes(layout.ColumnCount())); if (!layout.AllConstant()) { // Set the heap size for each row const auto heap_size_offset = layout.GetHeapSizeOffset(); const auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); for (idx_t i = 0; i < append_count; i++) { - Store(NumericCast(heap_sizes[i]), row_locations[i] + heap_size_offset); + Store(UnsafeNumericCast(heap_sizes[i]), row_locations[i] + heap_size_offset); } } @@ -549,7 +596,9 @@ void TupleDataCollection::Scatter(TupleDataChunkState &chunk_state, const DataCh const auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); const auto offset_heap_locations = FlatVector::GetData(chunk_state.heap_locations); for (idx_t i = 0; i < append_count; i++) { - D_ASSERT(offset_heap_locations[i] == original_heap_locations[i] + heap_sizes[i]); + if (heap_sizes[i] != 0) { + D_ASSERT(offset_heap_locations[i] == original_heap_locations[i] + heap_sizes[i]); + } } } #endif @@ -564,11 +613,11 @@ void TupleDataCollection::Scatter(TupleDataChunkState &chunk_state, const Vector } template -static void TupleDataTemplatedScatter(const Vector &source, const TupleDataVectorFormat &source_format, +static void TupleDataTemplatedScatter(const Vector &, const TupleDataVectorFormat &source_format, const SelectionVector &append_sel, const idx_t append_count, const TupleDataLayout &layout, const Vector &row_locations, - Vector &heap_locations, const idx_t col_idx, const UnifiedVectorFormat &dummy_arg, - const vector &child_functions) { + Vector &heap_locations, const idx_t col_idx, const UnifiedVectorFormat &, + const vector &) { // Source const auto &source_data = source_format.unified; const auto &source_sel = *source_data.sel; @@ -576,8 +625,8 @@ static void TupleDataTemplatedScatter(const Vector &source, const TupleDataVecto const auto &validity = source_data.validity; // Target - auto target_locations = FlatVector::GetData(row_locations); - auto target_heap_locations = FlatVector::GetData(heap_locations); + const auto target_locations = FlatVector::GetData(row_locations); + const auto target_heap_locations = FlatVector::GetData(heap_locations); // Precompute mask indexes idx_t entry_idx; @@ -614,7 +663,7 @@ static void TupleDataStructScatter(const Vector &source, const TupleDataVectorFo const auto &validity = source_data.validity; // Target - auto target_locations = FlatVector::GetData(row_locations); + const auto target_locations = FlatVector::GetData(row_locations); // Precompute mask indexes idx_t entry_idx; @@ -644,10 +693,8 @@ static void TupleDataStructScatter(const Vector &source, const TupleDataVectorFo D_ASSERT(struct_layout.ColumnCount() == struct_sources.size()); // Set the validity of the entries within the STRUCTs - const auto validity_bytes = ValidityBytes::SizeInBytes(struct_layout.ColumnCount()); - for (idx_t i = 0; i < append_count; i++) { - memset(struct_target_locations[i], ~0, validity_bytes); - } + InitializeValidityMask(struct_target_locations, append_count, + ValidityBytes::SizeInBytes(struct_layout.ColumnCount())); // Recurse through the struct children for (idx_t struct_col_idx = 0; struct_col_idx < struct_layout.ColumnCount(); struct_col_idx++) { @@ -666,7 +713,7 @@ static void TupleDataStructScatter(const Vector &source, const TupleDataVectorFo static void TupleDataListScatter(const Vector &source, const TupleDataVectorFormat &source_format, const SelectionVector &append_sel, const idx_t append_count, const TupleDataLayout &layout, const Vector &row_locations, Vector &heap_locations, - const idx_t col_idx, const UnifiedVectorFormat &dummy_arg, + const idx_t col_idx, const UnifiedVectorFormat &, const vector &child_functions) { // Source const auto &source_data = source_format.unified; @@ -675,8 +722,8 @@ static void TupleDataListScatter(const Vector &source, const TupleDataVectorForm const auto &validity = source_data.validity; // Target - auto target_locations = FlatVector::GetData(row_locations); - auto target_heap_locations = FlatVector::GetData(heap_locations); + const auto target_locations = FlatVector::GetData(row_locations); + const auto target_heap_locations = FlatVector::GetData(heap_locations); // Precompute mask indexes idx_t entry_idx; @@ -714,7 +761,7 @@ static void TupleDataListScatter(const Vector &source, const TupleDataVectorForm static void TupleDataArrayScatter(const Vector &source, const TupleDataVectorFormat &source_format, const SelectionVector &append_sel, const idx_t append_count, const TupleDataLayout &layout, const Vector &row_locations, Vector &heap_locations, - const idx_t col_idx, const UnifiedVectorFormat &dummy_arg, + const idx_t col_idx, const UnifiedVectorFormat &, const vector &child_functions) { // Source // The Array vector has fake list_entry_t's set by this point, so this is fine @@ -724,8 +771,8 @@ static void TupleDataArrayScatter(const Vector &source, const TupleDataVectorFor const auto &validity = source_data.validity; // Target - auto target_locations = FlatVector::GetData(row_locations); - auto target_heap_locations = FlatVector::GetData(heap_locations); + const auto target_locations = FlatVector::GetData(row_locations); + const auto target_heap_locations = FlatVector::GetData(heap_locations); // Precompute mask indexes idx_t entry_idx; @@ -761,12 +808,11 @@ static void TupleDataArrayScatter(const Vector &source, const TupleDataVectorFor // Collection Scatter //------------------------------------------------------------------------------ template -static void TupleDataTemplatedWithinCollectionScatter(const Vector &source, const TupleDataVectorFormat &source_format, +static void TupleDataTemplatedWithinCollectionScatter(const Vector &, const TupleDataVectorFormat &source_format, const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, - Vector &heap_locations, const idx_t col_idx, - const UnifiedVectorFormat &list_data, - const vector &child_functions) { + const TupleDataLayout &, const Vector &, Vector &heap_locations, + const idx_t, const UnifiedVectorFormat &list_data, + const vector &) { // Parent list data const auto &list_sel = *list_data.sel; const auto list_entries = UnifiedVectorFormat::GetData(list_data); @@ -779,7 +825,7 @@ static void TupleDataTemplatedWithinCollectionScatter(const Vector &source, cons const auto &source_validity = source_data.validity; // Target - auto target_heap_locations = FlatVector::GetData(heap_locations); + const auto target_heap_locations = FlatVector::GetData(heap_locations); for (idx_t i = 0; i < append_count; i++) { const auto list_idx = list_sel.get_index(append_sel.get_index(i)); @@ -822,7 +868,7 @@ static void TupleDataTemplatedWithinCollectionScatter(const Vector &source, cons static void TupleDataStructWithinCollectionScatter(const Vector &source, const TupleDataVectorFormat &source_format, const SelectionVector &append_sel, const idx_t append_count, const TupleDataLayout &layout, const Vector &row_locations, - Vector &heap_locations, const idx_t col_idx, + Vector &heap_locations, const idx_t, const UnifiedVectorFormat &list_data, const vector &child_functions) { // Parent list data @@ -836,7 +882,7 @@ static void TupleDataStructWithinCollectionScatter(const Vector &source, const T const auto &source_validity = source_data.validity; // Target - auto target_heap_locations = FlatVector::GetData(heap_locations); + const auto target_heap_locations = FlatVector::GetData(heap_locations); // Initialize the validity of the STRUCTs for (idx_t i = 0; i < append_count; i++) { @@ -900,7 +946,7 @@ static void TupleDataCollectionWithinCollectionScatter(const Vector &child_list, const auto &child_list_validity = child_list_data.validity; // Target - auto target_heap_locations = FlatVector::GetData(heap_locations); + const auto target_heap_locations = FlatVector::GetData(heap_locations); for (idx_t i = 0; i < append_count; i++) { const auto list_idx = list_sel.get_index(append_sel.get_index(i)); @@ -1054,6 +1100,7 @@ void TupleDataCollection::Gather(Vector &row_locations, const SelectionVector &s void TupleDataCollection::Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, const column_t column_id, Vector &result, const SelectionVector &target_sel, optional_ptr cached_cast_vector) const { + D_ASSERT(!cached_cast_vector || FlatVector::Validity(*cached_cast_vector).AllValid()); // ResetCachedCastVectors const auto &gather_function = gather_functions[column_id]; gather_function.function(layout, row_locations, column_id, scan_sel, scan_count, result, target_sel, cached_cast_vector, gather_function.child_functions); @@ -1063,10 +1110,10 @@ void TupleDataCollection::Gather(Vector &row_locations, const SelectionVector &s template static void TupleDataTemplatedGather(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx, const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, optional_ptr dummy_vector, - const vector &child_functions) { + const SelectionVector &target_sel, optional_ptr, + const vector &) { // Source - auto source_locations = FlatVector::GetData(row_locations); + const auto source_locations = FlatVector::GetData(row_locations); // Target auto target_data = FlatVector::GetData(target); @@ -1081,13 +1128,16 @@ static void TupleDataTemplatedGather(const TupleDataLayout &layout, Vector &row_ for (idx_t i = 0; i < scan_count; i++) { const auto &source_row = source_locations[scan_sel.get_index(i)]; const auto target_idx = target_sel.get_index(i); + target_data[target_idx] = Load(source_row + offset_in_row); ValidityBytes row_mask(source_row); - if (row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { - target_data[target_idx] = Load(source_row + offset_in_row); - TupleDataValueVerify(target.GetType(), target_data[target_idx]); - } else { + if (!row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { target_validity.SetInvalid(target_idx); } +#ifdef DEBUG + else { + TupleDataValueVerify(target.GetType(), target_data[target_idx]); + } +#endif } } @@ -1096,7 +1146,7 @@ static void TupleDataStructGather(const TupleDataLayout &layout, Vector &row_loc const SelectionVector &target_sel, optional_ptr dummy_vector, const vector &child_functions) { // Source - auto source_locations = FlatVector::GetData(row_locations); + const auto source_locations = FlatVector::GetData(row_locations); // Target auto &target_validity = FlatVector::Validity(target); @@ -1145,13 +1195,13 @@ static void TupleDataStructGather(const TupleDataLayout &layout, Vector &row_loc //------------------------------------------------------------------------------ static void TupleDataListGather(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx, const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, optional_ptr dummy_vector, + const SelectionVector &target_sel, optional_ptr, const vector &child_functions) { // Source - auto source_locations = FlatVector::GetData(row_locations); + const auto source_locations = FlatVector::GetData(row_locations); // Target - auto target_list_entries = FlatVector::GetData(target); + const auto target_list_entries = FlatVector::GetData(target); auto &target_list_validity = FlatVector::Validity(target); // Precompute mask indexes @@ -1161,7 +1211,7 @@ static void TupleDataListGather(const TupleDataLayout &layout, Vector &row_locat // Load pointers to the data from the row Vector heap_locations(LogicalType::POINTER); - auto source_heap_locations = FlatVector::GetData(heap_locations); + const auto source_heap_locations = FlatVector::GetData(heap_locations); const auto offset_in_row = layout.GetOffsets()[col_idx]; uint64_t target_list_offset = 0; @@ -1202,21 +1252,20 @@ static void TupleDataListGather(const TupleDataLayout &layout, Vector &row_locat // Collection Gather //------------------------------------------------------------------------------ template -static void TupleDataTemplatedWithinCollectionGather(const TupleDataLayout &layout, Vector &heap_locations, - const idx_t list_size_before, const SelectionVector &scan_sel, - const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, - optional_ptr list_vector, - const vector &child_functions) { +static void +TupleDataTemplatedWithinCollectionGather(const TupleDataLayout &, Vector &heap_locations, const idx_t list_size_before, + const SelectionVector &, const idx_t scan_count, Vector &target, + const SelectionVector &target_sel, optional_ptr list_vector, + const vector &) { // List parent const auto list_entries = FlatVector::GetData(*list_vector); const auto &list_validity = FlatVector::Validity(*list_vector); // Source - auto source_heap_locations = FlatVector::GetData(heap_locations); + const auto source_heap_locations = FlatVector::GetData(heap_locations); // Target - auto target_data = FlatVector::GetData(target); + const auto target_data = FlatVector::GetData(target); auto &target_validity = FlatVector::Validity(target); uint64_t target_offset = list_size_before; @@ -1265,7 +1314,7 @@ static void TupleDataStructWithinCollectionGather(const TupleDataLayout &layout, const auto &list_validity = FlatVector::Validity(*list_vector); // Source - auto source_heap_locations = FlatVector::GetData(heap_locations); + const auto source_heap_locations = FlatVector::GetData(heap_locations); // Target auto &target_validity = FlatVector::Validity(target); @@ -1317,17 +1366,17 @@ static void TupleDataCollectionWithinCollectionGather(const TupleDataLayout &lay const auto &list_validity = FlatVector::Validity(*list_vector); // Source - auto source_heap_locations = FlatVector::GetData(heap_locations); + const auto source_heap_locations = FlatVector::GetData(heap_locations); // Target - auto target_list_entries = FlatVector::GetData(target); + const auto target_list_entries = FlatVector::GetData(target); auto &target_validity = FlatVector::Validity(target); const auto child_list_size_before = ListVector::GetListSize(target); // We need to create a vector that has the combined list sizes (hugeint_t has same size as list_entry_t) Vector combined_list_vector(LogicalType::HUGEINT); FlatVector::SetValidity(combined_list_vector, list_validity); // Has same validity as list parent - auto combined_list_entries = FlatVector::GetData(combined_list_vector); + const auto combined_list_entries = FlatVector::GetData(combined_list_vector); uint64_t target_offset = list_size_before; uint64_t target_child_offset = child_list_size_before; @@ -1337,8 +1386,13 @@ static void TupleDataCollectionWithinCollectionGather(const TupleDataLayout &lay continue; } + // Set the offset of the combined list entry + auto &combined_list_entry = combined_list_entries[target_idx]; + combined_list_entry.offset = target_child_offset; + const auto &list_length = list_entries[target_idx].length; if (list_length == 0) { + combined_list_entry.length = 0; continue; } @@ -1351,10 +1405,6 @@ static void TupleDataCollectionWithinCollectionGather(const TupleDataLayout &lay const auto source_data_location = source_heap_location; source_heap_location += list_length * sizeof(uint64_t); - // Set the offset of the combined list entry - auto &combined_list_entry = combined_list_entries[target_sel.get_index(i)]; - combined_list_entry.offset = target_child_offset; - // Load the child validity and data belonging to this list entry for (idx_t child_i = 0; child_i < list_length; child_i++) { if (source_mask.RowIsValidUnsafe(child_i)) { @@ -1392,7 +1442,6 @@ static void TupleDataCastToArrayListGather(const TupleDataLayout &layout, Vector const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, const SelectionVector &target_sel, optional_ptr cached_cast_vector, const vector &child_functions) { - if (cached_cast_vector) { // Reuse the cached cast vector TupleDataListGather(layout, row_locations, col_idx, scan_sel, scan_count, *cached_cast_vector, target_sel, @@ -1508,23 +1557,24 @@ TupleDataGatherFunction TupleDataCollection::GetGatherFunction(const LogicalType return TupleDataGetGatherFunctionInternal(type, false); } - if (type.Contains(LogicalTypeId::ARRAY)) { + if (TypeVisitor::Contains(type, LogicalTypeId::ARRAY)) { // Special case: we cant handle arrays yet, so we need to replace them with lists when gathering - auto new_type = ArrayType::ConvertToList(type); + const auto new_type = ArrayType::ConvertToList(type); TupleDataGatherFunction result; // Theres only two cases: Either the array is within a struct, or it is within a list (or has now become a list) - if (new_type.InternalType() == PhysicalType::LIST) { + switch (new_type.InternalType()) { + case PhysicalType::LIST: result.function = TupleDataCastToArrayListGather; result.child_functions.push_back( TupleDataGetGatherFunctionInternal(ListType::GetChildType(new_type), true)); return result; - } else if (new_type.InternalType() == PhysicalType::STRUCT) { + case PhysicalType::STRUCT: result.function = TupleDataCastToArrayStructGather; for (const auto &child_type : StructType::GetChildTypes(new_type)) { result.child_functions.push_back(TupleDataGetGatherFunctionInternal(child_type.second, false)); } return result; - } else { + default: throw InternalException("Unsupported type for TupleDataCollection::GetGatherFunction"); } } diff --git a/src/duckdb/src/common/types/row/tuple_data_segment.cpp b/src/duckdb/src/common/types/row/tuple_data_segment.cpp index eb38913b..82e25001 100644 --- a/src/duckdb/src/common/types/row/tuple_data_segment.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_segment.cpp @@ -113,9 +113,12 @@ TupleDataSegment::TupleDataSegment(shared_ptr allocator_p) TupleDataSegment::~TupleDataSegment() { lock_guard guard(pinned_handles_lock); + if (allocator) { + allocator->SetDestroyBufferUponUnpin(); // Prevent blocks from being added to eviction queue + } pinned_row_handles.clear(); pinned_heap_handles.clear(); - allocator = nullptr; + allocator.reset(); } void SwapTupleDataSegment(TupleDataSegment &a, TupleDataSegment &b) { diff --git a/src/duckdb/src/common/types/selection_vector.cpp b/src/duckdb/src/common/types/selection_vector.cpp index 8458d4c6..7b6fda41 100644 --- a/src/duckdb/src/common/types/selection_vector.cpp +++ b/src/duckdb/src/common/types/selection_vector.cpp @@ -1,11 +1,12 @@ #include "duckdb/common/types/selection_vector.hpp" + #include "duckdb/common/printer.hpp" #include "duckdb/common/to_string.hpp" namespace duckdb { SelectionData::SelectionData(idx_t count) { - owned_data = make_unsafe_uniq_array(count); + owned_data = make_unsafe_uniq_array_uninitialized(count); #ifdef DEBUG for (idx_t i = 0; i < count; i++) { owned_data[i] = std::numeric_limits::max(); @@ -43,4 +44,19 @@ buffer_ptr SelectionVector::Slice(const SelectionVector &sel, idx return data; } +void SelectionVector::Verify(idx_t count, idx_t vector_size) const { +#ifdef DEBUG + D_ASSERT(vector_size >= 1); + for (idx_t i = 0; i < count; i++) { + auto index = get_index(i); + if (index >= vector_size) { + throw InternalException( + "Provided SelectionVector is invalid, index %d points to %d, which is out of range. " + "the valid range (0-%d)", + i, index, vector_size - 1); + } + } +#endif +} + } // namespace duckdb diff --git a/src/duckdb/src/common/types/time.cpp b/src/duckdb/src/common/types/time.cpp index 8d2ba7fc..fa4d135f 100644 --- a/src/duckdb/src/common/types/time.cpp +++ b/src/duckdb/src/common/types/time.cpp @@ -14,14 +14,14 @@ #include namespace duckdb { - static_assert(sizeof(dtime_t) == sizeof(int64_t), "dtime_t was padded"); // string format is hh:mm:ss.microsecondsZ // microseconds and Z are optional // ISO 8601 -bool Time::TryConvertInternal(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict) { +bool Time::TryConvertInternal(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict, + optional_ptr nanos) { int32_t hour = -1, min = -1, sec = -1, micros = -1; pos = 0; @@ -64,27 +64,38 @@ bool Time::TryConvertInternal(const char *buf, idx_t len, idx_t &pos, dtime_t &r // invalid separator return false; } - - if (!Date::ParseDoubleDigit(buf, len, pos, min)) { - return false; - } - if (min < 0 || min >= 60) { - return false; - } - - if (pos >= len) { - return false; + idx_t sep_pos = pos; + if (pos == len && !strict) { + min = 0; + } else { + if (!Date::ParseDoubleDigit(buf, len, pos, min)) { + return false; + } + if (min < 0 || min >= 60) { + return false; + } } - if (buf[pos++] != sep) { + if (pos > len) { return false; } + if (pos == len && (!strict || sep_pos + 2 == pos)) { + sec = 0; + } else { + if (buf[pos++] != sep) { + return false; + } - if (!Date::ParseDoubleDigit(buf, len, pos, sec)) { - return false; - } - if (sec < 0 || sec >= 60) { - return false; + if (pos == len && !strict) { + sec = 0; + } else { + if (!Date::ParseDoubleDigit(buf, len, pos, sec)) { + return false; + } + if (sec < 0 || sec >= 60) { + return false; + } + } } micros = 0; @@ -92,11 +103,19 @@ bool Time::TryConvertInternal(const char *buf, idx_t len, idx_t &pos, dtime_t &r pos++; // we expect some microseconds int32_t mult = 100000; + if (nanos) { + // do we expect nanoseconds? + mult *= Interval::NANOS_PER_MICRO; + } for (; pos < len && StringUtil::CharacterIsDigit(buf[pos]); pos++, mult /= 10) { if (mult > 0) { micros += (buf[pos] - '0') * mult; } } + if (nanos) { + *nanos = UnsafeNumericCast(micros % Interval::NANOS_PER_MICRO); + micros /= Interval::NANOS_PER_MICRO; + } } // in strict mode, check remaining string for non-space characters @@ -115,16 +134,18 @@ bool Time::TryConvertInternal(const char *buf, idx_t len, idx_t &pos, dtime_t &r return true; } -bool Time::TryConvertInterval(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict) { - return Time::TryConvertInternal(buf, len, pos, result, strict); +bool Time::TryConvertInterval(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict, + optional_ptr nanos) { + return Time::TryConvertInternal(buf, len, pos, result, strict, nanos); } -bool Time::TryConvertTime(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict) { - if (!Time::TryConvertInternal(buf, len, pos, result, strict)) { +bool Time::TryConvertTime(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict, + optional_ptr nanos) { + if (!Time::TryConvertInternal(buf, len, pos, result, strict, nanos)) { if (!strict) { // last chance, check if we can parse as timestamp timestamp_t timestamp; - if (Timestamp::TryConvertTimestamp(buf, len, timestamp) == TimestampCastResult::SUCCESS) { + if (Timestamp::TryConvertTimestamp(buf, len, timestamp, nanos) == TimestampCastResult::SUCCESS) { if (!Timestamp::IsFinite(timestamp)) { return false; } @@ -137,14 +158,15 @@ bool Time::TryConvertTime(const char *buf, idx_t len, idx_t &pos, dtime_t &resul return result.micros <= Interval::MICROS_PER_DAY; } -bool Time::TryConvertTimeTZ(const char *buf, idx_t len, idx_t &pos, dtime_tz_t &result, bool &has_offset, bool strict) { +bool Time::TryConvertTimeTZ(const char *buf, idx_t len, idx_t &pos, dtime_tz_t &result, bool &has_offset, bool strict, + optional_ptr nanos) { dtime_t time_part; has_offset = false; - if (!Time::TryConvertInternal(buf, len, pos, time_part, false)) { + if (!Time::TryConvertInternal(buf, len, pos, time_part, false, nanos)) { if (!strict) { // last chance, check if we can parse as timestamp timestamp_t timestamp; - if (Timestamp::TryConvertTimestamp(buf, len, timestamp) == TimestampCastResult::SUCCESS) { + if (Timestamp::TryConvertTimestamp(buf, len, timestamp, nanos) == TimestampCastResult::SUCCESS) { if (!Timestamp::IsFinite(timestamp)) { return false; } @@ -218,17 +240,17 @@ string Time::ConversionError(string_t str) { return Time::ConversionError(str.GetString()); } -dtime_t Time::FromCString(const char *buf, idx_t len, bool strict) { +dtime_t Time::FromCString(const char *buf, idx_t len, bool strict, optional_ptr nanos) { dtime_t result; idx_t pos; - if (!Time::TryConvertTime(buf, len, pos, result, strict)) { + if (!Time::TryConvertTime(buf, len, pos, result, strict, nanos)) { throw ConversionException(ConversionError(string(buf, len))); } return result; } -dtime_t Time::FromString(const string &str, bool strict) { - return Time::FromCString(str.c_str(), str.size(), strict); +dtime_t Time::FromString(const string &str, bool strict, optional_ptr nanos) { + return Time::FromCString(str.c_str(), str.size(), strict, nanos); } string Time::ToString(dtime_t time) { @@ -237,7 +259,7 @@ string Time::ToString(dtime_t time) { char micro_buffer[6]; auto length = TimeToStringCast::Length(time_units, micro_buffer); - auto buffer = make_unsafe_uniq_array(length); + auto buffer = make_unsafe_uniq_array_uninitialized(length); TimeToStringCast::Format(buffer.get(), length, time_units, micro_buffer); return string(buffer.get(), length); } @@ -273,6 +295,15 @@ dtime_t Time::FromTime(int32_t hour, int32_t minute, int32_t second, int32_t mic return dtime_t(result); } +int64_t Time::ToNanoTime(int32_t hour, int32_t minute, int32_t second, int32_t nanoseconds) { + int64_t result; + result = hour; // hours + result = result * Interval::MINS_PER_HOUR + minute; // hours -> minutes + result = result * Interval::SECS_PER_MINUTE + second; // minutes -> seconds + result = result * Interval::NANOS_PER_SEC + nanoseconds; // seconds -> nanoseconds + return result; +} + bool Time::IsValidTime(int32_t hour, int32_t minute, int32_t second, int32_t microseconds) { if (hour < 0 || hour >= 24) { return (hour == 24) && (minute == 0) && (second == 0) && (microseconds == 0); diff --git a/src/duckdb/src/common/types/timestamp.cpp b/src/duckdb/src/common/types/timestamp.cpp index e0f14a5e..d6b3f822 100644 --- a/src/duckdb/src/common/types/timestamp.cpp +++ b/src/duckdb/src/common/types/timestamp.cpp @@ -55,7 +55,8 @@ timestamp_t ×tamp_t::operator-=(const int64_t &delta) { return *this; } -bool Timestamp::TryConvertTimestampTZ(const char *str, idx_t len, timestamp_t &result, bool &has_offset, string_t &tz) { +bool Timestamp::TryConvertTimestampTZ(const char *str, idx_t len, timestamp_t &result, bool &has_offset, string_t &tz, + optional_ptr nanos) { idx_t pos; date_t date; dtime_t time; @@ -82,7 +83,7 @@ bool Timestamp::TryConvertTimestampTZ(const char *str, idx_t len, timestamp_t &r // TryConvertTime may recursively call us, so we opt for a stricter // operation. Note that we can't pass strict== true here because we // want to process any suffix. - if (!Time::TryConvertInterval(str + pos, len - pos, time_pos, time)) { + if (!Time::TryConvertInterval(str + pos, len - pos, time_pos, time, false, nanos)) { return false; } // We parsed an interval, so make sure it is in range. @@ -132,11 +133,12 @@ bool Timestamp::TryConvertTimestampTZ(const char *str, idx_t len, timestamp_t &r return true; } -TimestampCastResult Timestamp::TryConvertTimestamp(const char *str, idx_t len, timestamp_t &result) { +TimestampCastResult Timestamp::TryConvertTimestamp(const char *str, idx_t len, timestamp_t &result, + optional_ptr nanos) { string_t tz(nullptr, 0); bool has_offset = false; // We don't understand TZ without an extension, so fail if one was provided. - auto success = TryConvertTimestampTZ(str, len, result, has_offset, tz); + auto success = TryConvertTimestampTZ(str, len, result, has_offset, tz, nanos); if (!success) { return TimestampCastResult::ERROR_INCORRECT_FORMAT; } @@ -155,6 +157,31 @@ TimestampCastResult Timestamp::TryConvertTimestamp(const char *str, idx_t len, t return TimestampCastResult::ERROR_NON_UTC_TIMEZONE; } +bool Timestamp::TryFromTimestampNanos(timestamp_t input, int32_t nanos, timestamp_ns_t &result) { + if (!IsFinite(input)) { + result.value = input.value; + return true; + } + // Scale to ns + if (!TryMultiplyOperator::Operation(input.value, Interval::NANOS_PER_MICRO, result.value)) { + return false; + } + + return TryAddOperator::Operation(result.value, int64_t(nanos), result.value); +} + +TimestampCastResult Timestamp::TryConvertTimestamp(const char *str, idx_t len, timestamp_ns_t &result) { + int32_t nanos = 0; + auto success = TryConvertTimestamp(str, len, result, &nanos); + if (success != TimestampCastResult::SUCCESS) { + return success; + } + if (!TryFromTimestampNanos(result, nanos, result)) { + return TimestampCastResult::ERROR_INCORRECT_FORMAT; + } + return TimestampCastResult::SUCCESS; +} + string Timestamp::ConversionError(const string &str) { return StringUtil::Format("timestamp field value out of range: \"%s\", " "expected format is (YYYY-MM-DD HH:MM:SS[.US][±HH:MM| ZONE])", @@ -175,9 +202,9 @@ string Timestamp::UnsupportedTimezoneError(string_t str) { return Timestamp::UnsupportedTimezoneError(str.GetString()); } -timestamp_t Timestamp::FromCString(const char *str, idx_t len) { +timestamp_t Timestamp::FromCString(const char *str, idx_t len, optional_ptr nanos) { timestamp_t result; - auto cast_result = Timestamp::TryConvertTimestamp(str, len, result); + auto cast_result = Timestamp::TryConvertTimestamp(str, len, result, nanos); if (cast_result == TimestampCastResult::SUCCESS) { return result; } @@ -253,12 +280,13 @@ string Timestamp::ToString(timestamp_t timestamp) { } date_t Timestamp::GetDate(timestamp_t timestamp) { - if (timestamp == timestamp_t::infinity()) { + if (DUCKDB_UNLIKELY(timestamp == timestamp_t::infinity())) { return date_t::infinity(); - } else if (timestamp == timestamp_t::ninfinity()) { + } else if (DUCKDB_UNLIKELY(timestamp == timestamp_t::ninfinity())) { return date_t::ninfinity(); } - return date_t((timestamp.value + (timestamp.value < 0)) / Interval::MICROS_PER_DAY - (timestamp.value < 0)); + return date_t(UnsafeNumericCast((timestamp.value + (timestamp.value < 0)) / Interval::MICROS_PER_DAY - + (timestamp.value < 0))); } dtime_t Timestamp::GetTime(timestamp_t timestamp) { @@ -310,6 +338,19 @@ void Timestamp::Convert(timestamp_t timestamp, date_t &out_date, dtime_t &out_ti D_ASSERT(timestamp == Timestamp::FromDatetime(out_date, out_time)); } +void Timestamp::Convert(timestamp_ns_t input, date_t &out_date, dtime_t &out_time, int32_t &out_nanos) { + timestamp_t ms(input.value / Interval::NANOS_PER_MICRO); + out_date = Timestamp::GetDate(ms); + int64_t days_nanos; + if (!TryMultiplyOperator::Operation(out_date.days, Interval::NANOS_PER_DAY, + days_nanos)) { + throw ConversionException("Date out of range in timestamp_ns conversion"); + } + + out_time = dtime_t((input.value - days_nanos) / Interval::NANOS_PER_MICRO); + out_nanos = UnsafeNumericCast((input.value - days_nanos) % Interval::NANOS_PER_MICRO); +} + timestamp_t Timestamp::GetCurrentTimestamp() { auto now = system_clock::now(); auto epoch_ms = duration_cast(now.time_since_epoch()).count(); @@ -347,7 +388,7 @@ timestamp_t Timestamp::FromEpochMicroSeconds(int64_t micros) { } timestamp_t Timestamp::FromEpochNanoSecondsPossiblyInfinite(int64_t ns) { - return timestamp_t(ns / 1000); + return timestamp_t(ns / Interval::NANOS_PER_MICRO); } timestamp_t Timestamp::FromEpochNanoSeconds(int64_t ns) { @@ -355,6 +396,24 @@ timestamp_t Timestamp::FromEpochNanoSeconds(int64_t ns) { return FromEpochNanoSecondsPossiblyInfinite(ns); } +timestamp_ns_t Timestamp::TimestampNsFromEpochMillis(int64_t millis) { + D_ASSERT(Timestamp::IsFinite(timestamp_t(millis))); + timestamp_ns_t result; + if (!TryMultiplyOperator::Operation(millis, Interval::NANOS_PER_MICRO, result.value)) { + throw ConversionException("Could not convert Timestamp(US) to Timestamp(NS)"); + } + return result; +} + +timestamp_ns_t Timestamp::TimestampNsFromEpochMicros(int64_t micros) { + D_ASSERT(Timestamp::IsFinite(timestamp_t(micros))); + timestamp_ns_t result; + if (!TryMultiplyOperator::Operation(micros, Interval::NANOS_PER_MSEC, result.value)) { + throw ConversionException("Could not convert Timestamp(MS) to Timestamp(NS)"); + } + return result; +} + int64_t Timestamp::GetEpochSeconds(timestamp_t timestamp) { D_ASSERT(Timestamp::IsFinite(timestamp)); return timestamp.value / Interval::MICROS_PER_SEC; @@ -370,9 +429,8 @@ int64_t Timestamp::GetEpochMicroSeconds(timestamp_t timestamp) { } bool Timestamp::TryGetEpochNanoSeconds(timestamp_t timestamp, int64_t &result) { - constexpr static const int64_t NANOSECONDS_IN_MICROSECOND = 1000; D_ASSERT(Timestamp::IsFinite(timestamp)); - if (!TryMultiplyOperator::Operation(timestamp.value, NANOSECONDS_IN_MICROSECOND, result)) { + if (!TryMultiplyOperator::Operation(timestamp.value, Interval::NANOS_PER_MICRO, result)) { return false; } return true; diff --git a/src/duckdb/src/common/types/uuid.cpp b/src/duckdb/src/common/types/uuid.cpp index 82583fe9..eb50f209 100644 --- a/src/duckdb/src/common/types/uuid.cpp +++ b/src/duckdb/src/common/types/uuid.cpp @@ -49,7 +49,7 @@ bool UUID::FromString(const string &str, hugeint_t &result) { count++; } // Flip the first bit to make `order by uuid` same as `order by uuid::varchar` - result.upper ^= (uint64_t(1) << 63); + result.upper ^= NumericLimits::Minimum(); return count == 32; } diff --git a/src/duckdb/src/common/types/validity_mask.cpp b/src/duckdb/src/common/types/validity_mask.cpp index d91c7a73..fac6fea4 100644 --- a/src/duckdb/src/common/types/validity_mask.cpp +++ b/src/duckdb/src/common/types/validity_mask.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/numeric_utils.hpp" #include "duckdb/common/serializer/write_stream.hpp" #include "duckdb/common/serializer/read_stream.hpp" +#include "duckdb/common/types/selection_vector.hpp" namespace duckdb { @@ -71,7 +72,7 @@ void ValidityMask::Resize(idx_t old_size, idx_t new_size) { } } -idx_t ValidityMask::TargetCount() { +idx_t ValidityMask::TargetCount() const { return target_count; } @@ -94,20 +95,54 @@ bool ValidityMask::IsAligned(idx_t count) { return count % BITS_PER_VALUE == 0; } +void ValidityMask::CopySel(const ValidityMask &other, const SelectionVector &sel, idx_t source_offset, + idx_t target_offset, idx_t copy_count) { + if (!other.IsMaskSet() && !IsMaskSet()) { + // no need to copy anything if neither has any null values + return; + } + + if (!sel.IsSet() && IsAligned(source_offset) && IsAligned(target_offset)) { + // common case where we are shifting into an aligned mask using a flat vector + SliceInPlace(other, target_offset, source_offset, copy_count); + return; + } + for (idx_t i = 0; i < copy_count; i++) { + auto source_idx = sel.get_index(source_offset + i); + Set(target_offset + i, other.RowIsValid(source_idx)); + } +} + void ValidityMask::SliceInPlace(const ValidityMask &other, idx_t target_offset, idx_t source_offset, idx_t count) { EnsureWritable(); + const idx_t ragged = count % BITS_PER_VALUE; + const idx_t entire_units = count / BITS_PER_VALUE; if (IsAligned(source_offset) && IsAligned(target_offset)) { auto target_validity = GetData(); auto source_validity = other.GetData(); auto source_offset_entries = EntryCount(source_offset); auto target_offset_entries = EntryCount(target_offset); - memcpy(target_validity + target_offset_entries, source_validity + source_offset_entries, - sizeof(validity_t) * EntryCount(count)); + if (!source_validity) { + // if source has no validity mask - set all bytes to 1 + memset(target_validity + target_offset_entries, 0xFF, sizeof(validity_t) * entire_units); + } else { + memcpy(target_validity + target_offset_entries, source_validity + source_offset_entries, + sizeof(validity_t) * entire_units); + } + if (ragged) { + auto src_entry = + source_validity ? source_validity[source_offset_entries + entire_units] : ValidityBuffer::MAX_ENTRY; + src_entry &= (ValidityBuffer::MAX_ENTRY >> (BITS_PER_VALUE - ragged)); + + target_validity += target_offset_entries + entire_units; + auto tgt_entry = *target_validity; + tgt_entry &= (ValidityBuffer::MAX_ENTRY << ragged); + + *target_validity = tgt_entry | src_entry; + } return; } else if (IsAligned(target_offset)) { // Simple common case where we are shifting into an aligned mask (e.g., 0 in Slice above) - const idx_t entire_units = count / BITS_PER_VALUE; - const idx_t ragged = count % BITS_PER_VALUE; const idx_t tail = source_offset % BITS_PER_VALUE; const idx_t head = BITS_PER_VALUE - tail; auto source_validity = other.GetData() + (source_offset / BITS_PER_VALUE); diff --git a/src/duckdb/src/common/types/value.cpp b/src/duckdb/src/common/types/value.cpp index 9fae9858..5d893b6c 100644 --- a/src/duckdb/src/common/types/value.cpp +++ b/src/duckdb/src/common/types/value.cpp @@ -29,7 +29,7 @@ #include "duckdb/common/types/hash.hpp" #include "duckdb/function/cast/cast_function_set.hpp" #include "duckdb/main/error_manager.hpp" - +#include "duckdb/common/types/varint.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" @@ -240,8 +240,13 @@ Value Value::MinimumValue(const LogicalType &type) { const auto min_us = MinimumValue(LogicalType::TIMESTAMP).GetValue(); return Value::TIMESTAMPMS(timestamp_t(Timestamp::GetEpochMs(min_us))); } - case LogicalTypeId::TIMESTAMP_NS: - return Value::TIMESTAMPNS(timestamp_t(NumericLimits::Minimum())); + case LogicalTypeId::TIMESTAMP_NS: { + // Clear the fractional day. + auto min_ns = NumericLimits::Minimum(); + min_ns /= Interval::NANOS_PER_DAY; + min_ns *= Interval::NANOS_PER_DAY; + return Value::TIMESTAMPNS(timestamp_t(min_ns)); + } case LogicalTypeId::TIME_TZ: // "00:00:00+1559" from the PG docs, but actually 00:00:00+15:59:59 return Value::TIMETZ(dtime_tz_t(dtime_t(0), dtime_tz_t::MAX_OFFSET)); @@ -270,6 +275,11 @@ Value Value::MinimumValue(const LogicalType &type) { } case LogicalTypeId::ENUM: return Value::ENUM(0, type); + case LogicalTypeId::VARINT: + return Value::VARINT(Varint::VarcharToVarInt( + "-179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540" + "4589535143824642343213268894641827684675467035375169860499105765512820762454900903893289440758685084551339" + "42304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368")); default: throw InvalidTypeException(type, "MinimumValue requires numeric type"); } @@ -346,8 +356,15 @@ Value Value::MaximumValue(const LogicalType &type) { throw InternalException("Unknown decimal type"); } } - case LogicalTypeId::ENUM: - return Value::ENUM(EnumType::GetSize(type) - 1, type); + case LogicalTypeId::ENUM: { + auto enum_size = EnumType::GetSize(type); + return Value::ENUM(enum_size - (enum_size ? 1 : 0), type); + } + case LogicalTypeId::VARINT: + return Value::VARINT(Varint::VarcharToVarInt( + "1797693134862315708145274237317043567980705675258449965989174768031572607800285387605895586327668781715404" + "5895351438246423432132688946418276846754670353751698604991057655128207624549009038932894407586850845513394" + "2304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368")); default: throw InvalidTypeException(type, "MaximumValue requires numeric type"); } @@ -399,9 +416,9 @@ Value Value::NegativeInfinity(const LogicalType &type) { } } -Value Value::BOOLEAN(int8_t value) { +Value Value::BOOLEAN(bool value) { Value result(LogicalType::BOOLEAN); - result.value_.boolean = bool(value); + result.value_.boolean = value; result.is_null = false; return result; } @@ -841,6 +858,17 @@ Value Value::BLOB(const_data_ptr_t data, idx_t len) { return result; } +Value Value::VARINT(const_data_ptr_t data, idx_t len) { + return VARINT(string(const_char_ptr_cast(data), len)); +} + +Value Value::VARINT(const string &data) { + Value result(LogicalType::VARINT); + result.is_null = false; + result.value_info_ = make_shared_ptr(data); + return result; +} + Value Value::BLOB(const string &data) { Value result(LogicalType::BLOB); result.is_null = false; @@ -1190,6 +1218,10 @@ template <> timestamp_t Value::GetValue() const { return GetValueInternal(); } +template <> +dtime_tz_t Value::GetValue() const { + return GetValueInternal(); +} template <> DUCKDB_API interval_t Value::GetValue() const { @@ -1210,7 +1242,7 @@ Value Value::Numeric(const LogicalType &type, int64_t value) { switch (type.id()) { case LogicalTypeId::BOOLEAN: D_ASSERT(value == 0 || value == 1); - return Value::BOOLEAN(value ? 1 : 0); + return Value::BOOLEAN(value ? true : false); case LogicalTypeId::TINYINT: D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); return Value::TINYINT((int8_t)value); @@ -1629,6 +1661,16 @@ const vector &StructValue::GetChildren(const Value &value) { return value.value_info_->Get().GetValues(); } +const vector &MapValue::GetChildren(const Value &value) { + if (value.is_null) { + throw InternalException("Calling MapValue::GetChildren on a NULL value"); + } + D_ASSERT(value.type().id() == LogicalTypeId::MAP); + D_ASSERT(value.type().InternalType() == PhysicalType::LIST); + D_ASSERT(value.value_info_); + return value.value_info_->Get().GetValues(); +} + const vector &ListValue::GetChildren(const Value &value) { if (value.is_null) { throw InternalException("Calling ListValue::GetChildren on a NULL value"); diff --git a/src/duckdb/src/common/types/varint.cpp b/src/duckdb/src/common/types/varint.cpp new file mode 100644 index 00000000..121b9a3c --- /dev/null +++ b/src/duckdb/src/common/types/varint.cpp @@ -0,0 +1,295 @@ +#include "duckdb/common/types/varint.hpp" +#include "duckdb/common/exception/conversion_exception.hpp" +#include + +namespace duckdb { + +void Varint::Verify(const string_t &input) { +#ifdef DEBUG + // Size must be >= 4 + idx_t varint_bytes = input.GetSize(); + if (varint_bytes < 4) { + throw InternalException("Varint number of bytes is invalid, current number of bytes is %d", varint_bytes); + } + // Bytes in header must quantify the number of data bytes + auto varint_ptr = input.GetData(); + bool is_negative = (varint_ptr[0] & 0x80) == 0; + uint32_t number_of_bytes = 0; + char mask = 0x7F; + if (is_negative) { + number_of_bytes |= static_cast(~varint_ptr[0] & mask) << 16 & 0xFF0000; + number_of_bytes |= static_cast(~varint_ptr[1]) << 8 & 0xFF00; + ; + number_of_bytes |= static_cast(~varint_ptr[2]) & 0xFF; + } else { + number_of_bytes |= static_cast(varint_ptr[0] & mask) << 16 & 0xFF0000; + number_of_bytes |= static_cast(varint_ptr[1]) << 8 & 0xFF00; + number_of_bytes |= static_cast(varint_ptr[2]) & 0xFF; + } + if (number_of_bytes != varint_bytes - 3) { + throw InternalException("The number of bytes set in the Varint header: %d bytes. Does not " + "match the number of bytes encountered as the varint data: %d bytes.", + number_of_bytes, varint_bytes - 3); + } + // No bytes between 4 and end can be 0, unless total size == 4 + if (varint_bytes > 4) { + if (is_negative) { + if (~varint_ptr[3] == 0) { + throw InternalException("Invalid top data bytes set to 0 for VARINT values"); + } + } else { + if (varint_ptr[3] == 0) { + throw InternalException("Invalid top data bytes set to 0 for VARINT values"); + } + } + } +#endif +} +void Varint::SetHeader(char *blob, uint64_t number_of_bytes, bool is_negative) { + uint32_t header = static_cast(number_of_bytes); + // Set MSBit of 3rd byte + header |= 0x00800000; + if (is_negative) { + header = ~header; + } + // we ignore MSByte of header. + // write the 3 bytes to blob. + blob[0] = static_cast(header >> 16); + blob[1] = static_cast(header >> 8 & 0xFF); + blob[2] = static_cast(header & 0xFF); +} + +// Creates a blob representing the value 0 +string_t Varint::InitializeVarintZero(Vector &result) { + uint32_t blob_size = 1 + VARINT_HEADER_SIZE; + auto blob = StringVector::EmptyString(result, blob_size); + auto writable_blob = blob.GetDataWriteable(); + SetHeader(writable_blob, 1, false); + writable_blob[3] = 0; + blob.Finalize(); + return blob; +} + +string Varint::InitializeVarintZero() { + uint32_t blob_size = 1 + VARINT_HEADER_SIZE; + string result(blob_size, '0'); + SetHeader(&result[0], 1, false); + result[3] = 0; + return result; +} + +int Varint::CharToDigit(char c) { + return c - '0'; +} + +char Varint::DigitToChar(int digit) { + // FIXME: this would be the proper solution: + // return UnsafeNumericCast(digit + '0'); + return static_cast(digit + '0'); +} + +bool Varint::VarcharFormatting(const string_t &value, idx_t &start_pos, idx_t &end_pos, bool &is_negative, + bool &is_zero) { + // If it's empty we error + if (value.Empty()) { + return false; + } + start_pos = 0; + is_zero = false; + + auto int_value_char = value.GetData(); + end_pos = value.GetSize(); + + // If first character is -, we have a negative number, if + we have a + number + is_negative = int_value_char[0] == '-'; + if (is_negative) { + start_pos++; + } + if (int_value_char[0] == '+') { + start_pos++; + } + // Now lets trim 0s + bool at_least_one_zero = false; + while (start_pos < end_pos && int_value_char[start_pos] == '0') { + start_pos++; + at_least_one_zero = true; + } + if (start_pos == end_pos) { + if (at_least_one_zero) { + // This is a 0 value + is_zero = true; + return true; + } + // This is either a '+' or '-'. Hence, invalid. + return false; + } + idx_t cur_pos = start_pos; + // Verify all is numeric + while (cur_pos < end_pos && std::isdigit(int_value_char[cur_pos])) { + cur_pos++; + } + if (cur_pos < end_pos) { + idx_t possible_end = cur_pos; + // Oh oh, this is not a digit, if it's a . we might be fine, otherwise, this is invalid. + if (int_value_char[cur_pos] == '.') { + cur_pos++; + } else { + return false; + } + + while (cur_pos < end_pos) { + if (std::isdigit(int_value_char[cur_pos])) { + cur_pos++; + } else { + // By now we can only have numbers, otherwise this is invalid. + return false; + } + } + // Floor cast this boy + end_pos = possible_end; + } + return true; +} + +void Varint::GetByteArray(vector &byte_array, bool &is_negative, const string_t &blob) { + if (blob.GetSize() < 4) { + throw InvalidInputException("Invalid blob size."); + } + auto blob_ptr = blob.GetData(); + + // Determine if the number is negative + is_negative = (blob_ptr[0] & 0x80) == 0; + for (idx_t i = 3; i < blob.GetSize(); i++) { + if (is_negative) { + byte_array.push_back(static_cast(~blob_ptr[i])); + } else { + byte_array.push_back(static_cast(blob_ptr[i])); + } + } +} + +string Varint::VarIntToVarchar(const string_t &blob) { + string decimal_string; + vector byte_array; + bool is_negative; + GetByteArray(byte_array, is_negative, blob); + while (!byte_array.empty()) { + string quotient; + uint8_t remainder = 0; + for (uint8_t byte : byte_array) { + int new_value = remainder * 256 + byte; + quotient += DigitToChar(new_value / 10); + remainder = static_cast(new_value % 10); + } + decimal_string += DigitToChar(remainder); + // Remove leading zeros from the quotient + byte_array.clear(); + for (char digit : quotient) { + if (digit != '0' || !byte_array.empty()) { + byte_array.push_back(static_cast(CharToDigit(digit))); + } + } + } + if (is_negative) { + decimal_string += '-'; + } + // Reverse the string to get the correct decimal representation + std::reverse(decimal_string.begin(), decimal_string.end()); + return decimal_string; +} + +string Varint::VarcharToVarInt(const string_t &value) { + idx_t start_pos, end_pos; + bool is_negative, is_zero; + if (!VarcharFormatting(value, start_pos, end_pos, is_negative, is_zero)) { + throw ConversionException("Could not convert string \'%s\' to Varint", value.GetString()); + } + if (is_zero) { + // Return Value 0 + return InitializeVarintZero(); + } + auto int_value_char = value.GetData(); + idx_t actual_size = end_pos - start_pos; + + // we initalize result with space for our header + string result(VARINT_HEADER_SIZE, '0'); + unsafe_vector digits; + + // The max number a uint64_t can represent is 18.446.744.073.709.551.615 + // That has 20 digits + // In the worst case a remainder of a division will be 255, which is 3 digits + // Since the max value is 184, we need to take one more digit out + // Hence we end up with a max of 16 digits supported. + constexpr uint8_t max_digits = 16; + const idx_t number_of_digits = static_cast(std::ceil(static_cast(actual_size) / max_digits)); + + // lets convert the string to a uint64_t vector + idx_t cur_end = end_pos; + for (idx_t i = 0; i < number_of_digits; i++) { + idx_t cur_start = static_cast(start_pos) > static_cast(cur_end - max_digits) + ? start_pos + : cur_end - max_digits; + std::string current_number(int_value_char + cur_start, cur_end - cur_start); + digits.push_back(std::stoull(current_number)); + // move cur_end to more digits down the road + cur_end = cur_end - max_digits; + } + + // Now that we have our uint64_t vector, lets start our division process to figure out the new number and remainder + while (!digits.empty()) { + idx_t digit_idx = digits.size() - 1; + uint8_t remainder = 0; + idx_t digits_size = digits.size(); + for (idx_t i = 0; i < digits_size; i++) { + digits[digit_idx] += static_cast(remainder * pow(10, max_digits)); + remainder = static_cast(digits[digit_idx] % 256); + digits[digit_idx] /= 256; + if (digits[digit_idx] == 0 && digit_idx == digits.size() - 1) { + // we can cap this + digits.pop_back(); + } + digit_idx--; + } + if (is_negative) { + result.push_back(static_cast(~remainder)); + } else { + result.push_back(static_cast(remainder)); + } + } + std::reverse(result.begin() + VARINT_HEADER_SIZE, result.end()); + // Set header after we know the size of the varint + SetHeader(&result[0], result.size() - VARINT_HEADER_SIZE, is_negative); + return result; +} + +bool Varint::VarintToDouble(const string_t &blob, double &result, bool &strict) { + result = 0; + + if (blob.GetSize() < 4) { + throw InvalidInputException("Invalid blob size."); + } + auto blob_ptr = blob.GetData(); + + // Determine if the number is negative + bool is_negative = (blob_ptr[0] & 0x80) == 0; + idx_t byte_pos = 0; + for (idx_t i = blob.GetSize() - 1; i > 2; i--) { + if (is_negative) { + result += static_cast(~blob_ptr[i]) * pow(256, static_cast(byte_pos)); + } else { + result += static_cast(blob_ptr[i]) * pow(256, static_cast(byte_pos)); + } + byte_pos++; + } + + if (is_negative) { + result *= -1; + } + if (!std::isfinite(result)) { + // We throw an error + throw ConversionException("Could not convert varint '%s' to Double", VarIntToVarchar(blob)); + } + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/vector.cpp b/src/duckdb/src/common/types/vector.cpp index aa08072a..29e0a955 100644 --- a/src/duckdb/src/common/types/vector.cpp +++ b/src/duckdb/src/common/types/vector.cpp @@ -3,25 +3,26 @@ #include "duckdb/common/algorithm.hpp" #include "duckdb/common/assert.hpp" #include "duckdb/common/exception.hpp" +#include "duckdb/common/fsst.hpp" #include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/common/pair.hpp" #include "duckdb/common/printer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/type_visitor.hpp" +#include "duckdb/common/types/bit.hpp" #include "duckdb/common/types/null_value.hpp" #include "duckdb/common/types/sel_cache.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/types/value_map.hpp" #include "duckdb/common/types/vector_cache.hpp" #include "duckdb/common/uhugeint.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/storage/buffer/buffer_handle.hpp" #include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/storage/buffer/buffer_handle.hpp" #include "duckdb/storage/string_uncompressed.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/common/fsst.hpp" #include "fsst.h" -#include "duckdb/common/types/bit.hpp" -#include "duckdb/common/types/value_map.hpp" - -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/types/varint.hpp" #include // strlen() on Solaris @@ -30,7 +31,7 @@ namespace duckdb { UnifiedVectorFormat::UnifiedVectorFormat() : sel(nullptr), data(nullptr) { } -UnifiedVectorFormat::UnifiedVectorFormat(UnifiedVectorFormat &&other) noexcept { +UnifiedVectorFormat::UnifiedVectorFormat(UnifiedVectorFormat &&other) noexcept : sel(nullptr), data(nullptr) { bool refers_to_self = other.sel == &other.owned_sel; std::swap(sel, other.sel); std::swap(data, other.data); @@ -203,7 +204,9 @@ void Vector::Slice(const Vector &other, idx_t offset, idx_t end) { auto &child_vec = ArrayVector::GetEntry(new_vector); auto &other_child_vec = ArrayVector::GetEntry(other); D_ASSERT(ArrayType::GetSize(GetType()) == ArrayType::GetSize(other.GetType())); - child_vec.Slice(other_child_vec, offset, end); + const auto array_size = ArrayType::GetSize(GetType()); + // We need to slice the child vector with the multiplied offset and end + child_vec.Slice(other_child_vec, offset * array_size, end * array_size); new_vector.validity.Slice(other.validity, offset, end - offset); Reference(new_vector); } else { @@ -378,13 +381,20 @@ void Vector::Resize(idx_t current_size, idx_t new_size) { } // Copy the data buffer to a resized buffer. - auto new_data = make_unsafe_uniq_array(target_size); + auto new_data = make_unsafe_uniq_array_uninitialized(target_size); memcpy(new_data.get(), resize_info_entry.data, old_size); resize_info_entry.buffer->SetData(std::move(new_data)); resize_info_entry.vec.data = resize_info_entry.buffer->GetData(); } } +static bool IsStructOrArrayRecursive(const LogicalType &type) { + return TypeVisitor::Contains(type, [](const LogicalType &type) { + auto physical_type = type.InternalType(); + return (physical_type == PhysicalType::STRUCT || physical_type == PhysicalType::ARRAY); + }); +} + void Vector::SetValue(idx_t index, const Value &val) { if (GetVectorType() == VectorType::DICTIONARY_VECTOR) { // dictionary: apply dictionary and forward to child @@ -392,16 +402,16 @@ void Vector::SetValue(idx_t index, const Value &val) { auto &child = DictionaryVector::Child(*this); return child.SetValue(sel_vector.get_index(index), val); } - if (val.type() != GetType()) { + if (!val.IsNull() && val.type() != GetType()) { SetValue(index, val.DefaultCastAs(GetType())); return; } - D_ASSERT(val.type().InternalType() == GetType().InternalType()); + D_ASSERT(val.IsNull() || (val.type().InternalType() == GetType().InternalType())); validity.EnsureWritable(); validity.Set(index, !val.IsNull()); auto physical_type = GetType().InternalType(); - if (val.IsNull() && physical_type != PhysicalType::STRUCT && physical_type != PhysicalType::ARRAY) { + if (val.IsNull() && !IsStructOrArrayRecursive(GetType())) { // for structs and arrays we still need to set the child-entries to NULL // so we do not bail out yet return; @@ -450,9 +460,12 @@ void Vector::SetValue(idx_t index, const Value &val) { case PhysicalType::INTERVAL: reinterpret_cast(data)[index] = val.GetValueUnsafe(); break; - case PhysicalType::VARCHAR: - reinterpret_cast(data)[index] = StringVector::AddStringOrBlob(*this, StringValue::Get(val)); + case PhysicalType::VARCHAR: { + if (!val.IsNull()) { + reinterpret_cast(data)[index] = StringVector::AddStringOrBlob(*this, StringValue::Get(val)); + } break; + } case PhysicalType::STRUCT: { D_ASSERT(GetVectorType() == VectorType::CONSTANT_VECTOR || GetVectorType() == VectorType::FLAT_VECTOR); @@ -475,16 +488,23 @@ void Vector::SetValue(idx_t index, const Value &val) { } case PhysicalType::LIST: { auto offset = ListVector::GetListSize(*this); - auto &val_children = ListValue::GetChildren(val); - if (!val_children.empty()) { - for (idx_t i = 0; i < val_children.size(); i++) { - ListVector::PushBack(*this, val_children[i]); + if (val.IsNull()) { + auto &entry = reinterpret_cast(data)[index]; + ListVector::PushBack(*this, Value()); + entry.length = 1; + entry.offset = offset; + } else { + auto &val_children = ListValue::GetChildren(val); + if (!val_children.empty()) { + for (idx_t i = 0; i < val_children.size(); i++) { + ListVector::PushBack(*this, val_children[i]); + } } + //! now set the pointer + auto &entry = reinterpret_cast(data)[index]; + entry.length = val_children.size(); + entry.offset = offset; } - //! now set the pointer - auto &entry = reinterpret_cast(data)[index]; - entry.length = val_children.size(); - entry.offset = offset; break; } case PhysicalType::ARRAY: { @@ -553,8 +573,10 @@ Value Vector::GetValueInternal(const Vector &v_p, idx_t index_p) { throw InternalException("FSST Vector with non-string datatype found!"); } auto str_compressed = reinterpret_cast(data)[index]; - Value result = FSSTPrimitives::DecompressValue(FSSTVector::GetDecoder(*vector), str_compressed.GetData(), - str_compressed.GetSize()); + auto decoder = FSSTVector::GetDecoder(*vector); + auto &decompress_buffer = FSSTVector::GetDecompressBuffer(*vector); + Value result = FSSTPrimitives::DecompressValue(decoder, str_compressed.GetData(), str_compressed.GetSize(), + decompress_buffer); return result; } @@ -644,6 +666,10 @@ Value Vector::GetValueInternal(const Vector &v_p, idx_t index_p) { auto str = reinterpret_cast(data)[index]; return Value::BLOB(const_data_ptr_cast(str.GetData()), str.GetSize()); } + case LogicalTypeId::VARINT: { + auto str = reinterpret_cast(data)[index]; + return Value::VARINT(const_data_ptr_cast(str.GetData()), str.GetSize()); + } case LogicalTypeId::AGGREGATE_STATE: { auto str = reinterpret_cast(data)[index]; return Value::AGGREGATE_STATE(vector->GetType(), const_data_ptr_cast(str.GetData()), str.GetSize()); @@ -755,8 +781,10 @@ string Vector::ToString(idx_t count) const { case VectorType::FSST_VECTOR: { for (idx_t i = 0; i < count; i++) { string_t compressed_string = reinterpret_cast(data)[i]; - Value val = FSSTPrimitives::DecompressValue(FSSTVector::GetDecoder(*this), compressed_string.GetData(), - compressed_string.GetSize()); + auto decoder = FSSTVector::GetDecoder(*this); + auto &decompress_buffer = FSSTVector::GetDecompressBuffer(*this); + Value val = FSSTPrimitives::DecompressValue(decoder, compressed_string.GetData(), + compressed_string.GetSize(), decompress_buffer); retval += GetValue(i).ToString() + (i == count - 1 ? "" : ", "); } } break; @@ -783,6 +811,44 @@ void Vector::Print(idx_t count) const { Printer::Print(ToString(count)); } +// TODO: add the size of validity masks to this +idx_t Vector::GetAllocationSize(idx_t cardinality) const { + if (!type.IsNested()) { + auto physical_size = GetTypeIdSize(type.InternalType()); + return cardinality * physical_size; + } + auto internal_type = type.InternalType(); + switch (internal_type) { + case PhysicalType::LIST: { + auto physical_size = GetTypeIdSize(type.InternalType()); + auto total_size = physical_size * cardinality; + + auto child_cardinality = ListVector::GetListCapacity(*this); + auto &child_entry = ListVector::GetEntry(*this); + total_size += (child_entry.GetAllocationSize(child_cardinality)); + return total_size; + } + case PhysicalType::ARRAY: { + auto child_cardinality = ArrayVector::GetTotalSize(*this); + + auto &child_entry = ArrayVector::GetEntry(*this); + auto total_size = (child_entry.GetAllocationSize(child_cardinality)); + return total_size; + } + case PhysicalType::STRUCT: { + idx_t total_size = 0; + auto &children = StructVector::GetEntries(*this); + for (auto &child : children) { + total_size += child->GetAllocationSize(cardinality); + } + return total_size; + } + default: + throw NotImplementedException("Vector::GetAllocationSize not implemented for type: %s", type.ToString()); + break; + } +} + string Vector::ToString() const { string retval = VectorTypeToString(GetVectorType()) + " " + GetType().ToString() + ": (UNKNOWN COUNT) [ "; switch (GetVectorType()) { @@ -860,7 +926,10 @@ void Vector::Flatten(idx_t count) { // constant NULL, set nullmask validity.EnsureWritable(); validity.SetAllInvalid(count); - return; + if (GetType().InternalType() != PhysicalType::STRUCT) { + // for structs we still need to flatten the child vectors as well + return; + } } // non-null constant: have to repeat the constant switch (GetType().InternalType()) { @@ -925,9 +994,12 @@ void Vector::Flatten(idx_t count) { validity.SetAllInvalid(count); // Also invalidate the new child array new_child.validity.SetAllInvalid(count * array_size); + // Recurse + new_child.Flatten(count * array_size); + // TODO: the fast path should exit here, but the part below it is somehow required for correctness // Attach the flattened buffer and return - auxiliary = shared_ptr(flattened_buffer.release()); - return; + // auxiliary = shared_ptr(flattened_buffer.release()); + // return; } // Now we need to "unpack" the child vector. @@ -959,7 +1031,8 @@ void Vector::Flatten(idx_t count) { VectorOperations::Copy(*child_vec, new_child, sel, count * array_size, 0, 0); auxiliary = shared_ptr(flattened_buffer.release()); - } break; + break; + } case PhysicalType::STRUCT: { auto normalified_buffer = make_uniq(); @@ -973,7 +1046,8 @@ void Vector::Flatten(idx_t count) { new_children.push_back(std::move(vector)); } auxiliary = shared_ptr(normalified_buffer.release()); - } break; + break; + } default: throw InternalException("Unimplemented type for VectorOperations::Flatten"); } @@ -982,10 +1056,11 @@ void Vector::Flatten(idx_t count) { case VectorType::SEQUENCE_VECTOR: { int64_t start, increment, sequence_count; SequenceVector::GetSequence(*this, start, increment, sequence_count); + auto seq_count = NumericCast(sequence_count); - buffer = VectorBuffer::CreateStandardVector(GetType()); + buffer = VectorBuffer::CreateStandardVector(GetType(), MaxValue(STANDARD_VECTOR_SIZE, seq_count)); data = buffer->GetData(); - VectorOperations::GenerateSequence(*this, NumericCast(sequence_count), start, increment); + VectorOperations::GenerateSequence(*this, seq_count, start, increment); break; } default: @@ -1105,9 +1180,9 @@ void Vector::Serialize(Serializer &serializer, idx_t count) { UnifiedVectorFormat vdata; ToUnifiedFormat(count, vdata); - const bool all_valid = (count > 0) && !vdata.validity.AllValid(); - serializer.WriteProperty(100, "all_valid", all_valid); - if (all_valid) { + const bool has_validity_mask = (count > 0) && !vdata.validity.AllValid(); + serializer.WriteProperty(100, "has_validity_mask", has_validity_mask); + if (has_validity_mask) { ValidityMask flat_mask(count); flat_mask.Initialize(); for (idx_t i = 0; i < count; ++i) { @@ -1120,7 +1195,7 @@ void Vector::Serialize(Serializer &serializer, idx_t count) { if (TypeIsConstantSize(logical_type.InternalType())) { // constant size type: simple copy idx_t write_size = GetTypeIdSize(logical_type.InternalType()) * count; - auto ptr = make_unsafe_uniq_array(write_size); + auto ptr = make_unsafe_uniq_array_uninitialized(write_size); VectorOperations::WriteToStorage(*this, count, ptr.get()); serializer.WriteProperty(102, "data", ptr.get(), write_size); } else { @@ -1150,13 +1225,18 @@ void Vector::Serialize(Serializer &serializer, idx_t count) { auto list_size = ListVector::GetListSize(*this); // serialize the list entries in a flat array - auto entries = make_unsafe_uniq_array(count); + auto entries = make_unsafe_uniq_array_uninitialized(count); auto source_array = UnifiedVectorFormat::GetData(vdata); for (idx_t i = 0; i < count; i++) { auto idx = vdata.sel->get_index(i); auto source = source_array[idx]; - entries[i].offset = source.offset; - entries[i].length = source.length; + if (vdata.validity.RowIsValid(idx)) { + entries[i].offset = source.offset; + entries[i].length = source.length; + } else { + entries[i].offset = 0; + entries[i].length = 0; + } } serializer.WriteProperty(104, "list_size", list_size); serializer.WriteList(105, "entries", count, [&](Serializer::List &list, idx_t i) { @@ -1190,8 +1270,8 @@ void Vector::Deserialize(Deserializer &deserializer, idx_t count) { auto &validity = FlatVector::Validity(*this); validity.Reset(); - const auto has_validity = deserializer.ReadProperty(100, "all_valid"); - if (has_validity) { + const auto has_validity_mask = deserializer.ReadProperty(100, "has_validity_mask"); + if (has_validity_mask) { validity.Initialize(MaxValue(count, STANDARD_VECTOR_SIZE)); deserializer.ReadProperty(101, "validity", data_ptr_cast(validity.GetData()), validity.ValidityMaskSize(count)); } @@ -1199,7 +1279,7 @@ void Vector::Deserialize(Deserializer &deserializer, idx_t count) { if (TypeIsConstantSize(logical_type.InternalType())) { // constant size type: read fixed amount of data auto column_size = GetTypeIdSize(logical_type.InternalType()) * count; - auto ptr = make_unsafe_uniq_array(column_size); + auto ptr = make_unsafe_uniq_array_uninitialized(column_size); deserializer.ReadProperty(102, "data", ptr.get(), column_size); VectorOperations::ReadFromStorage(ptr.get(), count, *this); @@ -1381,6 +1461,23 @@ void Vector::Verify(Vector &vector_p, const SelectionVector &sel_p, idx_t count) } } + if (type.id() == LogicalTypeId::VARINT) { + switch (vtype) { + case VectorType::FLAT_VECTOR: { + auto &validity = FlatVector::Validity(*vector); + auto strings = FlatVector::GetData(*vector); + for (idx_t i = 0; i < count; i++) { + auto oidx = sel->get_index(i); + if (validity.RowIsValid(oidx)) { + Varint::Verify(strings[oidx]); + } + } + } break; + default: + break; + } + } + if (type.id() == LogicalTypeId::BIT) { switch (vtype) { case VectorType::FLAT_VECTOR: { @@ -1451,6 +1548,7 @@ void Vector::Verify(Vector &vector_p, const SelectionVector &sel_p, idx_t count) if (type.InternalType() == PhysicalType::STRUCT) { auto &child_types = StructType::GetChildTypes(type); D_ASSERT(!child_types.empty()); + // create a selection vector of the non-null entries of the struct vector auto &children = StructVector::GetEntries(*vector); D_ASSERT(child_types.size() == children.size()); @@ -1837,7 +1935,7 @@ string_t StringVector::AddString(Vector &vector, string_t data) { vector.auxiliary = make_buffer(); } D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRING_BUFFER); - auto &string_buffer = vector.auxiliary->Cast(); + auto &string_buffer = vector.auxiliary.get()->Cast(); return string_buffer.AddString(data); } @@ -1851,7 +1949,7 @@ string_t StringVector::AddStringOrBlob(Vector &vector, string_t data) { vector.auxiliary = make_buffer(); } D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRING_BUFFER); - auto &string_buffer = vector.auxiliary->Cast(); + auto &string_buffer = vector.auxiliary.get()->Cast(); return string_buffer.AddBlob(data); } @@ -1864,7 +1962,7 @@ string_t StringVector::EmptyString(Vector &vector, idx_t len) { vector.auxiliary = make_buffer(); } D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRING_BUFFER); - auto &string_buffer = vector.auxiliary->Cast(); + auto &string_buffer = vector.auxiliary.get()->Cast(); return string_buffer.EmptyString(len); } @@ -1918,7 +2016,7 @@ string_t FSSTVector::AddCompressedString(Vector &vector, string_t data) { vector.auxiliary = make_buffer(); } D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); - auto &fsst_string_buffer = vector.auxiliary->Cast(); + auto &fsst_string_buffer = vector.auxiliary.get()->Cast(); return fsst_string_buffer.AddBlob(data); } @@ -1932,7 +2030,18 @@ void *FSSTVector::GetDecoder(const Vector &vector) { return fsst_string_buffer.GetDecoder(); } -void FSSTVector::RegisterDecoder(Vector &vector, buffer_ptr &duckdb_fsst_decoder) { +vector &FSSTVector::GetDecompressBuffer(const Vector &vector) { + D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); + if (!vector.auxiliary) { + throw InternalException("GetDecompressBuffer called on FSST Vector without registered buffer"); + } + D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); + auto &fsst_string_buffer = vector.auxiliary->Cast(); + return fsst_string_buffer.GetDecompressBuffer(); +} + +void FSSTVector::RegisterDecoder(Vector &vector, buffer_ptr &duckdb_fsst_decoder, + const idx_t string_block_limit) { D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); if (!vector.auxiliary) { @@ -1941,7 +2050,7 @@ void FSSTVector::RegisterDecoder(Vector &vector, buffer_ptr &duckdb_fsst_d D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); auto &fsst_string_buffer = vector.auxiliary->Cast(); - fsst_string_buffer.AddDecoder(duckdb_fsst_decoder); + fsst_string_buffer.AddDecoder(duckdb_fsst_decoder, string_block_limit); } void FSSTVector::SetCount(Vector &vector, idx_t count) { @@ -1980,8 +2089,10 @@ void FSSTVector::DecompressVector(const Vector &src, Vector &dst, idx_t src_offs auto target_idx = dst_offset + i; string_t compressed_string = ldata[source_idx]; if (dst_mask.RowIsValid(target_idx) && compressed_string.GetSize() > 0) { - tdata[target_idx] = FSSTPrimitives::DecompressValue( - FSSTVector::GetDecoder(src), dst, compressed_string.GetData(), compressed_string.GetSize()); + auto decoder = FSSTVector::GetDecoder(src); + auto &decompress_buffer = FSSTVector::GetDecompressBuffer(src); + tdata[target_idx] = FSSTPrimitives::DecompressValue(decoder, dst, compressed_string.GetData(), + compressed_string.GetSize(), decompress_buffer); } else { tdata[target_idx] = string_t(nullptr, 0); } @@ -2190,7 +2301,7 @@ void ListVector::Append(Vector &target, const Vector &source, const SelectionVec } void ListVector::PushBack(Vector &target, const Value &insert) { - auto &target_buffer = target.auxiliary->Cast(); + auto &target_buffer = target.auxiliary.get()->Cast(); target_buffer.PushBack(insert); } diff --git a/src/duckdb/src/common/types/vector_buffer.cpp b/src/duckdb/src/common/types/vector_buffer.cpp index 59bc6f9c..aed61d19 100644 --- a/src/duckdb/src/common/types/vector_buffer.cpp +++ b/src/duckdb/src/common/types/vector_buffer.cpp @@ -67,11 +67,12 @@ VectorListBuffer::VectorListBuffer(const LogicalType &list_type, idx_t initial_c void VectorListBuffer::Reserve(idx_t to_reserve) { if (to_reserve > capacity) { - idx_t new_capacity = NextPowerOfTwo(to_reserve); - if (new_capacity == 0) { - // Overflow: set to_reserve to the maximum value - new_capacity = to_reserve; + if (to_reserve > DConstants::MAX_VECTOR_SIZE) { + // overflow: throw an exception + throw OutOfRangeException("Cannot resize vector to %d rows: maximum allowed vector size is %s", to_reserve, + StringUtil::BytesToHumanReadableString(DConstants::MAX_VECTOR_SIZE)); } + idx_t new_capacity = NextPowerOfTwo(to_reserve); D_ASSERT(new_capacity >= to_reserve); child->Resize(capacity, new_capacity); capacity = new_capacity; diff --git a/src/duckdb/src/common/vector_operations/vector_copy.cpp b/src/duckdb/src/common/vector_operations/vector_copy.cpp index 3823cd39..0880f23b 100644 --- a/src/duckdb/src/common/vector_operations/vector_copy.cpp +++ b/src/duckdb/src/common/vector_operations/vector_copy.cpp @@ -23,7 +23,7 @@ static void TemplatedCopy(const Vector &source, const SelectionVector &sel, Vect } } -static const ValidityMask &CopyValidityMask(const Vector &v) { +static const ValidityMask &ExtractValidityMask(const Vector &v) { switch (v.GetVectorType()) { case VectorType::FLAT_VECTOR: return FlatVector::Validity(v); @@ -35,10 +35,7 @@ static const ValidityMask &CopyValidityMask(const Vector &v) { } void VectorOperations::Copy(const Vector &source_p, Vector &target, const SelectionVector &sel_p, idx_t source_count, - idx_t source_offset, idx_t target_offset) { - D_ASSERT(source_offset <= source_count); - D_ASSERT(source_p.GetType() == target.GetType()); - idx_t copy_count = source_count - source_offset; + idx_t source_offset, idx_t target_offset, idx_t copy_count) { SelectionVector owned_sel; const SelectionVector *sel = &sel_p; @@ -101,25 +98,8 @@ void VectorOperations::Copy(const Vector &source_p, Vector &target, const Select tmask.Set(target_offset + i, valid); } } else { - auto &smask = CopyValidityMask(*source); - if (smask.IsMaskSet() || tmask.IsMaskSet()) { - for (idx_t i = 0; i < copy_count; i++) { - auto idx = sel->get_index(source_offset + i); - - if (smask.RowIsValid(idx)) { - // set valid - if (!tmask.AllValid()) { - tmask.SetValidUnsafe(target_offset + i); - } - } else { - // set invalid - if (tmask.AllValid()) { - tmask.Initialize(); - } - tmask.SetInvalidUnsafe(target_offset + i); - } - } - } + auto &smask = ExtractValidityMask(*source); + tmask.CopySel(smask, *sel, source_offset, target_offset, copy_count); } D_ASSERT(sel); @@ -190,7 +170,7 @@ void VectorOperations::Copy(const Vector &source_p, Vector &target, const Select D_ASSERT(source_children.size() == target_children.size()); for (idx_t i = 0; i < source_children.size(); i++) { VectorOperations::Copy(*source_children[i], *target_children[i], sel_p, source_count, source_offset, - target_offset); + target_offset, copy_count); } break; } @@ -284,6 +264,14 @@ void VectorOperations::Copy(const Vector &source_p, Vector &target, const Select } } +void VectorOperations::Copy(const Vector &source_p, Vector &target, const SelectionVector &sel_p, idx_t source_count, + idx_t source_offset, idx_t target_offset) { + D_ASSERT(source_offset <= source_count); + D_ASSERT(source_p.GetType() == target.GetType()); + idx_t copy_count = source_count - source_offset; + VectorOperations::Copy(source_p, target, sel_p, source_count, source_offset, target_offset, copy_count); +} + void VectorOperations::Copy(const Vector &source, Vector &target, idx_t source_count, idx_t source_offset, idx_t target_offset) { VectorOperations::Copy(source, target, *FlatVector::IncrementalSelectionVector(), source_count, source_offset, diff --git a/src/duckdb/src/common/vector_operations/vector_hash.cpp b/src/duckdb/src/common/vector_operations/vector_hash.cpp index 590642a4..e6ef5f5f 100644 --- a/src/duckdb/src/common/vector_operations/vector_hash.cpp +++ b/src/duckdb/src/common/vector_operations/vector_hash.cpp @@ -206,6 +206,9 @@ static inline void ArrayLoopHash(Vector &input, Vector &hashes, const SelectionV for (idx_t i = 0; i < count; i++) { auto lidx = idata.sel->get_index(i); if (idata.validity.RowIsValid(lidx)) { + if (FIRST_HASH) { + hdata[i] = 0; + } for (idx_t j = 0; j < array_size; j++) { auto offset = lidx * array_size + j; hdata[i] = CombineHashScalar(hdata[i], chdata[offset]); @@ -233,6 +236,9 @@ static inline void ArrayLoopHash(Vector &input, Vector &hashes, const SelectionV VectorOperations::Hash(dict_vec, array_hashes, array_size); auto ahdata = FlatVector::GetData(array_hashes); + if (FIRST_HASH) { + hdata[ridx] = 0; + } // Combine the hashes of the array for (idx_t j = 0; j < array_size; j++) { hdata[ridx] = CombineHashScalar(hdata[ridx], ahdata[j]); diff --git a/src/duckdb/src/common/virtual_file_system.cpp b/src/duckdb/src/common/virtual_file_system.cpp index 3bc099a2..74892a4e 100644 --- a/src/duckdb/src/common/virtual_file_system.cpp +++ b/src/duckdb/src/common/virtual_file_system.cpp @@ -13,15 +13,15 @@ unique_ptr VirtualFileSystem::OpenFile(const string &path, FileOpenF optional_ptr opener) { auto compression = flags.Compression(); if (compression == FileCompressionType::AUTO_DETECT) { - // auto detect compression settings based on file name + // auto-detect compression settings based on file name auto lower_path = StringUtil::Lower(path); if (StringUtil::EndsWith(lower_path, ".tmp")) { // strip .tmp lower_path = lower_path.substr(0, lower_path.length() - 4); } - if (StringUtil::EndsWith(lower_path, ".gz")) { + if (IsFileCompressed(path, FileCompressionType::GZIP)) { compression = FileCompressionType::GZIP; - } else if (StringUtil::EndsWith(lower_path, ".zst")) { + } else if (IsFileCompressed(path, FileCompressionType::ZSTD)) { compression = FileCompressionType::ZSTD; } else { compression = FileCompressionType::UNCOMPRESSED; diff --git a/src/duckdb/src/core_functions/aggregate/distributive/approx_count.cpp b/src/duckdb/src/core_functions/aggregate/distributive/approx_count.cpp index 6b8dd0d7..13d33220 100644 --- a/src/duckdb/src/core_functions/aggregate/distributive/approx_count.cpp +++ b/src/duckdb/src/core_functions/aggregate/distributive/approx_count.cpp @@ -1,113 +1,83 @@ -#include "duckdb/core_functions/aggregate/distributive_functions.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/types/hash.hpp" #include "duckdb/common/types/hyperloglog.hpp" +#include "duckdb/core_functions/aggregate/distributive_functions.hpp" #include "duckdb/function/function_set.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "hyperloglog.hpp" namespace duckdb { +// Algorithms from +// "New cardinality estimation algorithms for HyperLogLog sketches" +// Otmar Ertl, arXiv:1702.01284 struct ApproxDistinctCountState { - ApproxDistinctCountState() : log(nullptr) { - } - ~ApproxDistinctCountState() { - if (log) { - delete log; - } - } - - HyperLogLog *log; + HyperLogLog hll; }; struct ApproxCountDistinctFunction { template static void Initialize(STATE &state) { - state.log = nullptr; + new (&state) STATE(); } template static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.log) { - return; - } - if (!target.log) { - target.log = new HyperLogLog(); - } - D_ASSERT(target.log); - D_ASSERT(source.log); - auto new_log = target.log->MergePointer(*source.log); - delete target.log; - target.log = new_log; + target.hll.Merge(source.hll); } template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.log) { - target = UnsafeNumericCast(state.log->Count()); - } else { - target = 0; - } + target = UnsafeNumericCast(state.hll.Count()); } static bool IgnoreNull() { return true; } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.log) { - delete state.log; - state.log = nullptr; - } - } }; static void ApproxCountDistinctSimpleUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, data_ptr_t state, idx_t count) { D_ASSERT(input_count == 1); - - auto agg_state = reinterpret_cast(state); - if (!agg_state->log) { - agg_state->log = new HyperLogLog(); - } - - UnifiedVectorFormat vdata; - inputs[0].ToUnifiedFormat(count, vdata); + auto &input = inputs[0]; if (count > STANDARD_VECTOR_SIZE) { throw InternalException("ApproxCountDistinct - count must be at most vector size"); } - uint64_t indices[STANDARD_VECTOR_SIZE]; - uint8_t counts[STANDARD_VECTOR_SIZE]; - HyperLogLog::ProcessEntries(vdata, inputs[0].GetType(), indices, counts, count); - agg_state->log->AddToLog(vdata, count, indices, counts); + Vector hash_vec(LogicalType::HASH, count); + VectorOperations::Hash(input, hash_vec, count); + + auto agg_state = reinterpret_cast(state); + agg_state->hll.Update(input, hash_vec, count); } static void ApproxCountDistinctUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { D_ASSERT(input_count == 1); + auto &input = inputs[0]; + UnifiedVectorFormat idata; + input.ToUnifiedFormat(count, idata); + + if (count > STANDARD_VECTOR_SIZE) { + throw InternalException("ApproxCountDistinct - count must be at most vector size"); + } + Vector hash_vec(LogicalType::HASH, count); + VectorOperations::Hash(input, hash_vec, count); UnifiedVectorFormat sdata; state_vector.ToUnifiedFormat(count, sdata); - auto states = UnifiedVectorFormat::GetDataNoConst(sdata); + const auto states = UnifiedVectorFormat::GetDataNoConst(sdata); + UnifiedVectorFormat hdata; + hash_vec.ToUnifiedFormat(count, hdata); + const auto *hashes = UnifiedVectorFormat::GetData(hdata); for (idx_t i = 0; i < count; i++) { - auto agg_state = states[sdata.sel->get_index(i)]; - if (!agg_state->log) { - agg_state->log = new HyperLogLog(); + if (idata.validity.RowIsValid(idata.sel->get_index(i))) { + auto agg_state = states[sdata.sel->get_index(i)]; + const auto hash = hashes[hdata.sel->get_index(i)]; + agg_state->hll.InsertElement(hash); } } - - UnifiedVectorFormat vdata; - inputs[0].ToUnifiedFormat(count, vdata); - - if (count > STANDARD_VECTOR_SIZE) { - throw InternalException("ApproxCountDistinct - count must be at most vector size"); - } - uint64_t indices[STANDARD_VECTOR_SIZE]; - uint8_t counts[STANDARD_VECTOR_SIZE]; - HyperLogLog::ProcessEntries(vdata, inputs[0].GetType(), indices, counts, count); - HyperLogLog::AddToLogs(vdata, count, indices, counts, reinterpret_cast(states), sdata.sel); } AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type) { @@ -117,30 +87,13 @@ AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type) ApproxCountDistinctUpdateFunction, AggregateFunction::StateCombine, AggregateFunction::StateFinalize, - ApproxCountDistinctSimpleUpdateFunction, nullptr, - AggregateFunction::StateDestroy); + ApproxCountDistinctSimpleUpdateFunction); fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; return fun; } -AggregateFunctionSet ApproxCountDistinctFun::GetFunctions() { - AggregateFunctionSet approx_count("approx_count_distinct"); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::UTINYINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::USMALLINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::UINTEGER)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::UBIGINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::UHUGEINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TINYINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::SMALLINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::BIGINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::HUGEINT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::FLOAT)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::DOUBLE)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TIMESTAMP)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TIMESTAMP_TZ)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::BLOB)); - approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150))); - return approx_count; +AggregateFunction ApproxCountDistinctFun::GetFunction() { + return GetApproxCountDistinctFunction(LogicalType::ANY); } } // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/distributive/arg_min_max.cpp b/src/duckdb/src/core_functions/aggregate/distributive/arg_min_max.cpp index c39b0599..5120bc79 100644 --- a/src/duckdb/src/core_functions/aggregate/distributive/arg_min_max.cpp +++ b/src/duckdb/src/core_functions/aggregate/distributive/arg_min_max.cpp @@ -7,6 +7,8 @@ #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/planner/expression_binder.hpp" +#include "duckdb/core_functions/create_sort_key.hpp" +#include "duckdb/core_functions/aggregate/minmax_n_helpers.hpp" namespace duckdb { @@ -42,11 +44,6 @@ void ArgMinMaxStateBase::CreateValue(string_t &value) { value = string_t(uint32_t(0)); } -template <> -void ArgMinMaxStateBase::CreateValue(Vector *&value) { - value = nullptr; -} - template <> void ArgMinMaxStateBase::DestroyValue(string_t &value) { if (!value.IsInlined()) { @@ -54,12 +51,6 @@ void ArgMinMaxStateBase::DestroyValue(string_t &value) { } } -template <> -void ArgMinMaxStateBase::DestroyValue(Vector *&value) { - delete value; - value = nullptr; -} - template <> void ArgMinMaxStateBase::AssignValue(string_t &target, string_t new_value) { DestroyValue(target); @@ -104,7 +95,6 @@ struct ArgMinMaxState : public ArgMinMaxStateBase { template struct ArgMinMaxBase { - template static void Initialize(STATE &state) { new (&state) STATE; @@ -164,7 +154,7 @@ struct ArgMinMaxBase { if (!state.is_initialized || state.arg_null) { finalize_data.ReturnNull(); } else { - STATE::template ReadValue(finalize_data.result, state.arg, target); + STATE::template ReadValue(finalize_data.result, state.arg, target); } } @@ -175,7 +165,7 @@ struct ArgMinMaxBase { static unique_ptr Bind(ClientContext &context, AggregateFunction &function, vector> &arguments) { if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) { - ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type, false); + ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type); } function.arguments[0] = arguments[0]->return_type; function.return_type = arguments[0]->return_type; @@ -183,38 +173,55 @@ struct ArgMinMaxBase { } }; -template -struct VectorArgMinMaxBase : ArgMinMaxBase { - template - static void AssignVector(STATE &state, Vector &arg, bool arg_null, const idx_t idx) { - if (!state.arg) { - state.arg = new Vector(arg.GetType(), 1); - state.arg->SetVectorType(VectorType::CONSTANT_VECTOR); - } - state.arg_null = arg_null; - if (!arg_null) { - sel_t selv = UnsafeNumericCast(idx); - SelectionVector sel(&selv); - VectorOperations::Copy(arg, *state.arg, sel, 1, 0, 0); - } +struct SpecializedGenericArgMinMaxState { + static bool CreateExtraState(idx_t count) { + // nop extra state + return false; } + static void PrepareData(Vector &by, idx_t count, bool &, UnifiedVectorFormat &result) { + by.ToUnifiedFormat(count, result); + } +}; + +template +struct GenericArgMinMaxState { + static Vector CreateExtraState(idx_t count) { + return Vector(LogicalType::BLOB, count); + } + + static void PrepareData(Vector &by, idx_t count, Vector &extra_state, UnifiedVectorFormat &result) { + OrderModifiers modifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST); + CreateSortKeyHelpers::CreateSortKey(by, count, modifiers, extra_state); + extra_state.ToUnifiedFormat(count, result); + } +}; + +template +struct VectorArgMinMaxBase : ArgMinMaxBase { template static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { auto &arg = inputs[0]; UnifiedVectorFormat adata; arg.ToUnifiedFormat(count, adata); + using ARG_TYPE = typename STATE::ARG_TYPE; using BY_TYPE = typename STATE::BY_TYPE; auto &by = inputs[1]; UnifiedVectorFormat bdata; - by.ToUnifiedFormat(count, bdata); + auto extra_state = UPDATE_TYPE::CreateExtraState(count); + UPDATE_TYPE::PrepareData(by, count, extra_state, bdata); const auto bys = UnifiedVectorFormat::GetData(bdata); UnifiedVectorFormat sdata; state_vector.ToUnifiedFormat(count, sdata); - auto states = (STATE **)sdata.data; + STATE *last_state = nullptr; + sel_t assign_sel[STANDARD_VECTOR_SIZE]; + idx_t assign_count = 0; + + auto states = UnifiedVectorFormat::GetData(sdata); for (idx_t i = 0; i < count; i++) { const auto bidx = bdata.sel->get_index(i); if (!bdata.validity.RowIsValid(bidx)) { @@ -230,16 +237,42 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { const auto sidx = sdata.sel->get_index(i); auto &state = *states[sidx]; - if (!state.is_initialized) { + if (!state.is_initialized || COMPARATOR::template Operation(bval, state.value)) { STATE::template AssignValue(state.value, bval); - AssignVector(state, arg, arg_null, i); + state.arg_null = arg_null; + // micro-adaptivity: it is common we overwrite the same state repeatedly + // e.g. when running arg_max(val, ts) and ts is sorted in ascending order + // this check essentially says: + // "if we are overriding the same state as the last row, the last write was pointless" + // hence we skip the last write altogether + if (!arg_null) { + if (&state == last_state) { + assign_count--; + } + assign_sel[assign_count++] = UnsafeNumericCast(i); + last_state = &state; + } state.is_initialized = true; - - } else if (COMPARATOR::template Operation(bval, state.value)) { - STATE::template AssignValue(state.value, bval); - AssignVector(state, arg, arg_null, i); } } + if (assign_count == 0) { + // no need to assign anything: nothing left to do + return; + } + Vector sort_key(LogicalType::BLOB); + auto modifiers = OrderModifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST); + // slice with a selection vector and generate sort keys + SelectionVector sel(assign_sel); + Vector sliced_input(arg, sel, assign_count); + CreateSortKeyHelpers::CreateSortKey(sliced_input, assign_count, modifiers, sort_key); + auto sort_key_data = FlatVector::GetData(sort_key); + + // now assign sort keys + for (idx_t i = 0; i < assign_count; i++) { + const auto sidx = sdata.sel->get_index(sel.get_index(i)); + auto &state = *states[sidx]; + STATE::template AssignValue(state.arg, sort_key_data[i]); + } } template @@ -248,8 +281,12 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { return; } if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) { - STATE::template AssignValue(target.value, source.value); - AssignVector(target, *source.arg, source.arg_null, 0); + STATE::template AssignValue(target.value, source.value); + target.arg_null = source.arg_null; + if (!target.arg_null) { + STATE::template AssignValue(target.arg, source.arg); + ; + } target.is_initialized = true; } } @@ -259,7 +296,8 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { if (!state.is_initialized || state.arg_null) { finalize_data.ReturnNull(); } else { - VectorOperations::Copy(*state.arg, finalize_data.result, 1, 0, finalize_data.result_idx); + CreateSortKeyHelpers::DecodeSortKey(state.arg, finalize_data.result, finalize_data.result_idx, + OrderModifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST)); } } @@ -409,7 +447,17 @@ void AddDecimalArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType nullptr, nullptr, nullptr, nullptr, BindDecimalArgMinMax)); } -template +template +void AddGenericArgMinMaxFunction(AggregateFunctionSet &fun) { + using STATE = ArgMinMaxState; + fun.AddFunction( + AggregateFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, OP::template Update, + AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, + nullptr, OP::Bind, AggregateFunction::StateDestroy)); +} + +template static void AddArgMinMaxFunctions(AggregateFunctionSet &fun) { using OP = ArgMinMaxBase; AddArgMinMaxFunctionBy(fun, LogicalType::INTEGER); @@ -426,31 +474,220 @@ static void AddArgMinMaxFunctions(AggregateFunctionSet &fun) { AddDecimalArgMinMaxFunctionBy(fun, by_type); } - using VECTOR_OP = VectorArgMinMaxBase; - AddVectorArgMinMaxFunctionBy(fun, LogicalType::ANY); + using VECTOR_OP = VectorArgMinMaxBase; + AddVectorArgMinMaxFunctionBy(fun, LogicalType::ANY); + + // we always use LessThan when using sort keys because the ORDER_TYPE takes care of selecting the lowest or highest + using GENERIC_VECTOR_OP = VectorArgMinMaxBase>; + AddGenericArgMinMaxFunction(fun); +} + +//------------------------------------------------------------------------------ +// ArgMinMax(N) Function +//------------------------------------------------------------------------------ +//------------------------------------------------------------------------------ +// State +//------------------------------------------------------------------------------ + +template +class ArgMinMaxNState { +public: + using VAL_TYPE = A; + using ARG_TYPE = B; + + using V = typename VAL_TYPE::TYPE; + using K = typename ARG_TYPE::TYPE; + + BinaryAggregateHeap heap; + + bool is_initialized = false; + void Initialize(idx_t nval) { + heap.Initialize(nval); + is_initialized = true; + } +}; + +//------------------------------------------------------------------------------ +// Operation +//------------------------------------------------------------------------------ +template +static void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, + idx_t count) { + + auto &val_vector = inputs[0]; + auto &arg_vector = inputs[1]; + auto &n_vector = inputs[2]; + + UnifiedVectorFormat val_format; + UnifiedVectorFormat arg_format; + UnifiedVectorFormat n_format; + UnifiedVectorFormat state_format; + + auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count); + auto arg_extra_state = STATE::ARG_TYPE::CreateExtraState(arg_vector, count); + + STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format); + STATE::ARG_TYPE::PrepareData(arg_vector, count, arg_extra_state, arg_format); + + n_vector.ToUnifiedFormat(count, n_format); + state_vector.ToUnifiedFormat(count, state_format); + + auto states = UnifiedVectorFormat::GetData(state_format); + + for (idx_t i = 0; i < count; i++) { + const auto arg_idx = arg_format.sel->get_index(i); + const auto val_idx = val_format.sel->get_index(i); + if (!arg_format.validity.RowIsValid(arg_idx) || !val_format.validity.RowIsValid(val_idx)) { + continue; + } + const auto state_idx = state_format.sel->get_index(i); + auto &state = *states[state_idx]; + + // Initialize the heap if necessary and add the input to the heap + if (!state.is_initialized) { + static constexpr int64_t MAX_N = 1000000; + const auto nidx = n_format.sel->get_index(i); + if (!n_format.validity.RowIsValid(nidx)) { + throw InvalidInputException("Invalid input for arg_min/arg_max: n value cannot be NULL"); + } + const auto nval = UnifiedVectorFormat::GetData(n_format)[nidx]; + if (nval <= 0) { + throw InvalidInputException("Invalid input for arg_min/arg_max: n value must be > 0"); + } + if (nval >= MAX_N) { + throw InvalidInputException("Invalid input for arg_min/arg_max: n value must be < %d", MAX_N); + } + state.Initialize(UnsafeNumericCast(nval)); + } + + // Now add the input to the heap + auto arg_val = STATE::ARG_TYPE::Create(arg_format, arg_idx); + auto val_val = STATE::VAL_TYPE::Create(val_format, val_idx); + + state.heap.Insert(aggr_input.allocator, arg_val, val_val); + } +} + +//------------------------------------------------------------------------------ +// Bind +//------------------------------------------------------------------------------ +template +static void SpecializeArgMinMaxNFunction(AggregateFunction &function) { + using STATE = ArgMinMaxNState; + using OP = MinMaxNOperation; + + function.state_size = AggregateFunction::StateSize; + function.initialize = AggregateFunction::StateInitialize; + function.combine = AggregateFunction::StateCombine; + function.destructor = AggregateFunction::StateDestroy; + + function.finalize = MinMaxNOperation::Finalize; + function.update = ArgMinMaxNUpdate; } +template +static void SpecializeArgMinMaxNFunction(PhysicalType arg_type, AggregateFunction &function) { + switch (arg_type) { + case PhysicalType::VARCHAR: + SpecializeArgMinMaxNFunction(function); + break; + case PhysicalType::INT32: + SpecializeArgMinMaxNFunction, COMPARATOR>(function); + break; + case PhysicalType::INT64: + SpecializeArgMinMaxNFunction, COMPARATOR>(function); + break; + case PhysicalType::FLOAT: + SpecializeArgMinMaxNFunction, COMPARATOR>(function); + break; + case PhysicalType::DOUBLE: + SpecializeArgMinMaxNFunction, COMPARATOR>(function); + break; + default: + SpecializeArgMinMaxNFunction(function); + break; + } +} + +template +static void SpecializeArgMinMaxNFunction(PhysicalType val_type, PhysicalType arg_type, AggregateFunction &function) { + switch (val_type) { + case PhysicalType::VARCHAR: + SpecializeArgMinMaxNFunction(arg_type, function); + break; + case PhysicalType::INT32: + SpecializeArgMinMaxNFunction, COMPARATOR>(arg_type, function); + break; + case PhysicalType::INT64: + SpecializeArgMinMaxNFunction, COMPARATOR>(arg_type, function); + break; + case PhysicalType::FLOAT: + SpecializeArgMinMaxNFunction, COMPARATOR>(arg_type, function); + break; + case PhysicalType::DOUBLE: + SpecializeArgMinMaxNFunction, COMPARATOR>(arg_type, function); + break; + default: + SpecializeArgMinMaxNFunction(arg_type, function); + break; + } +} + +template +unique_ptr ArgMinMaxNBind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + for (auto &arg : arguments) { + if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + } + + const auto val_type = arguments[0]->return_type.InternalType(); + const auto arg_type = arguments[1]->return_type.InternalType(); + + // Specialize the function based on the input types + SpecializeArgMinMaxNFunction(val_type, arg_type, function); + + function.return_type = LogicalType::LIST(arguments[0]->return_type); + return nullptr; +} + +template +static void AddArgMinMaxNFunction(AggregateFunctionSet &set) { + AggregateFunction function({LogicalTypeId::ANY, LogicalTypeId::ANY, LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::ANY), nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, ArgMinMaxNBind); + + return set.AddFunction(function); +} + +//------------------------------------------------------------------------------ +// Function Registration +//------------------------------------------------------------------------------ + AggregateFunctionSet ArgMinFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); + AddArgMinMaxFunctions(fun); + AddArgMinMaxNFunction(fun); return fun; } AggregateFunctionSet ArgMaxFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); + AddArgMinMaxFunctions(fun); + AddArgMinMaxNFunction(fun); return fun; } AggregateFunctionSet ArgMinNullFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); + AddArgMinMaxFunctions(fun); return fun; } AggregateFunctionSet ArgMaxNullFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); + AddArgMinMaxFunctions(fun); return fun; } diff --git a/src/duckdb/src/core_functions/aggregate/distributive/bitagg.cpp b/src/duckdb/src/core_functions/aggregate/distributive/bitagg.cpp index 2d57a4f5..af305635 100644 --- a/src/duckdb/src/core_functions/aggregate/distributive/bitagg.cpp +++ b/src/duckdb/src/core_functions/aggregate/distributive/bitagg.cpp @@ -53,10 +53,10 @@ struct BitwiseOperation { template static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { if (!state.is_set) { - OP::template Assign(state, input); + OP::template Assign(state, input); state.is_set = true; } else { - OP::template Execute(state, input); + OP::template Execute(state, input); } } @@ -79,10 +79,10 @@ struct BitwiseOperation { } if (!target.is_set) { // target is NULL, use source value directly. - OP::template Assign(target, source.value); + OP::template Assign(target, source.value); target.is_set = true; } else { - OP::template Execute(target, source.value); + OP::template Execute(target, source.value); } } diff --git a/src/duckdb/src/core_functions/aggregate/distributive/entropy.cpp b/src/duckdb/src/core_functions/aggregate/distributive/entropy.cpp index e66f3078..426d4498 100644 --- a/src/duckdb/src/core_functions/aggregate/distributive/entropy.cpp +++ b/src/duckdb/src/core_functions/aggregate/distributive/entropy.cpp @@ -50,11 +50,12 @@ struct EntropyFunctionBase { template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - double count = state.count; + double count = static_cast(state.count); if (state.distinct) { double entropy = 0; for (auto &val : *state.distinct) { - entropy += (val.second / count) * log2(count / val.second); + double val_sec = static_cast(val.second); + entropy += (val_sec / count) * log2(count / val_sec); } target = entropy; } else { diff --git a/src/duckdb/src/core_functions/aggregate/distributive/minmax.cpp b/src/duckdb/src/core_functions/aggregate/distributive/minmax.cpp index d3a5dd49..9c8db50b 100644 --- a/src/duckdb/src/core_functions/aggregate/distributive/minmax.cpp +++ b/src/duckdb/src/core_functions/aggregate/distributive/minmax.cpp @@ -8,6 +8,8 @@ #include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/planner/expression_binder.hpp" #include "duckdb/function/function_binder.hpp" +#include "duckdb/core_functions/aggregate/sort_key_helpers.hpp" +#include "duckdb/core_functions/aggregate/minmax_n_helpers.hpp" namespace duckdb { @@ -147,28 +149,47 @@ struct MaxOperation : public NumericMinMaxBase { } }; -struct StringMinMaxBase : public MinMaxBase { - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.isset && !state.value.IsInlined()) { - delete[] state.value.GetData(); +struct MinMaxStringState : MinMaxState { + void Destroy() { + if (isset && !value.IsInlined()) { + delete[] value.GetData(); } } - template - static void Assign(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { - Destroy(state, input_data); + void Assign(string_t input) { if (input.IsInlined()) { - state.value = input; + // inlined string - we can directly store it into the string_t without having to allocate anything + Destroy(); + value = input; } else { - // non-inlined string, need to allocate space for it + // non-inlined string, need to allocate space for it somehow auto len = input.GetSize(); - auto ptr = new char[len]; + char *ptr; + if (!isset || value.GetSize() < len) { + // we cannot fit this into the current slot - destroy it and re-allocate + Destroy(); + ptr = new char[len]; + } else { + // this fits into the current slot - take over the pointer + ptr = value.GetDataWriteable(); + } memcpy(ptr, input.GetData(), len); - state.value = string_t(ptr, UnsafeNumericCast(len)); + value = string_t(ptr, UnsafeNumericCast(len)); } } +}; + +struct StringMinMaxBase : public MinMaxBase { + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + state.Destroy(); + } + + template + static void Assign(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { + state.Assign(input); + } template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { @@ -213,302 +234,57 @@ struct MaxOperationString : public StringMinMaxBase { } }; -template -static bool TemplatedOptimumType(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount) { - UnifiedVectorFormat lvdata, rvdata; - left.ToUnifiedFormat(lcount, lvdata); - right.ToUnifiedFormat(rcount, rvdata); - - lidx = lvdata.sel->get_index(lidx); - ridx = rvdata.sel->get_index(ridx); - - auto ldata = UnifiedVectorFormat::GetData(lvdata); - auto rdata = UnifiedVectorFormat::GetData(rvdata); - - auto &lval = ldata[lidx]; - auto &rval = rdata[ridx]; - - auto lnull = !lvdata.validity.RowIsValid(lidx); - auto rnull = !rvdata.validity.RowIsValid(ridx); - - return OP::Operation(lval, rval, lnull, rnull); -} - -template -static bool TemplatedOptimumList(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount); - -template -static bool TemplatedOptimumStruct(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount); - -template -static bool TemplatedOptimumArray(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount); - -template -static bool TemplatedOptimumValue(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount) { - D_ASSERT(left.GetType() == right.GetType()); - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::INT16: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::INT32: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::INT64: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::UINT8: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::UINT16: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::UINT32: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::UINT64: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::INT128: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::UINT128: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::FLOAT: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::DOUBLE: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::INTERVAL: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::VARCHAR: - return TemplatedOptimumType(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::LIST: - return TemplatedOptimumList(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::STRUCT: - return TemplatedOptimumStruct(left, lidx, lcount, right, ridx, rcount); - case PhysicalType::ARRAY: - return TemplatedOptimumArray(left, lidx, lcount, right, ridx, rcount); - default: - throw InternalException("Invalid type for distinct comparison"); - } -} - -template -static bool TemplatedOptimumStruct(Vector &left, idx_t lidx_p, idx_t lcount, Vector &right, idx_t ridx_p, - idx_t rcount) { - // STRUCT dictionaries apply to all the children - // so map the indexes first - UnifiedVectorFormat lvdata, rvdata; - left.ToUnifiedFormat(lcount, lvdata); - right.ToUnifiedFormat(rcount, rvdata); - - idx_t lidx = lvdata.sel->get_index(lidx_p); - idx_t ridx = rvdata.sel->get_index(ridx_p); - - // DISTINCT semantics are in effect for nested types - auto lnull = !lvdata.validity.RowIsValid(lidx); - auto rnull = !rvdata.validity.RowIsValid(ridx); - if (lnull || rnull) { - return OP::Operation(0, 0, lnull, rnull); - } - - auto &lchildren = StructVector::GetEntries(left); - auto &rchildren = StructVector::GetEntries(right); - - D_ASSERT(lchildren.size() == rchildren.size()); - for (idx_t col_no = 0; col_no < lchildren.size(); ++col_no) { - auto &lchild = *lchildren[col_no]; - auto &rchild = *rchildren[col_no]; - - // Strict comparisons use the OP for definite - if (TemplatedOptimumValue(lchild, lidx_p, lcount, rchild, ridx_p, rcount)) { - return true; - } - - if (col_no == lchildren.size() - 1) { - break; - } - - // Strict comparisons use IS NOT DISTINCT for possible - if (!TemplatedOptimumValue(lchild, lidx_p, lcount, rchild, ridx_p, rcount)) { - return false; - } - } - - return false; -} - -template -static bool TemplatedOptimumList(Vector &left, idx_t lidx, idx_t lcount, Vector &right, idx_t ridx, idx_t rcount) { - UnifiedVectorFormat lvdata, rvdata; - left.ToUnifiedFormat(lcount, lvdata); - right.ToUnifiedFormat(rcount, rvdata); - - // Update the indexes and vector sizes for recursion. - lidx = lvdata.sel->get_index(lidx); - ridx = rvdata.sel->get_index(ridx); - - lcount = ListVector::GetListSize(left); - rcount = ListVector::GetListSize(right); - - // DISTINCT semantics are in effect for nested types - auto lnull = !lvdata.validity.RowIsValid(lidx); - auto rnull = !rvdata.validity.RowIsValid(ridx); - if (lnull || rnull) { - return OP::Operation(0, 0, lnull, rnull); - } - - auto &lchild = ListVector::GetEntry(left); - auto &rchild = ListVector::GetEntry(right); - - auto ldata = UnifiedVectorFormat::GetData(lvdata); - auto rdata = UnifiedVectorFormat::GetData(rvdata); - - auto &lval = ldata[lidx]; - auto &rval = rdata[ridx]; - - for (idx_t pos = 0;; ++pos) { - // Tie-breaking uses the OP - if (pos == lval.length || pos == rval.length) { - return OP::Operation(lval.length, rval.length, false, false); - } - - // Strict comparisons use the OP for definite - lidx = lval.offset + pos; - ridx = rval.offset + pos; - if (TemplatedOptimumValue(lchild, lidx, lcount, rchild, ridx, rcount)) { - return true; - } - - // Strict comparisons use IS NOT DISTINCT for possible - if (!TemplatedOptimumValue(lchild, lidx, lcount, rchild, ridx, rcount)) { - return false; - } - } - - return false; -} - -// FIXME: We should try to unify this with TemplatedOptimumList -template -static bool TemplatedOptimumArray(Vector &left, idx_t lidx_p, idx_t lcount, Vector &right, idx_t ridx_p, idx_t rcount) { - // so map the indexes first - UnifiedVectorFormat lvdata, rvdata; - left.ToUnifiedFormat(lcount, lvdata); - right.ToUnifiedFormat(rcount, rvdata); - - idx_t lidx = lvdata.sel->get_index(lidx_p); - idx_t ridx = rvdata.sel->get_index(ridx_p); - - // DISTINCT semantics are in effect for nested types - auto lnull = !lvdata.validity.RowIsValid(lidx); - auto rnull = !rvdata.validity.RowIsValid(ridx); - if (lnull || rnull) { - return OP::Operation(0, 0, lnull, rnull); - } - - auto &lchild = ArrayVector::GetEntry(left); - auto &rchild = ArrayVector::GetEntry(right); - auto left_array_size = ArrayType::GetSize(left.GetType()); - auto right_array_size = ArrayType::GetSize(right.GetType()); - - D_ASSERT(left_array_size == right_array_size); - - auto lchild_count = lcount * left_array_size; - auto rchild_count = rcount * right_array_size; - - for (idx_t elem_idx = 0; elem_idx < left_array_size; elem_idx++) { - auto left_elem_idx = lidx * left_array_size + elem_idx; - auto right_elem_idx = ridx * right_array_size + elem_idx; - - // Strict comparisons use the OP for definite - if (TemplatedOptimumValue(lchild, left_elem_idx, lchild_count, rchild, right_elem_idx, rchild_count)) { - return true; - } - - // Strict comparisons use IS NOT DISTINCT for possible - if (!TemplatedOptimumValue(lchild, left_elem_idx, lchild_count, rchild, right_elem_idx, - rchild_count)) { - return false; - } - } - return false; -} - -struct VectorMinMaxState { - Vector *value; -}; - +template struct VectorMinMaxBase { + static constexpr OrderType ORDER_TYPE = ORDER_TYPE_TEMPLATED; + static bool IgnoreNull() { return true; } template static void Initialize(STATE &state) { - state.value = nullptr; + state.isset = false; } template static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.value) { - delete state.value; - } - state.value = nullptr; - } - - template - static void Assign(STATE &state, Vector &input, const idx_t idx) { - if (!state.value) { - state.value = new Vector(input.GetType()); - state.value->SetVectorType(VectorType::CONSTANT_VECTOR); - } - sel_t selv = UnsafeNumericCast(idx); - SelectionVector sel(&selv); - VectorOperations::Copy(input, *state.value, sel, 1, 0, 0); + state.Destroy(); } - template - static void Execute(STATE &state, Vector &input, const idx_t idx, const idx_t count) { - Assign(state, input, idx); + template + static void Assign(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { + state.Assign(input); } - template - static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { - auto &input = inputs[0]; - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - - auto states = (STATE **)sdata.data; - for (idx_t i = 0; i < count; i++) { - const auto idx = idata.sel->get_index(i); - if (!idata.validity.RowIsValid(idx)) { - continue; - } - const auto sidx = sdata.sel->get_index(i); - auto &state = *states[sidx]; - if (!state.value) { - Assign(state, input, i); - } else { - OP::template Execute(state, input, i, count); - } + template + static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { + if (!state.isset) { + Assign(state, input, input_data); + state.isset = true; + return; + } + if (LessThan::Operation(input, state.value)) { + Assign(state, input, input_data); } } template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.value) { + static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { + if (!source.isset) { + // source is NULL, nothing to do return; - } else if (!target.value) { - Assign(target, *source.value, 0); - } else { - OP::template Execute(target, *source.value, 0, 1); } + OP::template Execute(target, source.value, input_data); } template static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { - if (!state.value) { + if (!state.isset) { finalize_data.ReturnNull(); } else { - VectorOperations::Copy(*state.value, finalize_data.result, 1, 0, finalize_data.result_idx); + CreateSortKeyHelpers::DecodeSortKey(state.value, finalize_data.result, finalize_data.result_idx, + OrderModifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST)); } } @@ -520,56 +296,17 @@ struct VectorMinMaxBase { } }; -struct MinOperationVector : public VectorMinMaxBase { - template - static void Execute(STATE &state, Vector &input, const idx_t idx, const idx_t count) { - if (TemplatedOptimumValue(input, idx, count, *state.value, 0, 1)) { - Assign(state, input, idx); - } - } -}; +struct MinOperationVector : VectorMinMaxBase {}; -struct MaxOperationVector : public VectorMinMaxBase { - template - static void Execute(STATE &state, Vector &input, const idx_t idx, const idx_t count) { - if (TemplatedOptimumValue(input, idx, count, *state.value, 0, 1)) { - Assign(state, input, idx); - } - } -}; - -template -unique_ptr BindDecimalMinMax(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto decimal_type = arguments[0]->return_type; - auto name = function.name; - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - function = GetUnaryAggregate(LogicalType::SMALLINT); - break; - case PhysicalType::INT32: - function = GetUnaryAggregate(LogicalType::INTEGER); - break; - case PhysicalType::INT64: - function = GetUnaryAggregate(LogicalType::BIGINT); - break; - default: - function = GetUnaryAggregate(LogicalType::HUGEINT); - break; - } - function.name = std::move(name); - function.arguments[0] = decimal_type; - function.return_type = decimal_type; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return nullptr; -} +struct MaxOperationVector : VectorMinMaxBase {}; template static AggregateFunction GetMinMaxFunction(const LogicalType &type) { return AggregateFunction( - {type}, type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, - OP::template Update, AggregateFunction::StateCombine, - AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, AggregateFunction::StateDestroy); + {type}, LogicalType::BLOB, AggregateFunction::StateSize, AggregateFunction::StateInitialize, + AggregateSortKeyHelpers::UnaryUpdate, + AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, + AggregateFunction::StateDestroy); } template @@ -577,12 +314,12 @@ static AggregateFunction GetMinMaxOperator(const LogicalType &type) { auto internal_type = type.InternalType(); switch (internal_type) { case PhysicalType::VARCHAR: - return AggregateFunction::UnaryAggregateDestructor, string_t, string_t, OP_STRING>( - type.id(), type.id()); + return AggregateFunction::UnaryAggregateDestructor(type.id(), + type.id()); case PhysicalType::LIST: case PhysicalType::STRUCT: case PhysicalType::ARRAY: - return GetMinMaxFunction(type); + return GetMinMaxFunction(type); default: return GetUnaryAggregate(type); } @@ -591,7 +328,6 @@ static AggregateFunction GetMinMaxOperator(const LogicalType &type) { template unique_ptr BindMinMax(ClientContext &context, AggregateFunction &function, vector> &arguments) { - if (arguments[0]->return_type.id() == LogicalTypeId::VARCHAR) { auto str_collation = StringType::GetCollation(arguments[0]->return_type); if (!str_collation.empty()) { @@ -616,7 +352,7 @@ unique_ptr BindMinMax(ClientContext &context, AggregateFunction &f // Create a copied child and PushCollation for it. arguments.push_back(arguments[0]->Copy()); - ExpressionBinder::PushCollation(context, arguments[1], arguments[0]->return_type, false); + ExpressionBinder::PushCollation(context, arguments[1], arguments[0]->return_type); // Bind function like arg_min/arg_max. function.arguments[0] = arguments[0]->return_type; @@ -626,6 +362,9 @@ unique_ptr BindMinMax(ClientContext &context, AggregateFunction &f } auto input_type = arguments[0]->return_type; + if (input_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } auto name = std::move(function.name); function = GetMinMaxOperator(input_type); function.name = std::move(name); @@ -638,22 +377,171 @@ unique_ptr BindMinMax(ClientContext &context, AggregateFunction &f } template -static void AddMinMaxOperator(AggregateFunctionSet &set) { - set.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, BindDecimalMinMax)); - set.AddFunction(AggregateFunction({LogicalType::ANY}, LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, BindMinMax)); +static AggregateFunction GetMinMaxOperator(string name) { + return AggregateFunction(std::move(name), {LogicalType::ANY}, LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, BindMinMax); +} + +AggregateFunction MinFun::GetFunction() { + return GetMinMaxOperator("min"); +} + +AggregateFunction MaxFun::GetFunction() { + return GetMinMaxOperator("max"); +} + +//--------------------------------------------------- +// MinMaxN +//--------------------------------------------------- + +template +class MinMaxNState { +public: + using VAL_TYPE = A; + using T = typename VAL_TYPE::TYPE; + + UnaryAggregateHeap heap; + bool is_initialized = false; + + void Initialize(idx_t nval) { + heap.Initialize(nval); + is_initialized = true; + } + + static const T &GetValue(const T &val) { + return val; + } +}; + +template +static void MinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, + idx_t count) { + + auto &val_vector = inputs[0]; + auto &n_vector = inputs[1]; + + UnifiedVectorFormat val_format; + UnifiedVectorFormat n_format; + UnifiedVectorFormat state_format; + ; + auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count); + + STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format); + + n_vector.ToUnifiedFormat(count, n_format); + state_vector.ToUnifiedFormat(count, state_format); + + auto states = UnifiedVectorFormat::GetData(state_format); + + for (idx_t i = 0; i < count; i++) { + const auto val_idx = val_format.sel->get_index(i); + if (!val_format.validity.RowIsValid(val_idx)) { + continue; + } + const auto state_idx = state_format.sel->get_index(i); + auto &state = *states[state_idx]; + + // Initialize the heap if necessary and add the input to the heap + if (!state.is_initialized) { + static constexpr int64_t MAX_N = 1000000; + const auto nidx = n_format.sel->get_index(i); + if (!n_format.validity.RowIsValid(nidx)) { + throw InvalidInputException("Invalid input for MIN/MAX: n value cannot be NULL"); + } + const auto nval = UnifiedVectorFormat::GetData(n_format)[nidx]; + if (nval <= 0) { + throw InvalidInputException("Invalid input for MIN/MAX: n value must be > 0"); + } + if (nval >= MAX_N) { + throw InvalidInputException("Invalid input for MIN/MAX: n value must be < %d", MAX_N); + } + state.Initialize(UnsafeNumericCast(nval)); + } + + // Now add the input to the heap + auto val_val = STATE::VAL_TYPE::Create(val_format, val_idx); + state.heap.Insert(aggr_input.allocator, val_val); + } +} + +template +static void SpecializeMinMaxNFunction(AggregateFunction &function) { + using STATE = MinMaxNState; + using OP = MinMaxNOperation; + + function.state_size = AggregateFunction::StateSize; + function.initialize = AggregateFunction::StateInitialize; + function.combine = AggregateFunction::StateCombine; + function.destructor = AggregateFunction::StateDestroy; + + function.finalize = MinMaxNOperation::Finalize; + function.update = MinMaxNUpdate; } +template +static void SpecializeMinMaxNFunction(PhysicalType arg_type, AggregateFunction &function) { + switch (arg_type) { + case PhysicalType::VARCHAR: + SpecializeMinMaxNFunction(function); + break; + case PhysicalType::INT32: + SpecializeMinMaxNFunction, COMPARATOR>(function); + break; + case PhysicalType::INT64: + SpecializeMinMaxNFunction, COMPARATOR>(function); + break; + case PhysicalType::FLOAT: + SpecializeMinMaxNFunction, COMPARATOR>(function); + break; + case PhysicalType::DOUBLE: + SpecializeMinMaxNFunction, COMPARATOR>(function); + break; + default: + SpecializeMinMaxNFunction(function); + break; + } +} + +template +unique_ptr MinMaxNBind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + + for (auto &arg : arguments) { + if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + } + + const auto val_type = arguments[0]->return_type.InternalType(); + + // Specialize the function based on the input types + SpecializeMinMaxNFunction(val_type, function); + + function.return_type = LogicalType::LIST(arguments[0]->return_type); + return nullptr; +} + +template +static AggregateFunction GetMinMaxNFunction() { + return AggregateFunction({LogicalTypeId::ANY, LogicalType::BIGINT}, LogicalType::LIST(LogicalType::ANY), nullptr, + nullptr, nullptr, nullptr, nullptr, nullptr, MinMaxNBind, nullptr); +} + +//--------------------------------------------------- +// Function Registration +//---------------------------------------------------s + AggregateFunctionSet MinFun::GetFunctions() { AggregateFunctionSet min("min"); - AddMinMaxOperator(min); + min.AddFunction(GetFunction()); + min.AddFunction(GetMinMaxNFunction()); return min; } AggregateFunctionSet MaxFun::GetFunctions() { AggregateFunctionSet max("max"); - AddMinMaxOperator(max); + max.AddFunction(GetFunction()); + max.AddFunction(GetMinMaxNFunction()); return max; } diff --git a/src/duckdb/src/core_functions/aggregate/distributive/sum.cpp b/src/duckdb/src/core_functions/aggregate/distributive/sum.cpp index 61c996c0..3aa254e3 100644 --- a/src/duckdb/src/core_functions/aggregate/distributive/sum.cpp +++ b/src/duckdb/src/core_functions/aggregate/distributive/sum.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/types/decimal.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { @@ -80,6 +81,7 @@ void SumNoOverflowSerialize(Serializer &serializer, const optional_ptr SumNoOverflowDeserialize(Deserializer &deserializer, AggregateFunction &function) { + function.return_type = deserializer.Get(); return nullptr; } diff --git a/src/duckdb/src/core_functions/aggregate/holistic/approx_top_k.cpp b/src/duckdb/src/core_functions/aggregate/holistic/approx_top_k.cpp new file mode 100644 index 00000000..19b3ae88 --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/holistic/approx_top_k.cpp @@ -0,0 +1,388 @@ +#include "duckdb/core_functions/aggregate/histogram_helpers.hpp" +#include "duckdb/core_functions/aggregate/holistic_functions.hpp" +#include "duckdb/core_functions/aggregate/sort_key_helpers.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/common/string_map_set.hpp" +#include "duckdb/common/printer.hpp" + +namespace duckdb { + +struct ApproxTopKString { + ApproxTopKString() : str(UINT32_C(0)), hash(0) { + } + ApproxTopKString(string_t str_p, hash_t hash_p) : str(str_p), hash(hash_p) { + } + + string_t str; + hash_t hash; +}; + +struct ApproxTopKHash { + std::size_t operator()(const ApproxTopKString &k) const { + return k.hash; + } +}; + +struct ApproxTopKEquality { + bool operator()(const ApproxTopKString &a, const ApproxTopKString &b) const { + return Equals::Operation(a.str, b.str); + } +}; + +template +using approx_topk_map_t = unordered_map; + +// approx top k algorithm based on "A parallel space saving algorithm for frequent items and the Hurwitz zeta +// distribution" arxiv link - https://arxiv.org/pdf/1401.0702 +// together with the filter extension (Filtered Space-Saving) from "Estimating Top-k Destinations in Data Streams" +struct ApproxTopKValue { + //! The counter + idx_t count = 0; + //! Index in the values array + idx_t index = 0; + //! The string value + ApproxTopKString str_val; + //! Allocated data + char *dataptr = nullptr; + uint32_t size = 0; + uint32_t capacity = 0; +}; + +struct ApproxTopKState { + // the top-k data structure has two components + // a list of k values sorted on "count" (i.e. values[0] has the lowest count) + // a lookup map: string_t -> idx in "values" array + unsafe_unique_array stored_values; + unsafe_vector> values; + approx_topk_map_t> lookup_map; + unsafe_vector filter; + idx_t k = 0; + idx_t capacity = 0; + idx_t filter_mask; + + void Initialize(idx_t kval) { + static constexpr idx_t MONITORED_VALUES_RATIO = 3; + static constexpr idx_t FILTER_RATIO = 8; + + D_ASSERT(values.empty()); + D_ASSERT(lookup_map.empty()); + k = kval; + capacity = kval * MONITORED_VALUES_RATIO; + stored_values = make_unsafe_uniq_array_uninitialized(capacity); + values.reserve(capacity); + + // we scale the filter based on the amount of values we are monitoring + idx_t filter_size = NextPowerOfTwo(capacity * FILTER_RATIO); + filter_mask = filter_size - 1; + filter.resize(filter_size); + } + + static void CopyValue(ApproxTopKValue &value, const ApproxTopKString &input, AggregateInputData &input_data) { + value.str_val.hash = input.hash; + if (input.str.IsInlined()) { + // no need to copy + value.str_val = input; + return; + } + value.size = UnsafeNumericCast(input.str.GetSize()); + if (value.size > value.capacity) { + // need to re-allocate for this value + value.capacity = UnsafeNumericCast(NextPowerOfTwo(value.size)); + value.dataptr = char_ptr_cast(input_data.allocator.Allocate(value.capacity)); + } + // copy over the data + memcpy(value.dataptr, input.str.GetData(), value.size); + value.str_val.str = string_t(value.dataptr, value.size); + } + + void InsertOrReplaceEntry(const ApproxTopKString &input, AggregateInputData &aggr_input, idx_t increment = 1) { + if (values.size() < capacity) { + D_ASSERT(increment > 0); + // we can always add this entry + auto &val = stored_values[values.size()]; + val.index = values.size(); + values.push_back(val); + } + auto &value = values.back().get(); + if (value.count > 0) { + // the capacity is reached - we need to replace an entry + + // we use the filter as an early out + // based on the hash - we find a slot in the filter + // instead of monitoring the value immediately, we add to the slot in the filter + // ONLY when the value in the filter exceeds the current min value, we start monitoring the value + // this speeds up the algorithm as switching monitor values means we need to erase/insert in the hash table + auto &filter_value = filter[input.hash & filter_mask]; + if (filter_value + increment < value.count) { + // if the filter has a lower count than the current min count + // we can skip adding this entry (for now) + filter_value += increment; + return; + } + // the filter exceeds the min value - start monitoring this value + // erase the existing entry from the map + // and set the filter for the minimum value back to the current minimum value + filter[value.str_val.hash & filter_mask] = value.count; + lookup_map.erase(value.str_val); + } + CopyValue(value, input, aggr_input); + lookup_map.insert(make_pair(value.str_val, reference(value))); + IncrementCount(value, increment); + } + + void IncrementCount(ApproxTopKValue &value, idx_t increment = 1) { + value.count += increment; + // maintain sortedness of "values" + // swap while we have a higher count than the next entry + while (value.index > 0 && values[value.index].get().count > values[value.index - 1].get().count) { + // swap the elements around + auto &left = values[value.index]; + auto &right = values[value.index - 1]; + std::swap(left.get().index, right.get().index); + std::swap(left, right); + } + } + + void Verify() const { +#ifdef DEBUG + if (values.empty()) { + D_ASSERT(lookup_map.empty()); + return; + } + D_ASSERT(values.size() <= capacity); + for (idx_t k = 0; k < values.size(); k++) { + auto &val = values[k].get(); + D_ASSERT(val.count > 0); + // verify map exists + auto entry = lookup_map.find(val.str_val); + D_ASSERT(entry != lookup_map.end()); + // verify the index is correct + D_ASSERT(val.index == k); + if (k > 0) { + // sortedness + D_ASSERT(val.count <= values[k - 1].get().count); + } + } + // verify lookup map does not contain extra entries + D_ASSERT(lookup_map.size() == values.size()); +#endif + } +}; + +struct ApproxTopKOperation { + template + static void Initialize(STATE &state) { + new (&state) STATE(); + } + + template + static void Operation(STATE &state, const TYPE &input, AggregateInputData &aggr_input, Vector &top_k_vector, + idx_t offset, idx_t count) { + if (state.values.empty()) { + static constexpr int64_t MAX_APPROX_K = 1000000; + // not initialized yet - initialize the K value and set all counters to 0 + UnifiedVectorFormat kdata; + top_k_vector.ToUnifiedFormat(count, kdata); + auto kidx = kdata.sel->get_index(offset); + if (!kdata.validity.RowIsValid(kidx)) { + throw InvalidInputException("Invalid input for approx_top_k: k value cannot be NULL"); + } + auto kval = UnifiedVectorFormat::GetData(kdata)[kidx]; + if (kval <= 0) { + throw InvalidInputException("Invalid input for approx_top_k: k value must be > 0"); + } + if (kval >= MAX_APPROX_K) { + throw InvalidInputException("Invalid input for approx_top_k: k value must be < %d", MAX_APPROX_K); + } + state.Initialize(UnsafeNumericCast(kval)); + } + ApproxTopKString topk_string(input, Hash(input)); + auto entry = state.lookup_map.find(topk_string); + if (entry != state.lookup_map.end()) { + // the input is monitored - increment the count + state.IncrementCount(entry->second.get()); + } else { + // the input is not monitored - replace the first entry with the current entry and increment + state.InsertOrReplaceEntry(topk_string, aggr_input); + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input) { + if (source.values.empty()) { + // source is empty + return; + } + source.Verify(); + auto min_source = source.values.back().get().count; + idx_t min_target; + if (target.values.empty()) { + min_target = 0; + target.Initialize(source.k); + } else { + if (source.k != target.k) { + throw NotImplementedException("Approx Top K - cannot combine approx_top_K with different k values. " + "K values must be the same for all entries within the same group"); + } + min_target = target.values.back().get().count; + } + // for all entries in target + // check if they are tracked in source + // if they do - add the tracked count + // if they do not - add the minimum count + for (idx_t target_idx = 0; target_idx < target.values.size(); target_idx++) { + auto &val = target.values[target_idx].get(); + auto source_entry = source.lookup_map.find(val.str_val); + idx_t increment = min_source; + if (source_entry != source.lookup_map.end()) { + increment = source_entry->second.get().count; + } + if (increment == 0) { + continue; + } + target.IncrementCount(val, increment); + } + // now for each entry in source, if it is not tracked by the target, at the target minimum + for (auto &source_entry : source.values) { + auto &source_val = source_entry.get(); + auto target_entry = target.lookup_map.find(source_val.str_val); + if (target_entry != target.lookup_map.end()) { + // already tracked - no need to add anything + continue; + } + auto new_count = source_val.count + min_target; + idx_t increment; + if (target.values.size() >= target.capacity) { + idx_t current_min = target.values.empty() ? 0 : target.values.back().get().count; + D_ASSERT(target.values.size() == target.capacity); + // target already has capacity values + // check if we should insert this entry + if (new_count <= current_min) { + // if we do not we can skip this entry + continue; + } + increment = new_count - current_min; + } else { + // target does not have capacity entries yet + // just add this entry with the full count + increment = new_count; + } + target.InsertOrReplaceEntry(source_val.str_val, aggr_input, increment); + } + // copy over the filter + D_ASSERT(source.filter.size() == target.filter.size()); + for (idx_t filter_idx = 0; filter_idx < source.filter.size(); filter_idx++) { + target.filter[filter_idx] += source.filter[filter_idx]; + } + target.Verify(); + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + state.~STATE(); + } + + static bool IgnoreNull() { + return true; + } +}; + +template +static void ApproxTopKUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, + idx_t count) { + using STATE = ApproxTopKState; + auto &input = inputs[0]; + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + + auto &top_k_vector = inputs[1]; + + auto extra_state = OP::CreateExtraState(count); + UnifiedVectorFormat input_data; + OP::PrepareData(input, count, extra_state, input_data); + + auto states = UnifiedVectorFormat::GetData(sdata); + auto data = UnifiedVectorFormat::GetData(input_data); + for (idx_t i = 0; i < count; i++) { + auto idx = input_data.sel->get_index(i); + if (!input_data.validity.RowIsValid(idx)) { + continue; + } + auto &state = *states[sdata.sel->get_index(i)]; + ApproxTopKOperation::Operation(state, data[idx], aggr_input, top_k_vector, i, count); + } +} + +template +static void ApproxTopKFinalize(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, idx_t offset) { + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + auto states = UnifiedVectorFormat::GetData(sdata); + + auto &mask = FlatVector::Validity(result); + auto old_len = ListVector::GetListSize(result); + idx_t new_entries = 0; + // figure out how much space we need + for (idx_t i = 0; i < count; i++) { + auto &state = *states[sdata.sel->get_index(i)]; + if (state.values.empty()) { + continue; + } + // get up to k values for each state + // this can be less of fewer unique values were found + new_entries += MinValue(state.values.size(), state.k); + } + // reserve space in the list vector + ListVector::Reserve(result, old_len + new_entries); + auto list_entries = FlatVector::GetData(result); + auto &child_data = ListVector::GetEntry(result); + + idx_t current_offset = old_len; + for (idx_t i = 0; i < count; i++) { + const auto rid = i + offset; + auto &state = *states[sdata.sel->get_index(i)]; + if (state.values.empty()) { + mask.SetInvalid(rid); + continue; + } + auto &list_entry = list_entries[rid]; + list_entry.offset = current_offset; + for (idx_t val_idx = 0; val_idx < MinValue(state.values.size(), state.k); val_idx++) { + auto &val = state.values[val_idx].get(); + D_ASSERT(val.count > 0); + OP::template HistogramFinalize(val.str_val.str, child_data, current_offset); + current_offset++; + } + list_entry.length = current_offset - list_entry.offset; + } + D_ASSERT(current_offset == old_len + new_entries); + ListVector::SetListSize(result, current_offset); + result.Verify(count); +} + +unique_ptr ApproxTopKBind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + for (auto &arg : arguments) { + if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + } + if (arguments[0]->return_type.id() == LogicalTypeId::VARCHAR) { + function.update = ApproxTopKUpdate; + function.finalize = ApproxTopKFinalize; + } + function.return_type = LogicalType::LIST(arguments[0]->return_type); + return nullptr; +} + +AggregateFunction ApproxTopKFun::GetFunction() { + using STATE = ApproxTopKState; + using OP = ApproxTopKOperation; + return AggregateFunction("approx_top_k", {LogicalTypeId::ANY, LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::ANY), AggregateFunction::StateSize, + AggregateFunction::StateInitialize, ApproxTopKUpdate, + AggregateFunction::StateCombine, ApproxTopKFinalize, nullptr, ApproxTopKBind, + AggregateFunction::StateDestroy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/holistic/approximate_quantile.cpp b/src/duckdb/src/core_functions/aggregate/holistic/approximate_quantile.cpp index edda3458..5b6abcd2 100644 --- a/src/duckdb/src/core_functions/aggregate/holistic/approximate_quantile.cpp +++ b/src/duckdb/src/core_functions/aggregate/holistic/approximate_quantile.cpp @@ -133,33 +133,56 @@ struct ApproxQuantileScalarOperation : public ApproxQuantileOperation { } }; -AggregateFunction GetApproximateQuantileAggregateFunction(PhysicalType type) { - switch (type) { +static AggregateFunction GetApproximateQuantileAggregateFunction(const LogicalType &type) { + // Not binary comparable + if (type == LogicalType::TIME_TZ) { + return AggregateFunction::UnaryAggregateDestructor(type, type); + } + switch (type.InternalType()) { + case PhysicalType::INT8: + return AggregateFunction::UnaryAggregateDestructor(type, type); case PhysicalType::INT16: return AggregateFunction::UnaryAggregateDestructor(LogicalType::SMALLINT, - LogicalType::SMALLINT); + ApproxQuantileScalarOperation>(type, type); case PhysicalType::INT32: return AggregateFunction::UnaryAggregateDestructor(LogicalType::INTEGER, - LogicalType::INTEGER); + ApproxQuantileScalarOperation>(type, type); case PhysicalType::INT64: return AggregateFunction::UnaryAggregateDestructor(LogicalType::BIGINT, - LogicalType::BIGINT); + ApproxQuantileScalarOperation>(type, type); case PhysicalType::INT128: return AggregateFunction::UnaryAggregateDestructor(LogicalType::HUGEINT, - LogicalType::HUGEINT); + ApproxQuantileScalarOperation>(type, type); + case PhysicalType::FLOAT: + return AggregateFunction::UnaryAggregateDestructor(type, type); case PhysicalType::DOUBLE: return AggregateFunction::UnaryAggregateDestructor(LogicalType::DOUBLE, - LogicalType::DOUBLE); + ApproxQuantileScalarOperation>(type, type); default: throw InternalException("Unimplemented quantile aggregate"); } } +static AggregateFunction GetApproximateQuantileDecimalAggregateFunction(const LogicalType &type) { + switch (type.InternalType()) { + case PhysicalType::INT8: + return GetApproximateQuantileAggregateFunction(LogicalType::TINYINT); + case PhysicalType::INT16: + return GetApproximateQuantileAggregateFunction(LogicalType::SMALLINT); + case PhysicalType::INT32: + return GetApproximateQuantileAggregateFunction(LogicalType::INTEGER); + case PhysicalType::INT64: + return GetApproximateQuantileAggregateFunction(LogicalType::BIGINT); + case PhysicalType::INT128: + return GetApproximateQuantileAggregateFunction(LogicalType::HUGEINT); + default: + throw InternalException("Unimplemented quantile decimal aggregate"); + } +} + static float CheckApproxQuantile(const Value &quantile_val) { if (quantile_val.IsNull()) { throw BinderException("APPROXIMATE QUANTILE parameter cannot be NULL"); @@ -210,14 +233,14 @@ unique_ptr BindApproxQuantile(ClientContext &context, AggregateFun unique_ptr BindApproxQuantileDecimal(ClientContext &context, AggregateFunction &function, vector> &arguments) { auto bind_data = BindApproxQuantile(context, function, arguments); - function = GetApproximateQuantileAggregateFunction(arguments[0]->return_type.InternalType()); + function = GetApproximateQuantileDecimalAggregateFunction(arguments[0]->return_type); function.name = "approx_quantile"; function.serialize = ApproximateQuantileBindData::Serialize; function.deserialize = ApproximateQuantileBindData::Deserialize; return bind_data; } -AggregateFunction GetApproximateQuantileAggregate(PhysicalType type) { +AggregateFunction GetApproximateQuantileAggregate(const LogicalType &type) { auto fun = GetApproximateQuantileAggregateFunction(type); fun.bind = BindApproxQuantile; fun.serialize = ApproximateQuantileBindData::Serialize; @@ -287,9 +310,16 @@ AggregateFunction GetApproxQuantileListAggregateFunction(const LogicalType &type case LogicalTypeId::SMALLINT: return GetTypedApproxQuantileListAggregateFunction(type); case LogicalTypeId::INTEGER: + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: return GetTypedApproxQuantileListAggregateFunction(type); case LogicalTypeId::BIGINT: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::TIME_TZ: + // Not binary comparable + return GetTypedApproxQuantileListAggregateFunction(type); case LogicalTypeId::HUGEINT: return GetTypedApproxQuantileListAggregateFunction(type); case LogicalTypeId::FLOAT: @@ -307,10 +337,9 @@ AggregateFunction GetApproxQuantileListAggregateFunction(const LogicalType &type case PhysicalType::INT128: return GetTypedApproxQuantileListAggregateFunction(type); default: - throw NotImplementedException("Unimplemented approximate quantile list aggregate"); + throw NotImplementedException("Unimplemented approximate quantile list decimal aggregate"); } default: - // TODO: Add quantitative temporal types throw NotImplementedException("Unimplemented approximate quantile list aggregate"); } } @@ -342,11 +371,17 @@ AggregateFunctionSet ApproxQuantileFun::GetFunctions() { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, BindApproxQuantileDecimal)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::INT16)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::INT32)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::INT64)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::INT128)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(PhysicalType::DOUBLE)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::SMALLINT)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::INTEGER)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::BIGINT)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::HUGEINT)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::DOUBLE)); + + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::DATE)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIME)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIME_TZ)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIMESTAMP)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIMESTAMP_TZ)); // List variants approx_quantile.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::FLOAT)}, @@ -360,6 +395,13 @@ AggregateFunctionSet ApproxQuantileFun::GetFunctions() { approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::HUGEINT)); approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::FLOAT)); approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::DOUBLE)); + + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::DATE)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIME)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIME_TZ)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIMESTAMP)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIMESTAMP_TZ)); + return approx_quantile; } diff --git a/src/duckdb/src/core_functions/aggregate/holistic/mad.cpp b/src/duckdb/src/core_functions/aggregate/holistic/mad.cpp new file mode 100644 index 00000000..8be7415f --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/holistic/mad.cpp @@ -0,0 +1,330 @@ +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/core_functions/aggregate/holistic_functions.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/abs.hpp" +#include "duckdb/core_functions/aggregate/quantile_state.hpp" + +namespace duckdb { + +struct FrameSet { + inline explicit FrameSet(const SubFrames &frames_p) : frames(frames_p) { + } + + inline idx_t Size() const { + idx_t result = 0; + for (const auto &frame : frames) { + result += frame.end - frame.start; + } + + return result; + } + + inline bool Contains(idx_t i) const { + for (idx_t f = 0; f < frames.size(); ++f) { + const auto &frame = frames[f]; + if (frame.start <= i && i < frame.end) { + return true; + } + } + return false; + } + const SubFrames &frames; +}; + +struct QuantileReuseUpdater { + idx_t *index; + idx_t j; + + inline QuantileReuseUpdater(idx_t *index, idx_t j) : index(index), j(j) { + } + + inline void Neither(idx_t begin, idx_t end) { + } + + inline void Left(idx_t begin, idx_t end) { + } + + inline void Right(idx_t begin, idx_t end) { + for (; begin < end; ++begin) { + index[j++] = begin; + } + } + + inline void Both(idx_t begin, idx_t end) { + } +}; + +void ReuseIndexes(idx_t *index, const SubFrames &currs, const SubFrames &prevs) { + + // Copy overlapping indices by scanning the previous set and copying down into holes. + // We copy instead of leaving gaps in case there are fewer values in the current frame. + FrameSet prev_set(prevs); + FrameSet curr_set(currs); + const auto prev_count = prev_set.Size(); + idx_t j = 0; + for (idx_t p = 0; p < prev_count; ++p) { + auto idx = index[p]; + + // Shift down into any hole + if (j != p) { + index[j] = idx; + } + + // Skip overlapping values + if (curr_set.Contains(idx)) { + ++j; + } + } + + // Insert new indices + if (j > 0) { + QuantileReuseUpdater updater(index, j); + AggregateExecutor::IntersectFrames(prevs, currs, updater); + } else { + // No overlap: overwrite with new values + for (const auto &curr : currs) { + for (auto idx = curr.start; idx < curr.end; ++idx) { + index[j++] = idx; + } + } + } +} + +//===--------------------------------------------------------------------===// +// Median Absolute Deviation +//===--------------------------------------------------------------------===// +template +struct MadAccessor { + using INPUT_TYPE = T; + using RESULT_TYPE = R; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const RESULT_TYPE delta = input - UnsafeNumericCast(median); + return TryAbsOperator::Operation(delta); + } +}; + +// hugeint_t - double => undefined +template <> +struct MadAccessor { + using INPUT_TYPE = hugeint_t; + using RESULT_TYPE = double; + using MEDIAN_TYPE = double; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const auto delta = Hugeint::Cast(input) - median; + return TryAbsOperator::Operation(delta); + } +}; + +// date_t - timestamp_t => interval_t +template <> +struct MadAccessor { + using INPUT_TYPE = date_t; + using RESULT_TYPE = interval_t; + using MEDIAN_TYPE = timestamp_t; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const auto dt = Cast::Operation(input); + const auto delta = dt - median; + return Interval::FromMicro(TryAbsOperator::Operation(delta)); + } +}; + +// timestamp_t - timestamp_t => int64_t +template <> +struct MadAccessor { + using INPUT_TYPE = timestamp_t; + using RESULT_TYPE = interval_t; + using MEDIAN_TYPE = timestamp_t; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const auto delta = input - median; + return Interval::FromMicro(TryAbsOperator::Operation(delta)); + } +}; + +// dtime_t - dtime_t => int64_t +template <> +struct MadAccessor { + using INPUT_TYPE = dtime_t; + using RESULT_TYPE = interval_t; + using MEDIAN_TYPE = dtime_t; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const auto delta = input - median; + return Interval::FromMicro(TryAbsOperator::Operation(delta)); + } +}; + +template +struct MedianAbsoluteDeviationOperation : QuantileOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.v.empty()) { + finalize_data.ReturnNull(); + return; + } + using INPUT_TYPE = typename STATE::InputType; + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->Cast(); + D_ASSERT(bind_data.quantiles.size() == 1); + const auto &q = bind_data.quantiles[0]; + Interpolator interp(q, state.v.size(), false); + const auto med = interp.template Operation(state.v.data(), finalize_data.result); + + MadAccessor accessor(med); + target = interp.template Operation(state.v.data(), finalize_data.result, accessor); + } + + template + static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, + AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &result, + idx_t ridx, const STATE *gstate) { + auto rdata = FlatVector::GetData(result); + + QuantileIncluded included(fmask, dmask); + const auto n = FrameSize(included, frames); + + if (!n) { + auto &rmask = FlatVector::Validity(result); + rmask.Set(ridx, false); + return; + } + + // Compute the median + D_ASSERT(aggr_input_data.bind_data); + auto &bind_data = aggr_input_data.bind_data->Cast(); + + D_ASSERT(bind_data.quantiles.size() == 1); + const auto &quantile = bind_data.quantiles[0]; + auto &window_state = state.GetOrCreateWindowState(); + MEDIAN_TYPE med; + if (gstate && gstate->HasTrees()) { + med = gstate->GetWindowState().template WindowScalar(data, frames, n, result, quantile); + } else { + window_state.UpdateSkip(data, frames, included); + med = window_state.template WindowScalar(data, frames, n, result, quantile); + } + + // Lazily initialise frame state + window_state.SetCount(frames.back().end - frames.front().start); + auto index2 = window_state.m.data(); + D_ASSERT(index2); + + // The replacement trick does not work on the second index because if + // the median has changed, the previous order is not correct. + // It is probably close, however, and so reuse is helpful. + auto &prevs = window_state.prevs; + ReuseIndexes(index2, frames, prevs); + std::partition(index2, index2 + window_state.count, included); + + Interpolator interp(quantile, n, false); + + // Compute mad from the second index + using ID = QuantileIndirect; + ID indirect(data); + + using MAD = MadAccessor; + MAD mad(med); + + using MadIndirect = QuantileComposed; + MadIndirect mad_indirect(mad, indirect); + rdata[ridx] = interp.template Operation(index2, result, mad_indirect); + + // Prev is used by both skip lists and increments + prevs = frames; + } +}; + +unique_ptr BindMAD(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + return make_uniq(Value::DECIMAL(int16_t(5), 2, 1)); +} + +template +AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const LogicalType &input_type, + const LogicalType &target_type) { + using STATE = QuantileState; + using OP = MedianAbsoluteDeviationOperation; + auto fun = AggregateFunction::UnaryAggregateDestructor(input_type, target_type); + fun.bind = BindMAD; + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.window = AggregateFunction::UnaryWindow; + fun.window_init = OP::template WindowInit; + return fun; +} + +AggregateFunction GetMedianAbsoluteDeviationAggregateFunction(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::FLOAT: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case LogicalTypeId::DOUBLE: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case LogicalTypeId::DECIMAL: + switch (type.InternalType()) { + case PhysicalType::INT16: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case PhysicalType::INT32: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case PhysicalType::INT64: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case PhysicalType::INT128: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + default: + throw NotImplementedException("Unimplemented Median Absolute Deviation DECIMAL aggregate"); + } + break; + + case LogicalTypeId::DATE: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, + LogicalType::INTERVAL); + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + return GetTypedMedianAbsoluteDeviationAggregateFunction( + type, LogicalType::INTERVAL); + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, + LogicalType::INTERVAL); + + default: + throw NotImplementedException("Unimplemented Median Absolute Deviation aggregate"); + } +} + +unique_ptr BindMedianAbsoluteDeviationDecimal(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetMedianAbsoluteDeviationAggregateFunction(arguments[0]->return_type); + function.name = "mad"; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return BindMAD(context, function, arguments); +} + +AggregateFunctionSet MadFun::GetFunctions() { + AggregateFunctionSet mad("mad"); + mad.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, BindMedianAbsoluteDeviationDecimal)); + + const vector MAD_TYPES = {LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE, + LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ, + LogicalType::TIME_TZ}; + for (const auto &type : MAD_TYPES) { + mad.AddFunction(GetMedianAbsoluteDeviationAggregateFunction(type)); + } + return mad; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/holistic/mode.cpp b/src/duckdb/src/core_functions/aggregate/holistic/mode.cpp index f33ccc41..a8d0dbf1 100644 --- a/src/duckdb/src/core_functions/aggregate/holistic/mode.cpp +++ b/src/duckdb/src/core_functions/aggregate/holistic/mode.cpp @@ -1,7 +1,3 @@ -// MODE( ) -// Returns the most frequent value for the values within expr1. -// NULL values are ignored. If all the values are NULL, or there are 0 rows, then the function returns NULL. - #include "duckdb/common/exception.hpp" #include "duckdb/common/uhugeint.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" @@ -9,9 +5,15 @@ #include "duckdb/core_functions/aggregate/holistic_functions.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/common/unordered_map.hpp" - +#include "duckdb/common/owning_string_map.hpp" +#include "duckdb/core_functions/create_sort_key.hpp" +#include "duckdb/core_functions/aggregate/sort_key_helpers.hpp" #include +// MODE( ) +// Returns the most frequent value for the values within expr1. +// NULL values are ignored. If all the values are NULL, or there are 0 rows, then the function returns NULL. + namespace std { template <> @@ -42,15 +44,49 @@ struct hash { namespace duckdb { -template +struct ModeAttr { + ModeAttr() : count(0), first_row(std::numeric_limits::max()) { + } + size_t count; + idx_t first_row; +}; + +template +struct ModeStandard { + using MAP_TYPE = unordered_map; + + static MAP_TYPE *CreateEmpty(ArenaAllocator &) { + return new MAP_TYPE(); + } + static MAP_TYPE *CreateEmpty(Allocator &) { + return new MAP_TYPE(); + } + + template + static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) { + return RESULT_TYPE(input); + } +}; + +struct ModeString { + using MAP_TYPE = OwningStringMap; + + static MAP_TYPE *CreateEmpty(ArenaAllocator &allocator) { + return new MAP_TYPE(allocator); + } + static MAP_TYPE *CreateEmpty(Allocator &allocator) { + return new MAP_TYPE(allocator); + } + + template + static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) { + return StringVector::AddStringOrBlob(result, input); + } +}; + +template struct ModeState { - struct ModeAttr { - ModeAttr() : count(0), first_row(std::numeric_limits::max()) { - } - size_t count; - idx_t first_row; - }; - using Counts = unordered_map; + using Counts = typename TYPE_OP::MAP_TYPE; ModeState() { } @@ -72,8 +108,9 @@ struct ModeState { } void Reset() { - Counts empty; - frequency_map->swap(empty); + if (frequency_map) { + frequency_map->clear(); + } nonzero = 0; count = 0; valid = false; @@ -137,37 +174,27 @@ struct ModeIncluded { const ValidityMask &dmask; }; -struct ModeAssignmentStandard { - template - static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) { - return RESULT_TYPE(input); - } -}; - -struct ModeAssignmentString { - template - static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) { - return StringVector::AddString(result, input); - } -}; - -template -struct ModeFunction { +template +struct BaseModeFunction { template static void Initialize(STATE &state) { new (&state) STATE(); } template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { + static void Execute(STATE &state, const INPUT_TYPE &key, AggregateInputData &input_data) { if (!state.frequency_map) { - state.frequency_map = new typename STATE::Counts(); + state.frequency_map = TYPE_OP::CreateEmpty(input_data.allocator); } - auto key = KEY_TYPE(input); auto &i = (*state.frequency_map)[key]; - i.count++; + ++i.count; i.first_row = MinValue(i.first_row, state.count); - state.count++; + ++state.count; + } + + template + static void Operation(STATE &state, const INPUT_TYPE &key, AggregateUnaryInput &aggr_input) { + Execute(state, key, aggr_input.input); } template @@ -188,6 +215,18 @@ struct ModeFunction { target.count += source.count; } + static bool IgnoreNull() { + return true; + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + state.~STATE(); + } +}; + +template +struct ModeFunction : BaseModeFunction { template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { if (!state.frequency_map) { @@ -196,17 +235,17 @@ struct ModeFunction { } auto highest_frequency = state.Scan(); if (highest_frequency != state.frequency_map->end()) { - target = ASSIGN_OP::template Assign(finalize_data.result, highest_frequency->first); + target = TYPE_OP::template Assign(finalize_data.result, highest_frequency->first); } else { finalize_data.ReturnNull(); } } + template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &, idx_t count) { + static void ConstantOperation(STATE &state, const INPUT_TYPE &key, AggregateUnaryInput &aggr_input, idx_t count) { if (!state.frequency_map) { - state.frequency_map = new typename STATE::Counts(); + state.frequency_map = TYPE_OP::CreateEmpty(aggr_input.input.allocator); } - auto key = KEY_TYPE(input); auto &i = (*state.frequency_map)[key]; i.count += count; i.first_row = MinValue(i.first_row, state.count); @@ -229,7 +268,7 @@ struct ModeFunction { inline void Left(idx_t begin, idx_t end) { for (; begin < end; ++begin) { if (included(begin)) { - state.ModeRm(KEY_TYPE(data[begin]), begin); + state.ModeRm(data[begin], begin); } } } @@ -237,7 +276,7 @@ struct ModeFunction { inline void Right(idx_t begin, idx_t end) { for (; begin < end; ++begin) { if (included(begin)) { - state.ModeAdd(KEY_TYPE(data[begin]), begin); + state.ModeAdd(data[begin], begin); } } } @@ -260,17 +299,17 @@ struct ModeFunction { ModeIncluded included(fmask, dmask); if (!state.frequency_map) { - state.frequency_map = new typename STATE::Counts; + state.frequency_map = TYPE_OP::CreateEmpty(Allocator::DefaultAllocator()); } - const double tau = .25; - if (state.nonzero <= tau * state.frequency_map->size() || prevs.back().end <= frames.front().start || + const size_t tau_inverse = 4; // tau==0.25 + if (state.nonzero <= (state.frequency_map->size() / tau_inverse) || prevs.back().end <= frames.front().start || frames.back().end <= prevs.front().start) { state.Reset(); // for f ∈ F do for (const auto &frame : frames) { for (auto i = frame.start; i < frame.end; ++i) { if (included(i)) { - state.ModeAdd(KEY_TYPE(data[i]), i); + state.ModeAdd(data[i], i); } } } @@ -291,100 +330,100 @@ struct ModeFunction { } if (state.valid) { - rdata[rid] = ASSIGN_OP::template Assign(result, *state.mode); + rdata[rid] = TYPE_OP::template Assign(result, *state.mode); } else { rmask.Set(rid, false); } prevs = frames; } +}; - static bool IgnoreNull() { - return true; - } - +template +struct ModeFallbackFunction : BaseModeFunction { template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - state.~STATE(); + static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { + if (!state.frequency_map) { + finalize_data.ReturnNull(); + return; + } + auto highest_frequency = state.Scan(); + if (highest_frequency != state.frequency_map->end()) { + CreateSortKeyHelpers::DecodeSortKey(highest_frequency->first, finalize_data.result, + finalize_data.result_idx, + OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); + } else { + finalize_data.ReturnNull(); + } } }; -template +template > AggregateFunction GetTypedModeFunction(const LogicalType &type) { - using STATE = ModeState; - using OP = ModeFunction; - auto return_type = type.id() == LogicalTypeId::ANY ? LogicalType::VARCHAR : type; - auto func = AggregateFunction::UnaryAggregateDestructor(type, return_type); + using STATE = ModeState; + using OP = ModeFunction; + auto func = AggregateFunction::UnaryAggregateDestructor(type, type); func.window = AggregateFunction::UnaryWindow; return func; } +AggregateFunction GetFallbackModeFunction(const LogicalType &type) { + using STATE = ModeState; + using OP = ModeFallbackFunction; + AggregateFunction aggr({type}, type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + AggregateSortKeyHelpers::UnaryUpdate, AggregateFunction::StateCombine, + AggregateFunction::StateVoidFinalize, nullptr); + aggr.destructor = AggregateFunction::StateDestroy; + return aggr; +} + AggregateFunction GetModeAggregate(const LogicalType &type) { switch (type.InternalType()) { case PhysicalType::INT8: - return GetTypedModeFunction(type); + return GetTypedModeFunction(type); case PhysicalType::UINT8: - return GetTypedModeFunction(type); + return GetTypedModeFunction(type); case PhysicalType::INT16: - return GetTypedModeFunction(type); + return GetTypedModeFunction(type); case PhysicalType::UINT16: - return GetTypedModeFunction(type); + return GetTypedModeFunction(type); case PhysicalType::INT32: - return GetTypedModeFunction(type); + return GetTypedModeFunction(type); case PhysicalType::UINT32: - return GetTypedModeFunction(type); + return GetTypedModeFunction(type); case PhysicalType::INT64: - return GetTypedModeFunction(type); + return GetTypedModeFunction(type); case PhysicalType::UINT64: - return GetTypedModeFunction(type); + return GetTypedModeFunction(type); case PhysicalType::INT128: - return GetTypedModeFunction(type); + return GetTypedModeFunction(type); case PhysicalType::UINT128: - return GetTypedModeFunction(type); - + return GetTypedModeFunction(type); case PhysicalType::FLOAT: - return GetTypedModeFunction(type); + return GetTypedModeFunction(type); case PhysicalType::DOUBLE: - return GetTypedModeFunction(type); - + return GetTypedModeFunction(type); case PhysicalType::INTERVAL: - return GetTypedModeFunction(type); - + return GetTypedModeFunction(type); case PhysicalType::VARCHAR: - return GetTypedModeFunction( - LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150)); - + return GetTypedModeFunction(type); default: - throw NotImplementedException("Unimplemented mode aggregate"); + return GetFallbackModeFunction(type); } } -unique_ptr BindModeDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { +unique_ptr BindModeAggregate(ClientContext &context, AggregateFunction &function, + vector> &arguments) { function = GetModeAggregate(arguments[0]->return_type); function.name = "mode"; return nullptr; } AggregateFunctionSet ModeFun::GetFunctions() { - const vector TEMPORAL = {LogicalType::DATE, LogicalType::TIMESTAMP, LogicalType::TIME, - LogicalType::TIMESTAMP_TZ, LogicalType::TIME_TZ, LogicalType::INTERVAL}; - AggregateFunctionSet mode; - mode.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, BindModeDecimal)); - - for (const auto &type : LogicalType::Numeric()) { - if (type.id() != LogicalTypeId::DECIMAL) { - mode.AddFunction(GetModeAggregate(type)); - } - } - - for (const auto &type : TEMPORAL) { - mode.AddFunction(GetModeAggregate(type)); - } - - mode.AddFunction(GetModeAggregate(LogicalType::VARCHAR)); + mode.AddFunction(AggregateFunction({LogicalTypeId::ANY}, LogicalTypeId::ANY, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, BindModeAggregate)); return mode; } } // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/holistic/quantile.cpp b/src/duckdb/src/core_functions/aggregate/holistic/quantile.cpp index e0150700..779ce4de 100644 --- a/src/duckdb/src/core_functions/aggregate/holistic/quantile.cpp +++ b/src/duckdb/src/core_functions/aggregate/holistic/quantile.cpp @@ -1,141 +1,18 @@ #include "duckdb/execution/expression_executor.hpp" #include "duckdb/core_functions/aggregate/holistic_functions.hpp" -#include "duckdb/execution/merge_sort_tree.hpp" #include "duckdb/core_functions/aggregate/quantile_enum.hpp" #include "duckdb/planner/expression.hpp" #include "duckdb/common/operator/cast_operators.hpp" #include "duckdb/common/operator/abs.hpp" -#include "duckdb/common/operator/multiply.hpp" - +#include "duckdb/core_functions/aggregate/quantile_state.hpp" #include "duckdb/common/types/timestamp.hpp" #include "duckdb/common/queue.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" - -#include "SkipList.h" - -#include -#include -#include -#include +#include "duckdb/core_functions/aggregate/sort_key_helpers.hpp" namespace duckdb { -// Interval arithmetic -static interval_t MultiplyByDouble(const interval_t &i, const double &d) { // NOLINT - D_ASSERT(d >= 0 && d <= 1); - return Interval::FromMicro(std::llround(Interval::GetMicro(i) * d)); -} - -inline interval_t operator+(const interval_t &lhs, const interval_t &rhs) { - return Interval::FromMicro(Interval::GetMicro(lhs) + Interval::GetMicro(rhs)); -} - -inline interval_t operator-(const interval_t &lhs, const interval_t &rhs) { - return Interval::FromMicro(Interval::GetMicro(lhs) - Interval::GetMicro(rhs)); -} - -struct FrameSet { - inline explicit FrameSet(const SubFrames &frames_p) : frames(frames_p) { - } - - inline idx_t Size() const { - idx_t result = 0; - for (const auto &frame : frames) { - result += frame.end - frame.start; - } - - return result; - } - - inline bool Contains(idx_t i) const { - for (idx_t f = 0; f < frames.size(); ++f) { - const auto &frame = frames[f]; - if (frame.start <= i && i < frame.end) { - return true; - } - } - return false; - } - const SubFrames &frames; -}; - -struct QuantileIncluded { - inline explicit QuantileIncluded(const ValidityMask &fmask_p, const ValidityMask &dmask_p) - : fmask(fmask_p), dmask(dmask_p) { - } - - inline bool operator()(const idx_t &idx) const { - return fmask.RowIsValid(idx) && dmask.RowIsValid(idx); - } - - inline bool AllValid() const { - return fmask.AllValid() && dmask.AllValid(); - } - - const ValidityMask &fmask; - const ValidityMask &dmask; -}; - -struct QuantileReuseUpdater { - idx_t *index; - idx_t j; - - inline QuantileReuseUpdater(idx_t *index, idx_t j) : index(index), j(j) { - } - - inline void Neither(idx_t begin, idx_t end) { - } - - inline void Left(idx_t begin, idx_t end) { - } - - inline void Right(idx_t begin, idx_t end) { - for (; begin < end; ++begin) { - index[j++] = begin; - } - } - - inline void Both(idx_t begin, idx_t end) { - } -}; - -void ReuseIndexes(idx_t *index, const SubFrames &currs, const SubFrames &prevs) { - - // Copy overlapping indices by scanning the previous set and copying down into holes. - // We copy instead of leaving gaps in case there are fewer values in the current frame. - FrameSet prev_set(prevs); - FrameSet curr_set(currs); - const auto prev_count = prev_set.Size(); - idx_t j = 0; - for (idx_t p = 0; p < prev_count; ++p) { - auto idx = index[p]; - - // Shift down into any hole - if (j != p) { - index[j] = idx; - } - - // Skip overlapping values - if (curr_set.Contains(idx)) { - ++j; - } - } - - // Insert new indices - if (j > 0) { - QuantileReuseUpdater updater(index, j); - AggregateExecutor::IntersectFrames(prevs, currs, updater); - } else { - // No overlap: overwrite with new values - for (const auto &curr : currs) { - for (auto idx = curr.start; idx < curr.end; ++idx) { - index[j++] = idx; - } - } - } -} - template struct IndirectLess { inline explicit IndirectLess(const INPUT_TYPE *inputs_p) : inputs(inputs_p) { @@ -148,262 +25,6 @@ struct IndirectLess { const INPUT_TYPE *inputs; }; -struct CastInterpolation { - - template - static inline TARGET_TYPE Cast(const INPUT_TYPE &src, Vector &result) { - return Cast::Operation(src); - } - template - static inline TARGET_TYPE Interpolate(const TARGET_TYPE &lo, const double d, const TARGET_TYPE &hi) { - const auto delta = hi - lo; - return UnsafeNumericCast(lo + delta * d); - } -}; - -template <> -interval_t CastInterpolation::Cast(const dtime_t &src, Vector &result) { - return {0, 0, src.micros}; -} - -template <> -double CastInterpolation::Interpolate(const double &lo, const double d, const double &hi) { - return lo * (1.0 - d) + hi * d; -} - -template <> -dtime_t CastInterpolation::Interpolate(const dtime_t &lo, const double d, const dtime_t &hi) { - return dtime_t(std::llround(lo.micros * (1.0 - d) + hi.micros * d)); -} - -template <> -timestamp_t CastInterpolation::Interpolate(const timestamp_t &lo, const double d, const timestamp_t &hi) { - return timestamp_t(std::llround(lo.value * (1.0 - d) + hi.value * d)); -} - -template <> -hugeint_t CastInterpolation::Interpolate(const hugeint_t &lo, const double d, const hugeint_t &hi) { - return Hugeint::Convert(Interpolate(Hugeint::Cast(lo), d, Hugeint::Cast(hi))); -} - -template <> -interval_t CastInterpolation::Interpolate(const interval_t &lo, const double d, const interval_t &hi) { - const interval_t delta = hi - lo; - return lo + MultiplyByDouble(delta, d); -} - -template <> -string_t CastInterpolation::Cast(const std::string &src, Vector &result) { - return StringVector::AddString(result, src); -} - -template <> -string_t CastInterpolation::Cast(const string_t &src, Vector &result) { - return StringVector::AddString(result, src); -} - -// Direct access -template -struct QuantileDirect { - using INPUT_TYPE = T; - using RESULT_TYPE = T; - - inline const INPUT_TYPE &operator()(const INPUT_TYPE &x) const { - return x; - } -}; - -// Indirect access -template -struct QuantileIndirect { - using INPUT_TYPE = idx_t; - using RESULT_TYPE = T; - const RESULT_TYPE *data; - - explicit QuantileIndirect(const RESULT_TYPE *data_p) : data(data_p) { - } - - inline RESULT_TYPE operator()(const idx_t &input) const { - return data[input]; - } -}; - -// Composed access -template -struct QuantileComposed { - using INPUT_TYPE = typename INNER::INPUT_TYPE; - using RESULT_TYPE = typename OUTER::RESULT_TYPE; - - const OUTER &outer; - const INNER &inner; - - explicit QuantileComposed(const OUTER &outer_p, const INNER &inner_p) : outer(outer_p), inner(inner_p) { - } - - inline RESULT_TYPE operator()(const idx_t &input) const { - return outer(inner(input)); - } -}; - -// Accessed comparison -template -struct QuantileCompare { - using INPUT_TYPE = typename ACCESSOR::INPUT_TYPE; - const ACCESSOR &accessor; - const bool desc; - explicit QuantileCompare(const ACCESSOR &accessor_p, bool desc_p) : accessor(accessor_p), desc(desc_p) { - } - - inline bool operator()(const INPUT_TYPE &lhs, const INPUT_TYPE &rhs) const { - const auto lval = accessor(lhs); - const auto rval = accessor(rhs); - - return desc ? (rval < lval) : (lval < rval); - } -}; - -// Avoid using naked Values in inner loops... -struct QuantileValue { - explicit QuantileValue(const Value &v) : val(v), dbl(v.GetValue()) { - const auto &type = val.type(); - switch (type.id()) { - case LogicalTypeId::DECIMAL: { - integral = IntegralValue::Get(v); - scaling = Hugeint::POWERS_OF_TEN[DecimalType::GetScale(type)]; - break; - } - default: - break; - } - } - - Value val; - - // DOUBLE - double dbl; - - // DECIMAL - hugeint_t integral; - hugeint_t scaling; -}; - -bool operator==(const QuantileValue &x, const QuantileValue &y) { - return x.val == y.val; -} - -// Continuous interpolation -template -struct Interpolator { - Interpolator(const QuantileValue &q, const idx_t n_p, const bool desc_p) - : desc(desc_p), RN((double)(n_p - 1) * q.dbl), FRN(UnsafeNumericCast(floor(RN))), - CRN(UnsafeNumericCast(ceil(RN))), begin(0), end(n_p) { - } - - template > - TARGET_TYPE Interpolate(INPUT_TYPE lidx, INPUT_TYPE hidx, Vector &result, const ACCESSOR &accessor) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; - if (lidx == hidx) { - return CastInterpolation::Cast(accessor(lidx), result); - } else { - auto lo = CastInterpolation::Cast(accessor(lidx), result); - auto hi = CastInterpolation::Cast(accessor(hidx), result); - return CastInterpolation::Interpolate(lo, RN - FRN, hi); - } - } - - template > - TARGET_TYPE Operation(INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; - QuantileCompare comp(accessor, desc); - if (CRN == FRN) { - std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); - return CastInterpolation::Cast(accessor(v_t[FRN]), result); - } else { - std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); - std::nth_element(v_t + FRN, v_t + CRN, v_t + end, comp); - auto lo = CastInterpolation::Cast(accessor(v_t[FRN]), result); - auto hi = CastInterpolation::Cast(accessor(v_t[CRN]), result); - return CastInterpolation::Interpolate(lo, RN - FRN, hi); - } - } - - template - inline TARGET_TYPE Extract(const INPUT_TYPE **dest, Vector &result) const { - if (CRN == FRN) { - return CastInterpolation::Cast(*dest[0], result); - } else { - auto lo = CastInterpolation::Cast(*dest[0], result); - auto hi = CastInterpolation::Cast(*dest[1], result); - return CastInterpolation::Interpolate(lo, RN - FRN, hi); - } - } - - const bool desc; - const double RN; - const idx_t FRN; - const idx_t CRN; - - idx_t begin; - idx_t end; -}; - -// Discrete "interpolation" -template <> -struct Interpolator { - static inline idx_t Index(const QuantileValue &q, const idx_t n) { - idx_t floored; - switch (q.val.type().id()) { - case LogicalTypeId::DECIMAL: { - // Integer arithmetic for accuracy - const auto integral = q.integral; - const auto scaling = q.scaling; - const auto scaled_q = - DecimalMultiplyOverflowCheck::Operation(Hugeint::Convert(n), integral); - const auto scaled_n = - DecimalMultiplyOverflowCheck::Operation(Hugeint::Convert(n), scaling); - floored = Cast::Operation((scaled_n - scaled_q) / scaling); - break; - } - default: - const auto scaled_q = (double)(n * q.dbl); - floored = UnsafeNumericCast(floor(n - scaled_q)); - break; - } - - return MaxValue(1, n - floored) - 1; - } - - Interpolator(const QuantileValue &q, const idx_t n_p, bool desc_p) - : desc(desc_p), FRN(Index(q, n_p)), CRN(FRN), begin(0), end(n_p) { - } - - template > - TARGET_TYPE Interpolate(INPUT_TYPE lidx, INPUT_TYPE hidx, Vector &result, const ACCESSOR &accessor) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; - return CastInterpolation::Cast(accessor(lidx), result); - } - - template > - TARGET_TYPE Operation(INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; - QuantileCompare comp(accessor, desc); - std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); - return CastInterpolation::Cast(accessor(v_t[FRN]), result); - } - - template - TARGET_TYPE Extract(const INPUT_TYPE **dest, Vector &result) const { - return CastInterpolation::Cast(*dest[0], result); - } - - const bool desc; - const idx_t FRN; - const idx_t CRN; - - idx_t begin; - idx_t end; -}; - template static inline T QuantileAbs(const T &t) { return AbsOperator::Operation(t); @@ -435,1052 +56,511 @@ inline Value QuantileAbs(const Value &v) { } } -void BindQuantileInner(AggregateFunction &function, const LogicalType &type, QuantileSerializationType quantile_type); +//===--------------------------------------------------------------------===// +// Quantile Bind Data +//===--------------------------------------------------------------------===// +QuantileBindData::QuantileBindData() { +} -struct QuantileBindData : public FunctionData { - QuantileBindData() { - } +QuantileBindData::QuantileBindData(const Value &quantile_p) + : quantiles(1, QuantileValue(QuantileAbs(quantile_p))), order(1, 0), desc(quantile_p < 0) { +} - explicit QuantileBindData(const Value &quantile_p) - : quantiles(1, QuantileValue(QuantileAbs(quantile_p))), order(1, 0), desc(quantile_p < 0) { +QuantileBindData::QuantileBindData(const vector &quantiles_p) { + vector normalised; + size_t pos = 0; + size_t neg = 0; + for (idx_t i = 0; i < quantiles_p.size(); ++i) { + const auto &q = quantiles_p[i]; + pos += (q > 0); + neg += (q < 0); + normalised.emplace_back(QuantileAbs(q)); + order.push_back(i); } - - explicit QuantileBindData(const vector &quantiles_p) { - vector normalised; - size_t pos = 0; - size_t neg = 0; - for (idx_t i = 0; i < quantiles_p.size(); ++i) { - const auto &q = quantiles_p[i]; - pos += (q > 0); - neg += (q < 0); - normalised.emplace_back(QuantileAbs(q)); - order.push_back(i); - } - if (pos && neg) { - throw BinderException("QUANTILE parameters must have consistent signs"); - } - desc = (neg > 0); - - IndirectLess lt(normalised.data()); - std::sort(order.begin(), order.end(), lt); - - for (const auto &q : normalised) { - quantiles.emplace_back(QuantileValue(q)); - } + if (pos && neg) { + throw BinderException("QUANTILE parameters must have consistent signs"); } + desc = (neg > 0); - QuantileBindData(const QuantileBindData &other) : order(other.order), desc(other.desc) { - for (const auto &q : other.quantiles) { - quantiles.emplace_back(q); - } - } + IndirectLess lt(normalised.data()); + std::sort(order.begin(), order.end(), lt); - unique_ptr Copy() const override { - return make_uniq(*this); + for (const auto &q : normalised) { + quantiles.emplace_back(QuantileValue(q)); } +} - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return desc == other.desc && quantiles == other.quantiles && order == other.order; +QuantileBindData::QuantileBindData(const QuantileBindData &other) : order(other.order), desc(other.desc) { + for (const auto &q : other.quantiles) { + quantiles.emplace_back(q); } +} - static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - auto &bind_data = bind_data_p->Cast(); - vector raw; - for (const auto &q : bind_data.quantiles) { - raw.emplace_back(q.val); - } - serializer.WriteProperty(100, "quantiles", raw); - serializer.WriteProperty(101, "order", bind_data.order); - serializer.WriteProperty(102, "desc", bind_data.desc); - } +unique_ptr QuantileBindData::Copy() const { + return make_uniq(*this); +} - static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { - auto result = make_uniq(); - vector raw; - deserializer.ReadProperty(100, "quantiles", raw); - deserializer.ReadProperty(101, "order", result->order); - deserializer.ReadProperty(102, "desc", result->desc); - QuantileSerializationType deserialization_type; - deserializer.ReadPropertyWithDefault(103, "quantile_type", deserialization_type, - QuantileSerializationType::NON_DECIMAL); - - if (deserialization_type != QuantileSerializationType::NON_DECIMAL) { - LogicalType arg_type; - deserializer.ReadProperty(104, "logical_type", arg_type); - - BindQuantileInner(function, arg_type, deserialization_type); - } +bool QuantileBindData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return desc == other.desc && quantiles == other.quantiles && order == other.order; +} - for (const auto &r : raw) { - result->quantiles.emplace_back(QuantileValue(r)); - } - return std::move(result); +void QuantileBindData::Serialize(Serializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + auto &bind_data = bind_data_p->Cast(); + vector raw; + for (const auto &q : bind_data.quantiles) { + raw.emplace_back(q.val); } + serializer.WriteProperty(100, "quantiles", raw); + serializer.WriteProperty(101, "order", bind_data.order); + serializer.WriteProperty(102, "desc", bind_data.desc); +} - static void SerializeDecimalDiscrete(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - Serialize(serializer, bind_data_p, function); +unique_ptr QuantileBindData::Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto result = make_uniq(); + vector raw; + deserializer.ReadProperty(100, "quantiles", raw); + deserializer.ReadProperty(101, "order", result->order); + deserializer.ReadProperty(102, "desc", result->desc); + QuantileSerializationType deserialization_type; + deserializer.ReadPropertyWithExplicitDefault(103, "quantile_type", deserialization_type, + QuantileSerializationType::NON_DECIMAL); - serializer.WritePropertyWithDefault( - 103, "quantile_type", QuantileSerializationType::DECIMAL_DISCRETE, QuantileSerializationType::NON_DECIMAL); - serializer.WriteProperty(104, "logical_type", function.arguments[0]); + if (deserialization_type != QuantileSerializationType::NON_DECIMAL) { + deserializer.ReadDeletedProperty(104, "logical_type"); } - static void SerializeDecimalDiscreteList(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - - Serialize(serializer, bind_data_p, function); - serializer.WritePropertyWithDefault(103, "quantile_type", - QuantileSerializationType::DECIMAL_DISCRETE_LIST, - QuantileSerializationType::NON_DECIMAL); - serializer.WriteProperty(104, "logical_type", function.arguments[0]); + for (const auto &r : raw) { + result->quantiles.emplace_back(QuantileValue(r)); } - static void SerializeDecimalContinuous(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - Serialize(serializer, bind_data_p, function); - - serializer.WritePropertyWithDefault(103, "quantile_type", - QuantileSerializationType::DECIMAL_CONTINUOUS, - QuantileSerializationType::NON_DECIMAL); - serializer.WriteProperty(104, "logical_type", function.arguments[0]); - } - static void SerializeDecimalContinuousList(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { + return std::move(result); +} - Serialize(serializer, bind_data_p, function); +//===--------------------------------------------------------------------===// +// Cast Interpolation +//===--------------------------------------------------------------------===// +template <> +interval_t CastInterpolation::Cast(const dtime_t &src, Vector &result) { + return {0, 0, src.micros}; +} - serializer.WritePropertyWithDefault( - 103, "quantile_type", QuantileSerializationType::DECIMAL_CONTINUOUS_LIST, - QuantileSerializationType::NON_DECIMAL); - serializer.WriteProperty(104, "logical_type", function.arguments[0]); - } +template <> +double CastInterpolation::Interpolate(const double &lo, const double d, const double &hi) { + return lo * (1.0 - d) + hi * d; +} - vector quantiles; - vector order; - bool desc; -}; +template <> +dtime_t CastInterpolation::Interpolate(const dtime_t &lo, const double d, const dtime_t &hi) { + return dtime_t(std::llround(static_cast(lo.micros) * (1.0 - d) + static_cast(hi.micros) * d)); +} -template -struct QuantileSortTree : public MergeSortTree { +template <> +timestamp_t CastInterpolation::Interpolate(const timestamp_t &lo, const double d, const timestamp_t &hi) { + return timestamp_t(std::llround(static_cast(lo.value) * (1.0 - d) + static_cast(hi.value) * d)); +} - using BaseTree = MergeSortTree; - using Elements = typename BaseTree::Elements; +template <> +hugeint_t CastInterpolation::Interpolate(const hugeint_t &lo, const double d, const hugeint_t &hi) { + return Hugeint::Convert(Interpolate(Hugeint::Cast(lo), d, Hugeint::Cast(hi))); +} - explicit QuantileSortTree(Elements &&lowest_level) : BaseTree(std::move(lowest_level)) { - } +static interval_t MultiplyByDouble(const interval_t &i, const double &d) { // NOLINT + D_ASSERT(d >= 0 && d <= 1); + return Interval::FromMicro(std::llround(static_cast(Interval::GetMicro(i)) * d)); +} - template - static unique_ptr WindowInit(const INPUT_TYPE *data, AggregateInputData &aggr_input_data, - const ValidityMask &data_mask, const ValidityMask &filter_mask, - idx_t count) { - // Build the indirection array - using ElementType = typename QuantileSortTree::ElementType; - vector sorted(count); - if (filter_mask.AllValid() && data_mask.AllValid()) { - std::iota(sorted.begin(), sorted.end(), 0); - } else { - size_t valid = 0; - QuantileIncluded included(filter_mask, data_mask); - for (ElementType i = 0; i < count; ++i) { - if (included(i)) { - sorted[valid++] = i; - } - } - sorted.resize(valid); - } +inline interval_t operator+(const interval_t &lhs, const interval_t &rhs) { + return Interval::FromMicro(Interval::GetMicro(lhs) + Interval::GetMicro(rhs)); +} - // Sort it - auto &bind_data = aggr_input_data.bind_data->Cast(); - using Accessor = QuantileIndirect; - Accessor indirect(data); - QuantileCompare cmp(indirect, bind_data.desc); - std::sort(sorted.begin(), sorted.end(), cmp); +inline interval_t operator-(const interval_t &lhs, const interval_t &rhs) { + return Interval::FromMicro(Interval::GetMicro(lhs) - Interval::GetMicro(rhs)); +} - return make_uniq(std::move(sorted)); - } +template <> +interval_t CastInterpolation::Interpolate(const interval_t &lo, const double d, const interval_t &hi) { + const interval_t delta = hi - lo; + return lo + MultiplyByDouble(delta, d); +} - inline IDX SelectNth(const SubFrames &frames, size_t n) const { - return BaseTree::NthElement(BaseTree::SelectNth(frames, n)); - } +template <> +string_t CastInterpolation::Cast(const string_t &src, Vector &result) { + return StringVector::AddStringOrBlob(result, src); +} - template - RESULT_TYPE WindowScalar(const INPUT_TYPE *data, const SubFrames &frames, const idx_t n, Vector &result, - const QuantileValue &q) const { - D_ASSERT(n > 0); - - // Find the interpolated indicies within the frame - Interpolator interp(q, n, false); - const auto lo_data = SelectNth(frames, interp.FRN); - auto hi_data = lo_data; - if (interp.CRN != interp.FRN) { - hi_data = SelectNth(frames, interp.CRN); +//===--------------------------------------------------------------------===// +// Scalar Quantile +//===--------------------------------------------------------------------===// +template +struct QuantileScalarOperation : public QuantileOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.v.empty()) { + finalize_data.ReturnNull(); + return; } - - // Interpolate indirectly - using ID = QuantileIndirect; - ID indirect(data); - return interp.template Interpolate(lo_data, hi_data, result, indirect); + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->Cast(); + D_ASSERT(bind_data.quantiles.size() == 1); + Interpolator interp(bind_data.quantiles[0], state.v.size(), bind_data.desc); + target = interp.template Operation(state.v.data(), finalize_data.result); } - template - void WindowList(const INPUT_TYPE *data, const SubFrames &frames, const idx_t n, Vector &list, const idx_t lidx, - const QuantileBindData &bind_data) const { - D_ASSERT(n > 0); - - // Result is a constant LIST with a fixed length - auto ldata = FlatVector::GetData(list); - auto &lentry = ldata[lidx]; - lentry.offset = ListVector::GetListSize(list); - lentry.length = bind_data.quantiles.size(); - - ListVector::Reserve(list, lentry.offset + lentry.length); - ListVector::SetListSize(list, lentry.offset + lentry.length); - auto &result = ListVector::GetEntry(list); - auto rdata = FlatVector::GetData(result); - - using ID = QuantileIndirect; - ID indirect(data); - for (const auto &q : bind_data.order) { - const auto &quantile = bind_data.quantiles[q]; - Interpolator interp(quantile, n, false); - - const auto lo_data = SelectNth(frames, interp.FRN); - auto hi_data = lo_data; - if (interp.CRN != interp.FRN) { - hi_data = SelectNth(frames, interp.CRN); - } - - // Interpolate indirectly - rdata[lentry.offset + q] = - interp.template Interpolate(lo_data, hi_data, result, indirect); - } - } -}; - -template -struct PointerLess { - inline bool operator()(const T &lhi, const T &rhi) const { - return *lhi < *rhi; - } -}; - -template -struct QuantileState { - using SaveType = SAVE_TYPE; - using InputType = INPUT_TYPE; - - // Regular aggregation - vector v; - - // Windowed Quantile merge sort trees - using QuantileSortTree32 = QuantileSortTree; - using QuantileSortTree64 = QuantileSortTree; - unique_ptr qst32; - unique_ptr qst64; - - // Windowed Quantile skip lists - using PointerType = const InputType *; - using SkipListType = duckdb_skiplistlib::skip_list::HeadNode>; - SubFrames prevs; - unique_ptr s; - mutable vector dest; - - // Windowed MAD indirection - idx_t count; - vector m; - - QuantileState() : count(0) { - } - - ~QuantileState() { - } - - inline void SetCount(size_t count_p) { - count = count_p; - if (count >= m.size()) { - m.resize(count); - } - } - - inline SkipListType &GetSkipList(bool reset = false) { - if (reset || !s) { - s.reset(); - s = make_uniq(); - } - return *s; - } - - struct SkipListUpdater { - SkipListType &skip; - const INPUT_TYPE *data; - const QuantileIncluded &included; - - inline SkipListUpdater(SkipListType &skip, const INPUT_TYPE *data, const QuantileIncluded &included) - : skip(skip), data(data), included(included) { - } - - inline void Neither(idx_t begin, idx_t end) { - } - - inline void Left(idx_t begin, idx_t end) { - for (; begin < end; ++begin) { - if (included(begin)) { - skip.remove(data + begin); - } - } - } - - inline void Right(idx_t begin, idx_t end) { - for (; begin < end; ++begin) { - if (included(begin)) { - skip.insert(data + begin); - } - } - } - - inline void Both(idx_t begin, idx_t end) { - } - }; - - void UpdateSkip(const INPUT_TYPE *data, const SubFrames &frames, const QuantileIncluded &included) { - // No overlap, or no data - if (!s || prevs.back().end <= frames.front().start || frames.back().end <= prevs.front().start) { - auto &skip = GetSkipList(true); - for (const auto &frame : frames) { - for (auto i = frame.start; i < frame.end; ++i) { - if (included(i)) { - skip.insert(data + i); - } - } - } - } else { - auto &skip = GetSkipList(); - SkipListUpdater updater(skip, data, included); - AggregateExecutor::IntersectFrames(prevs, frames, updater); - } - } - - bool HasTrees() const { - return qst32 || qst64; - } - - template - RESULT_TYPE WindowScalar(const INPUT_TYPE *data, const SubFrames &frames, const idx_t n, Vector &result, - const QuantileValue &q) const { - D_ASSERT(n > 0); - if (qst32) { - return qst32->WindowScalar(data, frames, n, result, q); - } else if (qst64) { - return qst64->WindowScalar(data, frames, n, result, q); - } else if (s) { - // Find the position(s) needed - try { - Interpolator interp(q, s->size(), false); - s->at(interp.FRN, interp.CRN - interp.FRN + 1, dest); - return interp.template Extract(dest.data(), result); - } catch (const duckdb_skiplistlib::skip_list::IndexError &idx_err) { - throw InternalException(idx_err.message()); - } - } else { - throw InternalException("No accelerator for scalar QUANTILE"); - } - } - - template - void WindowList(const INPUT_TYPE *data, const SubFrames &frames, const idx_t n, Vector &list, const idx_t lidx, - const QuantileBindData &bind_data) const { - D_ASSERT(n > 0); - // Result is a constant LIST with a fixed length - auto ldata = FlatVector::GetData(list); - auto &lentry = ldata[lidx]; - lentry.offset = ListVector::GetListSize(list); - lentry.length = bind_data.quantiles.size(); - - ListVector::Reserve(list, lentry.offset + lentry.length); - ListVector::SetListSize(list, lentry.offset + lentry.length); - auto &result = ListVector::GetEntry(list); - auto rdata = FlatVector::GetData(result); - - for (const auto &q : bind_data.order) { - const auto &quantile = bind_data.quantiles[q]; - rdata[lentry.offset + q] = WindowScalar(data, frames, n, result, quantile); - } - } -}; - -struct QuantileOperation { - template - static void Initialize(STATE &state) { - new (&state) STATE(); - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { - state.v.emplace_back(input); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (source.v.empty()) { - return; - } - target.v.insert(target.v.end(), source.v.begin(), source.v.end()); - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - state.~STATE(); - } - - static bool IgnoreNull() { - return true; - } - - template - static void WindowInit(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, - data_ptr_t g_state) { - D_ASSERT(partition.input_count == 1); - - auto inputs = partition.inputs; - const auto count = partition.count; - const auto &filter_mask = partition.filter_mask; - const auto &stats = partition.stats; - - // If frames overlap significantly, then use local skip lists. - if (stats[0].end <= stats[1].begin) { - // Frames can overlap - const auto overlap = double(stats[1].begin - stats[0].end); - const auto cover = double(stats[1].end - stats[0].begin); - const auto ratio = overlap / cover; - if (ratio > .75) { - return; - } - } - - const auto data = FlatVector::GetData(inputs[0]); - const auto &data_mask = FlatVector::Validity(inputs[0]); - - // Build the tree - auto &state = *reinterpret_cast(g_state); - if (count < std::numeric_limits::max()) { - state.qst32 = QuantileSortTree::WindowInit(data, aggr_input_data, data_mask, - filter_mask, count); - } else { - state.qst64 = QuantileSortTree::WindowInit(data, aggr_input_data, data_mask, - filter_mask, count); - } - } - - static idx_t FrameSize(const QuantileIncluded &included, const SubFrames &frames) { - // Count the number of valid values - idx_t n = 0; - if (included.AllValid()) { - for (const auto &frame : frames) { - n += frame.end - frame.start; - } - } else { - // NULLs or FILTERed values, - for (const auto &frame : frames) { - for (auto i = frame.start; i < frame.end; ++i) { - n += included(i); - } - } - } - - return n; - } -}; - -template -static AggregateFunction QuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { // NOLINT - LogicalType result_type = - LogicalType::LIST(child_type.id() == LogicalTypeId::ANY ? LogicalType::VARCHAR : child_type); - return AggregateFunction( - {input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, - AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, - nullptr, AggregateFunction::StateDestroy); -} - -template -struct QuantileScalarOperation : public QuantileOperation { - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.v.empty()) { - finalize_data.ReturnNull(); - return; - } - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->Cast(); - D_ASSERT(bind_data.quantiles.size() == 1); - Interpolator interp(bind_data.quantiles[0], state.v.size(), bind_data.desc); - target = interp.template Operation(state.v.data(), finalize_data.result); - } - - template - static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, - AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &result, - idx_t ridx, const STATE *gstate) { - QuantileIncluded included(fmask, dmask); - const auto n = FrameSize(included, frames); + template + static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, + AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &result, + idx_t ridx, const STATE *gstate) { + QuantileIncluded included(fmask, dmask); + const auto n = FrameSize(included, frames); D_ASSERT(aggr_input_data.bind_data); - auto &bind_data = aggr_input_data.bind_data->Cast(); - - auto rdata = FlatVector::GetData(result); - auto &rmask = FlatVector::Validity(result); - - if (!n) { - rmask.Set(ridx, false); - return; - } - - const auto &quantile = bind_data.quantiles[0]; - if (gstate && gstate->HasTrees()) { - rdata[ridx] = gstate->template WindowScalar(data, frames, n, result, quantile); - } else { - // Update the skip list - state.UpdateSkip(data, frames, included); - - // Find the position(s) needed - rdata[ridx] = state.template WindowScalar(data, frames, n, result, quantile); - - // Save the previous state for next time - state.prevs = frames; - } - } -}; - -template -AggregateFunction GetTypedDiscreteQuantileAggregateFunction(const LogicalType &type) { - using STATE = QuantileState; - using OP = QuantileScalarOperation; - auto return_type = type.id() == LogicalTypeId::ANY ? LogicalType::VARCHAR : type; - auto fun = AggregateFunction::UnaryAggregateDestructor(type, return_type); - fun.window = AggregateFunction::UnaryWindow; - fun.window_init = OP::WindowInit; - return fun; -} - -AggregateFunction GetDiscreteQuantileAggregateFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::SMALLINT: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::INTEGER: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::BIGINT: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::HUGEINT: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::FLOAT: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::DOUBLE: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedDiscreteQuantileAggregateFunction(type); - case PhysicalType::INT32: - return GetTypedDiscreteQuantileAggregateFunction(type); - case PhysicalType::INT64: - return GetTypedDiscreteQuantileAggregateFunction(type); - case PhysicalType::INT128: - return GetTypedDiscreteQuantileAggregateFunction(type); - default: - throw NotImplementedException("Unimplemented discrete quantile aggregate"); - } - case LogicalTypeId::DATE: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::INTERVAL: - return GetTypedDiscreteQuantileAggregateFunction(type); - case LogicalTypeId::ANY: - return GetTypedDiscreteQuantileAggregateFunction(type); - - default: - throw NotImplementedException("Unimplemented discrete quantile aggregate"); - } -} - -template -struct QuantileListOperation : public QuantileOperation { - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.v.empty()) { - finalize_data.ReturnNull(); - return; - } - - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->Cast(); - - auto &result = ListVector::GetEntry(finalize_data.result); - auto ridx = ListVector::GetListSize(finalize_data.result); - ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); - auto rdata = FlatVector::GetData(result); - - auto v_t = state.v.data(); - D_ASSERT(v_t); - - auto &entry = target; - entry.offset = ridx; - idx_t lower = 0; - for (const auto &q : bind_data.order) { - const auto &quantile = bind_data.quantiles[q]; - Interpolator interp(quantile, state.v.size(), bind_data.desc); - interp.begin = lower; - rdata[ridx + q] = interp.template Operation(v_t, result); - lower = interp.FRN; - } - entry.length = bind_data.quantiles.size(); - - ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); - } - - template - static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, - AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &list, - idx_t lidx, const STATE *gstate) { - D_ASSERT(aggr_input_data.bind_data); - auto &bind_data = aggr_input_data.bind_data->Cast(); - - QuantileIncluded included(fmask, dmask); - const auto n = FrameSize(included, frames); - - // Result is a constant LIST with a fixed length - if (!n) { - auto &lmask = FlatVector::Validity(list); - lmask.Set(lidx, false); - return; - } - - if (gstate && gstate->HasTrees()) { - gstate->template WindowList(data, frames, n, list, lidx, bind_data); - } else { - // - state.UpdateSkip(data, frames, included); - state.template WindowList(data, frames, n, list, lidx, bind_data); - state.prevs = frames; - } - } -}; - -template -AggregateFunction GetTypedDiscreteQuantileListAggregateFunction(const LogicalType &type) { - using STATE = QuantileState; - using OP = QuantileListOperation; - auto fun = QuantileListAggregate(type, type); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - fun.window = AggregateFunction::UnaryWindow; - fun.window_init = OP::template WindowInit; - return fun; -} - -AggregateFunction GetDiscreteQuantileListAggregateFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::SMALLINT: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::INTEGER: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::BIGINT: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::HUGEINT: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::FLOAT: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::DOUBLE: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case PhysicalType::INT32: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case PhysicalType::INT64: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case PhysicalType::INT128: - return GetTypedDiscreteQuantileListAggregateFunction(type); - default: - throw NotImplementedException("Unimplemented discrete quantile list aggregate"); - } - case LogicalTypeId::DATE: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::INTERVAL: - return GetTypedDiscreteQuantileListAggregateFunction(type); - case LogicalTypeId::ANY: - return GetTypedDiscreteQuantileListAggregateFunction(type); - default: - throw NotImplementedException("Unimplemented discrete quantile list aggregate"); - } -} - -template -AggregateFunction GetTypedContinuousQuantileAggregateFunction(const LogicalType &input_type, - const LogicalType &target_type) { - using STATE = QuantileState; - using OP = QuantileScalarOperation; - auto fun = AggregateFunction::UnaryAggregateDestructor(input_type, target_type); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - fun.window = AggregateFunction::UnaryWindow; - fun.window_init = OP::template WindowInit; - return fun; -} - -AggregateFunction GetContinuousQuantileAggregateFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::SMALLINT: - return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::INTEGER: - return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::BIGINT: - return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::HUGEINT: - return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::FLOAT: - return GetTypedContinuousQuantileAggregateFunction(type, type); - case LogicalTypeId::DOUBLE: - return GetTypedContinuousQuantileAggregateFunction(type, type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedContinuousQuantileAggregateFunction(type, type); - case PhysicalType::INT32: - return GetTypedContinuousQuantileAggregateFunction(type, type); - case PhysicalType::INT64: - return GetTypedContinuousQuantileAggregateFunction(type, type); - case PhysicalType::INT128: - return GetTypedContinuousQuantileAggregateFunction(type, type); - default: - throw NotImplementedException("Unimplemented continuous quantile DECIMAL aggregate"); - } - case LogicalTypeId::DATE: - return GetTypedContinuousQuantileAggregateFunction(type, LogicalType::TIMESTAMP); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return GetTypedContinuousQuantileAggregateFunction(type, type); - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return GetTypedContinuousQuantileAggregateFunction(type, type); - - default: - throw NotImplementedException("Unimplemented continuous quantile aggregate"); - } -} - -template -AggregateFunction GetTypedContinuousQuantileListAggregateFunction(const LogicalType &input_type, - const LogicalType &result_type) { - using STATE = QuantileState; - using OP = QuantileListOperation; - auto fun = QuantileListAggregate(input_type, result_type); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - fun.window = AggregateFunction::UnaryWindow; - fun.window_init = OP::template WindowInit; - return fun; -} - -AggregateFunction GetContinuousQuantileListAggregateFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::SMALLINT: - return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::INTEGER: - return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::BIGINT: - return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::HUGEINT: - return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::FLOAT: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - case LogicalTypeId::DOUBLE: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - case PhysicalType::INT32: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - case PhysicalType::INT64: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - case PhysicalType::INT128: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - default: - throw NotImplementedException("Unimplemented discrete quantile DECIMAL list aggregate"); + auto &bind_data = aggr_input_data.bind_data->Cast(); + + auto rdata = FlatVector::GetData(result); + auto &rmask = FlatVector::Validity(result); + + if (!n) { + rmask.Set(ridx, false); + return; } - case LogicalTypeId::DATE: - return GetTypedContinuousQuantileListAggregateFunction(type, LogicalType::TIMESTAMP); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return GetTypedContinuousQuantileListAggregateFunction(type, type); - default: - throw NotImplementedException("Unimplemented discrete quantile list aggregate"); - } -} -template -struct MadAccessor { - using INPUT_TYPE = T; - using RESULT_TYPE = R; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } + const auto &quantile = bind_data.quantiles[0]; + if (gstate && gstate->HasTrees()) { + rdata[ridx] = gstate->GetWindowState().template WindowScalar(data, frames, n, result, + quantile); + } else { + auto &window_state = state.GetOrCreateWindowState(); - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const RESULT_TYPE delta = input - UnsafeNumericCast(median); - return TryAbsOperator::Operation(delta); - } -}; + // Update the skip list + window_state.UpdateSkip(data, frames, included); -// hugeint_t - double => undefined -template <> -struct MadAccessor { - using INPUT_TYPE = hugeint_t; - using RESULT_TYPE = double; - using MEDIAN_TYPE = double; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const auto delta = Hugeint::Cast(input) - median; - return TryAbsOperator::Operation(delta); - } -}; + // Find the position(s) needed + rdata[ridx] = window_state.template WindowScalar(data, frames, n, result, quantile); -// date_t - timestamp_t => interval_t -template <> -struct MadAccessor { - using INPUT_TYPE = date_t; - using RESULT_TYPE = interval_t; - using MEDIAN_TYPE = timestamp_t; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const auto dt = Cast::Operation(input); - const auto delta = dt - median; - return Interval::FromMicro(TryAbsOperator::Operation(delta)); + // Save the previous state for next time + window_state.prevs = frames; + } } }; -// timestamp_t - timestamp_t => int64_t -template <> -struct MadAccessor { - using INPUT_TYPE = timestamp_t; - using RESULT_TYPE = interval_t; - using MEDIAN_TYPE = timestamp_t; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const auto delta = input - median; - return Interval::FromMicro(TryAbsOperator::Operation(delta)); +struct QuantileScalarFallback : QuantileOperation { + template + static void Execute(STATE &state, const INPUT_TYPE &key, AggregateInputData &input_data) { + state.AddElement(key, input_data); } -}; -// dtime_t - dtime_t => int64_t -template <> -struct MadAccessor { - using INPUT_TYPE = dtime_t; - using RESULT_TYPE = interval_t; - using MEDIAN_TYPE = dtime_t; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const auto delta = input - median; - return Interval::FromMicro(TryAbsOperator::Operation(delta)); + template + static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { + if (state.v.empty()) { + finalize_data.ReturnNull(); + return; + } + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->Cast(); + D_ASSERT(bind_data.quantiles.size() == 1); + Interpolator interp(bind_data.quantiles[0], state.v.size(), bind_data.desc); + auto interpolation_result = interp.InterpolateInternal(state.v.data()); + CreateSortKeyHelpers::DecodeSortKey(interpolation_result, finalize_data.result, finalize_data.result_idx, + OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); } }; -template -struct MedianAbsoluteDeviationOperation : public QuantileOperation { - +//===--------------------------------------------------------------------===// +// Quantile List +//===--------------------------------------------------------------------===// +template +struct QuantileListOperation : QuantileOperation { template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { if (state.v.empty()) { finalize_data.ReturnNull(); return; } - using SAVE_TYPE = typename STATE::SaveType; + D_ASSERT(finalize_data.input.bind_data); auto &bind_data = finalize_data.input.bind_data->Cast(); - D_ASSERT(bind_data.quantiles.size() == 1); - const auto &q = bind_data.quantiles[0]; - Interpolator interp(q, state.v.size(), false); - const auto med = interp.template Operation(state.v.data(), finalize_data.result); - MadAccessor accessor(med); - target = interp.template Operation(state.v.data(), finalize_data.result, accessor); + auto &result = ListVector::GetEntry(finalize_data.result); + auto ridx = ListVector::GetListSize(finalize_data.result); + ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); + auto rdata = FlatVector::GetData(result); + + auto v_t = state.v.data(); + D_ASSERT(v_t); + + auto &entry = target; + entry.offset = ridx; + idx_t lower = 0; + for (const auto &q : bind_data.order) { + const auto &quantile = bind_data.quantiles[q]; + Interpolator interp(quantile, state.v.size(), bind_data.desc); + interp.begin = lower; + rdata[ridx + q] = interp.template Operation(v_t, result); + lower = interp.FRN; + } + entry.length = bind_data.quantiles.size(); + + ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); } template static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, - AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &result, - idx_t ridx, const STATE *gstate) { - auto rdata = FlatVector::GetData(result); + AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &list, + idx_t lidx, const STATE *gstate) { + D_ASSERT(aggr_input_data.bind_data); + auto &bind_data = aggr_input_data.bind_data->Cast(); QuantileIncluded included(fmask, dmask); const auto n = FrameSize(included, frames); + // Result is a constant LIST with a fixed length if (!n) { - auto &rmask = FlatVector::Validity(result); - rmask.Set(ridx, false); + auto &lmask = FlatVector::Validity(list); + lmask.Set(lidx, false); return; } - // Compute the median - D_ASSERT(aggr_input_data.bind_data); - auto &bind_data = aggr_input_data.bind_data->Cast(); - - D_ASSERT(bind_data.quantiles.size() == 1); - const auto &quantile = bind_data.quantiles[0]; - MEDIAN_TYPE med; if (gstate && gstate->HasTrees()) { - med = gstate->template WindowScalar(data, frames, n, result, quantile); + gstate->GetWindowState().template WindowList(data, frames, n, list, lidx, bind_data); } else { - state.UpdateSkip(data, frames, included); - med = state.template WindowScalar(data, frames, n, result, quantile); + auto &window_state = state.GetOrCreateWindowState(); + window_state.UpdateSkip(data, frames, included); + window_state.template WindowList(data, frames, n, list, lidx, bind_data); + window_state.prevs = frames; } + } +}; - // Lazily initialise frame state - state.SetCount(frames.back().end - frames.front().start); - auto index2 = state.m.data(); - D_ASSERT(index2); +struct QuantileListFallback : QuantileOperation { + template + static void Execute(STATE &state, const INPUT_TYPE &key, AggregateInputData &input_data) { + state.AddElement(key, input_data); + } - // The replacement trick does not work on the second index because if - // the median has changed, the previous order is not correct. - // It is probably close, however, and so reuse is helpful. - auto &prevs = state.prevs; - ReuseIndexes(index2, frames, prevs); - std::partition(index2, index2 + state.count, included); + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.v.empty()) { + finalize_data.ReturnNull(); + return; + } - Interpolator interp(quantile, n, false); + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->Cast(); - // Compute mad from the second index - using ID = QuantileIndirect; - ID indirect(data); + auto &result = ListVector::GetEntry(finalize_data.result); + auto ridx = ListVector::GetListSize(finalize_data.result); + ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); - using MAD = MadAccessor; - MAD mad(med); + D_ASSERT(state.v.data()); - using MadIndirect = QuantileComposed; - MadIndirect mad_indirect(mad, indirect); - rdata[ridx] = interp.template Operation(index2, result, mad_indirect); + auto &entry = target; + entry.offset = ridx; + idx_t lower = 0; + for (const auto &q : bind_data.order) { + const auto &quantile = bind_data.quantiles[q]; + Interpolator interp(quantile, state.v.size(), bind_data.desc); + interp.begin = lower; + auto interpolation_result = interp.InterpolateInternal(state.v.data()); + CreateSortKeyHelpers::DecodeSortKey(interpolation_result, result, ridx + q, + OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); + lower = interp.FRN; + } + entry.length = bind_data.quantiles.size(); - // Prev is used by both skip lists and increments - prevs = frames; + ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); } }; -unique_ptr BindMedian(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - return make_uniq(Value::DECIMAL(int16_t(5), 2, 1)); +//===--------------------------------------------------------------------===// +// Discrete Quantiles +//===--------------------------------------------------------------------===// +template +AggregateFunction GetDiscreteQuantileTemplated(const LogicalType &type) { + switch (type.InternalType()) { + case PhysicalType::INT8: + return OP::template GetFunction(type); + case PhysicalType::INT16: + return OP::template GetFunction(type); + case PhysicalType::INT32: + return OP::template GetFunction(type); + case PhysicalType::INT64: + return OP::template GetFunction(type); + case PhysicalType::INT128: + return OP::template GetFunction(type); + case PhysicalType::FLOAT: + return OP::template GetFunction(type); + case PhysicalType::DOUBLE: + return OP::template GetFunction(type); + case PhysicalType::INTERVAL: + return OP::template GetFunction(type); + case PhysicalType::VARCHAR: + return OP::template GetFunction(type); + default: + return OP::GetFallback(type); + } } -template -AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const LogicalType &input_type, - const LogicalType &target_type) { - using STATE = QuantileState; - using OP = MedianAbsoluteDeviationOperation; - auto fun = AggregateFunction::UnaryAggregateDestructor(input_type, target_type); - fun.bind = BindMedian; - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - fun.window = AggregateFunction::UnaryWindow; - fun.window_init = OP::template WindowInit; - return fun; +struct ScalarDiscreteQuantile { + template + static AggregateFunction GetFunction(const LogicalType &type) { + using STATE = QuantileState; + using OP = QuantileScalarOperation; + auto fun = AggregateFunction::UnaryAggregateDestructor(type, type); + fun.window = AggregateFunction::UnaryWindow; + fun.window_init = OP::WindowInit; + return fun; + } + + static AggregateFunction GetFallback(const LogicalType &type) { + using STATE = QuantileState; + using OP = QuantileScalarFallback; + + AggregateFunction fun( + {type}, type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, + AggregateSortKeyHelpers::UnaryUpdate, AggregateFunction::StateCombine, + AggregateFunction::StateVoidFinalize, nullptr, nullptr, + AggregateFunction::StateDestroy); + return fun; + } +}; + +template +static AggregateFunction QuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { // NOLINT + LogicalType result_type = LogicalType::LIST(child_type); + return AggregateFunction( + {input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, + AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, + nullptr, AggregateFunction::StateDestroy); } -AggregateFunction GetMedianAbsoluteDeviationAggregateFunction(const LogicalType &type) { +struct ListDiscreteQuantile { + template + static AggregateFunction GetFunction(const LogicalType &type) { + using STATE = QuantileState; + using OP = QuantileListOperation; + auto fun = QuantileListAggregate(type, type); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.window = AggregateFunction::UnaryWindow; + fun.window_init = OP::template WindowInit; + return fun; + } + + static AggregateFunction GetFallback(const LogicalType &type) { + using STATE = QuantileState; + using OP = QuantileListFallback; + + AggregateFunction fun( + {type}, LogicalType::LIST(type), AggregateFunction::StateSize, + AggregateFunction::StateInitialize, AggregateSortKeyHelpers::UnaryUpdate, + AggregateFunction::StateCombine, AggregateFunction::StateFinalize, + nullptr, nullptr, AggregateFunction::StateDestroy); + return fun; + } +}; + +AggregateFunction GetDiscreteQuantile(const LogicalType &type) { + return GetDiscreteQuantileTemplated(type); +} + +AggregateFunction GetDiscreteQuantileList(const LogicalType &type) { + return GetDiscreteQuantileTemplated(type); +} + +//===--------------------------------------------------------------------===// +// Continuous Quantiles +//===--------------------------------------------------------------------===// +template +AggregateFunction GetContinuousQuantileTemplated(const LogicalType &type) { switch (type.id()) { + case LogicalTypeId::TINYINT: + return OP::template GetFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::SMALLINT: + return OP::template GetFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::SQLNULL: + case LogicalTypeId::INTEGER: + return OP::template GetFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::BIGINT: + return OP::template GetFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::HUGEINT: + return OP::template GetFunction(type, LogicalType::DOUBLE); case LogicalTypeId::FLOAT: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + return OP::template GetFunction(type, type); + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::UHUGEINT: case LogicalTypeId::DOUBLE: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + return OP::template GetFunction(LogicalType::DOUBLE, LogicalType::DOUBLE); case LogicalTypeId::DECIMAL: switch (type.InternalType()) { case PhysicalType::INT16: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + return OP::template GetFunction(type, type); case PhysicalType::INT32: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + return OP::template GetFunction(type, type); case PhysicalType::INT64: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + return OP::template GetFunction(type, type); case PhysicalType::INT128: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + return OP::template GetFunction(type, type); default: - throw NotImplementedException("Unimplemented Median Absolute Deviation DECIMAL aggregate"); + throw NotImplementedException("Unimplemented continuous quantile DECIMAL aggregate"); } - break; - case LogicalTypeId::DATE: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, - LogicalType::INTERVAL); + return OP::template GetFunction(type, LogicalType::TIMESTAMP); case LogicalTypeId::TIMESTAMP: case LogicalTypeId::TIMESTAMP_TZ: - return GetTypedMedianAbsoluteDeviationAggregateFunction( - type, LogicalType::INTERVAL); + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + return OP::template GetFunction(type, type); case LogicalTypeId::TIME: case LogicalTypeId::TIME_TZ: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, - LogicalType::INTERVAL); - + return OP::template GetFunction(type, type); default: - throw NotImplementedException("Unimplemented Median Absolute Deviation aggregate"); + throw NotImplementedException("Unimplemented continuous quantile aggregate"); } } -unique_ptr BindMedianDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = BindMedian(context, function, arguments); +struct ScalarContinuousQuantile { + template + static AggregateFunction GetFunction(const LogicalType &input_type, const LogicalType &target_type) { + using STATE = QuantileState; + using OP = QuantileScalarOperation; + auto fun = + AggregateFunction::UnaryAggregateDestructor(input_type, target_type); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.window = AggregateFunction::UnaryWindow; + fun.window_init = OP::template WindowInit; + return fun; + } +}; + +struct ListContinuousQuantile { + template + static AggregateFunction GetFunction(const LogicalType &input_type, const LogicalType &target_type) { + using STATE = QuantileState; + using OP = QuantileListOperation; + auto fun = QuantileListAggregate(input_type, target_type); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.window = AggregateFunction::UnaryWindow; + fun.window_init = OP::template WindowInit; + return fun; + } +}; - function = GetDiscreteQuantileAggregateFunction(arguments[0]->return_type); - function.name = "median"; - function.serialize = QuantileBindData::SerializeDecimalDiscrete; - function.deserialize = QuantileBindData::Deserialize; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return bind_data; +AggregateFunction GetContinuousQuantile(const LogicalType &type) { + return GetContinuousQuantileTemplated(type); } -unique_ptr BindMedianAbsoluteDeviationDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetMedianAbsoluteDeviationAggregateFunction(arguments[0]->return_type); - function.name = "mad"; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return BindMedian(context, function, arguments); +AggregateFunction GetContinuousQuantileList(const LogicalType &type) { + return GetContinuousQuantileTemplated(type); } +//===--------------------------------------------------------------------===// +// Quantile binding +//===--------------------------------------------------------------------===// static const Value &CheckQuantile(const Value &quantile_val) { if (quantile_val.IsNull()) { throw BinderException("QUANTILE parameter cannot be NULL"); @@ -1498,6 +578,9 @@ static const Value &CheckQuantile(const Value &quantile_val) { unique_ptr BindQuantile(ClientContext &context, AggregateFunction &function, vector> &arguments) { + if (arguments.size() < 2) { + throw BinderException("QUANTILE requires a range argument between [0, 1]"); + } if (arguments[1]->HasParameter()) { throw ParameterNotResolvedException(); } @@ -1529,202 +612,235 @@ unique_ptr BindQuantile(ClientContext &context, AggregateFunction return make_uniq(quantiles); } -void BindQuantileInner(AggregateFunction &function, const LogicalType &type, QuantileSerializationType quantile_type) { - switch (quantile_type) { - case QuantileSerializationType::DECIMAL_DISCRETE: - function = GetDiscreteQuantileAggregateFunction(type); - function.serialize = QuantileBindData::SerializeDecimalDiscrete; - function.name = "quantile_disc"; - break; - case QuantileSerializationType::DECIMAL_DISCRETE_LIST: - function = GetDiscreteQuantileListAggregateFunction(type); - function.serialize = QuantileBindData::SerializeDecimalDiscreteList; - function.name = "quantile_disc"; - break; - case QuantileSerializationType::DECIMAL_CONTINUOUS: - function = GetContinuousQuantileAggregateFunction(type); - function.serialize = QuantileBindData::SerializeDecimalContinuous; - function.name = "quantile_cont"; - break; - case QuantileSerializationType::DECIMAL_CONTINUOUS_LIST: - function = GetContinuousQuantileListAggregateFunction(type); - function.serialize = QuantileBindData::SerializeDecimalContinuousList; - function.name = "quantile_cont"; - break; - case QuantileSerializationType::NON_DECIMAL: - throw SerializationException("NON_DECIMAL is not a valid quantile_type for BindQuantileInner"); +//===--------------------------------------------------------------------===// +// Function definitions +//===--------------------------------------------------------------------===// +static bool CanInterpolate(const LogicalType &type) { + if (type.HasAlias()) { + return false; + } + switch (type.id()) { + case LogicalTypeId::DECIMAL: + case LogicalTypeId::SQLNULL: + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::BIGINT: + case LogicalTypeId::UHUGEINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DATE: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return true; + default: + return false; } - function.deserialize = QuantileBindData::Deserialize; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; } -unique_ptr BindDiscreteQuantileDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = BindQuantile(context, function, arguments); - BindQuantileInner(function, arguments[0]->return_type, QuantileSerializationType::DECIMAL_DISCRETE); - return bind_data; -} +struct MedianFunction { + static AggregateFunction GetAggregate(const LogicalType &type) { + auto fun = CanInterpolate(type) ? GetContinuousQuantile(type) : GetDiscreteQuantile(type); + fun.name = "median"; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = Deserialize; + return fun; + } -unique_ptr BindDiscreteQuantileDecimalList(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = BindQuantile(context, function, arguments); - BindQuantileInner(function, arguments[0]->return_type, QuantileSerializationType::DECIMAL_DISCRETE_LIST); - return bind_data; -} + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto bind_data = QuantileBindData::Deserialize(deserializer, function); -unique_ptr BindContinuousQuantileDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = BindQuantile(context, function, arguments); - BindQuantileInner(function, arguments[0]->return_type, QuantileSerializationType::DECIMAL_CONTINUOUS); - return bind_data; -} + auto &input_type = function.arguments[0]; + function = GetAggregate(input_type); + return bind_data; + } -unique_ptr BindContinuousQuantileDecimalList(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = BindQuantile(context, function, arguments); - BindQuantileInner(function, arguments[0]->return_type, QuantileSerializationType::DECIMAL_CONTINUOUS_LIST); - return bind_data; -} -static bool CanInterpolate(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::INTERVAL: - case LogicalTypeId::VARCHAR: - case LogicalTypeId::ANY: - return false; - default: - return true; + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetAggregate(arguments[0]->return_type); + return make_uniq(Value::DECIMAL(int16_t(5), 2, 1)); } -} +}; -AggregateFunction GetMedianAggregate(const LogicalType &type) { - auto fun = CanInterpolate(type) ? GetContinuousQuantileAggregateFunction(type) - : GetDiscreteQuantileAggregateFunction(type); - fun.bind = BindMedian; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = QuantileBindData::Deserialize; - return fun; -} +struct DiscreteQuantileListFunction { + static AggregateFunction GetAggregate(const LogicalType &type) { + auto fun = GetDiscreteQuantileList(type); + fun.name = "quantile_disc"; + fun.bind = Bind; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = Deserialize; + // temporarily push an argument so we can bind the actual quantile + fun.arguments.emplace_back(LogicalType::LIST(LogicalType::DOUBLE)); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; + } -AggregateFunction GetDiscreteQuantileAggregate(const LogicalType &type) { - auto fun = GetDiscreteQuantileAggregateFunction(type); - fun.bind = BindQuantile; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = QuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - fun.arguments.emplace_back(LogicalType::DOUBLE); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto bind_data = QuantileBindData::Deserialize(deserializer, function); -AggregateFunction GetDiscreteQuantileListAggregate(const LogicalType &type) { - auto fun = GetDiscreteQuantileListAggregateFunction(type); - fun.bind = BindQuantile; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = QuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); - fun.arguments.push_back(list_of_double); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} + auto &input_type = function.arguments[0]; + function = GetAggregate(input_type); + return bind_data; + } -AggregateFunction GetContinuousQuantileAggregate(const LogicalType &type) { - auto fun = GetContinuousQuantileAggregateFunction(type); - fun.bind = BindQuantile; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = QuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - fun.arguments.emplace_back(LogicalType::DOUBLE); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetAggregate(arguments[0]->return_type); + return BindQuantile(context, function, arguments); + } +}; -AggregateFunction GetContinuousQuantileListAggregate(const LogicalType &type) { - auto fun = GetContinuousQuantileListAggregateFunction(type); - fun.bind = BindQuantile; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = QuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); - fun.arguments.push_back(list_of_double); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} +struct DiscreteQuantileFunction { + static AggregateFunction GetAggregate(const LogicalType &type) { + auto fun = GetDiscreteQuantile(type); + fun.name = "quantile_disc"; + fun.bind = Bind; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = Deserialize; + // temporarily push an argument so we can bind the actual quantile + fun.arguments.emplace_back(LogicalType::DOUBLE); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto bind_data = QuantileBindData::Deserialize(deserializer, function); + auto &quantile_data = bind_data->Cast(); + + auto &input_type = function.arguments[0]; + if (quantile_data.quantiles.size() == 1) { + function = GetAggregate(input_type); + } else { + function = DiscreteQuantileListFunction::GetAggregate(input_type); + } + return bind_data; + } + + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetAggregate(arguments[0]->return_type); + return BindQuantile(context, function, arguments); + } +}; + +struct ContinuousQuantileFunction { + static AggregateFunction GetAggregate(const LogicalType &type) { + auto fun = GetContinuousQuantile(type); + fun.name = "quantile_cont"; + fun.bind = Bind; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = Deserialize; + // temporarily push an argument so we can bind the actual quantile + fun.arguments.emplace_back(LogicalType::DOUBLE); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto bind_data = QuantileBindData::Deserialize(deserializer, function); + + auto &input_type = function.arguments[0]; + function = GetAggregate(input_type); + return bind_data; + } + + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetAggregate(function.arguments[0].id() == LogicalTypeId::DECIMAL ? arguments[0]->return_type + : function.arguments[0]); + return BindQuantile(context, function, arguments); + } +}; + +struct ContinuousQuantileListFunction { + static AggregateFunction GetAggregate(const LogicalType &type) { + auto fun = GetContinuousQuantileList(type); + fun.name = "quantile_cont"; + fun.bind = Bind; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = Deserialize; + // temporarily push an argument so we can bind the actual quantile + auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); + fun.arguments.push_back(list_of_double); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto bind_data = QuantileBindData::Deserialize(deserializer, function); + + auto &input_type = function.arguments[0]; + function = GetAggregate(input_type); + return bind_data; + } + + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetAggregate(function.arguments[0].id() == LogicalTypeId::DECIMAL ? arguments[0]->return_type + : function.arguments[0]); + return BindQuantile(context, function, arguments); + } +}; -AggregateFunction GetQuantileDecimalAggregate(const vector &arguments, const LogicalType &return_type, - bind_aggregate_function_t bind) { - AggregateFunction fun(arguments, return_type, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, bind); - fun.bind = bind; +template +AggregateFunction EmptyQuantileFunction(LogicalType input, LogicalType result, const LogicalType &extra_arg) { + AggregateFunction fun({std::move(input)}, std::move(result), nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + OP::Bind); + if (extra_arg.id() != LogicalTypeId::INVALID) { + fun.arguments.push_back(extra_arg); + } fun.serialize = QuantileBindData::Serialize; - fun.deserialize = QuantileBindData::Deserialize; + fun.deserialize = OP::Deserialize; fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; return fun; } -vector GetQuantileTypes() { - return {LogicalType::TINYINT, LogicalType::SMALLINT, - LogicalType::INTEGER, LogicalType::BIGINT, - LogicalType::HUGEINT, LogicalType::FLOAT, - LogicalType::DOUBLE, LogicalType::DATE, - LogicalType::TIMESTAMP, LogicalType::TIME, - LogicalType::TIMESTAMP_TZ, LogicalType::TIME_TZ, - LogicalType::INTERVAL, LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150)}; -} - AggregateFunctionSet MedianFun::GetFunctions() { - AggregateFunctionSet median("median"); - median.AddFunction( - GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, BindMedianDecimal)); - for (const auto &type : GetQuantileTypes()) { - median.AddFunction(GetMedianAggregate(type)); - } - return median; + AggregateFunctionSet set("median"); + set.AddFunction(EmptyQuantileFunction(LogicalType::ANY, LogicalType::ANY, LogicalTypeId::INVALID)); + return set; } AggregateFunctionSet QuantileDiscFun::GetFunctions() { - AggregateFunctionSet quantile_disc("quantile_disc"); - quantile_disc.AddFunction(GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL, LogicalType::DOUBLE}, - LogicalTypeId::DECIMAL, BindDiscreteQuantileDecimal)); - quantile_disc.AddFunction( - GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::DOUBLE)}, - LogicalType::LIST(LogicalTypeId::DECIMAL), BindDiscreteQuantileDecimalList)); - for (const auto &type : GetQuantileTypes()) { - quantile_disc.AddFunction(GetDiscreteQuantileAggregate(type)); - quantile_disc.AddFunction(GetDiscreteQuantileListAggregate(type)); - } - return quantile_disc; - // quantile + AggregateFunctionSet set("quantile_disc"); + set.AddFunction( + EmptyQuantileFunction(LogicalType::ANY, LogicalType::ANY, LogicalType::DOUBLE)); + set.AddFunction(EmptyQuantileFunction(LogicalType::ANY, LogicalType::ANY, + LogicalType::LIST(LogicalType::DOUBLE))); + // this function is here for deserialization - it cannot be called by users + set.AddFunction( + EmptyQuantileFunction(LogicalType::ANY, LogicalType::ANY, LogicalType::INVALID)); + return set; +} + +vector GetContinuousQuantileTypes() { + return {LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, LogicalType::BIGINT, + LogicalType::HUGEINT, LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE, + LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ, LogicalType::TIME_TZ}; } AggregateFunctionSet QuantileContFun::GetFunctions() { AggregateFunctionSet quantile_cont("quantile_cont"); - quantile_cont.AddFunction(GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL, LogicalType::DOUBLE}, - LogicalTypeId::DECIMAL, BindContinuousQuantileDecimal)); - quantile_cont.AddFunction( - GetQuantileDecimalAggregate({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::DOUBLE)}, - LogicalType::LIST(LogicalTypeId::DECIMAL), BindContinuousQuantileDecimalList)); - - for (const auto &type : GetQuantileTypes()) { - if (CanInterpolate(type)) { - quantile_cont.AddFunction(GetContinuousQuantileAggregate(type)); - quantile_cont.AddFunction(GetContinuousQuantileListAggregate(type)); - } + quantile_cont.AddFunction(EmptyQuantileFunction( + LogicalTypeId::DECIMAL, LogicalTypeId::DECIMAL, LogicalType::DOUBLE)); + quantile_cont.AddFunction(EmptyQuantileFunction( + LogicalTypeId::DECIMAL, LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::DOUBLE))); + for (const auto &type : GetContinuousQuantileTypes()) { + quantile_cont.AddFunction(EmptyQuantileFunction(type, type, LogicalType::DOUBLE)); + quantile_cont.AddFunction( + EmptyQuantileFunction(type, type, LogicalType::LIST(LogicalType::DOUBLE))); } return quantile_cont; } -AggregateFunctionSet MadFun::GetFunctions() { - AggregateFunctionSet mad("mad"); - mad.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, BindMedianAbsoluteDeviationDecimal)); - - const vector MAD_TYPES = {LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE, - LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ, - LogicalType::TIME_TZ}; - for (const auto &type : MAD_TYPES) { - mad.AddFunction(GetMedianAbsoluteDeviationAggregateFunction(type)); - } - return mad; -} - } // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/nested/binned_histogram.cpp b/src/duckdb/src/core_functions/aggregate/nested/binned_histogram.cpp new file mode 100644 index 00000000..b639475a --- /dev/null +++ b/src/duckdb/src/core_functions/aggregate/nested/binned_histogram.cpp @@ -0,0 +1,405 @@ +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/core_functions/aggregate/nested_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/core_functions/aggregate/histogram_helpers.hpp" +#include "duckdb/core_functions/scalar/generic_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/algorithm.hpp" + +namespace duckdb { + +template +struct HistogramBinState { + using TYPE = T; + + unsafe_vector *bin_boundaries; + unsafe_vector *counts; + + void Initialize() { + bin_boundaries = nullptr; + counts = nullptr; + } + + void Destroy() { + if (bin_boundaries) { + delete bin_boundaries; + bin_boundaries = nullptr; + } + if (counts) { + delete counts; + counts = nullptr; + } + } + + bool IsSet() { + return bin_boundaries; + } + + template + void InitializeBins(Vector &bin_vector, idx_t count, idx_t pos, AggregateInputData &aggr_input) { + bin_boundaries = new unsafe_vector(); + counts = new unsafe_vector(); + UnifiedVectorFormat bin_data; + bin_vector.ToUnifiedFormat(count, bin_data); + auto bin_counts = UnifiedVectorFormat::GetData(bin_data); + auto bin_index = bin_data.sel->get_index(pos); + auto bin_list = bin_counts[bin_index]; + if (!bin_data.validity.RowIsValid(bin_index)) { + throw BinderException("Histogram bin list cannot be NULL"); + } + + auto &bin_child = ListVector::GetEntry(bin_vector); + auto bin_count = ListVector::GetListSize(bin_vector); + UnifiedVectorFormat bin_child_data; + auto extra_state = OP::CreateExtraState(bin_count); + OP::PrepareData(bin_child, bin_count, extra_state, bin_child_data); + + bin_boundaries->reserve(bin_list.length); + for (idx_t i = 0; i < bin_list.length; i++) { + auto bin_child_idx = bin_child_data.sel->get_index(bin_list.offset + i); + if (!bin_child_data.validity.RowIsValid(bin_child_idx)) { + throw BinderException("Histogram bin entry cannot be NULL"); + } + bin_boundaries->push_back(OP::template ExtractValue(bin_child_data, bin_list.offset + i, aggr_input)); + } + // sort the bin boundaries + std::sort(bin_boundaries->begin(), bin_boundaries->end()); + // ensure there are no duplicate bin boundaries + for (idx_t i = 1; i < bin_boundaries->size(); i++) { + if (Equals::Operation((*bin_boundaries)[i - 1], (*bin_boundaries)[i])) { + bin_boundaries->erase_at(i); + i--; + } + } + + counts->resize(bin_list.length + 1); + } +}; + +struct HistogramBinFunction { + template + static void Initialize(STATE &state) { + state.Initialize(); + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + state.Destroy(); + } + + static bool IgnoreNull() { + return true; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { + if (!source.bin_boundaries) { + // nothing to combine + return; + } + if (!target.bin_boundaries) { + // target does not have bin boundaries - copy everything over + target.bin_boundaries = new unsafe_vector(); + target.counts = new unsafe_vector(); + *target.bin_boundaries = *source.bin_boundaries; + *target.counts = *source.counts; + } else { + // both source and target have bin boundaries + if (*target.bin_boundaries != *source.bin_boundaries) { + throw NotImplementedException( + "Histogram - cannot combine histograms with different bin boundaries. " + "Bin boundaries must be the same for all histograms within the same group"); + } + if (target.counts->size() != source.counts->size()) { + throw InternalException("Histogram combine - bin boundaries are the same but counts are different"); + } + D_ASSERT(target.counts->size() == source.counts->size()); + for (idx_t bin_idx = 0; bin_idx < target.counts->size(); bin_idx++) { + (*target.counts)[bin_idx] += (*source.counts)[bin_idx]; + } + } + } +}; + +struct HistogramRange { + static constexpr bool EXACT = false; + + template + static idx_t GetBin(T value, const unsafe_vector &bin_boundaries) { + auto entry = std::lower_bound(bin_boundaries.begin(), bin_boundaries.end(), value); + return UnsafeNumericCast(entry - bin_boundaries.begin()); + } +}; + +struct HistogramExact { + static constexpr bool EXACT = true; + + template + static idx_t GetBin(T value, const unsafe_vector &bin_boundaries) { + auto entry = std::lower_bound(bin_boundaries.begin(), bin_boundaries.end(), value); + if (entry == bin_boundaries.end() || !(*entry == value)) { + // entry not found - return last bucket + return bin_boundaries.size(); + } + return UnsafeNumericCast(entry - bin_boundaries.begin()); + } +}; + +template +static void HistogramBinUpdateFunction(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, + Vector &state_vector, idx_t count) { + auto &input = inputs[0]; + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + + auto &bin_vector = inputs[1]; + + auto extra_state = OP::CreateExtraState(count); + UnifiedVectorFormat input_data; + OP::PrepareData(input, count, extra_state, input_data); + + auto states = UnifiedVectorFormat::GetData *>(sdata); + auto data = UnifiedVectorFormat::GetData(input_data); + for (idx_t i = 0; i < count; i++) { + auto idx = input_data.sel->get_index(i); + if (!input_data.validity.RowIsValid(idx)) { + continue; + } + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.IsSet()) { + state.template InitializeBins(bin_vector, count, i, aggr_input); + } + auto bin_entry = HIST::template GetBin(data[idx], *state.bin_boundaries); + ++(*state.counts)[bin_entry]; + } +} + +static bool SupportsOtherBucket(const LogicalType &type) { + if (type.HasAlias()) { + return false; + } + switch (type.id()) { + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DECIMAL: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::UHUGEINT: + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + case LogicalTypeId::DATE: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::VARCHAR: + case LogicalTypeId::BLOB: + case LogicalTypeId::STRUCT: + case LogicalTypeId::LIST: + return true; + default: + return false; + } +} +static Value OtherBucketValue(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::DECIMAL: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::UHUGEINT: + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return Value::MaximumValue(type); + case LogicalTypeId::DATE: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + return Value::Infinity(type); + case LogicalTypeId::VARCHAR: + return Value(""); + case LogicalTypeId::BLOB: + return Value::BLOB(""); + case LogicalTypeId::STRUCT: { + // for structs we can set all child members to NULL + auto &child_types = StructType::GetChildTypes(type); + child_list_t child_list; + for (auto &child_type : child_types) { + child_list.push_back(make_pair(child_type.first, Value(child_type.second))); + } + return Value::STRUCT(std::move(child_list)); + } + case LogicalTypeId::LIST: + return Value::EMPTYLIST(ListType::GetChildType(type)); + default: + throw InternalException("Unsupported type for other bucket"); + } +} + +static void IsHistogramOtherBinFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &input_type = args.data[0].GetType(); + if (!SupportsOtherBucket(input_type)) { + result.Reference(Value::BOOLEAN(false)); + return; + } + auto v = OtherBucketValue(input_type); + Vector ref(v); + VectorOperations::NotDistinctFrom(args.data[0], ref, result, args.size()); +} + +template +static void HistogramBinFinalizeFunction(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, + idx_t offset) { + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + auto states = UnifiedVectorFormat::GetData *>(sdata); + + auto &mask = FlatVector::Validity(result); + auto old_len = ListVector::GetListSize(result); + idx_t new_entries = 0; + bool supports_other_bucket = SupportsOtherBucket(MapType::KeyType(result.GetType())); + // figure out how much space we need + for (idx_t i = 0; i < count; i++) { + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.bin_boundaries) { + continue; + } + new_entries += state.bin_boundaries->size(); + if (state.counts->back() > 0 && supports_other_bucket) { + // overflow bucket has entries + new_entries++; + } + } + // reserve space in the list vector + ListVector::Reserve(result, old_len + new_entries); + auto &keys = MapVector::GetKeys(result); + auto &values = MapVector::GetValues(result); + auto list_entries = FlatVector::GetData(result); + auto count_entries = FlatVector::GetData(values); + + idx_t current_offset = old_len; + for (idx_t i = 0; i < count; i++) { + const auto rid = i + offset; + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.bin_boundaries) { + mask.SetInvalid(rid); + continue; + } + + auto &list_entry = list_entries[rid]; + list_entry.offset = current_offset; + for (idx_t bin_idx = 0; bin_idx < state.bin_boundaries->size(); bin_idx++) { + OP::template HistogramFinalize((*state.bin_boundaries)[bin_idx], keys, current_offset); + count_entries[current_offset] = (*state.counts)[bin_idx]; + current_offset++; + } + if (state.counts->back() > 0 && supports_other_bucket) { + // add overflow bucket ("others") + // set bin boundary to NULL for overflow bucket + keys.SetValue(current_offset, OtherBucketValue(keys.GetType())); + count_entries[current_offset] = state.counts->back(); + current_offset++; + } + list_entry.length = current_offset - list_entry.offset; + } + D_ASSERT(current_offset == old_len + new_entries); + ListVector::SetListSize(result, current_offset); + result.Verify(count); +} + +template +static AggregateFunction GetHistogramBinFunction(const LogicalType &type) { + using STATE_TYPE = HistogramBinState; + + const char *function_name = HIST::EXACT ? "histogram_exact" : "histogram"; + + auto struct_type = LogicalType::MAP(type, LogicalType::UBIGINT); + return AggregateFunction( + function_name, {type, LogicalType::LIST(type)}, struct_type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, HistogramBinUpdateFunction, + AggregateFunction::StateCombine, HistogramBinFinalizeFunction, nullptr, + nullptr, AggregateFunction::StateDestroy); +} + +template +AggregateFunction GetHistogramBinFunction(const LogicalType &type) { + switch (type.InternalType()) { + case PhysicalType::BOOL: + return GetHistogramBinFunction(type); + case PhysicalType::UINT8: + return GetHistogramBinFunction(type); + case PhysicalType::UINT16: + return GetHistogramBinFunction(type); + case PhysicalType::UINT32: + return GetHistogramBinFunction(type); + case PhysicalType::UINT64: + return GetHistogramBinFunction(type); + case PhysicalType::INT8: + return GetHistogramBinFunction(type); + case PhysicalType::INT16: + return GetHistogramBinFunction(type); + case PhysicalType::INT32: + return GetHistogramBinFunction(type); + case PhysicalType::INT64: + return GetHistogramBinFunction(type); + case PhysicalType::FLOAT: + return GetHistogramBinFunction(type); + case PhysicalType::DOUBLE: + return GetHistogramBinFunction(type); + case PhysicalType::VARCHAR: + return GetHistogramBinFunction(type); + default: + return GetHistogramBinFunction(type); + } +} + +template +unique_ptr HistogramBinBindFunction(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + for (auto &arg : arguments) { + if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + } + + function = GetHistogramBinFunction(arguments[0]->return_type); + return nullptr; +} + +AggregateFunction HistogramFun::BinnedHistogramFunction() { + return AggregateFunction("histogram", {LogicalType::ANY, LogicalType::LIST(LogicalType::ANY)}, LogicalTypeId::MAP, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + HistogramBinBindFunction, nullptr); +} + +AggregateFunction HistogramExactFun::GetFunction() { + return AggregateFunction("histogram_exact", {LogicalType::ANY, LogicalType::LIST(LogicalType::ANY)}, + LogicalTypeId::MAP, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + HistogramBinBindFunction, nullptr); +} + +ScalarFunction IsHistogramOtherBinFun::GetFunction() { + return ScalarFunction("is_histogram_other_bin", {LogicalType::ANY}, LogicalType::BOOLEAN, + IsHistogramOtherBinFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/aggregate/nested/histogram.cpp b/src/duckdb/src/core_functions/aggregate/nested/histogram.cpp index 2f741434..447e8d0d 100644 --- a/src/duckdb/src/core_functions/aggregate/nested/histogram.cpp +++ b/src/duckdb/src/core_functions/aggregate/nested/histogram.cpp @@ -1,57 +1,13 @@ #include "duckdb/function/scalar/nested_functions.hpp" #include "duckdb/core_functions/aggregate/nested_functions.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/common/pair.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/common/types/vector.hpp" +#include "duckdb/common/string_map_set.hpp" +#include "duckdb/core_functions/aggregate/histogram_helpers.hpp" +#include "duckdb/common/owning_string_map.hpp" namespace duckdb { -struct HistogramFunctor { - template > - static void HistogramUpdate(UnifiedVectorFormat &sdata, UnifiedVectorFormat &input_data, idx_t count) { - auto states = (HistogramAggState **)sdata.data; - for (idx_t i = 0; i < count; i++) { - if (input_data.validity.RowIsValid(input_data.sel->get_index(i))) { - auto &state = *states[sdata.sel->get_index(i)]; - if (!state.hist) { - state.hist = new MAP_TYPE(); - } - auto value = UnifiedVectorFormat::GetData(input_data); - (*state.hist)[value[input_data.sel->get_index(i)]]++; - } - } - } - - template - static Value HistogramFinalize(T first) { - return Value::CreateValue(first); - } -}; - -struct HistogramStringFunctor { - template > - static void HistogramUpdate(UnifiedVectorFormat &sdata, UnifiedVectorFormat &input_data, idx_t count) { - auto states = (HistogramAggState **)sdata.data; - auto input_strings = UnifiedVectorFormat::GetData(input_data); - for (idx_t i = 0; i < count; i++) { - if (input_data.validity.RowIsValid(input_data.sel->get_index(i))) { - auto &state = *states[sdata.sel->get_index(i)]; - if (!state.hist) { - state.hist = new MAP_TYPE(); - } - (*state.hist)[input_strings[input_data.sel->get_index(i)].GetString()]++; - } - } - } - - template - static Value HistogramFinalize(T first) { - string_t value = first; - return Value::CreateValue(value); - } -}; - +template struct HistogramFunction { template static void Initialize(STATE &state) { @@ -59,7 +15,7 @@ struct HistogramFunction { } template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + static void Destroy(STATE &state, AggregateInputData &) { if (state.hist) { delete state.hist; } @@ -68,59 +24,97 @@ struct HistogramFunction { static bool IgnoreNull() { return true; } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { + if (!source.hist) { + return; + } + if (!target.hist) { + target.hist = MAP_TYPE::CreateEmpty(input_data.allocator); + } + for (auto &entry : *source.hist) { + (*target.hist)[entry.first] += entry.second; + } + } }; -template -static void HistogramUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, - idx_t count) { +template +struct DefaultMapType { + using MAP_TYPE = TYPE; - D_ASSERT(input_count == 1); + static TYPE *CreateEmpty(ArenaAllocator &) { + return new TYPE(); + } +}; - auto &input = inputs[0]; - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - UnifiedVectorFormat input_data; - input.ToUnifiedFormat(count, input_data); +template +struct StringMapType { + using MAP_TYPE = TYPE; - OP::template HistogramUpdate(sdata, input_data, count); -} + static TYPE *CreateEmpty(ArenaAllocator &allocator) { + return new TYPE(allocator); + } +}; -template -static void HistogramCombineFunction(Vector &state_vector, Vector &combined, AggregateInputData &, idx_t count) { +template +static void HistogramUpdateFunction(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, + Vector &state_vector, idx_t count) { + + D_ASSERT(input_count == 1); + auto &input = inputs[0]; UnifiedVectorFormat sdata; state_vector.ToUnifiedFormat(count, sdata); - auto states_ptr = (HistogramAggState **)sdata.data; - auto combined_ptr = FlatVector::GetData *>(combined); + auto extra_state = OP::CreateExtraState(count); + UnifiedVectorFormat input_data; + OP::PrepareData(input, count, extra_state, input_data); + auto states = UnifiedVectorFormat::GetData *>(sdata); + auto input_values = UnifiedVectorFormat::GetData(input_data); for (idx_t i = 0; i < count; i++) { - auto &state = *states_ptr[sdata.sel->get_index(i)]; - if (!state.hist) { + auto idx = input_data.sel->get_index(i); + if (!input_data.validity.RowIsValid(idx)) { continue; } - if (!combined_ptr[i]->hist) { - combined_ptr[i]->hist = new MAP_TYPE(); - } - D_ASSERT(combined_ptr[i]->hist); - D_ASSERT(state.hist); - for (auto &entry : *state.hist) { - (*combined_ptr[i]->hist)[entry.first] += entry.second; + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.hist) { + state.hist = MAP_TYPE::CreateEmpty(aggr_input.allocator); } + auto &input_value = input_values[idx]; + ++(*state.hist)[input_value]; } } template static void HistogramFinalizeFunction(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, idx_t offset) { + using HIST_STATE = HistogramAggState; UnifiedVectorFormat sdata; state_vector.ToUnifiedFormat(count, sdata); - auto states = (HistogramAggState **)sdata.data; + auto states = UnifiedVectorFormat::GetData(sdata); auto &mask = FlatVector::Validity(result); auto old_len = ListVector::GetListSize(result); - + idx_t new_entries = 0; + // figure out how much space we need + for (idx_t i = 0; i < count; i++) { + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.hist) { + continue; + } + new_entries += state.hist->size(); + } + // reserve space in the list vector + ListVector::Reserve(result, old_len + new_entries); + auto &keys = MapVector::GetKeys(result); + auto &values = MapVector::GetValues(result); + auto list_entries = FlatVector::GetData(result); + auto count_entries = FlatVector::GetData(values); + + idx_t current_offset = old_len; for (idx_t i = 0; i < count; i++) { const auto rid = i + offset; auto &state = *states[sdata.sel->get_index(i)]; @@ -129,135 +123,112 @@ static void HistogramFinalizeFunction(Vector &state_vector, AggregateInputData & continue; } + auto &list_entry = list_entries[rid]; + list_entry.offset = current_offset; for (auto &entry : *state.hist) { - Value bucket_value = OP::template HistogramFinalize(entry.first); - auto count_value = Value::CreateValue(entry.second); - auto struct_value = - Value::STRUCT({std::make_pair("key", bucket_value), std::make_pair("value", count_value)}); - ListVector::PushBack(result, struct_value); + OP::template HistogramFinalize(entry.first, keys, current_offset); + count_entries[current_offset] = entry.second; + current_offset++; } - - auto list_struct_data = ListVector::GetData(result); - list_struct_data[rid].length = ListVector::GetListSize(result) - old_len; - list_struct_data[rid].offset = old_len; - old_len += list_struct_data[rid].length; + list_entry.length = current_offset - list_entry.offset; } + D_ASSERT(current_offset == old_len + new_entries); + ListVector::SetListSize(result, current_offset); result.Verify(count); } -unique_ptr HistogramBindFunction(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - - D_ASSERT(arguments.size() == 1); - - if (arguments[0]->return_type.id() == LogicalTypeId::LIST || - arguments[0]->return_type.id() == LogicalTypeId::STRUCT || - arguments[0]->return_type.id() == LogicalTypeId::MAP) { - throw NotImplementedException("Unimplemented type for histogram %s", arguments[0]->return_type.ToString()); - } - auto child_type = function.arguments[0].id() == LogicalTypeId::ANY ? LogicalType::VARCHAR : function.arguments[0]; - auto struct_type = LogicalType::MAP(child_type, LogicalType::UBIGINT); - - function.return_type = struct_type; - return make_uniq(function.return_type); -} - -template > +template static AggregateFunction GetHistogramFunction(const LogicalType &type) { + using STATE_TYPE = HistogramAggState; + using HIST_FUNC = HistogramFunction; + + auto struct_type = LogicalType::MAP(type, LogicalType::UBIGINT); + return AggregateFunction( + "histogram", {type}, struct_type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, HistogramUpdateFunction, + AggregateFunction::StateCombine, HistogramFinalizeFunction, nullptr, + nullptr, AggregateFunction::StateDestroy); +} - using STATE_TYPE = HistogramAggState; - - return AggregateFunction("histogram", {type}, LogicalTypeId::MAP, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - HistogramUpdateFunction, HistogramCombineFunction, - HistogramFinalizeFunction, nullptr, HistogramBindFunction, - AggregateFunction::StateDestroy); +template +AggregateFunction GetMapTypeInternal(const LogicalType &type) { + return GetHistogramFunction(type); } template AggregateFunction GetMapType(const LogicalType &type) { if (IS_ORDERED) { - return GetHistogramFunction(type); + return GetMapTypeInternal>>(type); + } + return GetMapTypeInternal>>(type); +} + +template +AggregateFunction GetStringMapType(const LogicalType &type) { + if (IS_ORDERED) { + return GetMapTypeInternal>>(type); + } else { + return GetMapTypeInternal>>(type); } - return GetHistogramFunction>(type); } template AggregateFunction GetHistogramFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::BOOLEAN: + switch (type.InternalType()) { + case PhysicalType::BOOL: return GetMapType(type); - case LogicalTypeId::UTINYINT: + case PhysicalType::UINT8: return GetMapType(type); - case LogicalTypeId::USMALLINT: + case PhysicalType::UINT16: return GetMapType(type); - case LogicalTypeId::UINTEGER: + case PhysicalType::UINT32: return GetMapType(type); - case LogicalTypeId::UBIGINT: + case PhysicalType::UINT64: return GetMapType(type); - case LogicalTypeId::TINYINT: + case PhysicalType::INT8: return GetMapType(type); - case LogicalTypeId::SMALLINT: + case PhysicalType::INT16: return GetMapType(type); - case LogicalTypeId::INTEGER: + case PhysicalType::INT32: return GetMapType(type); - case LogicalTypeId::BIGINT: + case PhysicalType::INT64: return GetMapType(type); - case LogicalTypeId::FLOAT: + case PhysicalType::FLOAT: return GetMapType(type); - case LogicalTypeId::DOUBLE: + case PhysicalType::DOUBLE: return GetMapType(type); - case LogicalTypeId::TIMESTAMP: - return GetMapType(type); - case LogicalTypeId::TIMESTAMP_TZ: - return GetMapType(type); - case LogicalTypeId::TIMESTAMP_SEC: - return GetMapType(type); - case LogicalTypeId::TIMESTAMP_MS: - return GetMapType(type); - case LogicalTypeId::TIMESTAMP_NS: - return GetMapType(type); - case LogicalTypeId::TIME: - return GetMapType(type); - case LogicalTypeId::TIME_TZ: - return GetMapType(type); - case LogicalTypeId::DATE: - return GetMapType(type); - case LogicalTypeId::ANY: - return GetMapType(type); + case PhysicalType::VARCHAR: + return GetStringMapType(type); default: - throw InternalException("Unimplemented histogram aggregate"); + return GetStringMapType(type); + } +} + +template +unique_ptr HistogramBindFunction(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + + D_ASSERT(arguments.size() == 1); + + if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); } + function = GetHistogramFunction(arguments[0]->return_type); + return make_uniq(function.return_type); } AggregateFunctionSet HistogramFun::GetFunctions() { AggregateFunctionSet fun; - fun.AddFunction(GetHistogramFunction<>(LogicalType::BOOLEAN)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::UTINYINT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::USMALLINT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::UINTEGER)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::UBIGINT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TINYINT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::SMALLINT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::INTEGER)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::BIGINT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::FLOAT)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::DOUBLE)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP_TZ)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP_S)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP_MS)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIMESTAMP_NS)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIME)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::TIME_TZ)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::DATE)); - fun.AddFunction(GetHistogramFunction<>(LogicalType::ANY_PARAMS(LogicalType::VARCHAR))); + AggregateFunction histogram_function("histogram", {LogicalType::ANY}, LogicalTypeId::MAP, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, HistogramBindFunction, nullptr); + fun.AddFunction(HistogramFun::BinnedHistogramFunction()); + fun.AddFunction(histogram_function); return fun; } AggregateFunction HistogramFun::GetHistogramUnorderedMap(LogicalType &type) { - const auto &const_type = type; - return GetHistogramFunction(const_type); + return AggregateFunction("histogram", {LogicalType::ANY}, LogicalTypeId::MAP, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, HistogramBindFunction, nullptr); } } // namespace duckdb diff --git a/src/duckdb/src/core_functions/function_list.cpp b/src/duckdb/src/core_functions/function_list.cpp index e62330b5..c01d3e85 100644 --- a/src/duckdb/src/core_functions/function_list.cpp +++ b/src/duckdb/src/core_functions/function_list.cpp @@ -50,23 +50,28 @@ namespace duckdb { static const StaticFunctionDefinition internal_functions[] = { DUCKDB_SCALAR_FUNCTION(FactorialOperatorFun), DUCKDB_SCALAR_FUNCTION_SET(BitwiseAndFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ListHasAnyFunAlias), DUCKDB_SCALAR_FUNCTION(PowOperatorFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListInnerProductFunAlias), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListNegativeInnerProductFunAlias), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListDistanceFunAlias), DUCKDB_SCALAR_FUNCTION_SET(LeftShiftFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListCosineSimilarityFunAlias), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListCosineDistanceFunAlias), + DUCKDB_SCALAR_FUNCTION_ALIAS(ListHasAllFunAlias2), DUCKDB_SCALAR_FUNCTION_SET(RightShiftFun), DUCKDB_SCALAR_FUNCTION_SET(AbsOperatorFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ListHasAllFunAlias), DUCKDB_SCALAR_FUNCTION_ALIAS(PowOperatorFunAlias), DUCKDB_SCALAR_FUNCTION(StartsWithOperatorFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(AbsFun), DUCKDB_SCALAR_FUNCTION(AcosFun), + DUCKDB_SCALAR_FUNCTION(AcoshFun), DUCKDB_SCALAR_FUNCTION_SET(AgeFun), DUCKDB_SCALAR_FUNCTION_ALIAS(AggregateFun), DUCKDB_SCALAR_FUNCTION(AliasFun), DUCKDB_SCALAR_FUNCTION_ALIAS(ApplyFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ApproxCountDistinctFun), + DUCKDB_AGGREGATE_FUNCTION(ApproxCountDistinctFun), DUCKDB_AGGREGATE_FUNCTION_SET(ApproxQuantileFun), + DUCKDB_AGGREGATE_FUNCTION(ApproxTopKFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxNullFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinFun), @@ -77,6 +82,7 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayAggrFun), DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayAggregateFun), DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayApplyFun), + DUCKDB_SCALAR_FUNCTION_SET(ArrayCosineDistanceFun), DUCKDB_SCALAR_FUNCTION_SET(ArrayCosineSimilarityFun), DUCKDB_SCALAR_FUNCTION_SET(ArrayCrossProductFun), DUCKDB_SCALAR_FUNCTION_SET(ArrayDistanceFun), @@ -84,7 +90,11 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayDotProductFun), DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayFilterFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayGradeUpFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayHasAllFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayHasAnyFun), DUCKDB_SCALAR_FUNCTION_SET(ArrayInnerProductFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayNegativeDotProductFun), + DUCKDB_SCALAR_FUNCTION_SET(ArrayNegativeInnerProductFun), DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayReduceFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayReverseSortFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArraySliceFun), @@ -94,8 +104,10 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_SCALAR_FUNCTION(ArrayValueFun), DUCKDB_SCALAR_FUNCTION(ASCIIFun), DUCKDB_SCALAR_FUNCTION(AsinFun), + DUCKDB_SCALAR_FUNCTION(AsinhFun), DUCKDB_SCALAR_FUNCTION(AtanFun), DUCKDB_SCALAR_FUNCTION(Atan2Fun), + DUCKDB_SCALAR_FUNCTION(AtanhFun), DUCKDB_AGGREGATE_FUNCTION_SET(AvgFun), DUCKDB_SCALAR_FUNCTION_SET(BarFun), DUCKDB_SCALAR_FUNCTION_ALIAS(Base64Fun), @@ -109,6 +121,7 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_AGGREGATE_FUNCTION_SET(BitstringAggFun), DUCKDB_AGGREGATE_FUNCTION(BoolAndFun), DUCKDB_AGGREGATE_FUNCTION(BoolOrFun), + DUCKDB_SCALAR_FUNCTION(CanCastImplicitlyFun), DUCKDB_SCALAR_FUNCTION(CardinalityFun), DUCKDB_SCALAR_FUNCTION(CbrtFun), DUCKDB_SCALAR_FUNCTION_SET(CeilFun), @@ -117,6 +130,7 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_SCALAR_FUNCTION(ChrFun), DUCKDB_AGGREGATE_FUNCTION(CorrFun), DUCKDB_SCALAR_FUNCTION(CosFun), + DUCKDB_SCALAR_FUNCTION(CoshFun), DUCKDB_SCALAR_FUNCTION(CotFun), DUCKDB_AGGREGATE_FUNCTION(CovarPopFun), DUCKDB_AGGREGATE_FUNCTION(CovarSampFun), @@ -157,6 +171,7 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_SCALAR_FUNCTION_SET(EpochMsFun), DUCKDB_SCALAR_FUNCTION_SET(EpochNsFun), DUCKDB_SCALAR_FUNCTION_SET(EpochUsFun), + DUCKDB_SCALAR_FUNCTION_SET(EquiWidthBinsFun), DUCKDB_SCALAR_FUNCTION_SET(EraFun), DUCKDB_SCALAR_FUNCTION(ErrorFun), DUCKDB_SCALAR_FUNCTION(EvenFun), @@ -189,9 +204,11 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_SCALAR_FUNCTION(HashFun), DUCKDB_SCALAR_FUNCTION_SET(HexFun), DUCKDB_AGGREGATE_FUNCTION_SET(HistogramFun), + DUCKDB_AGGREGATE_FUNCTION(HistogramExactFun), DUCKDB_SCALAR_FUNCTION_SET(HoursFun), DUCKDB_SCALAR_FUNCTION(InSearchPathFun), DUCKDB_SCALAR_FUNCTION(InstrFun), + DUCKDB_SCALAR_FUNCTION(IsHistogramOtherBinFun), DUCKDB_SCALAR_FUNCTION_SET(IsFiniteFun), DUCKDB_SCALAR_FUNCTION_SET(IsInfiniteFun), DUCKDB_SCALAR_FUNCTION_SET(IsNanFun), @@ -216,13 +233,18 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_SCALAR_FUNCTION_ALIAS(ListAggrFun), DUCKDB_SCALAR_FUNCTION(ListAggregateFun), DUCKDB_SCALAR_FUNCTION_ALIAS(ListApplyFun), + DUCKDB_SCALAR_FUNCTION_SET(ListCosineDistanceFun), DUCKDB_SCALAR_FUNCTION_SET(ListCosineSimilarityFun), DUCKDB_SCALAR_FUNCTION_SET(ListDistanceFun), DUCKDB_SCALAR_FUNCTION(ListDistinctFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListDotProductFun), DUCKDB_SCALAR_FUNCTION(ListFilterFun), DUCKDB_SCALAR_FUNCTION_SET(ListGradeUpFun), + DUCKDB_SCALAR_FUNCTION(ListHasAllFun), + DUCKDB_SCALAR_FUNCTION(ListHasAnyFun), DUCKDB_SCALAR_FUNCTION_SET(ListInnerProductFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListNegativeDotProductFun), + DUCKDB_SCALAR_FUNCTION_SET(ListNegativeInnerProductFun), DUCKDB_SCALAR_FUNCTION_ALIAS(ListPackFun), DUCKDB_SCALAR_FUNCTION(ListReduceFun), DUCKDB_SCALAR_FUNCTION_SET(ListReverseSortFun), @@ -244,6 +266,7 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_SCALAR_FUNCTION_SET(MakeTimestampFun), DUCKDB_SCALAR_FUNCTION(MapFun), DUCKDB_SCALAR_FUNCTION(MapConcatFun), + DUCKDB_SCALAR_FUNCTION(MapContainsFun), DUCKDB_SCALAR_FUNCTION(MapEntriesFun), DUCKDB_SCALAR_FUNCTION(MapExtractFun), DUCKDB_SCALAR_FUNCTION(MapFromEntriesFun), @@ -251,10 +274,8 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_SCALAR_FUNCTION(MapValuesFun), DUCKDB_AGGREGATE_FUNCTION_SET(MaxFun), DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MaxByFun), - DUCKDB_SCALAR_FUNCTION(MD5Fun), - DUCKDB_SCALAR_FUNCTION(MD5NumberFun), - DUCKDB_SCALAR_FUNCTION(MD5NumberLowerFun), - DUCKDB_SCALAR_FUNCTION(MD5NumberUpperFun), + DUCKDB_SCALAR_FUNCTION_SET(MD5Fun), + DUCKDB_SCALAR_FUNCTION_SET(MD5NumberFun), DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MeanFun), DUCKDB_AGGREGATE_FUNCTION_SET(MedianFun), DUCKDB_SCALAR_FUNCTION_SET(MicrosecondsFun), @@ -267,6 +288,7 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_AGGREGATE_FUNCTION_SET(ModeFun), DUCKDB_SCALAR_FUNCTION_SET(MonthFun), DUCKDB_SCALAR_FUNCTION_SET(MonthNameFun), + DUCKDB_SCALAR_FUNCTION_SET(NanosecondsFun), DUCKDB_SCALAR_FUNCTION_SET(NextAfterFun), DUCKDB_SCALAR_FUNCTION_ALIAS(NowFun), DUCKDB_SCALAR_FUNCTION_ALIAS(OrdFun), @@ -313,10 +335,12 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_AGGREGATE_FUNCTION(StandardErrorOfTheMeanFun), DUCKDB_SCALAR_FUNCTION(SetBitFun), DUCKDB_SCALAR_FUNCTION(SetseedFun), - DUCKDB_SCALAR_FUNCTION(SHA256Fun), + DUCKDB_SCALAR_FUNCTION_SET(SHA1Fun), + DUCKDB_SCALAR_FUNCTION_SET(SHA256Fun), DUCKDB_SCALAR_FUNCTION_SET(SignFun), DUCKDB_SCALAR_FUNCTION_SET(SignBitFun), DUCKDB_SCALAR_FUNCTION(SinFun), + DUCKDB_SCALAR_FUNCTION(SinhFun), DUCKDB_AGGREGATE_FUNCTION(SkewnessFun), DUCKDB_SCALAR_FUNCTION_ALIAS(SplitFun), DUCKDB_SCALAR_FUNCTION(SqrtFun), @@ -340,6 +364,7 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_AGGREGATE_FUNCTION_SET(SumNoOverflowFun), DUCKDB_AGGREGATE_FUNCTION_ALIAS(SumkahanFun), DUCKDB_SCALAR_FUNCTION(TanFun), + DUCKDB_SCALAR_FUNCTION(TanhFun), DUCKDB_SCALAR_FUNCTION_SET(TimeBucketFun), DUCKDB_SCALAR_FUNCTION(TimeTZSortKeyFun), DUCKDB_SCALAR_FUNCTION_SET(TimezoneFun), @@ -378,6 +403,8 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_SCALAR_FUNCTION(UnionTagFun), DUCKDB_SCALAR_FUNCTION(UnionValueFun), DUCKDB_SCALAR_FUNCTION(UnpivotListFun), + DUCKDB_SCALAR_FUNCTION(UrlDecodeFun), + DUCKDB_SCALAR_FUNCTION(UrlEncodeFun), DUCKDB_SCALAR_FUNCTION(UUIDFun), DUCKDB_AGGREGATE_FUNCTION(VarPopFun), DUCKDB_AGGREGATE_FUNCTION(VarSampFun), diff --git a/src/duckdb/src/core_functions/lambda_functions.cpp b/src/duckdb/src/core_functions/lambda_functions.cpp index ee78be58..9a3a0310 100644 --- a/src/duckdb/src/core_functions/lambda_functions.cpp +++ b/src/duckdb/src/core_functions/lambda_functions.cpp @@ -161,7 +161,6 @@ struct ListFilterFunctor { }; vector LambdaFunctions::GetColumnInfo(DataChunk &args, const idx_t row_count) { - vector data; // skip the input list and then insert all remaining input vectors for (idx_t i = 1; i < args.ColumnCount(); i++) { @@ -172,8 +171,7 @@ vector LambdaFunctions::GetColumnInfo(DataChunk &ar } vector> -LambdaFunctions::GetInconstantColumnInfo(vector &data) { - +LambdaFunctions::GetMutableColumnInfo(vector &data) { vector> inconstant_info; for (auto &entry : data) { if (entry.vector.get().GetVectorType() != VectorType::CONSTANT_VECTOR) { @@ -246,8 +244,8 @@ void ListLambdaBindData::Serialize(Serializer &serializer, const optional_ptr ListLambdaBindData::Deserialize(Deserializer &deserializer, ScalarFunction &) { auto return_type = deserializer.ReadProperty(100, "return_type"); - auto lambda_expr = - deserializer.ReadPropertyWithDefault>(101, "lambda_expr", unique_ptr()); + auto lambda_expr = deserializer.ReadPropertyWithExplicitDefault>(101, "lambda_expr", + unique_ptr()); auto has_index = deserializer.ReadProperty(102, "has_index"); return make_uniq(return_type, std::move(lambda_expr), has_index); } @@ -290,7 +288,7 @@ void ExecuteLambda(DataChunk &args, ExpressionState &state, Vector &result) { } auto result_entries = FlatVector::GetData(result); - auto inconstant_column_infos = LambdaFunctions::GetInconstantColumnInfo(info.column_infos); + auto mutable_column_infos = LambdaFunctions::GetMutableColumnInfo(info.column_infos); // special-handling for the child_vector auto child_vector_size = ListVector::GetListSize(args.data[0]); @@ -347,7 +345,7 @@ void ExecuteLambda(DataChunk &args, ExpressionState &state, Vector &result) { // FIXME: reuse same selection vector for inconstant rows // adjust indexes for slicing child_info.sel.set_index(elem_cnt, list_entry.offset + child_idx); - for (auto &entry : inconstant_column_infos) { + for (auto &entry : mutable_column_infos) { entry.get().sel.set_index(elem_cnt, row_idx); } diff --git a/src/duckdb/src/core_functions/scalar/array/array_functions.cpp b/src/duckdb/src/core_functions/scalar/array/array_functions.cpp index c5e7189b..347ffcbd 100644 --- a/src/duckdb/src/core_functions/scalar/array/array_functions.cpp +++ b/src/duckdb/src/core_functions/scalar/array/array_functions.cpp @@ -1,182 +1,135 @@ #include "duckdb/core_functions/scalar/array_functions.hpp" -#include +#include "duckdb/core_functions/array_kernels.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" namespace duckdb { -//------------------------------------------------------------------------------ -// Functors -//------------------------------------------------------------------------------ +static unique_ptr ArrayGenericBinaryBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { -struct InnerProductOp { - static constexpr const char *NAME = "array_inner_product"; + const auto lhs_is_param = arguments[0]->HasParameter(); + const auto rhs_is_param = arguments[1]->HasParameter(); - template - inline static TYPE *GetResultData(Vector &result_vec) { - return FlatVector::GetData(result_vec); + if (lhs_is_param && rhs_is_param) { + throw ParameterNotResolvedException(); } - template - inline static void Operation(TYPE *l_data, idx_t l_idx, TYPE *r_data, idx_t r_idx, TYPE *result_data, - idx_t result_idx, idx_t size) { + const auto &lhs_type = arguments[0]->return_type; + const auto &rhs_type = arguments[1]->return_type; - TYPE inner_product = 0; + bound_function.arguments[0] = lhs_is_param ? rhs_type : lhs_type; + bound_function.arguments[1] = rhs_is_param ? lhs_type : rhs_type; - auto l_ptr = l_data + (l_idx * size); - auto r_ptr = r_data + (r_idx * size); - - for (idx_t elem_idx = 0; elem_idx < size; elem_idx++) { - auto x = *l_ptr++; - auto y = *r_ptr++; - inner_product += x * y; - } - - result_data[result_idx] = inner_product; + if (bound_function.arguments[0].id() != LogicalTypeId::ARRAY || + bound_function.arguments[1].id() != LogicalTypeId::ARRAY) { + throw InvalidInputException( + StringUtil::Format("%s: Arguments must be arrays of FLOAT or DOUBLE", bound_function.name)); } -}; -struct DistanceOp { - static constexpr const char *NAME = "array_distance"; + const auto lhs_size = ArrayType::GetSize(bound_function.arguments[0]); + const auto rhs_size = ArrayType::GetSize(bound_function.arguments[1]); - template - inline static TYPE *GetResultData(Vector &result_vec) { - return FlatVector::GetData(result_vec); + if (lhs_size != rhs_size) { + throw BinderException("%s: Array arguments must be of the same size", bound_function.name); } - template - inline static void Operation(TYPE *l_data, idx_t l_idx, TYPE *r_data, idx_t r_idx, TYPE *result_data, - idx_t result_idx, idx_t size) { - - TYPE distance = 0; - - auto l_ptr = l_data + (l_idx * size); - auto r_ptr = r_data + (r_idx * size); - - for (idx_t elem_idx = 0; elem_idx < size; elem_idx++) { - auto x = *l_ptr++; - auto y = *r_ptr++; - auto diff = x - y; - distance += diff * diff; - } + const auto &lhs_element_type = ArrayType::GetChildType(bound_function.arguments[0]); + const auto &rhs_element_type = ArrayType::GetChildType(bound_function.arguments[1]); - result_data[result_idx] = std::sqrt(distance); + // Resolve common type + LogicalType common_type; + if (!LogicalType::TryGetMaxLogicalType(context, lhs_element_type, rhs_element_type, common_type)) { + throw BinderException("%s: Cannot infer common element type (left = '%s', right = '%s')", bound_function.name, + lhs_element_type.ToString(), rhs_element_type.ToString()); } -}; - -struct CosineSimilarityOp { - static constexpr const char *NAME = "array_cosine_similarity"; - template - inline static TYPE *GetResultData(Vector &result_vec) { - return FlatVector::GetData(result_vec); + // Ensure it is float or double + if (common_type.id() != LogicalTypeId::FLOAT && common_type.id() != LogicalTypeId::DOUBLE) { + throw BinderException("%s: Arguments must be arrays of FLOAT or DOUBLE", bound_function.name); } - template - inline static void Operation(TYPE *l_data, idx_t l_idx, TYPE *r_data, idx_t r_idx, TYPE *result_data, - idx_t result_idx, idx_t size) { - - TYPE distance = 0; - TYPE norm_l = 0; - TYPE norm_r = 0; - - auto l_ptr = l_data + (l_idx * size); - auto r_ptr = r_data + (r_idx * size); - - for (idx_t i = 0; i < size; i++) { - auto x = *l_ptr++; - auto y = *r_ptr++; - distance += x * y; - norm_l += x * x; - norm_r += y * y; - } + // The important part is just that we resolve the size of the input arrays + bound_function.arguments[0] = LogicalType::ARRAY(common_type, lhs_size); + bound_function.arguments[1] = LogicalType::ARRAY(common_type, rhs_size); - auto similarity = distance / (std::sqrt(norm_l) * std::sqrt(norm_r)); + return nullptr; +} - // clamp to [-1, 1] to avoid floating point errors - result_data[result_idx] = std::max(static_cast(-1), std::min(similarity, static_cast(1))); - } -}; +//------------------------------------------------------------------------------ +// Element-wise combine functions +//------------------------------------------------------------------------------ +// Given two arrays of the same size, combine their elements into a single array +// of the same size as the input arrays. struct CrossProductOp { - static constexpr const char *NAME = "array_cross_product"; - template - inline static TYPE *GetResultData(Vector &result_vec) { - // Since we return an array here, we need to get the data pointer of the child - auto &child = ArrayVector::GetEntry(result_vec); - return FlatVector::GetData(child); - } - - template - inline static void Operation(TYPE *l_data, idx_t l_idx, TYPE *r_data, idx_t r_idx, TYPE *result_data, - idx_t result_idx, idx_t size) { + static void Operation(const TYPE *lhs_data, const TYPE *rhs_data, TYPE *res_data, idx_t size) { D_ASSERT(size == 3); - auto l_child_idx = l_idx * size; - auto r_child_idx = r_idx * size; - auto res_child_idx = result_idx * size; - - auto lx = l_data[l_child_idx + 0]; - auto ly = l_data[l_child_idx + 1]; - auto lz = l_data[l_child_idx + 2]; + auto lx = lhs_data[0]; + auto ly = lhs_data[1]; + auto lz = lhs_data[2]; - auto rx = r_data[r_child_idx + 0]; - auto ry = r_data[r_child_idx + 1]; - auto rz = r_data[r_child_idx + 2]; + auto rx = rhs_data[0]; + auto ry = rhs_data[1]; + auto rz = rhs_data[2]; - result_data[res_child_idx + 0] = ly * rz - lz * ry; - result_data[res_child_idx + 1] = lz * rx - lx * rz; - result_data[res_child_idx + 2] = lx * ry - ly * rx; + res_data[0] = ly * rz - lz * ry; + res_data[1] = lz * rx - lx * rz; + res_data[2] = lx * ry - ly * rx; } }; -//------------------------------------------------------------------------------ -// Generic Execute and Bind -//------------------------------------------------------------------------------ -// This is a generic executor function for fast binary math operations on -// real-valued arrays. Array elements are assumed to be either FLOAT or DOUBLE, -// and cannot be null. (although the array itself can be null). -// In the future we could extend this further to be truly generic and handle -// other types, unary/ternary operations and/or nulls. - -template -static inline void ArrayGenericBinaryExecute(Vector &left, Vector &right, Vector &result, idx_t size, idx_t count) { +template +static void ArrayFixedCombine(DataChunk &args, ExpressionState &state, Vector &result) { + const auto &lstate = state.Cast(); + const auto &expr = lstate.expr.Cast(); + const auto &func_name = expr.function.name; - auto &left_child = ArrayVector::GetEntry(left); - auto &right_child = ArrayVector::GetEntry(right); + const auto count = args.size(); + auto &lhs_child = ArrayVector::GetEntry(args.data[0]); + auto &rhs_child = ArrayVector::GetEntry(args.data[1]); + auto &res_child = ArrayVector::GetEntry(result); - auto &left_child_validity = FlatVector::Validity(left_child); - auto &right_child_validity = FlatVector::Validity(right_child); + const auto &lhs_child_validity = FlatVector::Validity(lhs_child); + const auto &rhs_child_validity = FlatVector::Validity(rhs_child); - UnifiedVectorFormat left_format; - UnifiedVectorFormat right_format; + UnifiedVectorFormat lhs_format; + UnifiedVectorFormat rhs_format; - left.ToUnifiedFormat(count, left_format); - right.ToUnifiedFormat(count, right_format); + args.data[0].ToUnifiedFormat(count, lhs_format); + args.data[1].ToUnifiedFormat(count, rhs_format); - auto left_data = FlatVector::GetData(left_child); - auto right_data = FlatVector::GetData(right_child); - auto result_data = OP::template GetResultData(result); + auto lhs_data = FlatVector::GetData(lhs_child); + auto rhs_data = FlatVector::GetData(rhs_child); + auto res_data = FlatVector::GetData(res_child); for (idx_t i = 0; i < count; i++) { - auto left_idx = left_format.sel->get_index(i); - auto right_idx = right_format.sel->get_index(i); + const auto lhs_idx = lhs_format.sel->get_index(i); + const auto rhs_idx = rhs_format.sel->get_index(i); - if (!left_format.validity.RowIsValid(left_idx) || !right_format.validity.RowIsValid(right_idx)) { + if (!lhs_format.validity.RowIsValid(lhs_idx) || !rhs_format.validity.RowIsValid(rhs_idx)) { FlatVector::SetNull(result, i, true); continue; } - auto left_offset = left_idx * size; - if (!left_child_validity.CheckAllValid(left_offset + size, left_offset)) { - throw InvalidInputException(StringUtil::Format("%s: left argument can not contain NULL values", OP::NAME)); + const auto left_offset = lhs_idx * N; + if (!lhs_child_validity.CheckAllValid(left_offset + N, left_offset)) { + throw InvalidInputException(StringUtil::Format("%s: left argument can not contain NULL values", func_name)); } - auto right_offset = right_idx * size; - if (!right_child_validity.CheckAllValid(right_offset + size, right_offset)) { - throw InvalidInputException(StringUtil::Format("%s: right argument can not contain NULL values", OP::NAME)); + const auto right_offset = rhs_idx * N; + if (!rhs_child_validity.CheckAllValid(right_offset + N, right_offset)) { + throw InvalidInputException( + StringUtil::Format("%s: right argument can not contain NULL values", func_name)); } + const auto result_offset = i * N; - OP::template Operation(left_data, left_idx, right_data, right_idx, result_data, i, size); + const auto lhs_data_ptr = lhs_data + left_offset; + const auto rhs_data_ptr = rhs_data + right_offset; + const auto res_data_ptr = res_data + result_offset; + + OP::Operation(lhs_data_ptr, rhs_data_ptr, res_data_ptr, N); } if (count == 1) { @@ -184,100 +137,123 @@ static inline void ArrayGenericBinaryExecute(Vector &left, Vector &right, Vector } } -template -static void ArrayGenericBinaryFunction(DataChunk &args, ExpressionState &, Vector &result) { - auto size = ArrayType::GetSize(args.data[0].GetType()); - auto child_type = ArrayType::GetChildType(args.data[0].GetType()); - switch (child_type.id()) { - case LogicalTypeId::DOUBLE: - ArrayGenericBinaryExecute(args.data[0], args.data[1], result, size, args.size()); - break; - case LogicalTypeId::FLOAT: - ArrayGenericBinaryExecute(args.data[0], args.data[1], result, size, args.size()); - break; - default: - throw NotImplementedException(StringUtil::Format("%s: Unsupported element type", OP::NAME)); - } -} +//------------------------------------------------------------------------------ +// Generic "fold" function +//------------------------------------------------------------------------------ +// Given two arrays, combine and reduce their elements into a single scalar value. -template -static unique_ptr ArrayGenericBinaryBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { +template +static void ArrayGenericFold(DataChunk &args, ExpressionState &state, Vector &result) { + const auto &lstate = state.Cast(); + const auto &expr = lstate.expr.Cast(); + const auto &func_name = expr.function.name; - // construct return type - auto &left_type = arguments[0]->return_type; - auto &right_type = arguments[1]->return_type; + const auto count = args.size(); + auto &lhs_child = ArrayVector::GetEntry(args.data[0]); + auto &rhs_child = ArrayVector::GetEntry(args.data[1]); - // mystery to me how anything non-array could ever end up here but it happened - if (left_type.id() != LogicalTypeId::ARRAY || right_type.id() != LogicalTypeId::ARRAY) { - throw InvalidInputException(StringUtil::Format("%s: Arguments must be arrays of FLOAT or DOUBLE", OP::NAME)); - } + const auto &lhs_child_validity = FlatVector::Validity(lhs_child); + const auto &rhs_child_validity = FlatVector::Validity(rhs_child); - auto left_size = ArrayType::GetSize(left_type); - auto right_size = ArrayType::GetSize(right_type); - if (left_size != right_size) { - throw InvalidInputException(StringUtil::Format("%s: Array arguments must be of the same size", OP::NAME)); - } - auto size = left_size; + UnifiedVectorFormat lhs_format; + UnifiedVectorFormat rhs_format; - auto child_type = - LogicalType::MaxLogicalType(context, ArrayType::GetChildType(left_type), ArrayType::GetChildType(right_type)); - if (child_type != LogicalTypeId::FLOAT && child_type != LogicalTypeId::DOUBLE) { - throw InvalidInputException( - StringUtil::Format("%s: Array arguments must be of type FLOAT or DOUBLE", OP::NAME)); - } + args.data[0].ToUnifiedFormat(count, lhs_format); + args.data[1].ToUnifiedFormat(count, rhs_format); - // the important part here is that we resolve the array size - auto array_type = LogicalType::ARRAY(child_type, size); + auto lhs_data = FlatVector::GetData(lhs_child); + auto rhs_data = FlatVector::GetData(rhs_child); + auto res_data = FlatVector::GetData(result); - bound_function.arguments[0] = array_type; - bound_function.arguments[1] = array_type; - bound_function.return_type = child_type; + const auto array_size = ArrayType::GetSize(args.data[0].GetType()); + D_ASSERT(array_size == ArrayType::GetSize(args.data[1].GetType())); - return nullptr; -} + for (idx_t i = 0; i < count; i++) { + const auto lhs_idx = lhs_format.sel->get_index(i); + const auto rhs_idx = rhs_format.sel->get_index(i); + + if (!lhs_format.validity.RowIsValid(lhs_idx) || !rhs_format.validity.RowIsValid(rhs_idx)) { + FlatVector::SetNull(result, i, true); + continue; + } + + const auto left_offset = lhs_idx * array_size; + if (!lhs_child_validity.CheckAllValid(left_offset + array_size, left_offset)) { + throw InvalidInputException(StringUtil::Format("%s: left argument can not contain NULL values", func_name)); + } + + const auto right_offset = rhs_idx * array_size; + if (!rhs_child_validity.CheckAllValid(right_offset + array_size, right_offset)) { + throw InvalidInputException( + StringUtil::Format("%s: right argument can not contain NULL values", func_name)); + } -template -static inline void ArrayFixedBinaryFunction(DataChunk &args, ExpressionState &, Vector &result) { - ArrayGenericBinaryExecute(args.data[0], args.data[1], result, N, args.size()); + const auto lhs_data_ptr = lhs_data + left_offset; + const auto rhs_data_ptr = rhs_data + right_offset; + + res_data[i] = OP::Operation(lhs_data_ptr, rhs_data_ptr, array_size); + } + + if (count == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } } //------------------------------------------------------------------------------ // Function Registration //------------------------------------------------------------------------------ - // Note: In the future we could add a wrapper with a non-type template parameter to specialize for specific array sizes // e.g. 256, 512, 1024, 2048 etc. which may allow the compiler to vectorize the loop better. Perhaps something for an // extension. +template +static void AddArrayFoldFunction(ScalarFunctionSet &set, const LogicalType &type) { + const auto array = LogicalType::ARRAY(type, optional_idx()); + if (type.id() == LogicalTypeId::FLOAT) { + set.AddFunction(ScalarFunction({array, array}, type, ArrayGenericFold, ArrayGenericBinaryBind)); + } else if (type.id() == LogicalTypeId::DOUBLE) { + set.AddFunction(ScalarFunction({array, array}, type, ArrayGenericFold, ArrayGenericBinaryBind)); + } else { + throw NotImplementedException("Array function not implemented for type %s", type.ToString()); + } +} + +ScalarFunctionSet ArrayDistanceFun::GetFunctions() { + ScalarFunctionSet set("array_distance"); + for (auto &type : LogicalType::Real()) { + AddArrayFoldFunction(set, type); + } + return set; +} + ScalarFunctionSet ArrayInnerProductFun::GetFunctions() { ScalarFunctionSet set("array_inner_product"); - // Generic array inner product function for (auto &type : LogicalType::Real()) { - set.AddFunction(ScalarFunction({LogicalType::ARRAY(type), LogicalType::ARRAY(type)}, type, - ArrayGenericBinaryFunction, - ArrayGenericBinaryBind)); + AddArrayFoldFunction(set, type); } return set; } -ScalarFunctionSet ArrayDistanceFun::GetFunctions() { - ScalarFunctionSet set("array_distance"); - // Generic array distance function +ScalarFunctionSet ArrayNegativeInnerProductFun::GetFunctions() { + ScalarFunctionSet set("array_negative_inner_product"); for (auto &type : LogicalType::Real()) { - set.AddFunction(ScalarFunction({LogicalType::ARRAY(type), LogicalType::ARRAY(type)}, type, - ArrayGenericBinaryFunction, ArrayGenericBinaryBind)); + AddArrayFoldFunction(set, type); } return set; } ScalarFunctionSet ArrayCosineSimilarityFun::GetFunctions() { ScalarFunctionSet set("array_cosine_similarity"); - // Generic array cosine similarity function for (auto &type : LogicalType::Real()) { - set.AddFunction(ScalarFunction({LogicalType::ARRAY(type), LogicalType::ARRAY(type)}, type, - ArrayGenericBinaryFunction, - ArrayGenericBinaryBind)); + AddArrayFoldFunction(set, type); + } + return set; +} + +ScalarFunctionSet ArrayCosineDistanceFun::GetFunctions() { + ScalarFunctionSet set("array_cosine_distance"); + for (auto &type : LogicalType::Real()) { + AddArrayFoldFunction(set, type); } return set; } @@ -285,14 +261,12 @@ ScalarFunctionSet ArrayCosineSimilarityFun::GetFunctions() { ScalarFunctionSet ArrayCrossProductFun::GetFunctions() { ScalarFunctionSet set("array_cross_product"); - // Generic array cross product function - auto double_arr = LogicalType::ARRAY(LogicalType::DOUBLE, 3); + auto float_array = LogicalType::ARRAY(LogicalType::FLOAT, 3); + auto double_array = LogicalType::ARRAY(LogicalType::DOUBLE, 3); set.AddFunction( - ScalarFunction({double_arr, double_arr}, double_arr, ArrayFixedBinaryFunction)); - - auto float_arr = LogicalType::ARRAY(LogicalType::FLOAT, 3); + ScalarFunction({float_array, float_array}, float_array, ArrayFixedCombine)); set.AddFunction( - ScalarFunction({float_arr, float_arr}, float_arr, ArrayFixedBinaryFunction)); + ScalarFunction({double_array, double_array}, double_array, ArrayFixedCombine)); return set; } diff --git a/src/duckdb/src/core_functions/scalar/blob/create_sort_key.cpp b/src/duckdb/src/core_functions/scalar/blob/create_sort_key.cpp index a9142443..5a643809 100644 --- a/src/duckdb/src/core_functions/scalar/blob/create_sort_key.cpp +++ b/src/duckdb/src/core_functions/scalar/blob/create_sort_key.cpp @@ -4,42 +4,10 @@ #include "duckdb/common/radix.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression_binder.hpp" +#include "duckdb/core_functions/create_sort_key.hpp" namespace duckdb { -struct OrderModifiers { - OrderModifiers(OrderType order_type, OrderByNullType null_type) : order_type(order_type), null_type(null_type) { - } - - OrderType order_type; - OrderByNullType null_type; - - bool operator==(const OrderModifiers &other) const { - return order_type == other.order_type && null_type == other.null_type; - } - - static OrderModifiers Parse(const string &val) { - auto lcase = StringUtil::Replace(StringUtil::Lower(val), "_", " "); - OrderType order_type; - if (StringUtil::StartsWith(lcase, "asc")) { - order_type = OrderType::ASCENDING; - } else if (StringUtil::StartsWith(lcase, "desc")) { - order_type = OrderType::DESCENDING; - } else { - throw BinderException("create_sort_key modifier must start with either ASC or DESC"); - } - OrderByNullType null_type; - if (StringUtil::EndsWith(lcase, "nulls first")) { - null_type = OrderByNullType::NULLS_FIRST; - } else if (StringUtil::EndsWith(lcase, "nulls last")) { - null_type = OrderByNullType::NULLS_LAST; - } else { - throw BinderException("create_sort_key modifier must end with either NULLS FIRST or NULLS LAST"); - } - return OrderModifiers(order_type, null_type); - } -}; - struct CreateSortKeyBindData : public FunctionData { vector modifiers; @@ -76,7 +44,7 @@ unique_ptr CreateSortKeyBind(ClientContext &context, ScalarFunctio } // push collations for (idx_t i = 0; i < arguments.size(); i += 2) { - ExpressionBinder::PushCollation(context, arguments[i], arguments[i]->return_type, false); + ExpressionBinder::PushCollation(context, arguments[i], arguments[i]->return_type); } // check if all types are constant bool all_constant = true; @@ -109,7 +77,9 @@ struct SortKeyVectorData { static constexpr data_t BLOB_ESCAPE_CHARACTER = 1; SortKeyVectorData(Vector &input, idx_t size, OrderModifiers modifiers) : vec(input) { - input.ToUnifiedFormat(size, format); + if (size != 0) { + input.ToUnifiedFormat(size, format); + } this->size = size; null_byte = NULL_FIRST_BYTE; @@ -140,7 +110,7 @@ struct SortKeyVectorData { } case PhysicalType::LIST: { auto &child_entry = ListVector::GetEntry(input); - auto child_size = ListVector::GetListSize(input); + auto child_size = size == 0 ? 0 : ListVector::GetListSize(input); child_data.push_back(make_uniq(child_entry, child_size, child_modifiers)); break; } @@ -152,6 +122,9 @@ struct SortKeyVectorData { SortKeyVectorData(const SortKeyVectorData &other) = delete; SortKeyVectorData &operator=(const SortKeyVectorData &) = delete; + void Initialize() { + } + PhysicalType GetPhysicalType() { return vec.GetType().InternalType(); } @@ -176,6 +149,21 @@ struct SortKeyConstantOperator { Radix::EncodeData(result, input); return sizeof(T); } + + static idx_t Decode(const_data_ptr_t input, Vector &result, idx_t result_idx, bool flip_bytes) { + auto result_data = FlatVector::GetData(result); + if (flip_bytes) { + // descending order - so flip bytes + data_t flipped_bytes[sizeof(T)]; + for (idx_t b = 0; b < sizeof(T); b++) { + flipped_bytes[b] = ~input[b]; + } + result_data[result_idx] = Radix::DecodeData(flipped_bytes); + } else { + result_data[result_idx] = Radix::DecodeData(input); + } + return sizeof(T); + } }; struct SortKeyVarcharOperator { @@ -194,6 +182,31 @@ struct SortKeyVarcharOperator { result[input_size] = SortKeyVectorData::STRING_DELIMITER; // null-byte delimiter return input_size + 1; } + + static idx_t Decode(const_data_ptr_t input, Vector &result, idx_t result_idx, bool flip_bytes) { + auto result_data = FlatVector::GetData(result); + // iterate until we encounter the string delimiter to figure out the string length + data_t string_delimiter = SortKeyVectorData::STRING_DELIMITER; + if (flip_bytes) { + string_delimiter = ~string_delimiter; + } + idx_t pos; + for (pos = 0; input[pos] != string_delimiter; pos++) { + } + idx_t str_len = pos; + // now allocate the string data and fill it with the decoded data + result_data[result_idx] = StringVector::EmptyString(result, str_len); + auto str_data = data_ptr_cast(result_data[result_idx].GetDataWriteable()); + for (pos = 0; pos < str_len; pos++) { + if (flip_bytes) { + str_data[pos] = (~input[pos]) - 1; + } else { + str_data[pos] = input[pos] - 1; + } + } + result_data[result_idx].Finalize(); + return pos + 1; + } }; struct SortKeyBlobOperator { @@ -228,6 +241,42 @@ struct SortKeyBlobOperator { result[result_offset++] = SortKeyVectorData::STRING_DELIMITER; // null-byte delimiter return result_offset; } + + static idx_t Decode(const_data_ptr_t input, Vector &result, idx_t result_idx, bool flip_bytes) { + auto result_data = FlatVector::GetData(result); + // scan until we find the delimiter, keeping in mind escapes + data_t string_delimiter = SortKeyVectorData::STRING_DELIMITER; + data_t escape_character = SortKeyVectorData::BLOB_ESCAPE_CHARACTER; + if (flip_bytes) { + string_delimiter = ~string_delimiter; + escape_character = ~escape_character; + } + idx_t blob_len = 0; + idx_t pos; + for (pos = 0; input[pos] != string_delimiter; pos++) { + blob_len++; + if (input[pos] == escape_character) { + // escape character - skip the next byte + pos++; + } + } + // now allocate the blob data and fill it with the decoded data + result_data[result_idx] = StringVector::EmptyString(result, blob_len); + auto str_data = data_ptr_cast(result_data[result_idx].GetDataWriteable()); + for (idx_t input_pos = 0, result_pos = 0; input_pos < pos; input_pos++) { + if (input[input_pos] == escape_character) { + // if we encounter an escape character - copy the NEXT byte + input_pos++; + } + if (flip_bytes) { + str_data[result_pos++] = ~input[input_pos]; + } else { + str_data[result_pos++] = input[input_pos]; + } + } + result_data[result_idx].Finalize(); + return pos + 1; + } }; struct SortKeyListEntry { @@ -402,7 +451,7 @@ static void GetSortKeyLengthRecursive(SortKeyVectorData &vector_data, SortKeyChu } } -static void GetSortKeyLength(SortKeyVectorData &vector_data, SortKeyLengthInfo &result) { +static void GetSortKeyLength(SortKeyVectorData &vector_data, SortKeyLengthInfo &result, SortKeyChunk chunk) { // top-level method auto physical_type = vector_data.GetPhysicalType(); if (TypeIsConstantSize(physical_type)) { @@ -411,7 +460,11 @@ static void GetSortKeyLength(SortKeyVectorData &vector_data, SortKeyLengthInfo & result.constant_length += GetTypeIdSize(physical_type); return; } - GetSortKeyLengthRecursive(vector_data, SortKeyChunk(0, vector_data.size), result); + GetSortKeyLengthRecursive(vector_data, chunk, result); +} + +static void GetSortKeyLength(SortKeyVectorData &vector_data, SortKeyLengthInfo &result) { + GetSortKeyLength(vector_data, result, SortKeyChunk(0, vector_data.size)); } //===--------------------------------------------------------------------===// @@ -633,7 +686,7 @@ static void FinalizeSortData(Vector &result, idx_t size) { case LogicalTypeId::BIGINT: { auto result_data = FlatVector::GetData(result); for (idx_t r = 0; r < size; r++) { - result_data[r] = BSwap(result_data[r]); + result_data[r] = BSwap(result_data[r]); } break; } @@ -642,40 +695,274 @@ static void FinalizeSortData(Vector &result, idx_t size) { } } -static void CreateSortKeyFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &bind_data = state.expr.Cast().bind_info->Cast(); - - // prepare the sort key data - vector> sort_key_data; - for (idx_t c = 0; c < args.ColumnCount(); c += 2) { - sort_key_data.push_back(make_uniq(args.data[c], args.size(), bind_data.modifiers[c / 2])); - } - +static void CreateSortKeyInternal(vector> &sort_key_data, + const vector &modifiers, Vector &result, idx_t row_count) { // two phases // a) get the length of the final sorted key // b) allocate the sorted key and construct // we do all of this in a vectorized manner - SortKeyLengthInfo key_lengths(args.size()); + SortKeyLengthInfo key_lengths(row_count); for (auto &vector_data : sort_key_data) { GetSortKeyLength(*vector_data, key_lengths); } // allocate the empty sort keys - auto data_pointers = unique_ptr(new data_ptr_t[args.size()]); - PrepareSortData(result, args.size(), key_lengths, data_pointers.get()); + auto data_pointers = unique_ptr(new data_ptr_t[row_count]); + PrepareSortData(result, row_count, key_lengths, data_pointers.get()); unsafe_vector offsets; - offsets.resize(args.size(), 0); + offsets.resize(row_count, 0); // now construct the sort keys for (idx_t c = 0; c < sort_key_data.size(); c++) { - SortKeyConstructInfo info(bind_data.modifiers[c], offsets, data_pointers.get()); + SortKeyConstructInfo info(modifiers[c], offsets, data_pointers.get()); ConstructSortKey(*sort_key_data[c], info); } - FinalizeSortData(result, args.size()); + FinalizeSortData(result, row_count); +} + +void CreateSortKeyHelpers::CreateSortKey(Vector &input, idx_t input_count, OrderModifiers order_modifier, + Vector &result) { + // prepare the sort key data + vector modifiers {order_modifier}; + vector> sort_key_data; + sort_key_data.push_back(make_uniq(input, input_count, order_modifier)); + + CreateSortKeyInternal(sort_key_data, modifiers, result, input_count); +} + +void CreateSortKeyHelpers::CreateSortKeyWithValidity(Vector &input, Vector &result, const OrderModifiers &modifiers, + const idx_t count) { + CreateSortKey(input, count, modifiers, result); + UnifiedVectorFormat format; + input.ToUnifiedFormat(count, format); + auto &validity = FlatVector::Validity(result); + + for (idx_t i = 0; i < count; i++) { + auto idx = format.sel->get_index(i); + if (!format.validity.RowIsValid(idx)) { + validity.SetInvalid(i); + } + } +} + +static void CreateSortKeyFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &bind_data = state.expr.Cast().bind_info->Cast(); + + // prepare the sort key data + vector> sort_key_data; + for (idx_t c = 0; c < args.ColumnCount(); c += 2) { + sort_key_data.push_back(make_uniq(args.data[c], args.size(), bind_data.modifiers[c / 2])); + } + CreateSortKeyInternal(sort_key_data, bind_data.modifiers, result, args.size()); + if (args.AllConstant()) { result.SetVectorType(VectorType::CONSTANT_VECTOR); } } +//===--------------------------------------------------------------------===// +// Decode Sort Key +//===--------------------------------------------------------------------===// +struct DecodeSortKeyData { + explicit DecodeSortKeyData(OrderModifiers modifiers, string_t &sort_key) + : data(const_data_ptr_cast(sort_key.GetData())), size(sort_key.GetSize()), position(0), + flip_bytes(modifiers.order_type == OrderType::DESCENDING) { + } + + const_data_ptr_t data; + idx_t size; + idx_t position; + bool flip_bytes; +}; + +void DecodeSortKeyRecursive(DecodeSortKeyData &decode_data, SortKeyVectorData &vector_data, Vector &result, + idx_t result_idx); + +template +void TemplatedDecodeSortKey(DecodeSortKeyData &decode_data, SortKeyVectorData &vector_data, Vector &result, + idx_t result_idx) { + auto validity_byte = decode_data.data[decode_data.position]; + decode_data.position++; + if (validity_byte == vector_data.null_byte) { + // NULL value + FlatVector::Validity(result).SetInvalid(result_idx); + return; + } + idx_t increment = OP::Decode(decode_data.data + decode_data.position, result, result_idx, decode_data.flip_bytes); + decode_data.position += increment; +} + +void DecodeSortKeyStruct(DecodeSortKeyData &decode_data, SortKeyVectorData &vector_data, Vector &result, + idx_t result_idx) { + // check if the top-level is valid or not + auto validity_byte = decode_data.data[decode_data.position]; + decode_data.position++; + if (validity_byte == vector_data.null_byte) { + // entire struct is NULL + // note that we still deserialize the children + FlatVector::Validity(result).SetInvalid(result_idx); + } + // recurse into children + auto &child_entries = StructVector::GetEntries(result); + for (idx_t c = 0; c < child_entries.size(); c++) { + auto &child_entry = child_entries[c]; + DecodeSortKeyRecursive(decode_data, *vector_data.child_data[c], *child_entry, result_idx); + } +} + +void DecodeSortKeyList(DecodeSortKeyData &decode_data, SortKeyVectorData &vector_data, Vector &result, + idx_t result_idx) { + // check if the top-level is valid or not + auto validity_byte = decode_data.data[decode_data.position]; + decode_data.position++; + if (validity_byte == vector_data.null_byte) { + // entire list is NULL + FlatVector::Validity(result).SetInvalid(result_idx); + return; + } + // list is valid - decode child elements + // we don't know how many there will be + // decode child elements until we encounter the list delimiter + auto list_delimiter = SortKeyVectorData::LIST_DELIMITER; + if (decode_data.flip_bytes) { + list_delimiter = ~list_delimiter; + } + auto list_data = FlatVector::GetData(result); + auto &child_vector = ListVector::GetEntry(result); + // get the current list size + auto start_list_size = ListVector::GetListSize(result); + auto new_list_size = start_list_size; + // loop until we find the list delimiter + while (decode_data.data[decode_data.position] != list_delimiter) { + // found a valid entry here - decode it + // first reserve space for it + new_list_size++; + ListVector::Reserve(result, new_list_size); + + // now decode the entry + DecodeSortKeyRecursive(decode_data, *vector_data.child_data[0], child_vector, new_list_size - 1); + } + // skip the list delimiter + decode_data.position++; + // set the list_entry_t information and update the list size + list_data[result_idx].length = new_list_size - start_list_size; + list_data[result_idx].offset = start_list_size; + ListVector::SetListSize(result, new_list_size); +} + +void DecodeSortKeyArray(DecodeSortKeyData &decode_data, SortKeyVectorData &vector_data, Vector &result, + idx_t result_idx) { + // check if the top-level is valid or not + auto validity_byte = decode_data.data[decode_data.position]; + decode_data.position++; + if (validity_byte == vector_data.null_byte) { + // entire array is NULL + // note that we still read the child elements + FlatVector::Validity(result).SetInvalid(result_idx); + } + // array is valid - decode child elements + // arrays need to encode exactly array_size child elements + // however the decoded data still contains a list delimiter + // we use this delimiter to verify we successfully decoded the entire array + auto list_delimiter = SortKeyVectorData::LIST_DELIMITER; + if (decode_data.flip_bytes) { + list_delimiter = ~list_delimiter; + } + auto &child_vector = ArrayVector::GetEntry(result); + auto array_size = ArrayType::GetSize(result.GetType()); + + idx_t found_elements = 0; + auto child_start = array_size * result_idx; + // loop until we find the list delimiter + while (decode_data.data[decode_data.position] != list_delimiter) { + found_elements++; + if (found_elements > array_size) { + // error - found too many elements + break; + } + // now decode the entry + DecodeSortKeyRecursive(decode_data, *vector_data.child_data[0], child_vector, child_start + found_elements - 1); + } + // skip the list delimiter + decode_data.position++; + if (found_elements != array_size) { + throw InvalidInputException("Failed to decode array - found %d elements but expected %d", found_elements, + array_size); + } +} + +void DecodeSortKeyRecursive(DecodeSortKeyData &decode_data, SortKeyVectorData &vector_data, Vector &result, + idx_t result_idx) { + switch (result.GetType().InternalType()) { + case PhysicalType::BOOL: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::UINT8: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::INT8: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::UINT16: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::INT16: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::UINT32: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::INT32: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::UINT64: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::INT64: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::FLOAT: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::DOUBLE: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::INTERVAL: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::UINT128: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::INT128: + TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::VARCHAR: + if (vector_data.vec.GetType().id() == LogicalTypeId::VARCHAR) { + TemplatedDecodeSortKey(decode_data, vector_data, result, result_idx); + } else { + TemplatedDecodeSortKey(decode_data, vector_data, result, result_idx); + } + break; + case PhysicalType::STRUCT: + DecodeSortKeyStruct(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::LIST: + DecodeSortKeyList(decode_data, vector_data, result, result_idx); + break; + case PhysicalType::ARRAY: + DecodeSortKeyArray(decode_data, vector_data, result, result_idx); + break; + default: + throw NotImplementedException("Unsupported type %s in DecodeSortKey", vector_data.vec.GetType()); + } +} + +void CreateSortKeyHelpers::DecodeSortKey(string_t sort_key, Vector &result, idx_t result_idx, + OrderModifiers modifiers) { + SortKeyVectorData sort_key_data(result, 0, modifiers); + DecodeSortKeyData decode_data(modifiers, sort_key); + DecodeSortKeyRecursive(decode_data, sort_key_data, result, result_idx); +} + ScalarFunction CreateSortKeyFun::GetFunction() { ScalarFunction sort_key_function("create_sort_key", {LogicalType::ANY}, LogicalType::BLOB, CreateSortKeyFunction, CreateSortKeyBind); diff --git a/src/duckdb/src/core_functions/scalar/date/date_diff.cpp b/src/duckdb/src/core_functions/scalar/date/date_diff.cpp index d4514b90..6266dda3 100644 --- a/src/duckdb/src/core_functions/scalar/date/date_diff.cpp +++ b/src/duckdb/src/core_functions/scalar/date/date_diff.cpp @@ -91,8 +91,8 @@ struct DateDiff { struct WeekOperator { template static inline TR Operation(TA startdate, TB enddate) { - return Date::Epoch(Date::GetMondayOfCurrentWeek(enddate)) / Interval::SECS_PER_WEEK - - Date::Epoch(Date::GetMondayOfCurrentWeek(startdate)) / Interval::SECS_PER_WEEK; + // Weeks do not count Monday crossings, just distance + return (enddate.days - startdate.days) / Interval::DAYS_PER_WEEK; } }; diff --git a/src/duckdb/src/core_functions/scalar/date/date_part.cpp b/src/duckdb/src/core_functions/scalar/date/date_part.cpp index 3d0b7577..c234e1e3 100644 --- a/src/duckdb/src/core_functions/scalar/date/date_part.cpp +++ b/src/duckdb/src/core_functions/scalar/date/date_part.cpp @@ -2,6 +2,7 @@ #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/enums/date_part_specifier.hpp" #include "duckdb/common/exception.hpp" +#include "duckdb/common/exception/conversion_exception.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/enum_util.hpp" #include "duckdb/common/types/date.hpp" @@ -11,6 +12,7 @@ #include "duckdb/execution/expression_executor.hpp" #include "duckdb/function/scalar/nested_functions.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/types/date_lookup_cache.hpp" namespace duckdb { @@ -98,6 +100,20 @@ static unique_ptr PropagateSimpleDatePartStatistics(vector +struct DateCacheLocalState : public FunctionLocalState { + explicit DateCacheLocalState() { + } + + DateLookupCache cache; +}; + +template +unique_ptr InitDateCacheLocalState(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data) { + return make_uniq>(); +} + struct DatePart { template static unique_ptr PropagateDatePartStatistics(vector &child_stats, @@ -403,6 +419,18 @@ struct DatePart { } }; + struct NanosecondsOperator { + template + static inline TR Operation(TA input) { + return MicrosecondsOperator::Operation(input) * Interval::NANOS_PER_MICRO; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 60000000000>(input.child_stats); + } + }; + struct MicrosecondsOperator { template static inline TR Operation(TA input) { @@ -466,7 +494,7 @@ struct DatePart { struct EpochOperator { template static inline TR Operation(TA input) { - return Date::Epoch(input); + return TR(Date::Epoch(input)); } template @@ -720,7 +748,7 @@ struct DatePart { if (mask & EPOCH) { auto double_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); if (double_data) { - double_data[idx] = Date::Epoch(input); + double_data[idx] = double(Date::Epoch(input)); } } if (mask & DOY) { @@ -732,25 +760,19 @@ struct DatePart { if (mask & JD) { auto double_data = HasPartValue(double_values, DatePartSpecifier::JULIAN_DAY); if (double_data) { - double_data[idx] = Date::ExtractJulianDay(input); + double_data[idx] = double(Date::ExtractJulianDay(input)); } } } }; }; -template -static void LastYearFunction(DataChunk &args, ExpressionState &state, Vector &result) { - int32_t last_year = 0; - UnaryExecutor::ExecuteWithNulls(args.data[0], result, args.size(), - [&](T input, ValidityMask &mask, idx_t idx) { - if (Value::IsFinite(input)) { - return Date::ExtractYear(input, &last_year); - } else { - mask.SetInvalid(idx); - return 0; - } - }); +template +static void DatePartCachedFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast>(); + UnaryExecutor::ExecuteWithNulls( + args.data[0], result, args.size(), + [&](T input, ValidityMask &mask, idx_t idx) { return lstate.cache.ExtractElement(input, mask, idx); }); } template <> @@ -1073,6 +1095,19 @@ int64_t DatePart::EpochMillisOperator::Operation(dtime_tz_t input) { return DatePart::EpochMillisOperator::Operation(input.time()); } +template <> +int64_t DatePart::NanosecondsOperator::Operation(timestamp_ns_t input) { + if (!Timestamp::IsFinite(input)) { + throw ConversionException("Can't get nanoseconds of infinite TIMESTAMP"); + } + date_t date; + dtime_t time; + int32_t nanos; + Timestamp::Convert(input, date, time, nanos); + // remove everything but the second & nanosecond part + return (time.micros % Interval::MICROS_PER_MINUTE) * Interval::NANOS_PER_MICRO + nanos; +} + template <> int64_t DatePart::MicrosecondsOperator::Operation(timestamp_t input) { D_ASSERT(Timestamp::IsFinite(input)); @@ -1189,7 +1224,7 @@ int64_t DatePart::HoursOperator::Operation(dtime_tz_t input) { template <> double DatePart::EpochOperator::Operation(timestamp_t input) { D_ASSERT(Timestamp::IsFinite(input)); - return Timestamp::GetEpochMicroSeconds(input) / double(Interval::MICROS_PER_SEC); + return double(Timestamp::GetEpochMicroSeconds(input)) / double(Interval::MICROS_PER_SEC); } template <> @@ -1203,7 +1238,7 @@ double DatePart::EpochOperator::Operation(interval_t input) { interval_epoch = interval_days * Interval::SECS_PER_DAY; // we add 0.25 days per year to sort of account for leap days interval_epoch += interval_years * (Interval::SECS_PER_DAY / 4); - return interval_epoch + input.micros / double(Interval::MICROS_PER_SEC); + return double(interval_epoch) + double(input.micros) / double(Interval::MICROS_PER_SEC); } // TODO: We can't propagate interval statistics because we can't easily compare interval_t for order. @@ -1215,7 +1250,7 @@ unique_ptr DatePart::EpochOperator::PropagateStatistics double DatePart::EpochOperator::Operation(dtime_t input) { - return input.micros / double(Interval::MICROS_PER_SEC); + return double(input.micros) / double(Interval::MICROS_PER_SEC); } template <> @@ -1301,7 +1336,7 @@ int64_t DatePart::TimezoneMinuteOperator::Operation(dtime_tz_t input) { template <> double DatePart::JulianDayOperator::Operation(date_t input) { - return Date::ExtractJulianDay(input); + return double(Date::ExtractJulianDay(input)); } template <> @@ -1659,14 +1694,15 @@ static unique_ptr DatePartBind(ClientContext &context, ScalarFunct return nullptr; } +template ScalarFunctionSet GetGenericDatePartFunction(scalar_function_t date_func, scalar_function_t ts_func, scalar_function_t interval_func, function_statistics_t date_stats, function_statistics_t ts_stats) { ScalarFunctionSet operator_set; - operator_set.AddFunction( - ScalarFunction({LogicalType::DATE}, LogicalType::BIGINT, std::move(date_func), nullptr, nullptr, date_stats)); - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BIGINT, std::move(ts_func), nullptr, nullptr, ts_stats)); + operator_set.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::BIGINT, std::move(date_func), nullptr, + nullptr, date_stats, DATE_CACHE)); + operator_set.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BIGINT, std::move(ts_func), nullptr, + nullptr, ts_stats, DATE_CACHE)); operator_set.AddFunction(ScalarFunction({LogicalType::INTERVAL}, LogicalType::BIGINT, std::move(interval_func))); return operator_set; } @@ -1945,20 +1981,24 @@ struct StructDatePart { return result; } }; +template +ScalarFunctionSet GetCachedDatepartFunction() { + return GetGenericDatePartFunction>( + DatePartCachedFunction, DatePartCachedFunction, + ScalarFunction::UnaryFunction, OP::template PropagateStatistics, + OP::template PropagateStatistics); +} ScalarFunctionSet YearFun::GetFunctions() { - return GetGenericDatePartFunction(LastYearFunction, LastYearFunction, - ScalarFunction::UnaryFunction, - DatePart::YearOperator::PropagateStatistics, - DatePart::YearOperator::PropagateStatistics); + return GetCachedDatepartFunction(); } ScalarFunctionSet MonthFun::GetFunctions() { - return GetDatePartFunction(); + return GetCachedDatepartFunction(); } ScalarFunctionSet DayFun::GetFunctions() { - return GetDatePartFunction(); + return GetCachedDatepartFunction(); } ScalarFunctionSet DecadeFun::GetFunctions() { @@ -2065,6 +2105,26 @@ ScalarFunctionSet EpochMsFun::GetFunctions() { return operator_set; } +ScalarFunctionSet NanosecondsFun::GetFunctions() { + using OP = DatePart::NanosecondsOperator; + using TR = int64_t; + const LogicalType &result_type = LogicalType::BIGINT; + auto operator_set = GetTimePartFunction(); + + auto ns_func = DatePart::UnaryFunction; + auto ns_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP_NS}, result_type, ns_func, nullptr, nullptr, ns_stats)); + + // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU + auto tstz_func = DatePart::UnaryFunction; + auto tstz_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); + + return operator_set; +} + ScalarFunctionSet MicrosecondsFun::GetFunctions() { return GetTimePartFunction(); } diff --git a/src/duckdb/src/core_functions/scalar/date/date_trunc.cpp b/src/duckdb/src/core_functions/scalar/date/date_trunc.cpp index 0493c71b..6e5bcc70 100644 --- a/src/duckdb/src/core_functions/scalar/date/date_trunc.cpp +++ b/src/duckdb/src/core_functions/scalar/date/date_trunc.cpp @@ -138,7 +138,7 @@ struct DateTrunc { dtime_t time; Timestamp::Convert(input, date, time); Time::Convert(time, hour, min, sec, micros); - micros -= micros % Interval::MICROS_PER_MSEC; + micros -= UnsafeNumericCast(micros % Interval::MICROS_PER_MSEC); return Timestamp::FromDatetime(date, Time::FromTime(hour, min, sec, micros)); } }; diff --git a/src/duckdb/src/core_functions/scalar/date/make_date.cpp b/src/duckdb/src/core_functions/scalar/date/make_date.cpp index 7b818438..1ef81e37 100644 --- a/src/duckdb/src/core_functions/scalar/date/make_date.cpp +++ b/src/duckdb/src/core_functions/scalar/date/make_date.cpp @@ -66,9 +66,9 @@ struct MakeTimeOperator { if (ss < 0 || ss > Interval::SECS_PER_MINUTE) { ss_32 = Cast::Operation(ss); } else { - ss_32 = UnsafeNumericCast(ss); + ss_32 = LossyNumericCast(ss); } - auto micros = UnsafeNumericCast(std::round((ss - ss_32) * Interval::MICROS_PER_SEC)); + auto micros = LossyNumericCast(std::round((ss - ss_32) * Interval::MICROS_PER_SEC)); if (!Time::IsValidTime(hh_32, mm_32, ss_32, micros)) { throw ConversionException("Time out of range: %d:%d:%d.%d", hh_32, mm_32, ss_32, micros); diff --git a/src/duckdb/src/core_functions/scalar/date/strftime.cpp b/src/duckdb/src/core_functions/scalar/date/strftime.cpp index 01c907a5..8aa34d32 100644 --- a/src/duckdb/src/core_functions/scalar/date/strftime.cpp +++ b/src/duckdb/src/core_functions/scalar/date/strftime.cpp @@ -80,6 +80,19 @@ static void StrfTimeFunctionTimestamp(DataChunk &args, ExpressionState &state, V info.format.ConvertTimestampVector(args.data[REVERSED ? 1 : 0], result, args.size()); } +template +static void StrfTimeFunctionTimestampNS(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + + if (info.is_null) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + info.format.ConvertTimestampNSVector(args.data[REVERSED ? 1 : 0], result, args.size()); +} + ScalarFunctionSet StrfTimeFun::GetFunctions() { ScalarFunctionSet strftime; @@ -87,10 +100,14 @@ ScalarFunctionSet StrfTimeFun::GetFunctions() { StrfTimeFunctionDate, StrfTimeBindFunction)); strftime.AddFunction(ScalarFunction({LogicalType::TIMESTAMP, LogicalType::VARCHAR}, LogicalType::VARCHAR, StrfTimeFunctionTimestamp, StrfTimeBindFunction)); + strftime.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_NS, LogicalType::VARCHAR}, LogicalType::VARCHAR, + StrfTimeFunctionTimestampNS, StrfTimeBindFunction)); strftime.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE}, LogicalType::VARCHAR, StrfTimeFunctionDate, StrfTimeBindFunction)); strftime.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP}, LogicalType::VARCHAR, StrfTimeFunctionTimestamp, StrfTimeBindFunction)); + strftime.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP_NS}, LogicalType::VARCHAR, + StrfTimeFunctionTimestampNS, StrfTimeBindFunction)); return strftime; } @@ -126,60 +143,29 @@ struct StrpTimeBindData : public FunctionData { } }; -static unique_ptr StrpTimeBindFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[1]->IsFoldable()) { - throw InvalidInputException(*arguments[0], "strptime format must be a constant"); - } - Value format_value = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - string format_string; - StrpTimeFormat format; - if (format_value.IsNull()) { - return make_uniq(format, format_string); - } else if (format_value.type().id() == LogicalTypeId::VARCHAR) { - format_string = format_value.ToString(); - format.format_specifier = format_string; - string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); - if (!error.empty()) { - throw InvalidInputException(*arguments[0], "Failed to parse format specifier %s: %s", format_string, error); - } - if (format.HasFormatSpecifier(StrTimeSpecifier::UTC_OFFSET)) { - bound_function.return_type = LogicalType::TIMESTAMP_TZ; - } - return make_uniq(format, format_string); - } else if (format_value.type() == LogicalType::LIST(LogicalType::VARCHAR)) { - const auto &children = ListValue::GetChildren(format_value); - if (children.empty()) { - throw InvalidInputException(*arguments[0], "strptime format list must not be empty"); - } - vector format_strings; - vector formats; - for (const auto &child : children) { - format_string = child.ToString(); - format.format_specifier = format_string; - string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); - if (!error.empty()) { - throw InvalidInputException(*arguments[0], "Failed to parse format specifier %s: %s", format_string, - error); - } - // If any format has UTC offsets, then we have to produce TSTZ - if (format.HasFormatSpecifier(StrTimeSpecifier::UTC_OFFSET)) { - bound_function.return_type = LogicalType::TIMESTAMP_TZ; - } - format_strings.emplace_back(format_string); - formats.emplace_back(format); - } - return make_uniq(formats, format_strings); - } else { - throw InvalidInputException(*arguments[0], "strptime format must be a string"); - } +template +inline T StrpTimeResult(StrpTimeFormat::ParseResult &parsed) { + return parsed.ToTimestamp(); +} + +template <> +inline timestamp_ns_t StrpTimeResult(StrpTimeFormat::ParseResult &parsed) { + return parsed.ToTimestampNS(); +} + +template +inline bool StrpTimeTryResult(StrpTimeFormat &format, string_t &input, T &result, string &error) { + return format.TryParseTimestamp(input, result, error); +} + +template <> +inline bool StrpTimeTryResult(StrpTimeFormat &format, string_t &input, timestamp_ns_t &result, string &error) { + return format.TryParseTimestampNS(input, result, error); } struct StrpTimeFunction { + template static void Parse(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); auto &info = func_expr.bind_info->Cast(); @@ -196,17 +182,18 @@ struct StrpTimeFunction { ConstantVector::SetNull(result, true); return; } - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) { + UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) { StrpTimeFormat::ParseResult result; for (auto &format : info.formats) { if (format.Parse(input, result)) { - return result.ToTimestamp(); + return StrpTimeResult(result); } } throw InvalidInputException(result.FormatError(input, info.formats[0].format_specifier)); }); } + template static void TryParse(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); auto &info = func_expr.bind_info->Cast(); @@ -217,19 +204,94 @@ struct StrpTimeFunction { return; } - UnaryExecutor::ExecuteWithNulls( - args.data[0], result, args.size(), [&](string_t input, ValidityMask &mask, idx_t idx) { - timestamp_t result; - string error; - for (auto &format : info.formats) { - if (format.TryParseTimestamp(input, result, error)) { - return result; - } - } - - mask.SetInvalid(idx); - return timestamp_t(); - }); + UnaryExecutor::ExecuteWithNulls(args.data[0], result, args.size(), + [&](string_t input, ValidityMask &mask, idx_t idx) { + T result; + string error; + for (auto &format : info.formats) { + if (StrpTimeTryResult(format, input, result, error)) { + return result; + } + } + + mask.SetInvalid(idx); + return T(); + }); + } + + static unique_ptr Bind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments[1]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[1]->IsFoldable()) { + throw InvalidInputException(*arguments[0], "strptime format must be a constant"); + } + Value format_value = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + string format_string; + StrpTimeFormat format; + if (format_value.IsNull()) { + return make_uniq(format, format_string); + } else if (format_value.type().id() == LogicalTypeId::VARCHAR) { + format_string = format_value.ToString(); + format.format_specifier = format_string; + string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); + if (!error.empty()) { + throw InvalidInputException(*arguments[0], "Failed to parse format specifier %s: %s", format_string, + error); + } + if (format.HasFormatSpecifier(StrTimeSpecifier::UTC_OFFSET)) { + bound_function.return_type = LogicalType::TIMESTAMP_TZ; + } else if (format.HasFormatSpecifier(StrTimeSpecifier::NANOSECOND_PADDED)) { + bound_function.return_type = LogicalType::TIMESTAMP_NS; + if (bound_function.name == "strptime") { + bound_function.function = Parse; + } else { + bound_function.function = TryParse; + } + } + return make_uniq(format, format_string); + } else if (format_value.type() == LogicalType::LIST(LogicalType::VARCHAR)) { + const auto &children = ListValue::GetChildren(format_value); + if (children.empty()) { + throw InvalidInputException(*arguments[0], "strptime format list must not be empty"); + } + vector format_strings; + vector formats; + bool has_offset = false; + bool has_nanos = false; + + for (const auto &child : children) { + format_string = child.ToString(); + format.format_specifier = format_string; + string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); + if (!error.empty()) { + throw InvalidInputException(*arguments[0], "Failed to parse format specifier %s: %s", format_string, + error); + } + has_offset = has_offset || format.HasFormatSpecifier(StrTimeSpecifier::UTC_OFFSET); + has_nanos = has_nanos || format.HasFormatSpecifier(StrTimeSpecifier::NANOSECOND_PADDED); + format_strings.emplace_back(format_string); + formats.emplace_back(format); + } + + if (has_offset) { + // If any format has UTC offsets, then we have to produce TSTZ + bound_function.return_type = LogicalType::TIMESTAMP_TZ; + } else if (has_nanos) { + // If any format has nanoseconds, then we have to produce TSNS + // unless there is an offset, in which case we produce + bound_function.return_type = LogicalType::TIMESTAMP_NS; + if (bound_function.name == "strptime") { + bound_function.function = Parse; + } else { + bound_function.function = TryParse; + } + } + return make_uniq(formats, format_strings); + } else { + throw InvalidInputException(*arguments[0], "strptime format must be a string"); + } } }; @@ -238,12 +300,12 @@ ScalarFunctionSet StrpTimeFun::GetFunctions() { const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::TIMESTAMP, - StrpTimeFunction::Parse, StrpTimeBindFunction); + StrpTimeFunction::Parse, StrpTimeFunction::Bind); fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; strptime.AddFunction(fun); - fun = ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, StrpTimeFunction::Parse, - StrpTimeBindFunction); + fun = ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, + StrpTimeFunction::Parse, StrpTimeFunction::Bind); fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; strptime.AddFunction(fun); return strptime; @@ -254,12 +316,12 @@ ScalarFunctionSet TryStrpTimeFun::GetFunctions() { const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::TIMESTAMP, - StrpTimeFunction::TryParse, StrpTimeBindFunction); + StrpTimeFunction::TryParse, StrpTimeFunction::Bind); fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; try_strptime.AddFunction(fun); - fun = ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, StrpTimeFunction::TryParse, - StrpTimeBindFunction); + fun = ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, + StrpTimeFunction::TryParse, StrpTimeFunction::Bind); fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; try_strptime.AddFunction(fun); diff --git a/src/duckdb/src/core_functions/scalar/date/to_interval.cpp b/src/duckdb/src/core_functions/scalar/date/to_interval.cpp index d5ff5de6..e16111f8 100644 --- a/src/duckdb/src/core_functions/scalar/date/to_interval.cpp +++ b/src/duckdb/src/core_functions/scalar/date/to_interval.cpp @@ -8,7 +8,7 @@ namespace duckdb { template <> bool TryMultiplyOperator::Operation(double left, int64_t right, int64_t &result) { - return TryCast::Operation(left * right, result); + return TryCast::Operation(left * double(right), result); } struct ToMillenniaOperator { diff --git a/src/duckdb/src/core_functions/scalar/enum/enum_functions.cpp b/src/duckdb/src/core_functions/scalar/enum/enum_functions.cpp index 5722cf58..ddf07c3d 100644 --- a/src/duckdb/src/core_functions/scalar/enum/enum_functions.cpp +++ b/src/duckdb/src/core_functions/scalar/enum/enum_functions.cpp @@ -28,7 +28,7 @@ static void EnumRangeFunction(DataChunk &input, ExpressionState &state, Vector & for (idx_t i = 0; i < enum_size; i++) { enum_values.emplace_back(enum_vector.GetValue(i)); } - auto val = Value::LIST(enum_values); + auto val = Value::LIST(LogicalType::VARCHAR, enum_values); result.Reference(val); } diff --git a/src/duckdb/src/core_functions/scalar/generic/can_implicitly_cast.cpp b/src/duckdb/src/core_functions/scalar/generic/can_implicitly_cast.cpp new file mode 100644 index 00000000..37b25d48 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/generic/can_implicitly_cast.cpp @@ -0,0 +1,40 @@ +#include "duckdb/core_functions/scalar/generic_functions.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/function/cast_rules.hpp" + +namespace duckdb { + +bool CanCastImplicitly(ClientContext &context, const LogicalType &source, const LogicalType &target) { + return CastFunctionSet::Get(context).ImplicitCastCost(source, target) >= 0; +} + +static void CanCastImplicitlyFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &context = state.GetContext(); + bool can_cast_implicitly = CanCastImplicitly(context, args.data[0].GetType(), args.data[1].GetType()); + auto v = Value::BOOLEAN(can_cast_implicitly); + result.Reference(v); +} + +unique_ptr BindCanCastImplicitlyExpression(FunctionBindExpressionInput &input) { + auto &source_type = input.function.children[0]->return_type; + auto &target_type = input.function.children[1]->return_type; + if (source_type.id() == LogicalTypeId::UNKNOWN || source_type.id() == LogicalTypeId::SQLNULL || + target_type.id() == LogicalTypeId::UNKNOWN || target_type.id() == LogicalTypeId::SQLNULL) { + // parameter - unknown return type + return nullptr; + } + // emit a constant expression + return make_uniq( + Value::BOOLEAN(CanCastImplicitly(input.context, source_type, target_type))); +} + +ScalarFunction CanCastImplicitlyFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::BOOLEAN, CanCastImplicitlyFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.bind_expression = BindCanCastImplicitlyExpression; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/generic/error.cpp b/src/duckdb/src/core_functions/scalar/generic/error.cpp index 9c172e87..e9047378 100644 --- a/src/duckdb/src/core_functions/scalar/generic/error.cpp +++ b/src/duckdb/src/core_functions/scalar/generic/error.cpp @@ -11,7 +11,7 @@ struct ErrorOperator { }; ScalarFunction ErrorFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::VARCHAR}, LogicalType::SQLNULL, + auto fun = ScalarFunction("error", {LogicalType::VARCHAR}, LogicalType::SQLNULL, ScalarFunction::UnaryFunction); // Set the function with side effects to avoid the optimization. fun.stability = FunctionStability::VOLATILE; diff --git a/src/duckdb/src/core_functions/scalar/generic/least.cpp b/src/duckdb/src/core_functions/scalar/generic/least.cpp index 16f859bf..d91b4939 100644 --- a/src/duckdb/src/core_functions/scalar/generic/least.cpp +++ b/src/duckdb/src/core_functions/scalar/generic/least.cpp @@ -1,5 +1,7 @@ #include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/core_functions/scalar/generic_functions.hpp" +#include "duckdb/core_functions/create_sort_key.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" namespace duckdb { @@ -11,58 +13,121 @@ struct LeastOperator { } }; -template +struct LeastGreatestSortKeyState : public FunctionLocalState { + explicit LeastGreatestSortKeyState(idx_t column_count) + : intermediate(LogicalType::BLOB), modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST) { + vector types; + // initialize sort key chunk + for (idx_t i = 0; i < column_count; i++) { + types.push_back(LogicalType::BLOB); + } + sort_keys.Initialize(Allocator::DefaultAllocator(), types); + } + + DataChunk sort_keys; + Vector intermediate; + OrderModifiers modifiers; +}; + +unique_ptr LeastGreatestSortKeyInit(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data) { + return make_uniq(expr.children.size()); +} + +template +struct StandardLeastGreatest { + static constexpr bool IS_STRING = STRING; + + static DataChunk &Prepare(DataChunk &args, ExpressionState &) { + return args; + } + + static Vector &TargetVector(Vector &result, ExpressionState &) { + return result; + } + + static void FinalizeResult(idx_t rows, bool result_has_value[], Vector &result, ExpressionState &) { + auto &result_mask = FlatVector::Validity(result); + for (idx_t i = 0; i < rows; i++) { + if (!result_has_value[i]) { + result_mask.SetInvalid(i); + } + } + } +}; + +struct SortKeyLeastGreatest { + static constexpr bool IS_STRING = false; + + static DataChunk &Prepare(DataChunk &args, ExpressionState &state) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + lstate.sort_keys.Reset(); + for (idx_t c_idx = 0; c_idx < args.ColumnCount(); c_idx++) { + CreateSortKeyHelpers::CreateSortKey(args.data[c_idx], args.size(), lstate.modifiers, + lstate.sort_keys.data[c_idx]); + } + lstate.sort_keys.SetCardinality(args.size()); + return lstate.sort_keys; + } + + static Vector &TargetVector(Vector &result, ExpressionState &state) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + return lstate.intermediate; + } + + static void FinalizeResult(idx_t rows, bool result_has_value[], Vector &result, ExpressionState &state) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + auto result_keys = FlatVector::GetData(lstate.intermediate); + auto &result_mask = FlatVector::Validity(result); + for (idx_t i = 0; i < rows; i++) { + if (!result_has_value[i]) { + result_mask.SetInvalid(i); + } else { + CreateSortKeyHelpers::DecodeSortKey(result_keys[i], result, i, lstate.modifiers); + } + } + } +}; + +template > static void LeastGreatestFunction(DataChunk &args, ExpressionState &state, Vector &result) { if (args.ColumnCount() == 1) { // single input: nop result.Reference(args.data[0]); return; } + auto &input = BASE_OP::Prepare(args, state); + auto &result_vector = BASE_OP::TargetVector(result, state); + auto result_type = VectorType::CONSTANT_VECTOR; - for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { + for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { if (args.data[col_idx].GetVectorType() != VectorType::CONSTANT_VECTOR) { // non-constant input: result is not a constant vector result_type = VectorType::FLAT_VECTOR; } - if (IS_STRING) { + if (BASE_OP::IS_STRING) { // for string vectors we add a reference to the heap of the children - StringVector::AddHeapReference(result, args.data[col_idx]); + StringVector::AddHeapReference(result_vector, input.data[col_idx]); } } - auto result_data = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - // copy over the first column - bool result_has_value[STANDARD_VECTOR_SIZE]; - { - UnifiedVectorFormat vdata; - args.data[0].ToUnifiedFormat(args.size(), vdata); - auto input_data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < args.size(); i++) { - auto vindex = vdata.sel->get_index(i); - if (vdata.validity.RowIsValid(vindex)) { - result_data[i] = input_data[vindex]; - result_has_value[i] = true; - } else { - result_has_value[i] = false; - } - } - } - // now handle the remainder of the columns - for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { - if (args.data[col_idx].GetVectorType() == VectorType::CONSTANT_VECTOR && - ConstantVector::IsNull(args.data[col_idx])) { + auto result_data = FlatVector::GetData(result_vector); + bool result_has_value[STANDARD_VECTOR_SIZE] {false}; + // perform the operation column-by-column + for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { + if (input.data[col_idx].GetVectorType() == VectorType::CONSTANT_VECTOR && + ConstantVector::IsNull(input.data[col_idx])) { // ignore null vector continue; } UnifiedVectorFormat vdata; - args.data[col_idx].ToUnifiedFormat(args.size(), vdata); + input.data[col_idx].ToUnifiedFormat(input.size(), vdata); auto input_data = UnifiedVectorFormat::GetData(vdata); if (!vdata.validity.AllValid()) { // potential new null entries: have to check the null mask - for (idx_t i = 0; i < args.size(); i++) { + for (idx_t i = 0; i < input.size(); i++) { auto vindex = vdata.sel->get_index(i); if (vdata.validity.RowIsValid(vindex)) { // not a null entry: perform the operation and add to new set @@ -75,7 +140,7 @@ static void LeastGreatestFunction(DataChunk &args, ExpressionState &state, Vecto } } else { // no new null entries: only need to perform the operation - for (idx_t i = 0; i < args.size(); i++) { + for (idx_t i = 0; i < input.size(); i++) { auto vindex = vdata.sel->get_index(i); auto ivalue = input_data[vindex]; @@ -86,51 +151,89 @@ static void LeastGreatestFunction(DataChunk &args, ExpressionState &state, Vecto } } } - for (idx_t i = 0; i < args.size(); i++) { - if (!result_has_value[i]) { - result_mask.SetInvalid(i); + BASE_OP::FinalizeResult(input.size(), result_has_value, result, state); + result.SetVectorType(result_type); +} + +template +unique_ptr BindLeastGreatest(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + LogicalType child_type = ExpressionBinder::GetExpressionReturnType(*arguments[0]); + for (idx_t i = 1; i < arguments.size(); i++) { + auto arg_type = ExpressionBinder::GetExpressionReturnType(*arguments[i]); + if (!LogicalType::TryGetMaxLogicalType(context, child_type, arg_type, child_type)) { + throw BinderException(arguments[i]->query_location, + "Cannot combine types of %s and %s - an explicit cast is required", + child_type.ToString(), arg_type.ToString()); } } - result.SetVectorType(result_type); + switch (child_type.id()) { + case LogicalTypeId::UNKNOWN: + throw ParameterNotResolvedException(); + case LogicalTypeId::INTEGER_LITERAL: + child_type = IntegerLiteral::GetType(child_type); + break; + case LogicalTypeId::STRING_LITERAL: + child_type = LogicalType::VARCHAR; + break; + default: + break; + } + switch (child_type.InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + bound_function.function = LeastGreatestFunction; + break; + case PhysicalType::INT16: + bound_function.function = LeastGreatestFunction; + break; + case PhysicalType::INT32: + bound_function.function = LeastGreatestFunction; + break; + case PhysicalType::INT64: + bound_function.function = LeastGreatestFunction; + break; + case PhysicalType::INT128: + bound_function.function = LeastGreatestFunction; + break; + case PhysicalType::DOUBLE: + bound_function.function = LeastGreatestFunction; + break; + case PhysicalType::VARCHAR: + bound_function.function = LeastGreatestFunction>; + break; + default: + // fallback with sort keys + bound_function.function = LeastGreatestFunction; + bound_function.init_local_state = LeastGreatestSortKeyInit; + break; + } + bound_function.arguments[0] = child_type; + bound_function.varargs = child_type; + bound_function.return_type = child_type; + return nullptr; } -template -ScalarFunction GetLeastGreatestFunction(const LogicalType &type) { - return ScalarFunction({type}, type, LeastGreatestFunction, nullptr, nullptr, nullptr, nullptr, type, - FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING); +template +ScalarFunction GetLeastGreatestFunction() { + return ScalarFunction({LogicalType::ANY}, LogicalType::ANY, nullptr, BindLeastGreatest, nullptr, nullptr, + nullptr, LogicalType::ANY, FunctionStability::CONSISTENT, + FunctionNullHandling::SPECIAL_HANDLING); } template static ScalarFunctionSet GetLeastGreatestFunctions() { ScalarFunctionSet fun_set; - fun_set.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::BIGINT, LeastGreatestFunction, - nullptr, nullptr, nullptr, nullptr, LogicalType::BIGINT, - FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - fun_set.AddFunction(ScalarFunction( - {LogicalType::HUGEINT}, LogicalType::HUGEINT, LeastGreatestFunction, nullptr, nullptr, nullptr, - nullptr, LogicalType::HUGEINT, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - fun_set.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, LeastGreatestFunction, - nullptr, nullptr, nullptr, nullptr, LogicalType::DOUBLE, - FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - fun_set.AddFunction(ScalarFunction( - {LogicalType::VARCHAR}, LogicalType::VARCHAR, LeastGreatestFunction, nullptr, nullptr, - nullptr, nullptr, LogicalType::VARCHAR, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - - fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::TIMESTAMP)); - fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::TIME)); - fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::DATE)); - - fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::TIMESTAMP_TZ)); - fun_set.AddFunction(GetLeastGreatestFunction(LogicalType::TIME_TZ)); + fun_set.AddFunction(GetLeastGreatestFunction()); return fun_set; } ScalarFunctionSet LeastFun::GetFunctions() { - return GetLeastGreatestFunctions(); + return GetLeastGreatestFunctions(); } ScalarFunctionSet GreatestFun::GetFunctions() { - return GetLeastGreatestFunctions(); + return GetLeastGreatestFunctions(); } } // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/generic/typeof.cpp b/src/duckdb/src/core_functions/scalar/generic/typeof.cpp index a1b01f8c..b74a0cef 100644 --- a/src/duckdb/src/core_functions/scalar/generic/typeof.cpp +++ b/src/duckdb/src/core_functions/scalar/generic/typeof.cpp @@ -1,4 +1,6 @@ #include "duckdb/core_functions/scalar/generic_functions.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" namespace duckdb { @@ -7,9 +9,20 @@ static void TypeOfFunction(DataChunk &args, ExpressionState &state, Vector &resu result.Reference(v); } +unique_ptr BindTypeOfFunctionExpression(FunctionBindExpressionInput &input) { + auto &return_type = input.function.children[0]->return_type; + if (return_type.id() == LogicalTypeId::UNKNOWN || return_type.id() == LogicalTypeId::SQLNULL) { + // parameter - unknown return type + return nullptr; + } + // emit a constant expression + return make_uniq(Value(return_type.ToString())); +} + ScalarFunction TypeOfFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, TypeOfFunction); fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.bind_expression = BindTypeOfFunctionExpression; return fun; } diff --git a/src/duckdb/src/core_functions/scalar/list/array_slice.cpp b/src/duckdb/src/core_functions/scalar/list/array_slice.cpp index 8075aeea..3cc0960d 100644 --- a/src/duckdb/src/core_functions/scalar/list/array_slice.cpp +++ b/src/duckdb/src/core_functions/scalar/list/array_slice.cpp @@ -42,7 +42,7 @@ unique_ptr ListSliceBindData::Copy() const { template static idx_t CalculateSliceLength(idx_t begin, idx_t end, INDEX_TYPE step, bool svalid) { if (step < 0) { - step = abs(step); + step = AbsValue(step); } if (step == 0 && svalid) { throw InvalidInputException("Slice step cannot be zero"); diff --git a/src/duckdb/src/core_functions/scalar/list/list_aggregates.cpp b/src/duckdb/src/core_functions/scalar/list/list_aggregates.cpp index fe7a95fc..e423c1a7 100644 --- a/src/duckdb/src/core_functions/scalar/list/list_aggregates.cpp +++ b/src/duckdb/src/core_functions/scalar/list/list_aggregates.cpp @@ -10,6 +10,8 @@ #include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression_binder.hpp" #include "duckdb/function/function_binder.hpp" +#include "duckdb/core_functions/create_sort_key.hpp" +#include "duckdb/common/owning_string_map.hpp" namespace duckdb { @@ -55,7 +57,7 @@ struct ListAggregatesBindData : public FunctionData { } static unique_ptr DeserializeFunction(Deserializer &deserializer, ScalarFunction &bound_function) { - auto result = deserializer.ReadPropertyWithDefault>( + auto result = deserializer.ReadPropertyWithExplicitDefault>( 100, "bind_data", unique_ptr(nullptr)); if (!result) { return ListAggregatesBindFailure(bound_function); @@ -93,16 +95,23 @@ struct StateVector { struct FinalizeValueFunctor { template - static Value FinalizeValue(T first) { - return Value::CreateValue(first); + static void HistogramFinalize(T value, Vector &result, idx_t offset) { + FlatVector::GetData(result)[offset] = value; } }; struct FinalizeStringValueFunctor { template - static Value FinalizeValue(T first) { - string_t value = first; - return Value::CreateValue(value); + static void HistogramFinalize(T value, Vector &result, idx_t offset) { + FlatVector::GetData(result)[offset] = StringVector::AddStringOrBlob(result, value); + } +}; + +struct FinalizeGenericValueFunctor { + template + static void HistogramFinalize(T value, Vector &result, idx_t offset) { + CreateSortKeyHelpers::DecodeSortKey(value, result, offset, + OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); } }; @@ -115,32 +124,44 @@ struct AggregateFunctor { struct DistinctFunctor { template > static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { - UnifiedVectorFormat sdata; state_vector.ToUnifiedFormat(count, sdata); - auto states = (HistogramAggState **)sdata.data; - - auto result_data = FlatVector::GetData(result); + auto states = UnifiedVectorFormat::GetData *>(sdata); - idx_t offset = 0; + auto old_len = ListVector::GetListSize(result); + idx_t new_entries = 0; + // figure out how much space we need for (idx_t i = 0; i < count; i++) { - - auto state = states[sdata.sel->get_index(i)]; - result_data[i].offset = offset; - - if (!state->hist) { - result_data[i].length = 0; + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.hist) { continue; } + new_entries += state.hist->size(); + } + // reserve space in the list vector + ListVector::Reserve(result, old_len + new_entries); + auto &child_elements = ListVector::GetEntry(result); + auto list_entries = FlatVector::GetData(result); - result_data[i].length = state->hist->size(); - offset += state->hist->size(); + idx_t current_offset = old_len; + for (idx_t i = 0; i < count; i++) { + const auto rid = i; + auto &state = *states[sdata.sel->get_index(i)]; + auto &list_entry = list_entries[rid]; + list_entry.offset = current_offset; + if (!state.hist) { + list_entry.length = 0; + continue; + } - for (auto &entry : *state->hist) { - Value bucket_value = OP::template FinalizeValue(entry.first); - ListVector::PushBack(result, bucket_value); + for (auto &entry : *state.hist) { + OP::template HistogramFinalize(entry.first, child_elements, current_offset); + current_offset++; } + list_entry.length = current_offset - list_entry.offset; } + D_ASSERT(current_offset == old_len + new_entries); + ListVector::SetListSize(result, current_offset); result.Verify(count); } }; @@ -148,13 +169,11 @@ struct DistinctFunctor { struct UniqueFunctor { template > static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { - UnifiedVectorFormat sdata; state_vector.ToUnifiedFormat(count, sdata); - auto states = (HistogramAggState **)sdata.data; + auto states = UnifiedVectorFormat::GetData *>(sdata); auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < count; i++) { auto state = states[sdata.sel->get_index(i)]; @@ -163,7 +182,6 @@ struct UniqueFunctor { result_data[i] = 0; continue; } - result_data[i] = state->hist->size(); } result.Verify(count); @@ -206,8 +224,8 @@ static void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vect auto list_entries = UnifiedVectorFormat::GetData(lists_data); // state_buffer holds the state for each list of this chunk - idx_t size = aggr.function.state_size(); - auto state_buffer = make_unsafe_uniq_array(size * count); + idx_t size = aggr.function.state_size(aggr.function); + auto state_buffer = make_unsafe_uniq_array_uninitialized(size * count); // state vector for initialize and finalize StateVector state_vector(count, info.aggr_expr->Copy()); @@ -226,7 +244,7 @@ static void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vect // initialize the state for this list auto state_ptr = state_buffer.get() + size * i; states[i] = state_ptr; - aggr.function.initialize(states[i]); + aggr.function.initialize(aggr.function, states[i]); auto lists_index = lists_data.sel->get_index(i); const auto &list_entry = list_entries[lists_index]; @@ -305,49 +323,12 @@ static void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vect result, state_vector.state_vector, count); break; case PhysicalType::INT32: - if (key_type.id() == LogicalTypeId::DATE) { - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - } else { - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - } + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); break; case PhysicalType::INT64: - switch (key_type.id()) { - case LogicalTypeId::TIME: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case LogicalTypeId::TIME_TZ: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case LogicalTypeId::TIMESTAMP: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case LogicalTypeId::TIMESTAMP_MS: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case LogicalTypeId::TIMESTAMP_NS: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case LogicalTypeId::TIMESTAMP_SEC: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case LogicalTypeId::TIMESTAMP_TZ: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - default: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - } + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); break; case PhysicalType::FLOAT: FUNCTION_FUNCTOR::template ListExecuteFunction( @@ -358,11 +339,15 @@ static void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vect result, state_vector.state_vector, count); break; case PhysicalType::VARCHAR: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); + FUNCTION_FUNCTOR::template ListExecuteFunction>(result, state_vector.state_vector, + count); break; default: - throw InternalException("Unimplemented histogram aggregate"); + FUNCTION_FUNCTOR::template ListExecuteFunction>(result, state_vector.state_vector, + count); + break; } } @@ -486,8 +471,7 @@ static unique_ptr ListAggregatesBind(ClientContext &context, Scala // create the unordered map histogram function D_ASSERT(best_function.arguments.size() == 1); - auto key_type = best_function.arguments[0]; - auto aggr_function = HistogramFun::GetHistogramUnorderedMap(key_type); + auto aggr_function = HistogramFun::GetHistogramUnorderedMap(child_type); return ListAggregatesBindFunction(context, bound_function, child_type, aggr_function, arguments); } diff --git a/src/duckdb/src/core_functions/scalar/list/list_distance.cpp b/src/duckdb/src/core_functions/scalar/list/list_distance.cpp index aa70e4a1..23e19f87 100644 --- a/src/duckdb/src/core_functions/scalar/list/list_distance.cpp +++ b/src/duckdb/src/core_functions/scalar/list/list_distance.cpp @@ -1,58 +1,64 @@ #include "duckdb/core_functions/scalar/list_functions.hpp" -#include +#include "duckdb/core_functions/array_kernels.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" namespace duckdb { -template -static void ListDistance(DataChunk &args, ExpressionState &, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); +//------------------------------------------------------------------------------ +// Generic "fold" function +//------------------------------------------------------------------------------ +// Given two lists of the same size, combine and reduce their elements into a +// single scalar value. + +template +static void ListGenericFold(DataChunk &args, ExpressionState &state, Vector &result) { + const auto &lstate = state.Cast(); + const auto &expr = lstate.expr.Cast(); + const auto &func_name = expr.function.name; auto count = args.size(); - auto &left = args.data[0]; - auto &right = args.data[1]; - auto left_count = ListVector::GetListSize(left); - auto right_count = ListVector::GetListSize(right); - auto &left_child = ListVector::GetEntry(left); - auto &right_child = ListVector::GetEntry(right); + auto &lhs_vec = args.data[0]; + auto &rhs_vec = args.data[1]; + + const auto lhs_count = ListVector::GetListSize(lhs_vec); + const auto rhs_count = ListVector::GetListSize(rhs_vec); + + auto &lhs_child = ListVector::GetEntry(lhs_vec); + auto &rhs_child = ListVector::GetEntry(rhs_vec); + + lhs_child.Flatten(lhs_count); + rhs_child.Flatten(rhs_count); - D_ASSERT(left_child.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(right_child.GetVectorType() == VectorType::FLAT_VECTOR); + D_ASSERT(lhs_child.GetVectorType() == VectorType::FLAT_VECTOR); + D_ASSERT(rhs_child.GetVectorType() == VectorType::FLAT_VECTOR); - if (!FlatVector::Validity(left_child).CheckAllValid(left_count)) { - throw InvalidInputException("list_distance: left argument can not contain NULL values"); + if (!FlatVector::Validity(lhs_child).CheckAllValid(lhs_count)) { + throw InvalidInputException("%s: left argument can not contain NULL values", func_name); } - if (!FlatVector::Validity(right_child).CheckAllValid(right_count)) { - throw InvalidInputException("list_distance: right argument can not contain NULL values"); + if (!FlatVector::Validity(rhs_child).CheckAllValid(rhs_count)) { + throw InvalidInputException("%s: right argument can not contain NULL values", func_name); } - auto left_data = FlatVector::GetData(left_child); - auto right_data = FlatVector::GetData(right_child); + auto lhs_data = FlatVector::GetData(lhs_child); + auto rhs_data = FlatVector::GetData(rhs_child); - BinaryExecutor::Execute( - left, right, result, count, [&](list_entry_t left, list_entry_t right) { + BinaryExecutor::ExecuteWithNulls( + lhs_vec, rhs_vec, result, count, + [&](const list_entry_t &left, const list_entry_t &right, ValidityMask &mask, idx_t row_idx) { if (left.length != right.length) { - throw InvalidInputException(StringUtil::Format( - "list_distance: list dimensions must be equal, got left length %d and right length %d", left.length, - right.length)); + throw InvalidInputException( + "%s: list dimensions must be equal, got left length '%d' and right length '%d'", func_name, + left.length, right.length); } - auto dimensions = left.length; - - NUMERIC_TYPE distance = 0; - - auto l_ptr = left_data + left.offset; - auto r_ptr = right_data + right.offset; - - for (idx_t i = 0; i < dimensions; i++) { - auto x = *l_ptr++; - auto y = *r_ptr++; - auto diff = x - y; - distance += diff * diff; + if (!OP::ALLOW_EMPTY && left.length == 0) { + mask.SetInvalid(row_idx); + return TYPE(); } - return std::sqrt(distance); + return OP::Operation(lhs_data + left.offset, rhs_data + right.offset, left.length); }); if (args.AllConstant()) { @@ -60,12 +66,59 @@ static void ListDistance(DataChunk &args, ExpressionState &, Vector &result) { } } +//------------------------------------------------------------------------- +// Function Registration +//------------------------------------------------------------------------- + +template +static void AddListFoldFunction(ScalarFunctionSet &set, const LogicalType &type) { + const auto list = LogicalType::LIST(type); + if (type.id() == LogicalTypeId::FLOAT) { + set.AddFunction(ScalarFunction({list, list}, type, ListGenericFold)); + } else if (type.id() == LogicalTypeId::DOUBLE) { + set.AddFunction(ScalarFunction({list, list}, type, ListGenericFold)); + } else { + throw NotImplementedException("List function not implemented for type %s", type.ToString()); + } +} + ScalarFunctionSet ListDistanceFun::GetFunctions() { ScalarFunctionSet set("list_distance"); - set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::FLOAT), LogicalType::LIST(LogicalType::FLOAT)}, - LogicalType::FLOAT, ListDistance)); - set.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::DOUBLE), LogicalType::LIST(LogicalType::DOUBLE)}, - LogicalType::DOUBLE, ListDistance)); + for (auto &type : LogicalType::Real()) { + AddListFoldFunction(set, type); + } + return set; +} + +ScalarFunctionSet ListInnerProductFun::GetFunctions() { + ScalarFunctionSet set("list_inner_product"); + for (auto &type : LogicalType::Real()) { + AddListFoldFunction(set, type); + } + return set; +} + +ScalarFunctionSet ListNegativeInnerProductFun::GetFunctions() { + ScalarFunctionSet set("list_negative_inner_product"); + for (auto &type : LogicalType::Real()) { + AddListFoldFunction(set, type); + } + return set; +} + +ScalarFunctionSet ListCosineSimilarityFun::GetFunctions() { + ScalarFunctionSet set("list_cosine_similarity"); + for (auto &type : LogicalType::Real()) { + AddListFoldFunction(set, type); + } + return set; +} + +ScalarFunctionSet ListCosineDistanceFun::GetFunctions() { + ScalarFunctionSet set("list_cosine_distance"); + for (auto &type : LogicalType::Real()) { + AddListFoldFunction(set, type); + } return set; } diff --git a/src/duckdb/src/core_functions/scalar/list/list_has_any_or_all.cpp b/src/duckdb/src/core_functions/scalar/list/list_has_any_or_all.cpp new file mode 100644 index 00000000..4a3e3509 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/list/list_has_any_or_all.cpp @@ -0,0 +1,227 @@ +#include "duckdb/core_functions/lambda_functions.hpp" +#include "duckdb/core_functions/scalar/list_functions.hpp" +#include "duckdb/core_functions/create_sort_key.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/common/string_map_set.hpp" + +namespace duckdb { + +static unique_ptr ListHasAnyOrAllBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); + arguments[1] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[1])); + + const auto lhs_is_param = arguments[0]->HasParameter(); + const auto rhs_is_param = arguments[1]->HasParameter(); + + if (lhs_is_param && rhs_is_param) { + throw ParameterNotResolvedException(); + } + + const auto &lhs_list = arguments[0]->return_type; + const auto &rhs_list = arguments[1]->return_type; + + if (lhs_is_param) { + bound_function.arguments[0] = rhs_list; + bound_function.arguments[1] = rhs_list; + return nullptr; + } + if (rhs_is_param) { + bound_function.arguments[0] = lhs_list; + bound_function.arguments[1] = lhs_list; + return nullptr; + } + + bound_function.arguments[0] = lhs_list; + bound_function.arguments[1] = rhs_list; + + const auto &lhs_child = ListType::GetChildType(bound_function.arguments[0]); + const auto &rhs_child = ListType::GetChildType(bound_function.arguments[1]); + + if (lhs_child != LogicalType::SQLNULL && rhs_child != LogicalType::SQLNULL && lhs_child != rhs_child) { + LogicalType common_child; + if (!LogicalType::TryGetMaxLogicalType(context, lhs_child, rhs_child, common_child)) { + throw BinderException("'%s' cannot compare lists of different types: '%s' and '%s'", bound_function.name, + lhs_child.ToString(), rhs_child.ToString()); + } + bound_function.arguments[0] = LogicalType::LIST(common_child); + bound_function.arguments[1] = LogicalType::LIST(common_child); + } + + return nullptr; +} + +static void ListHasAnyFunction(DataChunk &args, ExpressionState &, Vector &result) { + + auto &l_vec = args.data[0]; + auto &r_vec = args.data[1]; + + if (ListType::GetChildType(l_vec.GetType()) == LogicalType::SQLNULL || + ListType::GetChildType(r_vec.GetType()) == LogicalType::SQLNULL) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::GetData(result)[0] = false; + return; + } + + const auto l_size = ListVector::GetListSize(l_vec); + const auto r_size = ListVector::GetListSize(r_vec); + + auto &l_child = ListVector::GetEntry(l_vec); + auto &r_child = ListVector::GetEntry(r_vec); + + // Setup unified formats for the list elements + UnifiedVectorFormat l_child_format; + UnifiedVectorFormat r_child_format; + + l_child.ToUnifiedFormat(l_size, l_child_format); + r_child.ToUnifiedFormat(r_size, r_child_format); + + // Create the sort keys for the list elements + Vector l_sortkey_vec(LogicalType::BLOB, l_size); + Vector r_sortkey_vec(LogicalType::BLOB, r_size); + + const OrderModifiers order_modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + + CreateSortKeyHelpers::CreateSortKey(l_child, l_size, order_modifiers, l_sortkey_vec); + CreateSortKeyHelpers::CreateSortKey(r_child, r_size, order_modifiers, r_sortkey_vec); + + const auto l_sortkey_ptr = FlatVector::GetData(l_sortkey_vec); + const auto r_sortkey_ptr = FlatVector::GetData(r_sortkey_vec); + + string_set_t set; + + BinaryExecutor::Execute( + l_vec, r_vec, result, args.size(), [&](const list_entry_t &l_list, const list_entry_t &r_list) { + // Short circuit if either list is empty + if (l_list.length == 0 || r_list.length == 0) { + return false; + } + + auto build_list = l_list; + auto probe_list = r_list; + + auto build_data = l_sortkey_ptr; + auto probe_data = r_sortkey_ptr; + + auto build_format = &l_child_format; + auto probe_format = &r_child_format; + + // Use the smaller list to build the set + if (r_list.length < l_list.length) { + + build_list = r_list; + probe_list = l_list; + + build_data = r_sortkey_ptr; + probe_data = l_sortkey_ptr; + + build_format = &r_child_format; + probe_format = &l_child_format; + } + + // Reset the set + set.clear(); + + // Build the set + for (auto idx = build_list.offset; idx < build_list.offset + build_list.length; idx++) { + const auto entry_idx = build_format->sel->get_index(idx); + if (build_format->validity.RowIsValid(entry_idx)) { + set.insert(build_data[entry_idx]); + } + } + // Probe the set + for (auto idx = probe_list.offset; idx < probe_list.offset + probe_list.length; idx++) { + const auto entry_idx = probe_format->sel->get_index(idx); + if (probe_format->validity.RowIsValid(entry_idx) && set.find(probe_data[entry_idx]) != set.end()) { + return true; + } + } + return false; + }); +} + +static void ListHasAllFunction(DataChunk &args, ExpressionState &state, Vector &result) { + + const auto &func_expr = state.expr.Cast(); + const auto swap = func_expr.function.name == "<@"; + + auto &l_vec = args.data[swap ? 1 : 0]; + auto &r_vec = args.data[swap ? 0 : 1]; + + if (ListType::GetChildType(l_vec.GetType()) == LogicalType::SQLNULL && + ListType::GetChildType(r_vec.GetType()) == LogicalType::SQLNULL) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::GetData(result)[0] = true; + return; + } + + const auto l_size = ListVector::GetListSize(l_vec); + const auto r_size = ListVector::GetListSize(r_vec); + + auto &l_child = ListVector::GetEntry(l_vec); + auto &r_child = ListVector::GetEntry(r_vec); + + // Setup unified formats for the list elements + UnifiedVectorFormat build_format; + UnifiedVectorFormat probe_format; + + l_child.ToUnifiedFormat(l_size, build_format); + r_child.ToUnifiedFormat(r_size, probe_format); + + // Create the sort keys for the list elements + Vector l_sortkey_vec(LogicalType::BLOB, l_size); + Vector r_sortkey_vec(LogicalType::BLOB, r_size); + + const OrderModifiers order_modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + + CreateSortKeyHelpers::CreateSortKey(l_child, l_size, order_modifiers, l_sortkey_vec); + CreateSortKeyHelpers::CreateSortKey(r_child, r_size, order_modifiers, r_sortkey_vec); + + const auto build_data = FlatVector::GetData(l_sortkey_vec); + const auto probe_data = FlatVector::GetData(r_sortkey_vec); + + string_set_t set; + + BinaryExecutor::Execute( + l_vec, r_vec, result, args.size(), [&](const list_entry_t &build_list, const list_entry_t &probe_list) { + // Short circuit if the probe list is empty + if (probe_list.length == 0) { + return true; + } + + // Reset the set + set.clear(); + + // Build the set + for (auto idx = build_list.offset; idx < build_list.offset + build_list.length; idx++) { + const auto entry_idx = build_format.sel->get_index(idx); + if (build_format.validity.RowIsValid(entry_idx)) { + set.insert(build_data[entry_idx]); + } + } + + // Probe the set + for (auto idx = probe_list.offset; idx < probe_list.offset + probe_list.length; idx++) { + const auto entry_idx = probe_format.sel->get_index(idx); + if (probe_format.validity.RowIsValid(entry_idx) && set.find(probe_data[entry_idx]) == set.end()) { + return false; + } + } + return true; + }); +} + +ScalarFunction ListHasAnyFun::GetFunction() { + ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LIST(LogicalType::ANY)}, LogicalType::BOOLEAN, + ListHasAnyFunction, ListHasAnyOrAllBind); + return fun; +} + +ScalarFunction ListHasAllFun::GetFunction() { + ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LIST(LogicalType::ANY)}, LogicalType::BOOLEAN, + ListHasAllFunction, ListHasAnyOrAllBind); + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/list/list_reduce.cpp b/src/duckdb/src/core_functions/scalar/list/list_reduce.cpp index b58a7412..a3b4e01d 100644 --- a/src/duckdb/src/core_functions/scalar/list/list_reduce.cpp +++ b/src/duckdb/src/core_functions/scalar/list/list_reduce.cpp @@ -6,7 +6,8 @@ namespace duckdb { struct ReduceExecuteInfo { - ReduceExecuteInfo(LambdaFunctions::LambdaInfo &info, ClientContext &context) : left_slice(*info.child_vector) { + ReduceExecuteInfo(LambdaFunctions::LambdaInfo &info, ClientContext &context) + : left_slice(make_uniq(*info.child_vector)) { SelectionVector left_vector(info.row_count); active_rows.Resize(0, info.row_count); active_rows.SetAllValid(info.row_count); @@ -25,19 +26,18 @@ struct ReduceExecuteInfo { left_vector.set_index(reduced_row_idx, info.list_entries[list_column_format_index].offset); reduced_row_idx++; } else { - // Remove the invalid rows - info.result_validity->SetInvalid(original_row_idx); + // Set the row as invalid and remove it from the active rows. + FlatVector::SetNull(info.result, original_row_idx, true); active_rows.SetInvalid(original_row_idx); } } - - left_slice.Slice(left_vector, reduced_row_idx); + left_slice->Slice(left_vector, reduced_row_idx); if (info.has_index) { input_types.push_back(LogicalType::BIGINT); } - input_types.push_back(left_slice.GetType()); - input_types.push_back(left_slice.GetType()); + input_types.push_back(left_slice->GetType()); + input_types.push_back(left_slice->GetType()); for (auto &entry : info.column_infos) { input_types.push_back(entry.vector.get().GetType()); } @@ -45,7 +45,7 @@ struct ReduceExecuteInfo { expr_executor = make_uniq(context, *info.lambda_expr); }; ValidityMask active_rows; - Vector left_slice; + unique_ptr left_slice; unique_ptr expr_executor; vector input_types; @@ -83,12 +83,14 @@ static bool ExecuteReduce(idx_t loops, ReduceExecuteInfo &execute_info, LambdaFu right_sel.set_index(reduced_row_idx, info.list_entries[list_column_format_index].offset + loops + 1); execute_info.left_sel.set_index(reduced_row_idx, valid_row_idx); execute_info.active_rows_sel.set_index(reduced_row_idx, original_row_idx); - reduced_row_idx++; + } else { execute_info.active_rows.SetInvalid(original_row_idx); - info.result.SetValue(original_row_idx, execute_info.left_slice.GetValue(valid_row_idx)); + auto val = execute_info.left_slice->GetValue(valid_row_idx); + info.result.SetValue(original_row_idx, val); } + original_row_idx++; valid_row_idx++; } @@ -102,7 +104,7 @@ static bool ExecuteReduce(idx_t loops, ReduceExecuteInfo &execute_info, LambdaFu Vector index_vector(Value::BIGINT(UnsafeNumericCast(loops + 1))); // slice the left and right slice - execute_info.left_slice.Slice(execute_info.left_slice, execute_info.left_sel, reduced_row_idx); + execute_info.left_slice->Slice(*execute_info.left_slice, execute_info.left_sel, reduced_row_idx); Vector right_slice(*info.child_vector, right_sel, reduced_row_idx); // create the input chunk @@ -114,7 +116,7 @@ static bool ExecuteReduce(idx_t loops, ReduceExecuteInfo &execute_info, LambdaFu if (info.has_index) { input_chunk.data[0].Reference(index_vector); } - input_chunk.data[slice_offset + 1].Reference(execute_info.left_slice); + input_chunk.data[slice_offset + 1].Reference(*execute_info.left_slice); input_chunk.data[slice_offset].Reference(right_slice); // add the other columns @@ -132,16 +134,16 @@ static bool ExecuteReduce(idx_t loops, ReduceExecuteInfo &execute_info, LambdaFu result_chunk.Reset(); result_chunk.SetCardinality(reduced_row_idx); - execute_info.expr_executor->Execute(input_chunk, result_chunk); - // use the result chunk to update the left slice - execute_info.left_slice.Reference(result_chunk.data[0]); + // We need to copy the result into left_slice to avoid data loss due to vector.Reference(...). + // Otherwise, we only keep the data of the previous iteration alive, not that of previous iterations. + execute_info.left_slice = make_uniq(result_chunk.data[0].GetType(), reduced_row_idx); + VectorOperations::Copy(result_chunk.data[0], *execute_info.left_slice, reduced_row_idx, 0, 0); return false; } -void LambdaFunctions::ListReduceFunction(duckdb::DataChunk &args, duckdb::ExpressionState &state, - duckdb::Vector &result) { +void LambdaFunctions::ListReduceFunction(DataChunk &args, ExpressionState &state, Vector &result) { // Initializes the left slice from the list entries, active rows, the expression executor and the input types bool completed = false; LambdaFunctions::LambdaInfo info(args, state, result, completed); @@ -160,16 +162,15 @@ void LambdaFunctions::ListReduceFunction(duckdb::DataChunk &args, duckdb::Expres DataChunk even_result_chunk; even_result_chunk.Initialize(Allocator::DefaultAllocator(), {info.lambda_expr->return_type}); + // Execute reduce until all rows are finished. idx_t loops = 0; bool end = false; - // Execute reduce until all rows are finished while (!end) { auto &result_chunk = loops % 2 ? odd_result_chunk : even_result_chunk; auto &spare_result_chunk = loops % 2 ? even_result_chunk : odd_result_chunk; end = ExecuteReduce(loops, execute_info, info, result_chunk); spare_result_chunk.Reset(); - loops++; } diff --git a/src/duckdb/src/core_functions/scalar/list/list_sort.cpp b/src/duckdb/src/core_functions/scalar/list/list_sort.cpp index f206c509..0fbe54ba 100644 --- a/src/duckdb/src/core_functions/scalar/list/list_sort.cpp +++ b/src/duckdb/src/core_functions/scalar/list/list_sort.cpp @@ -130,8 +130,6 @@ static void ListSortFunction(DataChunk &args, ExpressionState &state, Vector &re // get the child vector auto lists_size = ListVector::GetListSize(sort_result_vec); auto &child_vector = ListVector::GetEntry(sort_result_vec); - UnifiedVectorFormat child_data; - child_vector.ToUnifiedFormat(lists_size, child_data); // get the lists data UnifiedVectorFormat lists_data; diff --git a/src/duckdb/src/core_functions/scalar/list/list_value.cpp b/src/duckdb/src/core_functions/scalar/list/list_value.cpp index e2ae537f..cc7a5df6 100644 --- a/src/duckdb/src/core_functions/scalar/list/list_value.cpp +++ b/src/duckdb/src/core_functions/scalar/list/list_value.cpp @@ -11,17 +11,49 @@ namespace duckdb { -static void ListValueFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - auto &child_type = ListType::GetChildType(result.GetType()); +struct ListValueAssign { + template + static T Assign(const T &input, Vector &result) { + return input; + } +}; - result.SetVectorType(VectorType::CONSTANT_VECTOR); - for (idx_t i = 0; i < args.ColumnCount(); i++) { - if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::FLAT_VECTOR); +struct ListValueStringAssign { + template + static T Assign(const T &input, Vector &result) { + return StringVector::AddStringOrBlob(result, input); + } +}; + +template +static void TemplatedListValueFunction(DataChunk &args, Vector &result) { + idx_t list_size = args.ColumnCount(); + ListVector::Reserve(result, args.size() * list_size); + auto result_data = FlatVector::GetData(result); + auto &list_child = ListVector::GetEntry(result); + auto child_data = FlatVector::GetData(list_child); + auto &child_validity = FlatVector::Validity(list_child); + + auto unified_format = args.ToUnifiedFormat(); + for (idx_t r = 0; r < args.size(); r++) { + for (idx_t c = 0; c < list_size; c++) { + auto input_idx = unified_format[c].sel->get_index(r); + auto result_idx = r * list_size + c; + auto input_data = UnifiedVectorFormat::GetData(unified_format[c]); + if (unified_format[c].validity.RowIsValid(input_idx)) { + child_data[result_idx] = OP::template Assign(input_data[input_idx], list_child); + } else { + child_validity.SetInvalid(result_idx); + } } + result_data[r].offset = r * list_size; + result_data[r].length = list_size; } + ListVector::SetListSize(result, args.size() * list_size); +} +static void TemplatedListValueFunctionFallback(DataChunk &args, Vector &result) { + auto &child_type = ListType::GetChildType(result.GetType()); auto result_data = FlatVector::GetData(result); for (idx_t i = 0; i < args.size(); i++) { result_data[i].offset = ListVector::GetListSize(result); @@ -31,7 +63,73 @@ static void ListValueFunction(DataChunk &args, ExpressionState &state, Vector &r } result_data[i].length = args.ColumnCount(); } - result.Verify(args.size()); +} + +static void ListValueFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + if (args.ColumnCount() == 0) { + // no columns - early out - result is a constant empty list + auto result_data = FlatVector::GetData(result); + result_data[0].length = 0; + result_data[0].offset = 0; + return; + } + for (idx_t i = 0; i < args.ColumnCount(); i++) { + if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::FLAT_VECTOR); + } + } + auto &result_type = ListVector::GetEntry(result).GetType(); + switch (result_type.InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::INT16: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::INT32: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::INT64: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::UINT8: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::UINT16: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::UINT32: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::UINT64: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::INT128: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::UINT128: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::FLOAT: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::DOUBLE: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::INTERVAL: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::VARCHAR: + TemplatedListValueFunction(args, result); + break; + default: { + TemplatedListValueFunctionFallback(args, result); + break; + } + } } template diff --git a/src/duckdb/src/core_functions/scalar/map/map_contains.cpp b/src/duckdb/src/core_functions/scalar/map/map_contains.cpp new file mode 100644 index 00000000..19a46015 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/map/map_contains.cpp @@ -0,0 +1,56 @@ +#include "duckdb/core_functions/scalar/map_functions.hpp" +#include "duckdb/function/scalar/list/contains_or_position.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" + +namespace duckdb { + +static void MapContainsFunction(DataChunk &input, ExpressionState &state, Vector &result) { + const auto count = input.size(); + + auto &map_vec = input.data[0]; + auto &key_vec = MapVector::GetKeys(map_vec); + auto &arg_vec = input.data[1]; + + ListSearchOp(map_vec, key_vec, arg_vec, result, count); + + if (count == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static unique_ptr MapContainsBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(bound_function.arguments.size() == 2); + + const auto &map = arguments[0]->return_type; + const auto &key = arguments[1]->return_type; + + if (map.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + + if (key.id() == LogicalTypeId::UNKNOWN) { + // Infer the argument type from the map type + bound_function.arguments[0] = map; + bound_function.arguments[1] = MapType::KeyType(map); + } else { + LogicalType max_child_type; + if (!LogicalType::TryGetMaxLogicalType(context, MapType::KeyType(map), key, max_child_type)) { + throw BinderException( + "%s: Cannot match element of type '%s' in a map of type '%s' - an explicit cast is required", + bound_function.name, key.ToString(), map.ToString()); + } + + bound_function.arguments[0] = LogicalType::MAP(max_child_type, MapType::ValueType(map)); + bound_function.arguments[1] = max_child_type; + } + return nullptr; +} + +ScalarFunction MapContainsFun::GetFunction() { + + ScalarFunction fun("map_contains", {LogicalType::MAP(LogicalType::ANY, LogicalType::ANY), LogicalType::ANY}, + LogicalType::BOOLEAN, MapContainsFunction, MapContainsBind); + return fun; +} +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/map/map_extract.cpp b/src/duckdb/src/core_functions/scalar/map/map_extract.cpp index 9cf1ca10..79056cd0 100644 --- a/src/duckdb/src/core_functions/scalar/map/map_extract.cpp +++ b/src/duckdb/src/core_functions/scalar/map/map_extract.cpp @@ -4,125 +4,9 @@ #include "duckdb/parser/expression/bound_expression.hpp" #include "duckdb/common/types/data_chunk.hpp" #include "duckdb/function/scalar/nested_functions.hpp" - +#include "duckdb/function/scalar/list/contains_or_position.hpp" namespace duckdb { -struct MapKeyArgFunctor { - // MAP is a LIST(STRUCT(K,V)) - // meaning the MAP itself is a List, but the child vector that we're interested in (the keys) - // are a level deeper than the initial child vector - - static Vector &GetList(Vector &map) { - return map; - } - static idx_t GetListSize(Vector &map) { - return ListVector::GetListSize(map); - } - static Vector &GetEntry(Vector &map) { - return MapVector::GetKeys(map); - } -}; - -void FillResult(Vector &map, Vector &offsets, Vector &result, idx_t count) { - UnifiedVectorFormat map_data; - map.ToUnifiedFormat(count, map_data); - - UnifiedVectorFormat offset_data; - offsets.ToUnifiedFormat(count, offset_data); - - auto result_data = FlatVector::GetData(result); - auto entry_count = ListVector::GetListSize(map); - auto &values_entries = MapVector::GetValues(map); - UnifiedVectorFormat values_entry_data; - // Note: this vector can have a different size than the map - values_entries.ToUnifiedFormat(entry_count, values_entry_data); - - for (idx_t row = 0; row < count; row++) { - idx_t offset_idx = offset_data.sel->get_index(row); - auto offset = UnifiedVectorFormat::GetData(offset_data)[offset_idx]; - - // Get the current size of the list, for the offset - idx_t current_offset = ListVector::GetListSize(result); - if (!offset_data.validity.RowIsValid(offset_idx) || !offset) { - // Set the entry data for this result row - auto &entry = result_data[row]; - entry.length = 0; - entry.offset = current_offset; - continue; - } - // All list indices start at 1, reduce by 1 to get the actual index - offset--; - - // Get the 'values' list entry corresponding to the offset - idx_t value_index = map_data.sel->get_index(row); - auto &value_list_entry = UnifiedVectorFormat::GetData(map_data)[value_index]; - - // Add the values to the result - idx_t list_offset = value_list_entry.offset + UnsafeNumericCast(offset); - // All keys are unique, only one will ever match - idx_t length = 1; - ListVector::Append(result, values_entries, length + list_offset, list_offset); - - // Set the entry data for this result row - auto &entry = result_data[row]; - entry.length = length; - entry.offset = current_offset; - } -} - -static bool ArgumentIsConstantNull(Vector &argument) { - return argument.GetType().id() == LogicalTypeId::SQLNULL; -} - -static void MapExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.data.size() == 2); - result.SetVectorType(VectorType::FLAT_VECTOR); - - auto &map = args.data[0]; - auto &key = args.data[1]; - - idx_t tuple_count = args.size(); - // Optimization: because keys are not allowed to be NULL, we can early-out - if (ArgumentIsConstantNull(map) || ArgumentIsConstantNull(key)) { - //! We don't need to look through the map if the 'key' to look for is NULL - ListVector::SetListSize(result, 0); - result.SetVectorType(VectorType::CONSTANT_VECTOR); - auto list_data = ConstantVector::GetData(result); - list_data->offset = 0; - list_data->length = 0; - result.Verify(tuple_count); - return; - } - D_ASSERT(map.GetType().id() == LogicalTypeId::MAP); - - UnifiedVectorFormat map_data; - - // Create the chunk we'll feed to ListPosition - DataChunk list_position_chunk; - vector chunk_types; - chunk_types.reserve(2); - chunk_types.push_back(map.GetType()); - chunk_types.push_back(key.GetType()); - list_position_chunk.InitializeEmpty(chunk_types.begin(), chunk_types.end()); - - // Populate it with the map keys list and the key vector - list_position_chunk.data[0].Reference(map); - list_position_chunk.data[1].Reference(key); - list_position_chunk.SetCardinality(tuple_count); - - Vector position_vector(LogicalType::LIST(LogicalType::INTEGER), tuple_count); - // We can pass around state as it's not used by ListPositionFunction anyways - ListContainsOrPosition(list_position_chunk, position_vector); - - FillResult(map, position_vector, result, tuple_count); - - if (tuple_count == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - - result.Verify(tuple_count); -} - static unique_ptr MapExtractBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() != 2) { @@ -151,8 +35,79 @@ static unique_ptr MapExtractBind(ClientContext &context, ScalarFun return make_uniq(bound_function.return_type); } +static void MapExtractFunc(DataChunk &args, ExpressionState &state, Vector &result) { + const auto count = args.size(); + + auto &map_vec = args.data[0]; + auto &arg_vec = args.data[1]; + + const auto map_is_null = map_vec.GetType().id() == LogicalTypeId::SQLNULL; + const auto arg_is_null = arg_vec.GetType().id() == LogicalTypeId::SQLNULL; + + if (map_is_null || arg_is_null) { + // Short-circuit if either the map or the arg is NULL + ListVector::SetListSize(result, 0); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::GetData(result)[0] = {0, 0}; + result.Verify(count); + return; + } + + auto &key_vec = MapVector::GetKeys(map_vec); + auto &val_vec = MapVector::GetValues(map_vec); + + // Collect the matching positions + Vector pos_vec(LogicalType::INTEGER, count); + ListSearchOp(map_vec, key_vec, arg_vec, pos_vec, args.size()); + + UnifiedVectorFormat val_format; + UnifiedVectorFormat pos_format; + UnifiedVectorFormat lst_format; + + val_vec.ToUnifiedFormat(ListVector::GetListSize(map_vec), val_format); + pos_vec.ToUnifiedFormat(count, pos_format); + map_vec.ToUnifiedFormat(count, lst_format); + + const auto pos_data = UnifiedVectorFormat::GetData(pos_format); + const auto inc_list_data = ListVector::GetData(map_vec); + const auto out_list_data = ListVector::GetData(result); + + idx_t offset = 0; + for (idx_t row_idx = 0; row_idx < count; row_idx++) { + auto lst_idx = lst_format.sel->get_index(row_idx); + if (!lst_format.validity.RowIsValid(lst_idx)) { + FlatVector::SetNull(result, row_idx, true); + continue; + } + + auto &inc_list = inc_list_data[lst_idx]; + auto &out_list = out_list_data[row_idx]; + + const auto pos_idx = pos_format.sel->get_index(row_idx); + if (!pos_format.validity.RowIsValid(pos_idx)) { + // We didnt find the key in the map, so return an empty list + out_list.offset = offset; + out_list.length = 0; + continue; + } + + // Compute the actual position of the value in the map value vector + const auto pos = inc_list.offset + UnsafeNumericCast(pos_data[pos_idx] - 1); + out_list.offset = offset; + out_list.length = 1; + ListVector::Append(result, val_vec, pos + 1, pos); + offset++; + } + + if (args.size() == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + + result.Verify(count); +} + ScalarFunction MapExtractFun::GetFunction() { - ScalarFunction fun({LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, MapExtractFunction, MapExtractBind); + ScalarFunction fun({LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, MapExtractFunc, MapExtractBind); fun.varargs = LogicalType::ANY; fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; return fun; diff --git a/src/duckdb/src/core_functions/scalar/math/numeric.cpp b/src/duckdb/src/core_functions/scalar/math/numeric.cpp index 4a6055a9..1c47fbd9 100644 --- a/src/duckdb/src/core_functions/scalar/math/numeric.cpp +++ b/src/duckdb/src/core_functions/scalar/math/numeric.cpp @@ -516,7 +516,7 @@ struct RoundOperatorPrecision { return input; } } - return UnsafeNumericCast(rounded_value); + return LossyNumericCast(rounded_value); } }; @@ -527,7 +527,7 @@ struct RoundOperator { if (std::isinf(rounded_value) || std::isnan(rounded_value)) { return input; } - return UnsafeNumericCast(rounded_value); + return LossyNumericCast(rounded_value); } }; @@ -1128,6 +1128,102 @@ ScalarFunction AcosFun::GetFunction() { ScalarFunction::UnaryFunction>); } +//===--------------------------------------------------------------------===// +// cosh +//===--------------------------------------------------------------------===// +struct CoshOperator { + template + static inline TR Operation(TA input) { + return (double)std::cosh(input); + } +}; + +ScalarFunction CoshFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// acosh +//===--------------------------------------------------------------------===// +struct AcoshOperator { + template + static inline TR Operation(TA input) { + return (double)std::acosh(input); + } +}; + +ScalarFunction AcoshFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// sinh +//===--------------------------------------------------------------------===// +struct SinhOperator { + template + static inline TR Operation(TA input) { + return (double)std::sinh(input); + } +}; + +ScalarFunction SinhFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// asinh +//===--------------------------------------------------------------------===// +struct AsinhOperator { + template + static inline TR Operation(TA input) { + return (double)std::asinh(input); + } +}; + +ScalarFunction AsinhFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// tanh +//===--------------------------------------------------------------------===// +struct TanhOperator { + template + static inline TR Operation(TA input) { + return (double)std::tanh(input); + } +}; + +ScalarFunction TanhFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// atanh +//===--------------------------------------------------------------------===// +struct AtanhOperator { + template + static inline TR Operation(TA input) { + if (input < -1 || input > 1) { + throw InvalidInputException("ATANH is undefined outside [-1,1]"); + } + if (input == -1 || input == 1) { + return INFINITY; + } + return (double)std::atanh(input); + } +}; + +ScalarFunction AtanhFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + //===--------------------------------------------------------------------===// // cot //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/core_functions/scalar/operators/bitwise.cpp b/src/duckdb/src/core_functions/scalar/operators/bitwise.cpp index d9bb5895..6e9415aa 100644 --- a/src/duckdb/src/core_functions/scalar/operators/bitwise.cpp +++ b/src/duckdb/src/core_functions/scalar/operators/bitwise.cpp @@ -206,11 +206,10 @@ ScalarFunctionSet BitwiseNotFun::GetFunctions() { //===--------------------------------------------------------------------===// // << [bitwise_left_shift] //===--------------------------------------------------------------------===// - struct BitwiseShiftLeftOperator { template static inline TR Operation(TA input, TB shift) { - TA max_shift = TA(sizeof(TA) * 8); + TA max_shift = TA(sizeof(TA) * 8) + (NumericLimits::IsSigned() ? 0 : 1); if (input < 0) { throw OutOfRangeException("Cannot left-shift negative number %s", NumericHelper::ToString(input)); } diff --git a/src/duckdb/src/core_functions/scalar/random/setseed.cpp b/src/duckdb/src/core_functions/scalar/random/setseed.cpp index 32965cf1..a4e1d01d 100644 --- a/src/duckdb/src/core_functions/scalar/random/setseed.cpp +++ b/src/duckdb/src/core_functions/scalar/random/setseed.cpp @@ -39,7 +39,7 @@ static void SetSeedFunction(DataChunk &args, ExpressionState &state, Vector &res if (input_seeds[i] < -1.0 || input_seeds[i] > 1.0 || Value::IsNan(input_seeds[i])) { throw InvalidInputException("SETSEED accepts seed values between -1.0 and 1.0, inclusive"); } - auto norm_seed = NumericCast((input_seeds[i] + 1.0) * half_max); + auto norm_seed = LossyNumericCast((input_seeds[i] + 1.0) * half_max); random_engine.SetSeed(norm_seed); } diff --git a/src/duckdb/src/core_functions/scalar/string/bar.cpp b/src/duckdb/src/core_functions/scalar/string/bar.cpp index e9cd400c..b571e7ac 100644 --- a/src/duckdb/src/core_functions/scalar/string/bar.cpp +++ b/src/duckdb/src/core_functions/scalar/string/bar.cpp @@ -40,7 +40,7 @@ static string_t BarScalarFunction(double x, double min, double max, double max_w result.clear(); - auto width_as_int = NumericCast(width * PARTIAL_BLOCKS_COUNT); + auto width_as_int = LossyNumericCast(width * PARTIAL_BLOCKS_COUNT); idx_t full_blocks_count = (width_as_int / PARTIAL_BLOCKS_COUNT); for (idx_t i = 0; i < full_blocks_count; i++) { result += FULL_BLOCK; diff --git a/src/duckdb/src/core_functions/scalar/string/hex.cpp b/src/duckdb/src/core_functions/scalar/string/hex.cpp index f399b65c..6f982f26 100644 --- a/src/duckdb/src/core_functions/scalar/string/hex.cpp +++ b/src/duckdb/src/core_functions/scalar/string/hex.cpp @@ -42,7 +42,7 @@ static void WriteHugeIntHexBytes(T x, char *&output, idx_t buffer_size) { static void WriteBinBytes(uint64_t x, char *&output, idx_t buffer_size) { idx_t offset = buffer_size; for (; offset >= 1; offset -= 1) { - *output = ((x >> (offset - 1)) & 0x01) + '0'; + *output = NumericCast(((x >> (offset - 1)) & 0x01) + '0'); output++; } } @@ -392,6 +392,8 @@ ScalarFunctionSet HexFun::GetFunctions() { ScalarFunctionSet to_hex; to_hex.AddFunction( ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, ToHexFunction)); + to_hex.AddFunction( + ScalarFunction({LogicalType::VARINT}, LogicalType::VARCHAR, ToHexFunction)); to_hex.AddFunction( ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, ToHexFunction)); to_hex.AddFunction( @@ -414,6 +416,8 @@ ScalarFunctionSet BinFun::GetFunctions() { to_binary.AddFunction( ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, ToBinaryFunction)); + to_binary.AddFunction( + ScalarFunction({LogicalType::VARINT}, LogicalType::VARCHAR, ToBinaryFunction)); to_binary.AddFunction(ScalarFunction({LogicalType::UBIGINT}, LogicalType::VARCHAR, ToBinaryFunction)); to_binary.AddFunction( diff --git a/src/duckdb/src/core_functions/scalar/string/md5.cpp b/src/duckdb/src/core_functions/scalar/string/md5.cpp index 6e7ac124..399e3a90 100644 --- a/src/duckdb/src/core_functions/scalar/string/md5.cpp +++ b/src/duckdb/src/core_functions/scalar/string/md5.cpp @@ -30,19 +30,6 @@ struct MD5Number128Operator { } }; -template -struct MD5Number64Operator { - template - static RESULT_TYPE Operation(INPUT_TYPE input) { - data_t digest[MD5Context::MD5_HASH_LENGTH_BINARY]; - - MD5Context context; - context.Add(input); - context.Finish(digest); - return *reinterpret_cast(&digest[lower ? 8 : 0]); - } -}; - static void MD5Function(DataChunk &args, ExpressionState &state, Vector &result) { auto &input = args.data[0]; @@ -55,32 +42,18 @@ static void MD5NumberFunction(DataChunk &args, ExpressionState &state, Vector &r UnaryExecutor::Execute(input, result, args.size()); } -static void MD5NumberUpperFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - - UnaryExecutor::Execute>(input, result, args.size()); -} - -static void MD5NumberLowerFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - - UnaryExecutor::Execute>(input, result, args.size()); -} - -ScalarFunction MD5Fun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, MD5Function); -} - -ScalarFunction MD5NumberFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::HUGEINT, MD5NumberFunction); -} - -ScalarFunction MD5NumberUpperFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::UBIGINT, MD5NumberUpperFunction); +ScalarFunctionSet MD5Fun::GetFunctions() { + ScalarFunctionSet set("md5"); + set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, MD5Function)); + set.AddFunction(ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, MD5Function)); + return set; } -ScalarFunction MD5NumberLowerFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::UBIGINT, MD5NumberLowerFunction); +ScalarFunctionSet MD5NumberFun::GetFunctions() { + ScalarFunctionSet set("md5_number"); + set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::HUGEINT, MD5NumberFunction)); + set.AddFunction(ScalarFunction({LogicalType::BLOB}, LogicalType::HUGEINT, MD5NumberFunction)); + return set; } } // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/printf.cpp b/src/duckdb/src/core_functions/scalar/string/printf.cpp index b71bedef..8b670b5a 100644 --- a/src/duckdb/src/core_functions/scalar/string/printf.cpp +++ b/src/duckdb/src/core_functions/scalar/string/printf.cpp @@ -27,15 +27,26 @@ unique_ptr BindPrintfFunction(ClientContext &context, ScalarFuncti for (idx_t i = 1; i < arguments.size(); i++) { switch (arguments[i]->return_type.id()) { case LogicalTypeId::BOOLEAN: + bound_function.arguments.emplace_back(LogicalType::BOOLEAN); + break; case LogicalTypeId::TINYINT: case LogicalTypeId::SMALLINT: case LogicalTypeId::INTEGER: case LogicalTypeId::BIGINT: + bound_function.arguments.emplace_back(LogicalType::BIGINT); + break; + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + bound_function.arguments.emplace_back(LogicalType::UBIGINT); + break; case LogicalTypeId::FLOAT: case LogicalTypeId::DOUBLE: + bound_function.arguments.emplace_back(LogicalType::DOUBLE); + break; case LogicalTypeId::VARCHAR: - // these types are natively supported - bound_function.arguments.push_back(arguments[i]->return_type); + bound_function.arguments.push_back(LogicalType::VARCHAR); break; case LogicalTypeId::DECIMAL: // decimal type: add cast to double @@ -125,6 +136,11 @@ static void PrintfFunction(DataChunk &args, ExpressionState &state, Vector &resu format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); break; } + case LogicalTypeId::UBIGINT: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } case LogicalTypeId::FLOAT: { auto arg_data = FlatVector::GetData(col); format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); diff --git a/src/duckdb/src/core_functions/scalar/string/repeat.cpp b/src/duckdb/src/core_functions/scalar/string/repeat.cpp index b124c65b..31318290 100644 --- a/src/duckdb/src/core_functions/scalar/string/repeat.cpp +++ b/src/duckdb/src/core_functions/scalar/string/repeat.cpp @@ -31,11 +31,56 @@ static void RepeatFunction(DataChunk &args, ExpressionState &, Vector &result) { }); } +unique_ptr RepeatBindFunction(ClientContext &, ScalarFunction &bound_function, + vector> &arguments) { + switch (arguments[0]->return_type.id()) { + case LogicalTypeId::UNKNOWN: + throw ParameterNotResolvedException(); + case LogicalTypeId::LIST: + break; + default: + throw NotImplementedException("repeat(list, count) requires a list as parameter"); + } + bound_function.arguments[0] = arguments[0]->return_type; + bound_function.return_type = arguments[0]->return_type; + return nullptr; +} + +static void RepeatListFunction(DataChunk &args, ExpressionState &, Vector &result) { + auto &list_vector = args.data[0]; + auto &cnt_vector = args.data[1]; + + auto &source_child = ListVector::GetEntry(list_vector); + auto &result_child = ListVector::GetEntry(result); + + idx_t current_size = ListVector::GetListSize(result); + BinaryExecutor::Execute( + list_vector, cnt_vector, result, args.size(), [&](list_entry_t list_input, int64_t cnt) { + idx_t copy_count = cnt <= 0 || list_input.length == 0 ? 0 : UnsafeNumericCast(cnt); + idx_t result_length = list_input.length * copy_count; + idx_t new_size = current_size + result_length; + ListVector::Reserve(result, new_size); + list_entry_t result_list; + result_list.offset = current_size; + result_list.length = result_length; + for (idx_t i = 0; i < copy_count; i++) { + // repeat the list contents "cnt" times + VectorOperations::Copy(source_child, result_child, list_input.offset + list_input.length, + list_input.offset, current_size); + current_size += list_input.length; + } + return result_list; + }); + ListVector::SetListSize(result, current_size); +} + ScalarFunctionSet RepeatFun::GetFunctions() { ScalarFunctionSet repeat; for (const auto &type : {LogicalType::VARCHAR, LogicalType::BLOB}) { repeat.AddFunction(ScalarFunction({type, LogicalType::BIGINT}, type, RepeatFunction)); } + repeat.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::ANY), RepeatListFunction, RepeatBindFunction)); return repeat; } diff --git a/src/duckdb/src/core_functions/scalar/string/reverse.cpp b/src/duckdb/src/core_functions/scalar/string/reverse.cpp index 95c3cf1a..cef1441f 100644 --- a/src/duckdb/src/core_functions/scalar/string/reverse.cpp +++ b/src/duckdb/src/core_functions/scalar/string/reverse.cpp @@ -3,7 +3,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/common/vector_operations/unary_executor.hpp" -#include "utf8proc.hpp" +#include "utf8proc_wrapper.hpp" #include @@ -23,10 +23,9 @@ static bool StrReverseASCII(const char *input, idx_t n, char *output) { //! Unicode string reverse using grapheme breakers static void StrReverseUnicode(const char *input, idx_t n, char *output) { - utf8proc_grapheme_callback(input, n, [&](size_t start, size_t end) { - memcpy(output + n - end, input + start, end - start); - return true; - }); + for (auto cluster : Utf8Proc::GraphemeClusters(input, n)) { + memcpy(output + n - cluster.end, input + cluster.start, cluster.end - cluster.start); + } } struct ReverseOperator { diff --git a/src/duckdb/src/core_functions/scalar/string/sha1.cpp b/src/duckdb/src/core_functions/scalar/string/sha1.cpp new file mode 100644 index 00000000..82ec9b7a --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/sha1.cpp @@ -0,0 +1,35 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "mbedtls_wrapper.hpp" + +namespace duckdb { + +struct SHA1Operator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto hash = StringVector::EmptyString(result, duckdb_mbedtls::MbedTlsWrapper::SHA1_HASH_LENGTH_TEXT); + + duckdb_mbedtls::MbedTlsWrapper::SHA1State state; + state.AddString(input.GetString()); + state.FinishHex(hash.GetDataWriteable()); + + hash.Finalize(); + return hash; + } +}; + +static void SHA1Function(DataChunk &args, ExpressionState &state, Vector &result) { + auto &input = args.data[0]; + + UnaryExecutor::ExecuteString(input, result, args.size()); +} + +ScalarFunctionSet SHA1Fun::GetFunctions() { + ScalarFunctionSet set("sha1"); + set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, SHA1Function)); + set.AddFunction(ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, SHA1Function)); + return set; +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/sha256.cpp b/src/duckdb/src/core_functions/scalar/string/sha256.cpp index efc09c05..32ca5f5c 100644 --- a/src/duckdb/src/core_functions/scalar/string/sha256.cpp +++ b/src/duckdb/src/core_functions/scalar/string/sha256.cpp @@ -25,8 +25,11 @@ static void SHA256Function(DataChunk &args, ExpressionState &state, Vector &resu UnaryExecutor::ExecuteString(input, result, args.size()); } -ScalarFunction SHA256Fun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, SHA256Function); +ScalarFunctionSet SHA256Fun::GetFunctions() { + ScalarFunctionSet set("sha256"); + set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, SHA256Function)); + set.AddFunction(ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, SHA256Function)); + return set; } } // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/string/url_encode.cpp b/src/duckdb/src/core_functions/scalar/string/url_encode.cpp new file mode 100644 index 00000000..51d49079 --- /dev/null +++ b/src/duckdb/src/core_functions/scalar/string/url_encode.cpp @@ -0,0 +1,49 @@ +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/core_functions/scalar/string_functions.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +struct URLEncodeOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto input_str = input.GetData(); + auto input_size = input.GetSize(); + idx_t result_length = StringUtil::URLEncodeSize(input_str, input_size); + auto result_str = StringVector::EmptyString(result, result_length); + StringUtil::URLEncodeBuffer(input_str, input_size, result_str.GetDataWriteable()); + result_str.Finalize(); + return result_str; + } +}; + +static void URLEncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::ExecuteString(args.data[0], result, args.size()); +} + +ScalarFunction UrlEncodeFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, URLEncodeFunction); +} + +struct URLDecodeOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto input_str = input.GetData(); + auto input_size = input.GetSize(); + idx_t result_length = StringUtil::URLDecodeSize(input_str, input_size); + auto result_str = StringVector::EmptyString(result, result_length); + StringUtil::URLDecodeBuffer(input_str, input_size, result_str.GetDataWriteable()); + result_str.Finalize(); + return result_str; + } +}; + +static void URLDecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::ExecuteString(args.data[0], result, args.size()); +} + +ScalarFunction UrlDecodeFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, URLDecodeFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/core_functions/scalar/struct/struct_pack.cpp b/src/duckdb/src/core_functions/scalar/struct/struct_pack.cpp index b173439f..4e71ea36 100644 --- a/src/duckdb/src/core_functions/scalar/struct/struct_pack.cpp +++ b/src/duckdb/src/core_functions/scalar/struct/struct_pack.cpp @@ -18,7 +18,7 @@ static void StructPackFunction(DataChunk &args, ExpressionState &state, Vector & #endif bool all_const = true; auto &child_entries = StructVector::GetEntries(result); - for (size_t i = 0; i < args.ColumnCount(); i++) { + for (idx_t i = 0; i < args.ColumnCount(); i++) { if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { all_const = false; } @@ -26,7 +26,6 @@ static void StructPackFunction(DataChunk &args, ExpressionState &state, Vector & child_entries[i]->Reference(args.data[i]); } result.SetVectorType(all_const ? VectorType::CONSTANT_VECTOR : VectorType::FLAT_VECTOR); - result.Verify(args.size()); } diff --git a/src/duckdb/src/core_functions/scalar/union/union_extract.cpp b/src/duckdb/src/core_functions/scalar/union/union_extract.cpp index fe838cb1..8478ad0f 100644 --- a/src/duckdb/src/core_functions/scalar/union/union_extract.cpp +++ b/src/duckdb/src/core_functions/scalar/union/union_extract.cpp @@ -45,7 +45,9 @@ static unique_ptr UnionExtractBind(ClientContext &context, ScalarF if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { throw ParameterNotResolvedException(); } - D_ASSERT(LogicalTypeId::UNION == arguments[0]->return_type.id()); + if (arguments[0]->return_type.id() != LogicalTypeId::UNION) { + throw BinderException("union_extract can only take a union parameter"); + } idx_t union_member_count = UnionType::GetMemberCount(arguments[0]->return_type); if (union_member_count == 0) { throw InternalException("Can't extract something from an empty union"); @@ -88,7 +90,7 @@ static unique_ptr UnionExtractBind(ClientContext &context, ScalarF for (idx_t i = 0; i < union_member_count; i++) { candidates.push_back(UnionType::GetMemberName(arguments[0]->return_type, i)); } - auto closest_settings = StringUtil::TopNLevenshtein(candidates, key); + auto closest_settings = StringUtil::TopNJaroWinkler(candidates, key); auto message = StringUtil::CandidatesMessage(closest_settings, "Candidate Entries"); throw BinderException("Could not find key \"%s\" in union\n%s", key, message); } diff --git a/src/duckdb/src/execution/adaptive_filter.cpp b/src/duckdb/src/execution/adaptive_filter.cpp index 166174d0..ce4025d2 100644 --- a/src/duckdb/src/execution/adaptive_filter.cpp +++ b/src/duckdb/src/execution/adaptive_filter.cpp @@ -1,12 +1,12 @@ #include "duckdb/planner/expression/bound_conjunction_expression.hpp" #include "duckdb/execution/adaptive_filter.hpp" #include "duckdb/planner/table_filter.hpp" +#include "duckdb/common/numeric_utils.hpp" #include "duckdb/common/vector.hpp" namespace duckdb { -AdaptiveFilter::AdaptiveFilter(const Expression &expr) - : iteration_count(0), observe_interval(10), execute_interval(20), warmup(true) { +AdaptiveFilter::AdaptiveFilter(const Expression &expr) : observe_interval(10), execute_interval(20), warmup(true) { auto &conj_expr = expr.Cast(); D_ASSERT(conj_expr.children.size() > 1); for (idx_t idx = 0; idx < conj_expr.children.size(); idx++) { @@ -18,15 +18,34 @@ AdaptiveFilter::AdaptiveFilter(const Expression &expr) right_random_border = 100 * (conj_expr.children.size() - 1); } -AdaptiveFilter::AdaptiveFilter(TableFilterSet *table_filters) - : iteration_count(0), observe_interval(10), execute_interval(20), warmup(true) { - for (auto &table_filter : table_filters->filters) { - permutation.push_back(table_filter.first); +AdaptiveFilter::AdaptiveFilter(const TableFilterSet &table_filters) + : observe_interval(10), execute_interval(20), warmup(true) { + for (idx_t idx = 0; idx < table_filters.filters.size(); idx++) { + permutation.push_back(idx); swap_likeliness.push_back(100); } swap_likeliness.pop_back(); - right_random_border = 100 * (table_filters->filters.size() - 1); + right_random_border = 100 * (table_filters.filters.size() - 1); +} + +AdaptiveFilterState AdaptiveFilter::BeginFilter() const { + if (permutation.size() <= 1) { + return AdaptiveFilterState(); + } + AdaptiveFilterState state; + state.start_time = high_resolution_clock::now(); + return state; } + +void AdaptiveFilter::EndFilter(AdaptiveFilterState state) { + if (permutation.size() <= 1) { + // nothing to permute + return; + } + auto end_time = high_resolution_clock::now(); + AdaptRuntimeStatistics(duration_cast>(end_time - state.start_time).count()); +} + void AdaptiveFilter::AdaptRuntimeStatistics(double duration) { iteration_count++; runtime_sum += duration; @@ -35,7 +54,7 @@ void AdaptiveFilter::AdaptRuntimeStatistics(double duration) { // the last swap was observed if (observe && iteration_count == observe_interval) { // keep swap if runtime decreased, else reverse swap - if (prev_mean - (runtime_sum / iteration_count) <= 0) { + if (prev_mean - (runtime_sum / static_cast(iteration_count)) <= 0) { // reverse swap because runtime didn't decrease std::swap(permutation[swap_idx], permutation[swap_idx + 1]); @@ -54,11 +73,11 @@ void AdaptiveFilter::AdaptRuntimeStatistics(double duration) { runtime_sum = 0.0; } else if (!observe && iteration_count == execute_interval) { // save old mean to evaluate swap - prev_mean = runtime_sum / iteration_count; + prev_mean = runtime_sum / static_cast(iteration_count); // get swap index and swap likeliness - std::uniform_int_distribution distribution(1, NumericCast(right_random_border)); // a <= i <= b - auto random_number = UnsafeNumericCast(distribution(generator) - 1); + // a <= i <= b + auto random_number = generator.NextRandomInteger(1, NumericCast(right_random_border)); swap_idx = random_number / 100; // index to be swapped idx_t likeliness = random_number - 100 * swap_idx; // random number between [0, 100) diff --git a/src/duckdb/src/execution/aggregate_hashtable.cpp b/src/duckdb/src/execution/aggregate_hashtable.cpp index 5e944fba..e09fd9b7 100644 --- a/src/duckdb/src/execution/aggregate_hashtable.cpp +++ b/src/duckdb/src/execution/aggregate_hashtable.cpp @@ -9,6 +9,7 @@ #include "duckdb/common/types/row/tuple_data_iterator.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/ht_entry.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" namespace duckdb { @@ -122,7 +123,7 @@ idx_t GroupedAggregateHashTable::InitialCapacity() { idx_t GroupedAggregateHashTable::GetCapacityForCount(idx_t count) { count = MaxValue(InitialCapacity(), count); - return NextPowerOfTwo(NumericCast(static_cast(count) * LOAD_FACTOR)); + return NextPowerOfTwo(LossyNumericCast(static_cast(count) * LOAD_FACTOR)); } idx_t GroupedAggregateHashTable::Capacity() const { @@ -130,7 +131,7 @@ idx_t GroupedAggregateHashTable::Capacity() const { } idx_t GroupedAggregateHashTable::ResizeThreshold() const { - return NumericCast(static_cast(Capacity()) / LOAD_FACTOR); + return LossyNumericCast(static_cast(Capacity()) / LOAD_FACTOR); } idx_t GroupedAggregateHashTable::ApplyBitMask(hash_t hash) const { @@ -146,7 +147,7 @@ void GroupedAggregateHashTable::Verify() { continue; } auto hash = Load(entry.GetPointer() + hash_offset); - D_ASSERT(entry.GetSalt() == aggr_ht_entry_t::ExtractSalt(hash)); + D_ASSERT(entry.GetSalt() == ht_entry_t::ExtractSalt(hash)); total_count++; } D_ASSERT(total_count == Count()); @@ -154,7 +155,7 @@ void GroupedAggregateHashTable::Verify() { } void GroupedAggregateHashTable::ClearPointerTable() { - std::fill_n(entries, capacity, aggr_ht_entry_t(0)); + std::fill_n(entries, capacity, ht_entry_t::GetEmptyEntry()); } void GroupedAggregateHashTable::ResetCount() { @@ -173,8 +174,8 @@ void GroupedAggregateHashTable::Resize(idx_t size) { } capacity = size; - hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(aggr_ht_entry_t)); - entries = reinterpret_cast(hash_map.get()); + hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(ht_entry_t)); + entries = reinterpret_cast(hash_map.get()); ClearPointerTable(); bitmask = capacity - 1; @@ -201,7 +202,7 @@ void GroupedAggregateHashTable::Resize(idx_t size) { } auto &entry = entries[entry_idx]; D_ASSERT(!entry.IsOccupied()); - entry.SetSalt(aggr_ht_entry_t::ExtractSalt(hash)); + entry.SetSalt(ht_entry_t::ExtractSalt(hash)); entry.SetPointer(row_location); D_ASSERT(entry.IsOccupied()); } @@ -333,7 +334,7 @@ idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, V const auto &hash = hashes[r]; ht_offsets[r] = ApplyBitMask(hash); D_ASSERT(ht_offsets[r] == hash % capacity); - hash_salts[r] = aggr_ht_entry_t::ExtractSalt(hash); + hash_salts[r] = ht_entry_t::ExtractSalt(hash); } // we start out with all entries [0, 1, 2, ..., groups.size()] @@ -354,7 +355,7 @@ idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, V auto &chunk_state = state.append_state.chunk_state; TupleDataCollection::ToUnifiedFormat(chunk_state, state.group_chunk); if (!state.group_data) { - state.group_data = make_unsafe_uniq_array(state.group_chunk.ColumnCount()); + state.group_data = make_unsafe_uniq_array_uninitialized(state.group_chunk.ColumnCount()); } TupleDataCollection::GetVectorData(chunk_state, state.group_data.get()); @@ -380,13 +381,9 @@ idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, V // Same salt, compare group keys state.group_compare_vector.set_index(need_compare_count++, index); break; - } else { - // Different salts, move to next entry (linear probing) - if (++ht_offset >= capacity) { - ht_offset = 0; - } - continue; } + // Different salts, move to next entry (linear probing) + IncrementAndWrap(ht_offset, bitmask); } else { // Cell is unoccupied, let's claim it // Set salt (also marks as occupied) entry.SetSalt(salt); @@ -439,9 +436,7 @@ idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, V for (idx_t i = 0; i < no_match_count; i++) { const auto index = state.no_match_vector.get_index(i); auto &ht_offset = ht_offsets[index]; - if (++ht_offset >= capacity) { - ht_offset = 0; - } + IncrementAndWrap(ht_offset, bitmask); } sel_vector = &state.no_match_vector; remaining_entries = no_match_count; diff --git a/src/duckdb/src/execution/expression_executor.cpp b/src/duckdb/src/execution/expression_executor.cpp index 8c70ca22..716672d8 100644 --- a/src/duckdb/src/execution/expression_executor.cpp +++ b/src/duckdb/src/execution/expression_executor.cpp @@ -249,7 +249,7 @@ static inline idx_t DefaultSelectLoop(const SelectionVector *bsel, const uint8_t for (idx_t i = 0; i < count; i++) { auto bidx = bsel->get_index(i); auto result_idx = sel->get_index(i); - if (bdata[bidx] > 0 && (NO_NULL || mask.RowIsValid(bidx))) { + if ((NO_NULL || mask.RowIsValid(bidx)) && bdata[bidx] > 0) { if (HAS_TRUE_SEL) { true_sel->set_index(true_count++, result_idx); } diff --git a/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp b/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp index de1fc801..37161cfd 100644 --- a/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp @@ -2,7 +2,6 @@ #include "duckdb/execution/expression_executor.hpp" #include "duckdb/planner/expression/bound_conjunction_expression.hpp" #include "duckdb/execution/adaptive_filter.hpp" -#include "duckdb/common/chrono.hpp" #include @@ -60,8 +59,7 @@ idx_t ExpressionExecutor::Select(const BoundConjunctionExpression &expr, Express if (expr.type == ExpressionType::CONJUNCTION_AND) { // get runtime statistics - auto start_time = high_resolution_clock::now(); - + auto filter_state = state.adaptive_filter->BeginFilter(); const SelectionVector *current_sel = sel; idx_t current_count = count; idx_t false_count = 0; @@ -96,14 +94,12 @@ idx_t ExpressionExecutor::Select(const BoundConjunctionExpression &expr, Express current_sel = true_sel; } } - // adapt runtime statistics - auto end_time = high_resolution_clock::now(); - state.adaptive_filter->AdaptRuntimeStatistics(duration_cast>(end_time - start_time).count()); + state.adaptive_filter->EndFilter(filter_state); return current_count; } else { // get runtime statistics - auto start_time = high_resolution_clock::now(); + auto filter_state = state.adaptive_filter->BeginFilter(); const SelectionVector *current_sel = sel; idx_t current_count = count; @@ -135,8 +131,7 @@ idx_t ExpressionExecutor::Select(const BoundConjunctionExpression &expr, Express } // adapt runtime statistics - auto end_time = high_resolution_clock::now(); - state.adaptive_filter->AdaptRuntimeStatistics(duration_cast>(end_time - start_time).count()); + state.adaptive_filter->EndFilter(filter_state); return result_count; } } diff --git a/src/duckdb/src/execution/index/art/art.cpp b/src/duckdb/src/execution/index/art/art.cpp index a28a0f66..be4beef1 100644 --- a/src/duckdb/src/execution/index/art/art.cpp +++ b/src/duckdb/src/execution/index/art/art.cpp @@ -1,34 +1,38 @@ #include "duckdb/execution/index/art/art.hpp" #include "duckdb/common/types/conflict_manager.hpp" +#include "duckdb/common/unordered_map.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/execution/index/art/art_key.hpp" +#include "duckdb/execution/index/art/base_leaf.hpp" +#include "duckdb/execution/index/art/base_node.hpp" #include "duckdb/execution/index/art/iterator.hpp" #include "duckdb/execution/index/art/leaf.hpp" -#include "duckdb/execution/index/art/node16.hpp" #include "duckdb/execution/index/art/node256.hpp" -#include "duckdb/execution/index/art/node4.hpp" +#include "duckdb/execution/index/art/node256_leaf.hpp" #include "duckdb/execution/index/art/node48.hpp" #include "duckdb/execution/index/art/prefix.hpp" +#include "duckdb/optimizer/matcher/expression_matcher.hpp" +#include "duckdb/planner/expression/bound_between_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/storage/arena_allocator.hpp" #include "duckdb/storage/metadata/metadata_reader.hpp" #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/storage/table_io_manager.hpp" -#include "duckdb/optimizer/matcher/expression_matcher.hpp" namespace duckdb { struct ARTIndexScanState : public IndexScanState { - - //! Scan predicates (single predicate scan or range scan) + //! The predicates to scan. + //! A single predicate for point lookups, and two predicates for range scans. Value values[2]; - //! Expressions of the scan predicates + //! The expressions over the scan predicates. ExpressionType expressions[2]; bool checked = false; - //! All scanned row IDs - vector result_ids; - Iterator iterator; + //! All scanned row IDs. + unsafe_vector row_ids; }; //===--------------------------------------------------------------------===// @@ -37,40 +41,13 @@ struct ARTIndexScanState : public IndexScanState { ART::ART(const string &name, const IndexConstraintType index_constraint_type, const vector &column_ids, TableIOManager &table_io_manager, const vector> &unbound_expressions, - AttachedDatabase &db, const shared_ptr, ALLOCATOR_COUNT>> &allocators_ptr, + AttachedDatabase &db, + const shared_ptr, ALLOCATOR_COUNT>> &allocators_ptr, const IndexStorageInfo &info) : BoundIndex(name, ART::TYPE_NAME, index_constraint_type, column_ids, table_io_manager, unbound_expressions, db), allocators(allocators_ptr), owns_data(false) { - // initialize all allocators - if (!allocators) { - owns_data = true; - auto &block_manager = table_io_manager.GetIndexBlockManager(); - - array, ALLOCATOR_COUNT> allocator_array = { - make_uniq(sizeof(Prefix), block_manager), - make_uniq(sizeof(Leaf), block_manager), - make_uniq(sizeof(Node4), block_manager), - make_uniq(sizeof(Node16), block_manager), - make_uniq(sizeof(Node48), block_manager), - make_uniq(sizeof(Node256), block_manager)}; - allocators = - make_shared_ptr, ALLOCATOR_COUNT>>(std::move(allocator_array)); - } - - // deserialize lazily - if (info.IsValid()) { - - if (!info.root_block_ptr.IsValid()) { - InitAllocators(info); - - } else { - // old storage file - Deserialize(info.root_block_ptr); - } - } - - // validate the types of the key columns + // FIXME: Use the new byte representation function to support nested types. for (idx_t i = 0; i < types.size(); i++) { switch (types[i]) { case PhysicalType::BOOL: @@ -92,28 +69,61 @@ ART::ART(const string &name, const IndexConstraintType index_constraint_type, co throw InvalidTypeException(logical_types[i], "Invalid type for index key."); } } + + // Initialize the allocators. + SetPrefixCount(info); + if (!allocators) { + owns_data = true; + auto prefix_size = NumericCast(prefix_count) + NumericCast(Prefix::METADATA_SIZE); + auto &block_manager = table_io_manager.GetIndexBlockManager(); + + array, ALLOCATOR_COUNT> allocator_array = { + make_unsafe_uniq(prefix_size, block_manager), + make_unsafe_uniq(sizeof(Leaf), block_manager), + make_unsafe_uniq(sizeof(Node4), block_manager), + make_unsafe_uniq(sizeof(Node16), block_manager), + make_unsafe_uniq(sizeof(Node48), block_manager), + make_unsafe_uniq(sizeof(Node256), block_manager), + make_unsafe_uniq(sizeof(Node7Leaf), block_manager), + make_unsafe_uniq(sizeof(Node15Leaf), block_manager), + make_unsafe_uniq(sizeof(Node256Leaf), block_manager), + }; + allocators = + make_shared_ptr, ALLOCATOR_COUNT>>(std::move(allocator_array)); + } + + if (!info.IsValid()) { + // We create a new ART. + return; + } + + if (info.root_block_ptr.IsValid()) { + // Backwards compatibility. + Deserialize(info.root_block_ptr); + return; + } + + // Set the root node and initialize the allocators. + tree.Set(info.root); + InitAllocators(info); } //===--------------------------------------------------------------------===// -// Initialize Predicate Scans +// Initialize Scans //===--------------------------------------------------------------------===// -//! Initialize a single predicate scan on the index with the given expression and column IDs -static unique_ptr InitializeScanSinglePredicate(const Transaction &transaction, const Value &value, +static unique_ptr InitializeScanSinglePredicate(const Value &value, const ExpressionType expression_type) { - // initialize point lookup auto result = make_uniq(); result->values[0] = value; result->expressions[0] = expression_type; return std::move(result); } -//! Initialize a two predicate scan on the index with the given expression and column IDs -static unique_ptr InitializeScanTwoPredicates(const Transaction &transaction, const Value &low_value, +static unique_ptr InitializeScanTwoPredicates(const Value &low_value, const ExpressionType low_expression_type, const Value &high_value, const ExpressionType high_expression_type) { - // initialize range lookup auto result = make_uniq(); result->values[0] = low_value; result->expressions[0] = low_expression_type; @@ -122,64 +132,64 @@ static unique_ptr InitializeScanTwoPredicates(const Transaction return std::move(result); } -unique_ptr ART::TryInitializeScan(const Transaction &transaction, const Expression &index_expr, - const Expression &filter_expr) { - +unique_ptr ART::TryInitializeScan(const Expression &expr, const Expression &filter_expr) { Value low_value, high_value, equal_value; ExpressionType low_comparison_type = ExpressionType::INVALID, high_comparison_type = ExpressionType::INVALID; - // try to find a matching index for any of the filter expressions - // create a matcher for a comparison with a constant + // Try to find a matching index for any of the filter expressions. ComparisonExpressionMatcher matcher; - // match on a comparison type + // Match on a comparison type. matcher.expr_type = make_uniq(); - // match on a constant comparison with the indexed expression - matcher.matchers.push_back(make_uniq(index_expr)); + // Match on a constant comparison with the indexed expression. + matcher.matchers.push_back(make_uniq(expr)); matcher.matchers.push_back(make_uniq()); - matcher.policy = SetMatcher::Policy::UNORDERED; vector> bindings; - if (matcher.Match(const_cast(filter_expr), bindings)) { // NOLINT: Match does not alter the expr - // range or equality comparison with constant value - // we can use our index here - // bindings[0] = the expression - // bindings[1] = the index expression - // bindings[2] = the constant + auto filter_match = + matcher.Match(const_cast(filter_expr), bindings); // NOLINT: Match does not alter the expr. + if (filter_match) { + // This is a range or equality comparison with a constant value, so we can use the index. + // bindings[0] = the expression + // bindings[1] = the index expression + // bindings[2] = the constant auto &comparison = bindings[0].get().Cast(); auto constant_value = bindings[2].get().Cast().value; auto comparison_type = comparison.type; + if (comparison.left->type == ExpressionType::VALUE_CONSTANT) { - // the expression is on the right side, we flip them around + // The expression is on the right side, we flip the comparison expression. comparison_type = FlipComparisonExpression(comparison_type); } + if (comparison_type == ExpressionType::COMPARE_EQUAL) { - // equality value - // equality overrides any other bounds so we just break here + // An equality value overrides any other bounds. equal_value = constant_value; } else if (comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO || comparison_type == ExpressionType::COMPARE_GREATERTHAN) { - // greater than means this is a lower bound + // This is a lower bound. low_value = constant_value; low_comparison_type = comparison_type; } else { - // smaller than means this is an upper bound + // This is an upper bound. high_value = constant_value; high_comparison_type = comparison_type; } + } else if (filter_expr.type == ExpressionType::COMPARE_BETWEEN) { - // BETWEEN expression auto &between = filter_expr.Cast(); - if (!between.input->Equals(index_expr)) { - // expression doesn't match the index expression + if (!between.input->Equals(expr)) { + // The expression does not match the index expression. return nullptr; } + if (between.lower->type != ExpressionType::VALUE_CONSTANT || between.upper->type != ExpressionType::VALUE_CONSTANT) { - // not a constant comparison + // Not a constant expression. return nullptr; } - low_value = (between.lower->Cast()).value; + + low_value = between.lower->Cast().value; low_comparison_type = between.lower_inclusive ? ExpressionType::COMPARE_GREATERTHANOREQUALTO : ExpressionType::COMPARE_GREATERTHAN; high_value = (between.upper->Cast()).value; @@ -187,165 +197,177 @@ unique_ptr ART::TryInitializeScan(const Transaction &transaction between.upper_inclusive ? ExpressionType::COMPARE_LESSTHANOREQUALTO : ExpressionType::COMPARE_LESSTHAN; } - if (!equal_value.IsNull() || !low_value.IsNull() || !high_value.IsNull()) { - // we can scan this index using this predicate: try a scan - unique_ptr index_state; - if (!equal_value.IsNull()) { - // equality predicate - index_state = InitializeScanSinglePredicate(transaction, equal_value, ExpressionType::COMPARE_EQUAL); - } else if (!low_value.IsNull() && !high_value.IsNull()) { - // two-sided predicate - index_state = InitializeScanTwoPredicates(transaction, low_value, low_comparison_type, high_value, - high_comparison_type); - } else if (!low_value.IsNull()) { - // less than predicate - index_state = InitializeScanSinglePredicate(transaction, low_value, low_comparison_type); - } else { - D_ASSERT(!high_value.IsNull()); - index_state = InitializeScanSinglePredicate(transaction, high_value, high_comparison_type); - } - return index_state; + // We cannot use an index scan. + if (equal_value.IsNull() && low_value.IsNull() && high_value.IsNull()) { + return nullptr; } - return nullptr; + + // Initialize the index scan state and return it. + if (!equal_value.IsNull()) { + // Equality predicate. + return InitializeScanSinglePredicate(equal_value, ExpressionType::COMPARE_EQUAL); + } + if (!low_value.IsNull() && !high_value.IsNull()) { + // Two-sided predicate. + return InitializeScanTwoPredicates(low_value, low_comparison_type, high_value, high_comparison_type); + } + if (!low_value.IsNull()) { + // Less-than predicate. + return InitializeScanSinglePredicate(low_value, low_comparison_type); + } + // Greater-than predicate. + return InitializeScanSinglePredicate(high_value, high_comparison_type); } //===--------------------------------------------------------------------===// -// Keys +// ART Keys //===--------------------------------------------------------------------===// -template -static void TemplatedGenerateKeys(ArenaAllocator &allocator, Vector &input, idx_t count, vector &keys) { - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - +template +static void TemplatedGenerateKeys(ArenaAllocator &allocator, Vector &input, idx_t count, unsafe_vector &keys) { D_ASSERT(keys.size() >= count); - auto input_data = UnifiedVectorFormat::GetData(idata); + + UnifiedVectorFormat data; + input.ToUnifiedFormat(count, data); + auto input_data = UnifiedVectorFormat::GetData(data); + for (idx_t i = 0; i < count; i++) { - auto idx = idata.sel->get_index(i); - if (idata.validity.RowIsValid(idx)) { - ARTKey::CreateARTKey(allocator, input.GetType(), keys[i], input_data[idx]); - } else { - // we need to possibly reset the former key value in the keys vector - keys[i] = ARTKey(); + auto idx = data.sel->get_index(i); + if (IS_NOT_NULL || data.validity.RowIsValid(idx)) { + ARTKey::CreateARTKey(allocator, keys[i], input_data[idx]); + continue; } + + // We need to reset the key value in the reusable keys vector. + keys[i] = ARTKey(); } } -template -static void ConcatenateKeys(ArenaAllocator &allocator, Vector &input, idx_t count, vector &keys) { - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); +template +static void ConcatenateKeys(ArenaAllocator &allocator, Vector &input, idx_t count, unsafe_vector &keys) { + UnifiedVectorFormat data; + input.ToUnifiedFormat(count, data); + auto input_data = UnifiedVectorFormat::GetData(data); - auto input_data = UnifiedVectorFormat::GetData(idata); for (idx_t i = 0; i < count; i++) { - auto idx = idata.sel->get_index(i); - - // key is not NULL (no previous column entry was NULL) - if (!keys[i].Empty()) { - if (!idata.validity.RowIsValid(idx)) { - // this column entry is NULL, set whole key to NULL - keys[i] = ARTKey(); - } else { - auto other_key = ARTKey::CreateARTKey(allocator, input.GetType(), input_data[idx]); - keys[i].ConcatenateARTKey(allocator, other_key); - } + auto idx = data.sel->get_index(i); + + if (IS_NOT_NULL) { + auto other_key = ARTKey::CreateARTKey(allocator, input_data[idx]); + keys[i].Concat(allocator, other_key); + continue; + } + + // A previous column entry was NULL. + if (keys[i].Empty()) { + continue; } + + // This column entry is NULL, so we set the whole key to NULL. + if (!data.validity.RowIsValid(idx)) { + keys[i] = ARTKey(); + continue; + } + + // Concatenate the keys. + auto other_key = ARTKey::CreateARTKey(allocator, input_data[idx]); + keys[i].Concat(allocator, other_key); } } -void ART::GenerateKeys(ArenaAllocator &allocator, DataChunk &input, vector &keys) { - // generate keys for the first input column +template +void GenerateKeysInternal(ArenaAllocator &allocator, DataChunk &input, unsafe_vector &keys) { switch (input.data[0].GetType().InternalType()) { case PhysicalType::BOOL: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; case PhysicalType::INT8: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; case PhysicalType::INT16: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; case PhysicalType::INT32: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; case PhysicalType::INT64: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; case PhysicalType::INT128: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; case PhysicalType::UINT8: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; case PhysicalType::UINT16: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; case PhysicalType::UINT32: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; case PhysicalType::UINT64: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; case PhysicalType::UINT128: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; case PhysicalType::FLOAT: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; case PhysicalType::DOUBLE: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; case PhysicalType::VARCHAR: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); break; default: throw InternalException("Invalid type for index"); } + // We concatenate the keys for each remaining column of a compound key. for (idx_t i = 1; i < input.ColumnCount(); i++) { - // for each of the remaining columns, concatenate switch (input.data[i].GetType().InternalType()) { case PhysicalType::BOOL: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; case PhysicalType::INT8: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; case PhysicalType::INT16: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; case PhysicalType::INT32: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; case PhysicalType::INT64: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; case PhysicalType::INT128: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; case PhysicalType::UINT8: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; case PhysicalType::UINT16: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; case PhysicalType::UINT32: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; case PhysicalType::UINT64: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; case PhysicalType::UINT128: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; case PhysicalType::FLOAT: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; case PhysicalType::DOUBLE: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; case PhysicalType::VARCHAR: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], input.size(), keys); break; default: throw InternalException("Invalid type for index"); @@ -353,195 +375,168 @@ void ART::GenerateKeys(ArenaAllocator &allocator, DataChunk &input, vector +void ART::GenerateKeys<>(ArenaAllocator &allocator, DataChunk &input, unsafe_vector &keys) { + GenerateKeysInternal(allocator, input, keys); +} -struct KeySection { - KeySection(idx_t start_p, idx_t end_p, idx_t depth_p, data_t key_byte_p) - : start(start_p), end(end_p), depth(depth_p), key_byte(key_byte_p) {}; - KeySection(idx_t start_p, idx_t end_p, vector &keys, KeySection &key_section) - : start(start_p), end(end_p), depth(key_section.depth + 1), key_byte(keys[end_p].data[key_section.depth]) {}; - idx_t start; - idx_t end; - idx_t depth; - data_t key_byte; -}; +template <> +void ART::GenerateKeys(ArenaAllocator &allocator, DataChunk &input, unsafe_vector &keys) { + GenerateKeysInternal(allocator, input, keys); +} -void GetChildSections(vector &child_sections, vector &keys, KeySection &key_section) { +void ART::GenerateKeyVectors(ArenaAllocator &allocator, DataChunk &input, Vector &row_ids, unsafe_vector &keys, + unsafe_vector &row_id_keys) { + GenerateKeys<>(allocator, input, keys); - idx_t child_start_idx = key_section.start; - for (idx_t i = key_section.start + 1; i <= key_section.end; i++) { - if (keys[i - 1].data[key_section.depth] != keys[i].data[key_section.depth]) { - child_sections.emplace_back(child_start_idx, i - 1, keys, key_section); - child_start_idx = i; - } - } - child_sections.emplace_back(child_start_idx, key_section.end, keys, key_section); + DataChunk row_id_chunk; + row_id_chunk.Initialize(Allocator::DefaultAllocator(), vector {LogicalType::ROW_TYPE}, input.size()); + row_id_chunk.data[0].Reference(row_ids); + row_id_chunk.SetCardinality(input.size()); + GenerateKeys<>(allocator, row_id_chunk, row_id_keys); } -bool Construct(ART &art, vector &keys, row_t *row_ids, Node &node, KeySection &key_section, - bool &has_constraint) { +//===--------------------------------------------------------------------===// +// Construct from sorted data. +//===--------------------------------------------------------------------===// - D_ASSERT(key_section.start < keys.size()); - D_ASSERT(key_section.end < keys.size()); - D_ASSERT(key_section.start <= key_section.end); +bool ART::ConstructInternal(const unsafe_vector &keys, const unsafe_vector &row_ids, Node &node, + ARTKeySection §ion) { + D_ASSERT(section.start < keys.size()); + D_ASSERT(section.end < keys.size()); + D_ASSERT(section.start <= section.end); - auto &start_key = keys[key_section.start]; - auto &end_key = keys[key_section.end]; + auto &start = keys[section.start]; + auto &end = keys[section.end]; + D_ASSERT(start.len != 0); - // increment the depth until we reach a leaf or find a mismatching byte - auto prefix_start = key_section.depth; - while (start_key.len != key_section.depth && - start_key.ByteMatches(end_key, UnsafeNumericCast(key_section.depth))) { - key_section.depth++; + // Increment the depth until we reach a leaf or find a mismatching byte. + auto prefix_depth = section.depth; + while (start.len != section.depth && start.ByteMatches(end, section.depth)) { + section.depth++; } - // we reached a leaf, i.e. all the bytes of start_key and end_key match - if (start_key.len == key_section.depth) { - // end_idx is inclusive - auto num_row_ids = key_section.end - key_section.start + 1; - - // check for possible constraint violation - auto single_row_id = num_row_ids == 1; - if (has_constraint && !single_row_id) { + if (start.len == section.depth) { + // We reached a leaf. All the bytes of start_key and end_key match. + auto row_id_count = section.end - section.start + 1; + if (IsUnique() && row_id_count != 1) { return false; } - reference ref_node(node); - Prefix::New(art, ref_node, start_key, UnsafeNumericCast(prefix_start), - UnsafeNumericCast(start_key.len - prefix_start)); - if (single_row_id) { - Leaf::New(ref_node, row_ids[key_section.start]); + reference ref(node); + auto count = UnsafeNumericCast(start.len - prefix_depth); + Prefix::New(*this, ref, start, prefix_depth, count); + if (row_id_count == 1) { + Leaf::New(ref, row_ids[section.start].GetRowId()); } else { - Leaf::New(art, ref_node, row_ids + key_section.start, num_row_ids); + Leaf::New(*this, ref, row_ids, section.start, row_id_count); } return true; } - // create a new node and recurse - - // we will find at least two child entries of this node, otherwise we'd have reached a leaf - vector child_sections; - GetChildSections(child_sections, keys, key_section); + // Create a new node and recurse. + unsafe_vector children; + section.GetChildSections(children, keys); - // set the prefix - reference ref_node(node); - auto prefix_length = key_section.depth - prefix_start; - Prefix::New(art, ref_node, start_key, UnsafeNumericCast(prefix_start), - UnsafeNumericCast(prefix_length)); + // Create the prefix. + reference ref(node); + auto prefix_length = section.depth - prefix_depth; + Prefix::New(*this, ref, start, prefix_depth, prefix_length); - // set the node - auto node_type = Node::GetARTNodeTypeByCount(child_sections.size()); - Node::New(art, ref_node, node_type); - - // recurse on each child section - for (auto &child_section : child_sections) { + // Create the node. + Node::New(*this, ref, Node::GetNodeType(children.size())); + for (auto &child : children) { Node new_child; - auto no_violation = Construct(art, keys, row_ids, new_child, child_section, has_constraint); - Node::InsertChild(art, ref_node, child_section.key_byte, new_child); - if (!no_violation) { + auto success = ConstructInternal(keys, row_ids, new_child, child); + Node::InsertChild(*this, ref, child.key_byte, new_child); + if (!success) { return false; } } return true; } -bool ART::ConstructFromSorted(idx_t count, vector &keys, Vector &row_identifiers) { - - // prepare the row_identifiers - row_identifiers.Flatten(count); - auto row_ids = FlatVector::GetData(row_identifiers); - - auto key_section = KeySection(0, count - 1, 0, 0); - auto has_constraint = IsUnique(); - if (!Construct(*this, keys, row_ids, tree, key_section, has_constraint)) { +bool ART::Construct(unsafe_vector &keys, unsafe_vector &row_ids, const idx_t row_count) { + ARTKeySection section(0, row_count - 1, 0, 0); + if (!ConstructInternal(keys, row_ids, tree, section)) { return false; } #ifdef DEBUG - D_ASSERT(!VerifyAndToStringInternal(true).empty()); - for (idx_t i = 0; i < count; i++) { - D_ASSERT(!keys[i].Empty()); - auto leaf = Lookup(tree, keys[i], 0); - D_ASSERT(Leaf::ContainsRowId(*this, *leaf, row_ids[i])); - } + unsafe_vector row_ids_debug; + Iterator it(*this); + it.FindMinimum(tree); + ARTKey empty_key = ARTKey(); + it.Scan(empty_key, NumericLimits().Maximum(), row_ids_debug, false); + D_ASSERT(row_count == row_ids_debug.size()); #endif - return true; } //===--------------------------------------------------------------------===// -// Insert / Verification / Constraint Checking +// Insert and Constraint Checking //===--------------------------------------------------------------------===// -ErrorData ART::Insert(IndexLock &lock, DataChunk &input, Vector &row_ids) { +ErrorData ART::Insert(IndexLock &lock, DataChunk &input, Vector &row_ids) { D_ASSERT(row_ids.GetType().InternalType() == ROW_TYPE); - D_ASSERT(logical_types[0] == input.data[0].GetType()); + auto row_count = input.size(); - // generate the keys for the given input - ArenaAllocator arena_allocator(BufferAllocator::Get(db)); - vector keys(input.size()); - GenerateKeys(arena_allocator, input, keys); - - // get the corresponding row IDs - row_ids.Flatten(input.size()); - auto row_identifiers = FlatVector::GetData(row_ids); + ArenaAllocator allocator(BufferAllocator::Get(db)); + unsafe_vector keys(row_count); + unsafe_vector row_id_keys(row_count); + GenerateKeyVectors(allocator, input, row_ids, keys, row_id_keys); - // now insert the elements into the index + // Insert the entries into the index. idx_t failed_index = DConstants::INVALID_INDEX; - for (idx_t i = 0; i < input.size(); i++) { + auto was_empty = !tree.HasMetadata(); + for (idx_t i = 0; i < row_count; i++) { if (keys[i].Empty()) { continue; } - - row_t row_id = row_identifiers[i]; - if (!Insert(tree, keys[i], 0, row_id)) { - // failed to insert because of constraint violation + if (!Insert(tree, keys[i], 0, row_id_keys[i], tree.GetGateStatus())) { + // Insertion failure due to a constraint violation. failed_index = i; break; } } - // failed to insert because of constraint violation: remove previously inserted entries + // Remove any previously inserted entries. if (failed_index != DConstants::INVALID_INDEX) { for (idx_t i = 0; i < failed_index; i++) { if (keys[i].Empty()) { continue; } - row_t row_id = row_identifiers[i]; - Erase(tree, keys[i], 0, row_id); + Erase(tree, keys[i], 0, row_id_keys[i], tree.GetGateStatus()); } } + if (was_empty) { + // All nodes are in-memory. + VerifyAllocationsInternal(); + } + if (failed_index != DConstants::INVALID_INDEX) { - return ErrorData(ConstraintException("PRIMARY KEY or UNIQUE constraint violated: duplicate key \"%s\"", - AppendRowError(input, failed_index))); + auto msg = AppendRowError(input, failed_index); + return ErrorData(ConstraintException("PRIMARY KEY or UNIQUE constraint violated: duplicate key \"%s\"", msg)); } #ifdef DEBUG - for (idx_t i = 0; i < input.size(); i++) { + for (idx_t i = 0; i < row_count; i++) { if (keys[i].Empty()) { continue; } - - auto leaf = Lookup(tree, keys[i], 0); - D_ASSERT(Leaf::ContainsRowId(*this, *leaf, row_identifiers[i])); + D_ASSERT(Lookup(tree, keys[i], 0)); } #endif - return ErrorData(); } -ErrorData ART::Append(IndexLock &lock, DataChunk &appended_data, Vector &row_identifiers) { - DataChunk expression_result; - expression_result.Initialize(Allocator::DefaultAllocator(), logical_types); - - // first resolve the expressions for the index - ExecuteExpressions(appended_data, expression_result); - - // now insert into the index - return Insert(lock, expression_result, row_identifiers); +ErrorData ART::Append(IndexLock &lock, DataChunk &input, Vector &row_ids) { + // Execute all column expressions before inserting the data chunk. + DataChunk expr_chunk; + expr_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); + ExecuteExpressions(input, expr_chunk); + return Insert(lock, expr_chunk, row_ids); } void ART::VerifyAppend(DataChunk &chunk) { @@ -554,87 +549,102 @@ void ART::VerifyAppend(DataChunk &chunk, ConflictManager &conflict_manager) { CheckConstraintsForChunk(chunk, conflict_manager); } -bool ART::InsertToLeaf(Node &leaf, const row_t &row_id) { +void ART::InsertIntoEmpty(Node &node, const ARTKey &key, const idx_t depth, const ARTKey &row_id, + const GateStatus status) { + D_ASSERT(depth <= key.len); + D_ASSERT(!node.HasMetadata()); - if (IsUnique()) { - return false; + if (status == GateStatus::GATE_SET) { + Leaf::New(node, row_id.GetRowId()); + return; } - Leaf::Insert(*this, leaf, row_id); - return true; + reference ref(node); + auto count = key.len - depth; + + Prefix::New(*this, ref, key, depth, count); + Leaf::New(ref, row_id.GetRowId()); } -bool ART::Insert(Node &node, const ARTKey &key, idx_t depth, const row_t &row_id) { +bool ART::InsertIntoNode(Node &node, const ARTKey &key, const idx_t depth, const ARTKey &row_id, + const GateStatus status) { + D_ASSERT(depth < key.len); + auto child = node.GetChildMutable(*this, key[depth]); - // node is currently empty, create a leaf here with the key - if (!node.HasMetadata()) { - D_ASSERT(depth <= key.len); - reference ref_node(node); - Prefix::New(*this, ref_node, key, UnsafeNumericCast(depth), - UnsafeNumericCast(key.len - depth)); - Leaf::New(ref_node, row_id); - return true; + // Recurse, if a child exists at key[depth]. + if (child) { + D_ASSERT(child->HasMetadata()); + bool success = Insert(*child, key, depth + 1, row_id, status); + node.ReplaceChild(*this, key[depth], *child); + return success; } - auto node_type = node.GetType(); - - // insert the row ID into this leaf - if (node_type == NType::LEAF || node_type == NType::LEAF_INLINED) { - return InsertToLeaf(node, row_id); + // Create an inlined prefix at key[depth]. + if (status == GateStatus::GATE_SET) { + Node remainder; + auto byte = key[depth]; + auto success = Insert(remainder, key, depth + 1, row_id, status); + Node::InsertChild(*this, node, byte, remainder); + return success; } - if (node_type != NType::PREFIX) { - D_ASSERT(depth < key.len); - auto child = node.GetChildMutable(*this, key[depth]); + // Insert an inlined leaf at key[depth]. + Node leaf; + reference ref(leaf); - // recurse, if a child exists at key[depth] - if (child) { - bool success = Insert(*child, key, depth + 1, row_id); - node.ReplaceChild(*this, key[depth], *child); - return success; - } - - // insert a new leaf node at key[depth] - Node leaf_node; - reference ref_node(leaf_node); - if (depth + 1 < key.len) { - Prefix::New(*this, ref_node, key, UnsafeNumericCast(depth + 1), - UnsafeNumericCast(key.len - depth - 1)); - } - Leaf::New(ref_node, row_id); - Node::InsertChild(*this, node, key[depth], leaf_node); - return true; + // Create the prefix. + if (depth + 1 < key.len) { + auto count = key.len - depth - 1; + Prefix::New(*this, ref, key, depth + 1, count); } - // this is a prefix node, traverse - reference next_node(node); - auto mismatch_position = Prefix::TraverseMutable(*this, next_node, key, depth); + // Create the inlined leaf. + Leaf::New(ref, row_id.GetRowId()); + Node::InsertChild(*this, node, key[depth], leaf); + return true; +} - // prefix matches key - if (next_node.get().GetType() != NType::PREFIX) { - return Insert(next_node, key, depth, row_id); +bool ART::Insert(Node &node, const ARTKey &key, idx_t depth, const ARTKey &row_id, const GateStatus status) { + if (!node.HasMetadata()) { + InsertIntoEmpty(node, key, depth, row_id, status); + return true; } - // prefix does not match the key, we need to create a new Node4; this new Node4 has two children, - // the remaining part of the prefix, and the new leaf - Node remaining_prefix; - auto prefix_byte = Prefix::GetByte(*this, next_node, mismatch_position); - Prefix::Split(*this, next_node, remaining_prefix, mismatch_position); - Node4::New(*this, next_node); - - // insert remaining prefix - Node4::InsertChild(*this, next_node, prefix_byte, remaining_prefix); + // Enter a nested leaf. + if (status == GateStatus::GATE_NOT_SET && node.GetGateStatus() == GateStatus::GATE_SET) { + return Insert(node, row_id, 0, row_id, GateStatus::GATE_SET); + } - // insert new leaf - Node leaf_node; - reference ref_node(leaf_node); - if (depth + 1 < key.len) { - Prefix::New(*this, ref_node, key, UnsafeNumericCast(depth + 1), - UnsafeNumericCast(key.len - depth - 1)); + auto type = node.GetType(); + switch (type) { + case NType::LEAF_INLINED: { + if (IsUnique()) { + return false; + } + Leaf::InsertIntoInlined(*this, node, row_id, depth, status); + return true; + } + case NType::LEAF: { + Leaf::TransformToNested(*this, node); + return Insert(node, key, depth, row_id, status); + } + case NType::NODE_7_LEAF: + case NType::NODE_15_LEAF: + case NType::NODE_256_LEAF: { + auto byte = key[Prefix::ROW_ID_COUNT]; + Node::InsertChild(*this, node, byte); + return true; + } + case NType::NODE_4: + case NType::NODE_16: + case NType::NODE_48: + case NType::NODE_256: + return InsertIntoNode(node, key, depth, row_id, status); + case NType::PREFIX: + return Prefix::Insert(*this, node, key, depth, row_id, status); + default: + throw InternalException("Invalid node type for Insert."); } - Leaf::New(ref_node, row_id); - Node4::InsertChild(*this, next_node, key[depth], leaf_node); - return true; } //===--------------------------------------------------------------------===// @@ -649,356 +659,291 @@ void ART::CommitDrop(IndexLock &index_lock) { } void ART::Delete(IndexLock &state, DataChunk &input, Vector &row_ids) { + auto row_count = input.size(); - DataChunk expression; - expression.Initialize(Allocator::DefaultAllocator(), logical_types); + DataChunk expr_chunk; + expr_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); + ExecuteExpressions(input, expr_chunk); - // first resolve the expressions - ExecuteExpressions(input, expression); + ArenaAllocator allocator(BufferAllocator::Get(db)); + unsafe_vector keys(row_count); + unsafe_vector row_id_keys(row_count); + GenerateKeyVectors(allocator, expr_chunk, row_ids, keys, row_id_keys); - // then generate the keys for the given input - ArenaAllocator arena_allocator(BufferAllocator::Get(db)); - vector keys(expression.size()); - GenerateKeys(arena_allocator, expression, keys); - - // now erase the elements from the database - row_ids.Flatten(input.size()); - auto row_identifiers = FlatVector::GetData(row_ids); - - for (idx_t i = 0; i < input.size(); i++) { + for (idx_t i = 0; i < row_count; i++) { if (keys[i].Empty()) { continue; } - Erase(tree, keys[i], 0, row_identifiers[i]); + Erase(tree, keys[i], 0, row_id_keys[i], tree.GetGateStatus()); + } + + if (!tree.HasMetadata()) { + // No more allocations. + VerifyAllocationsInternal(); } #ifdef DEBUG - // verify that we removed all row IDs - for (idx_t i = 0; i < input.size(); i++) { + for (idx_t i = 0; i < row_count; i++) { if (keys[i].Empty()) { continue; } - auto leaf = Lookup(tree, keys[i], 0); - if (leaf) { - D_ASSERT(!Leaf::ContainsRowId(*this, *leaf, row_identifiers[i])); + if (leaf && leaf->GetType() == NType::LEAF_INLINED) { + D_ASSERT(leaf->GetRowId() != row_id_keys[i].GetRowId()); } } #endif } -void ART::Erase(Node &node, const ARTKey &key, idx_t depth, const row_t &row_id) { - +void ART::Erase(Node &node, reference key, idx_t depth, reference row_id, + GateStatus status) { if (!node.HasMetadata()) { return; } - // handle prefix - reference next_node(node); - if (next_node.get().GetType() == NType::PREFIX) { - Prefix::TraverseMutable(*this, next_node, key, depth); - if (next_node.get().GetType() == NType::PREFIX) { + // Traverse the prefix. + reference next(node); + if (next.get().GetType() == NType::PREFIX) { + Prefix::TraverseMutable(*this, next, key, depth); + + // Prefixes don't match: nothing to erase. + if (next.get().GetType() == NType::PREFIX && next.get().GetGateStatus() == GateStatus::GATE_NOT_SET) { return; } } - // delete a row ID from a leaf (root is leaf with possible prefix nodes) - if (next_node.get().GetType() == NType::LEAF || next_node.get().GetType() == NType::LEAF_INLINED) { - if (Leaf::Remove(*this, next_node, row_id)) { + // Delete the row ID from the leaf. + // This is the root node, which can be a leaf with possible prefix nodes. + if (next.get().GetType() == NType::LEAF_INLINED) { + if (next.get().GetRowId() == row_id.get().GetRowId()) { Node::Free(*this, node); } return; } - D_ASSERT(depth < key.len); - auto child = next_node.get().GetChildMutable(*this, key[depth]); - if (child) { - D_ASSERT(child->HasMetadata()); + // Transform a deprecated leaf. + if (next.get().GetType() == NType::LEAF) { + D_ASSERT(status == GateStatus::GATE_NOT_SET); + Leaf::TransformToNested(*this, next); + } - auto temp_depth = depth + 1; - reference child_node(*child); - if (child_node.get().GetType() == NType::PREFIX) { - Prefix::TraverseMutable(*this, child_node, key, temp_depth); - if (child_node.get().GetType() == NType::PREFIX) { - return; - } - } + // Enter a nested leaf. + if (status == GateStatus::GATE_NOT_SET && next.get().GetGateStatus() == GateStatus::GATE_SET) { + return Erase(next, row_id, 0, row_id, GateStatus::GATE_SET); + } - if (child_node.get().GetType() == NType::LEAF || child_node.get().GetType() == NType::LEAF_INLINED) { - // leaf found, remove entry - if (Leaf::Remove(*this, child_node, row_id)) { - Node::DeleteChild(*this, next_node, node, key[depth]); - } - return; + D_ASSERT(depth < key.get().len); + if (next.get().IsLeafNode()) { + auto byte = key.get()[depth]; + if (next.get().HasByte(*this, byte)) { + Node::DeleteChild(*this, next, node, key.get()[depth], status, key.get()); } + return; + } - // recurse - Erase(*child, key, depth + 1, row_id); - next_node.get().ReplaceChild(*this, key[depth], *child); + auto child = next.get().GetChildMutable(*this, key.get()[depth]); + if (!child) { + // No child at the byte: nothing to erase. + return; } -} -//===--------------------------------------------------------------------===// -// Point Query (Equal) -//===--------------------------------------------------------------------===// + // Transform a deprecated leaf. + if (child->GetType() == NType::LEAF) { + D_ASSERT(status == GateStatus::GATE_NOT_SET); + Leaf::TransformToNested(*this, *child); + } -static ARTKey CreateKey(ArenaAllocator &allocator, PhysicalType type, Value &value) { - D_ASSERT(type == value.type().InternalType()); - switch (type) { - case PhysicalType::BOOL: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::INT8: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::INT16: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::INT32: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::INT64: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::UINT8: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::UINT16: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::UINT32: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::UINT64: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::INT128: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::UINT128: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::FLOAT: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::DOUBLE: - return ARTKey::CreateARTKey(allocator, value.type(), value); - case PhysicalType::VARCHAR: - return ARTKey::CreateARTKey(allocator, value.type(), value); - default: - throw InternalException("Invalid type for the ART key"); + // Enter a nested leaf. + if (status == GateStatus::GATE_NOT_SET && child->GetGateStatus() == GateStatus::GATE_SET) { + Erase(*child, row_id, 0, row_id, GateStatus::GATE_SET); + if (!child->HasMetadata()) { + Node::DeleteChild(*this, next, node, key.get()[depth], status, key.get()); + } else { + next.get().ReplaceChild(*this, key.get()[depth], *child); + } + return; } -} -bool ART::SearchEqual(ARTKey &key, idx_t max_count, vector &result_ids) { + auto temp_depth = depth + 1; + reference ref(*child); - auto leaf = Lookup(tree, key, 0); - if (!leaf) { - return true; - } - return Leaf::GetRowIds(*this, *leaf, result_ids, max_count); -} + if (ref.get().GetType() == NType::PREFIX) { + Prefix::TraverseMutable(*this, ref, key, temp_depth); -void ART::SearchEqualJoinNoFetch(ARTKey &key, idx_t &result_size) { + // Prefixes don't match: nothing to erase. + if (ref.get().GetType() == NType::PREFIX && ref.get().GetGateStatus() == GateStatus::GATE_NOT_SET) { + return; + } + } - // we need to look for a leaf - auto leaf_node = Lookup(tree, key, 0); - if (!leaf_node) { - result_size = 0; + if (ref.get().GetType() == NType::LEAF_INLINED) { + if (ref.get().GetRowId() == row_id.get().GetRowId()) { + Node::DeleteChild(*this, next, node, key.get()[depth], status, key.get()); + } return; } - // we only perform index joins on PK/FK columns - D_ASSERT(leaf_node->GetType() == NType::LEAF_INLINED); - result_size = 1; - return; + // Recurse. + Erase(*child, key, depth + 1, row_id, status); + if (!child->HasMetadata()) { + Node::DeleteChild(*this, next, node, key.get()[depth], status, key.get()); + } else { + next.get().ReplaceChild(*this, key.get()[depth], *child); + } } //===--------------------------------------------------------------------===// -// Lookup +// Point and range lookups //===--------------------------------------------------------------------===// -optional_ptr ART::Lookup(const Node &node, const ARTKey &key, idx_t depth) { +const unsafe_optional_ptr ART::Lookup(const Node &node, const ARTKey &key, idx_t depth) { + reference ref(node); + while (ref.get().HasMetadata()) { - reference node_ref(node); - while (node_ref.get().HasMetadata()) { + // Return the leaf. + if (ref.get().IsAnyLeaf() || ref.get().GetGateStatus() == GateStatus::GATE_SET) { + return unsafe_optional_ptr(ref.get()); + } - // traverse prefix, if exists - reference next_node(node_ref.get()); - if (next_node.get().GetType() == NType::PREFIX) { - Prefix::Traverse(*this, next_node, key, depth); - if (next_node.get().GetType() == NType::PREFIX) { + // Traverse the prefix. + if (ref.get().GetType() == NType::PREFIX) { + Prefix::Traverse(*this, ref, key, depth); + if (ref.get().GetType() == NType::PREFIX && ref.get().GetGateStatus() == GateStatus::GATE_NOT_SET) { + // Prefix mismatch, return nullptr. return nullptr; } + continue; } - if (next_node.get().GetType() == NType::LEAF || next_node.get().GetType() == NType::LEAF_INLINED) { - return &next_node.get(); - } - + // Get the child node. D_ASSERT(depth < key.len); - auto child = next_node.get().GetChild(*this, key[depth]); + auto child = ref.get().GetChild(*this, key[depth]); + + // No child at the matching byte, return nullptr. if (!child) { - // prefix matches key, but no child at byte, ART/subtree does not contain key return nullptr; } - // lookup in child node - node_ref = *child; - D_ASSERT(node_ref.get().HasMetadata()); + // Continue in the child. + ref = *child; + D_ASSERT(ref.get().HasMetadata()); depth++; } return nullptr; } -//===--------------------------------------------------------------------===// -// Greater Than and Less Than -//===--------------------------------------------------------------------===// +bool ART::SearchEqual(ARTKey &key, idx_t max_count, unsafe_vector &row_ids) { + auto leaf = Lookup(tree, key, 0); + if (!leaf) { + return true; + } -bool ART::SearchGreater(ARTIndexScanState &state, ARTKey &key, bool equal, idx_t max_count, vector &result_ids) { + Iterator it(*this); + it.FindMinimum(*leaf); + ARTKey empty_key = ARTKey(); + return it.Scan(empty_key, max_count, row_ids, false); +} +bool ART::SearchGreater(ARTKey &key, bool equal, idx_t max_count, unsafe_vector &row_ids) { if (!tree.HasMetadata()) { return true; } - Iterator &it = state.iterator; - // find the lowest value that satisfies the predicate - if (!it.art) { - it.art = this; - if (!it.LowerBound(tree, key, equal, 0)) { - // early-out, if the maximum value in the ART is lower than the lower bound - return true; - } + // Find the lowest value that satisfies the predicate. + Iterator it(*this); + + // Early-out, if the maximum value in the ART is lower than the lower bound. + if (!it.LowerBound(tree, key, equal, 0)) { + return true; } - // after that we continue the scan; we don't need to check the bounds as any value following this value is - // automatically bigger and hence satisfies our predicate - ARTKey empty_key = ARTKey(); - return it.Scan(empty_key, max_count, result_ids, false); + // We continue the scan. We do not check the bounds as any value following this value is + // greater and satisfies our predicate. + return it.Scan(ARTKey(), max_count, row_ids, false); } -bool ART::SearchLess(ARTIndexScanState &state, ARTKey &upper_bound, bool equal, idx_t max_count, - vector &result_ids) { - +bool ART::SearchLess(ARTKey &upper_bound, bool equal, idx_t max_count, unsafe_vector &row_ids) { if (!tree.HasMetadata()) { return true; } - Iterator &it = state.iterator; - if (!it.art) { - it.art = this; - // find the minimum value in the ART: we start scanning from this value - it.FindMinimum(tree); - // early-out, if the minimum value is higher than the upper bound - if (it.current_key > upper_bound) { - return true; - } + // Find the minimum value in the ART: we start scanning from this value. + Iterator it(*this); + it.FindMinimum(tree); + + // Early-out, if the minimum value is higher than the upper bound. + if (it.current_key.GreaterThan(upper_bound, equal, it.GetNestedDepth())) { + return true; } - // now continue the scan until we reach the upper bound - return it.Scan(upper_bound, max_count, result_ids, equal); + // Continue the scan until we reach the upper bound. + return it.Scan(upper_bound, max_count, row_ids, equal); } -//===--------------------------------------------------------------------===// -// Closed Range Query -//===--------------------------------------------------------------------===// - -bool ART::SearchCloseRange(ARTIndexScanState &state, ARTKey &lower_bound, ARTKey &upper_bound, bool left_equal, - bool right_equal, idx_t max_count, vector &result_ids) { +bool ART::SearchCloseRange(ARTKey &lower_bound, ARTKey &upper_bound, bool left_equal, bool right_equal, idx_t max_count, + unsafe_vector &row_ids) { + // Find the first node that satisfies the left predicate. + Iterator it(*this); - Iterator &it = state.iterator; - - // find the first node that satisfies the left predicate - if (!it.art) { - it.art = this; - if (!it.LowerBound(tree, lower_bound, left_equal, 0)) { - // early-out, if the maximum value in the ART is lower than the lower bound - return true; - } + // Early-out, if the maximum value in the ART is lower than the lower bound. + if (!it.LowerBound(tree, lower_bound, left_equal, 0)) { + return true; } - // now continue the scan until we reach the upper bound - return it.Scan(upper_bound, max_count, result_ids, right_equal); + // Continue the scan until we reach the upper bound. + return it.Scan(upper_bound, max_count, row_ids, right_equal); } -bool ART::Scan(const Transaction &transaction, const DataTable &table, IndexScanState &state, const idx_t max_count, - vector &result_ids) { - +bool ART::Scan(IndexScanState &state, const idx_t max_count, unsafe_vector &row_ids) { auto &scan_state = state.Cast(); - vector row_ids; - bool success; - - // FIXME: the key directly owning the data for a single key might be more efficient D_ASSERT(scan_state.values[0].type().InternalType() == types[0]); ArenaAllocator arena_allocator(Allocator::Get(db)); - auto key = CreateKey(arena_allocator, types[0], scan_state.values[0]); + auto key = ARTKey::CreateKey(arena_allocator, types[0], scan_state.values[0]); if (scan_state.values[1].IsNull()) { - - // single predicate + // Single predicate. lock_guard l(lock); switch (scan_state.expressions[0]) { case ExpressionType::COMPARE_EQUAL: - success = SearchEqual(key, max_count, row_ids); - break; + return SearchEqual(key, max_count, row_ids); case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - success = SearchGreater(scan_state, key, true, max_count, row_ids); - break; + return SearchGreater(key, true, max_count, row_ids); case ExpressionType::COMPARE_GREATERTHAN: - success = SearchGreater(scan_state, key, false, max_count, row_ids); - break; + return SearchGreater(key, false, max_count, row_ids); case ExpressionType::COMPARE_LESSTHANOREQUALTO: - success = SearchLess(scan_state, key, true, max_count, row_ids); - break; + return SearchLess(key, true, max_count, row_ids); case ExpressionType::COMPARE_LESSTHAN: - success = SearchLess(scan_state, key, false, max_count, row_ids); - break; + return SearchLess(key, false, max_count, row_ids); default: throw InternalException("Index scan type not implemented"); } - - } else { - - // two predicates - lock_guard l(lock); - - D_ASSERT(scan_state.values[1].type().InternalType() == types[0]); - auto upper_bound = CreateKey(arena_allocator, types[0], scan_state.values[1]); - - bool left_equal = scan_state.expressions[0] == ExpressionType ::COMPARE_GREATERTHANOREQUALTO; - bool right_equal = scan_state.expressions[1] == ExpressionType ::COMPARE_LESSTHANOREQUALTO; - success = SearchCloseRange(scan_state, key, upper_bound, left_equal, right_equal, max_count, row_ids); - } - - if (!success) { - return false; - } - if (row_ids.empty()) { - return true; } - // sort the row ids - sort(row_ids.begin(), row_ids.end()); - // duplicate eliminate the row ids and append them to the row ids of the state - result_ids.reserve(row_ids.size()); - - result_ids.push_back(row_ids[0]); - for (idx_t i = 1; i < row_ids.size(); i++) { - if (row_ids[i] != row_ids[i - 1]) { - result_ids.push_back(row_ids[i]); - } - } - return true; + // Two predicates. + lock_guard l(lock); + D_ASSERT(scan_state.values[1].type().InternalType() == types[0]); + auto upper_bound = ARTKey::CreateKey(arena_allocator, types[0], scan_state.values[1]); + bool left_equal = scan_state.expressions[0] == ExpressionType ::COMPARE_GREATERTHANOREQUALTO; + bool right_equal = scan_state.expressions[1] == ExpressionType ::COMPARE_LESSTHANOREQUALTO; + return SearchCloseRange(key, upper_bound, left_equal, right_equal, max_count, row_ids); } //===--------------------------------------------------------------------===// -// More Verification / Constraint Checking +// More Constraint Checking //===--------------------------------------------------------------------===// -string ART::GenerateErrorKeyName(DataChunk &input, idx_t row) { - - // FIXME: why exactly can we not pass the expression_chunk as an argument to this - // FIXME: function instead of re-executing? - // re-executing the expressions is not very fast, but we're going to throw, so we don't care - DataChunk expression_chunk; - expression_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); - ExecuteExpressions(input, expression_chunk); +string ART::GenerateErrorKeyName(DataChunk &input, idx_t row_idx) { + DataChunk expr_chunk; + expr_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); + ExecuteExpressions(input, expr_chunk); string key_name; - for (idx_t k = 0; k < expression_chunk.ColumnCount(); k++) { + for (idx_t k = 0; k < expr_chunk.ColumnCount(); k++) { if (k > 0) { key_name += ", "; } - key_name += unbound_expressions[k]->GetName() + ": " + expression_chunk.data[k].GetValue(row).ToString(); + key_name += unbound_expressions[k]->GetName() + ": " + expr_chunk.data[k].GetValue(row_idx).ToString(); } return key_name; } @@ -1006,7 +951,7 @@ string ART::GenerateErrorKeyName(DataChunk &input, idx_t row) { string ART::GenerateConstraintErrorMessage(VerifyExistenceType verify_type, const string &key_name) { switch (verify_type) { case VerifyExistenceType::APPEND: { - // APPEND to PK/UNIQUE table, but node/key already exists in PK/UNIQUE table + // APPEND to PK/UNIQUE table, but node/key already exists in PK/UNIQUE table. string type = IsPrimary() ? "primary key" : "unique"; return StringUtil::Format("Duplicate key \"%s\" violates %s constraint. " "If this is an unexpected constraint violation please double " @@ -1015,12 +960,12 @@ string ART::GenerateConstraintErrorMessage(VerifyExistenceType verify_type, cons key_name, type); } case VerifyExistenceType::APPEND_FK: { - // APPEND_FK to FK table, node/key does not exist in PK/UNIQUE table + // APPEND_FK to FK table, node/key does not exist in PK/UNIQUE table. return StringUtil::Format( "Violates foreign key constraint because key \"%s\" does not exist in the referenced table", key_name); } case VerifyExistenceType::DELETE_FK: { - // DELETE_FK that still exists in a FK table, i.e., not a valid delete + // DELETE_FK that still exists in a FK table, i.e., not a valid delete. return StringUtil::Format("Violates foreign key constraint because key \"%s\" is still referenced by a foreign " "key in a different table", key_name); @@ -1031,23 +976,19 @@ string ART::GenerateConstraintErrorMessage(VerifyExistenceType verify_type, cons } void ART::CheckConstraintsForChunk(DataChunk &input, ConflictManager &conflict_manager) { - - // don't alter the index during constraint checking + // Lock the index during constraint checking. lock_guard l(lock); - // first resolve the expressions for the index - DataChunk expression_chunk; - expression_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); - ExecuteExpressions(input, expression_chunk); + DataChunk expr_chunk; + expr_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); + ExecuteExpressions(input, expr_chunk); - // generate the keys for the given input ArenaAllocator arena_allocator(BufferAllocator::Get(db)); - vector keys(expression_chunk.size()); - GenerateKeys(arena_allocator, expression_chunk, keys); + unsafe_vector keys(expr_chunk.size()); + GenerateKeys<>(arena_allocator, expr_chunk, keys); - idx_t found_conflict = DConstants::INVALID_INDEX; + auto found_conflict = DConstants::INVALID_INDEX; for (idx_t i = 0; found_conflict == DConstants::INVALID_INDEX && i < input.size(); i++) { - if (keys[i].Empty()) { if (conflict_manager.AddNull(i)) { found_conflict = i; @@ -1063,8 +1004,8 @@ void ART::CheckConstraintsForChunk(DataChunk &input, ConflictManager &conflict_m continue; } - // when we find a node, we need to update the 'matches' and 'row_ids' - // NOTE: leaves can have more than one row_id, but for UNIQUE/PRIMARY KEY they will only have one + // If we find a node, we need to update the 'matches' and 'row_ids'. + // We only perform constraint checking on unique indexes, i.e., all leaves are inlined. D_ASSERT(leaf->GetType() == NType::LEAF_INLINED); if (conflict_manager.AddHit(i, leaf->GetRowId())) { found_conflict = i; @@ -1072,7 +1013,6 @@ void ART::CheckConstraintsForChunk(DataChunk &input, ConflictManager &conflict_m } conflict_manager.FinishLookup(); - if (found_conflict == DConstants::INVALID_INDEX) { return; } @@ -1082,93 +1022,178 @@ void ART::CheckConstraintsForChunk(DataChunk &input, ConflictManager &conflict_m throw ConstraintException(exception_msg); } +string ART::GetConstraintViolationMessage(VerifyExistenceType verify_type, idx_t failed_index, DataChunk &input) { + auto key_name = GenerateErrorKeyName(input, failed_index); + auto exception_msg = GenerateConstraintErrorMessage(verify_type, key_name); + return exception_msg; +} + //===--------------------------------------------------------------------===// -// Helper functions for (de)serialization +// Storage and Memory //===--------------------------------------------------------------------===// -IndexStorageInfo ART::GetStorageInfo(const bool get_buffers) { +void ART::TransformToDeprecated() { + auto idx = Node::GetAllocatorIdx(NType::PREFIX); + auto &block_manager = (*allocators)[idx]->block_manager; + unsafe_unique_ptr deprecated_allocator; + + if (prefix_count != Prefix::DEPRECATED_COUNT) { + auto prefix_size = NumericCast(Prefix::DEPRECATED_COUNT) + NumericCast(Prefix::METADATA_SIZE); + deprecated_allocator = make_unsafe_uniq(prefix_size, block_manager); + } + + // Transform all leaves, and possibly the prefixes. + if (tree.HasMetadata()) { + Node::TransformToDeprecated(*this, tree, deprecated_allocator); + } + + // Replace the prefix allocator with the deprecated allocator. + if (deprecated_allocator) { + prefix_count = Prefix::DEPRECATED_COUNT; + + D_ASSERT((*allocators)[idx]->IsEmpty()); + (*allocators)[idx]->Reset(); + (*allocators)[idx] = std::move(deprecated_allocator); + } +} + +IndexStorageInfo ART::GetStorageInfo(const case_insensitive_map_t &options, const bool to_wal) { + // If the storage format uses deprecated leaf storage, + // then we need to transform all nested leaves before serialization. + auto v1_0_0_option = options.find("v1_0_0_storage"); + bool v1_0_0_storage = v1_0_0_option == options.end() || v1_0_0_option->second != Value(false); + if (v1_0_0_storage) { + TransformToDeprecated(); + } - // set the name and root node - IndexStorageInfo info; - info.name = name; + IndexStorageInfo info(name); info.root = tree.Get(); + info.options = options; - if (!get_buffers) { - // store the data on disk as partial blocks and set the block ids - WritePartialBlocks(); + for (auto &allocator : *allocators) { + allocator->RemoveEmptyBuffers(); + } + +#ifdef DEBUG + if (v1_0_0_storage) { + D_ASSERT((*allocators)[Node::GetAllocatorIdx(NType::NODE_7_LEAF)]->IsEmpty()); + D_ASSERT((*allocators)[Node::GetAllocatorIdx(NType::NODE_15_LEAF)]->IsEmpty()); + D_ASSERT((*allocators)[Node::GetAllocatorIdx(NType::NODE_256_LEAF)]->IsEmpty()); + D_ASSERT((*allocators)[Node::GetAllocatorIdx(NType::PREFIX)]->GetSegmentSize() == + Prefix::DEPRECATED_COUNT + Prefix::METADATA_SIZE); + } +#endif + + auto allocator_count = v1_0_0_storage ? DEPRECATED_ALLOCATOR_COUNT : ALLOCATOR_COUNT; + if (!to_wal) { + // Store the data on disk as partial blocks and set the block ids. + WritePartialBlocks(v1_0_0_storage); } else { - // set the correct allocation sizes and get the map containing all buffers - for (const auto &allocator : *allocators) { - info.buffers.push_back(allocator->InitSerializationToWAL()); + // Set the correct allocation sizes and get the map containing all buffers. + for (idx_t i = 0; i < allocator_count; i++) { + info.buffers.push_back((*allocators)[i]->InitSerializationToWAL()); } } - for (const auto &allocator : *allocators) { - info.allocator_infos.push_back(allocator->GetInfo()); + for (idx_t i = 0; i < allocator_count; i++) { + info.allocator_infos.push_back((*allocators)[i]->GetInfo()); } - return info; } -void ART::WritePartialBlocks() { - - // use the partial block manager to serialize all allocator data +void ART::WritePartialBlocks(const bool v1_0_0_storage) { auto &block_manager = table_io_manager.GetIndexBlockManager(); PartialBlockManager partial_block_manager(block_manager, PartialBlockType::FULL_CHECKPOINT); - for (auto &allocator : *allocators) { - allocator->SerializeBuffers(partial_block_manager); + idx_t allocator_count = v1_0_0_storage ? DEPRECATED_ALLOCATOR_COUNT : ALLOCATOR_COUNT; + for (idx_t i = 0; i < allocator_count; i++) { + (*allocators)[i]->SerializeBuffers(partial_block_manager); } partial_block_manager.FlushPartialBlocks(); } void ART::InitAllocators(const IndexStorageInfo &info) { - - // set the root node - tree.Set(info.root); - - // initialize the allocators - D_ASSERT(info.allocator_infos.size() == ALLOCATOR_COUNT); for (idx_t i = 0; i < info.allocator_infos.size(); i++) { (*allocators)[i]->Init(info.allocator_infos[i]); } } void ART::Deserialize(const BlockPointer &pointer) { - D_ASSERT(pointer.IsValid()); + auto &metadata_manager = table_io_manager.GetMetadataManager(); MetadataReader reader(metadata_manager, pointer); tree = reader.Read(); - for (idx_t i = 0; i < ALLOCATOR_COUNT; i++) { + for (idx_t i = 0; i < DEPRECATED_ALLOCATOR_COUNT; i++) { (*allocators)[i]->Deserialize(metadata_manager, reader.Read()); } } -//===--------------------------------------------------------------------===// -// Vacuum -//===--------------------------------------------------------------------===// +void ART::SetPrefixCount(const IndexStorageInfo &info) { + auto numeric_max = NumericLimits().Maximum(); + auto max_aligned = AlignValueFloor(numeric_max - Prefix::METADATA_SIZE); + + if (info.IsValid() && info.root_block_ptr.IsValid()) { + prefix_count = Prefix::DEPRECATED_COUNT; + return; + } -void ART::InitializeVacuum(ARTFlags &flags) { - flags.vacuum_flags.reserve(flags.vacuum_flags.size() + allocators->size()); + if (info.IsValid()) { + auto serialized_count = info.allocator_infos[0].segment_size - Prefix::METADATA_SIZE; + prefix_count = NumericCast(serialized_count); + return; + } + + if (!IsUnique()) { + prefix_count = Prefix::ROW_ID_COUNT; + return; + } + + idx_t compound_size = 0; + for (const auto &type : types) { + compound_size += GetTypeIdSize(type); + } + + auto aligned = AlignValue(compound_size) - 1; + if (aligned > NumericCast(max_aligned)) { + prefix_count = max_aligned; + return; + } + + prefix_count = NumericCast(aligned); +} + +idx_t ART::GetInMemorySize(IndexLock &index_lock) { + D_ASSERT(owns_data); + + idx_t in_memory_size = 0; for (auto &allocator : *allocators) { - flags.vacuum_flags.push_back(allocator->InitializeVacuum()); + in_memory_size += allocator->GetInMemorySize(); } + return in_memory_size; } -void ART::FinalizeVacuum(const ARTFlags &flags) { +//===--------------------------------------------------------------------===// +// Vacuum +//===--------------------------------------------------------------------===// +void ART::InitializeVacuum(unordered_set &indexes) { for (idx_t i = 0; i < allocators->size(); i++) { - if (flags.vacuum_flags[i]) { - (*allocators)[i]->FinalizeVacuum(); + if ((*allocators)[i]->InitializeVacuum()) { + indexes.insert(NumericCast(i)); } } } -void ART::Vacuum(IndexLock &state) { +void ART::FinalizeVacuum(const unordered_set &indexes) { + for (const auto &idx : indexes) { + (*allocators)[idx]->FinalizeVacuum(); + } +} +void ART::Vacuum(IndexLock &state) { D_ASSERT(owns_data); if (!tree.HasMetadata()) { @@ -1178,60 +1203,34 @@ void ART::Vacuum(IndexLock &state) { return; } - // holds true, if an allocator needs a vacuum, and false otherwise - ARTFlags flags; - InitializeVacuum(flags); + // True, if an allocator needs a vacuum, false otherwise. + unordered_set indexes; + InitializeVacuum(indexes); - // skip vacuum if no allocators require it - auto perform_vacuum = false; - for (const auto &vacuum_flag : flags.vacuum_flags) { - if (vacuum_flag) { - perform_vacuum = true; - break; - } - } - if (!perform_vacuum) { + // Skip vacuum, if no allocators require it. + if (indexes.empty()) { return; } - // traverse the allocated memory of the tree to perform a vacuum - tree.Vacuum(*this, flags); - - // finalize the vacuum operation - FinalizeVacuum(flags); -} + // Traverse the allocated memory of the tree to perform a vacuum. + tree.Vacuum(*this, indexes); -//===--------------------------------------------------------------------===// -// Size -//===--------------------------------------------------------------------===// - -idx_t ART::GetInMemorySize(IndexLock &index_lock) { - - D_ASSERT(owns_data); - - idx_t in_memory_size = 0; - for (auto &allocator : *allocators) { - in_memory_size += allocator->GetInMemorySize(); - } - return in_memory_size; + // Finalize the vacuum operation. + FinalizeVacuum(indexes); } //===--------------------------------------------------------------------===// // Merging //===--------------------------------------------------------------------===// -void ART::InitializeMerge(ARTFlags &flags) { - +void ART::InitializeMerge(unsafe_vector &upper_bounds) { D_ASSERT(owns_data); - - flags.merge_buffer_counts.reserve(allocators->size()); for (auto &allocator : *allocators) { - flags.merge_buffer_counts.emplace_back(allocator->GetUpperBoundBufferId()); + upper_bounds.emplace_back(allocator->GetUpperBoundBufferId()); } } bool ART::MergeIndexes(IndexLock &state, BoundIndex &other_index) { - auto &other_art = other_index.Cast(); if (!other_art.tree.HasMetadata()) { return true; @@ -1239,33 +1238,31 @@ bool ART::MergeIndexes(IndexLock &state, BoundIndex &other_index) { if (other_art.owns_data) { if (tree.HasMetadata()) { - // fully deserialize other_index, and traverse it to increment its buffer IDs - ARTFlags flags; - InitializeMerge(flags); - other_art.tree.InitializeMerge(other_art, flags); + // Fully deserialize other_index, and traverse it to increment its buffer IDs. + unsafe_vector upper_bounds; + InitializeMerge(upper_bounds); + other_art.tree.InitMerge(other_art, upper_bounds); } - // merge the node storage + // Merge the node storage. for (idx_t i = 0; i < allocators->size(); i++) { (*allocators)[i]->Merge(*(*other_art.allocators)[i]); } } - // merge the ARTs - if (!tree.Merge(*this, other_art.tree)) { + // Merge the ARTs. + D_ASSERT(tree.GetGateStatus() == other_art.tree.GetGateStatus()); + if (!tree.Merge(*this, other_art.tree, tree.GetGateStatus())) { return false; } return true; } //===--------------------------------------------------------------------===// -// Utility +// Verification //===--------------------------------------------------------------------===// string ART::VerifyAndToString(IndexLock &state, const bool only_verify) { - // FIXME: this can be improved by counting the allocations of each node type, - // FIXME: and by asserting that each fixed-size allocator lists an equal number of - // FIXME: allocations of that type return VerifyAndToStringInternal(only_verify); } @@ -1276,10 +1273,26 @@ string ART::VerifyAndToStringInternal(const bool only_verify) { return "[empty]"; } -string ART::GetConstraintViolationMessage(VerifyExistenceType verify_type, idx_t failed_index, DataChunk &input) { - auto key_name = GenerateErrorKeyName(input, failed_index); - auto exception_msg = GenerateConstraintErrorMessage(verify_type, key_name); - return exception_msg; +void ART::VerifyAllocations(IndexLock &state) { + return VerifyAllocationsInternal(); +} + +void ART::VerifyAllocationsInternal() { +#ifdef DEBUG + unordered_map node_counts; + for (idx_t i = 0; i < allocators->size(); i++) { + node_counts[NumericCast(i)] = 0; + } + + if (tree.HasMetadata()) { + tree.VerifyAllocations(*this, node_counts); + } + + for (idx_t i = 0; i < allocators->size(); i++) { + auto segment_count = (*allocators)[i]->GetSegmentCount(); + D_ASSERT(segment_count == node_counts[NumericCast(i)]); + } +#endif } constexpr const char *ART::TYPE_NAME; diff --git a/src/duckdb/src/execution/index/art/art_key.cpp b/src/duckdb/src/execution/index/art/art_key.cpp index e9be0abb..d5769f0f 100644 --- a/src/duckdb/src/execution/index/art/art_key.cpp +++ b/src/duckdb/src/execution/index/art/art_key.cpp @@ -2,98 +2,181 @@ namespace duckdb { +//===--------------------------------------------------------------------===// +// ARTKey +//===--------------------------------------------------------------------===// + ARTKey::ARTKey() : len(0) { } -ARTKey::ARTKey(const data_ptr_t &data, const uint32_t &len) : len(len), data(data) { +ARTKey::ARTKey(const data_ptr_t data, idx_t len) : len(len), data(data) { } -ARTKey::ARTKey(ArenaAllocator &allocator, const uint32_t &len) : len(len) { +ARTKey::ARTKey(ArenaAllocator &allocator, idx_t len) : len(len) { data = allocator.Allocate(len); } template <> -ARTKey ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, string_t value) { +ARTKey ARTKey::CreateARTKey(ArenaAllocator &allocator, string_t value) { auto string_data = const_data_ptr_cast(value.GetData()); auto string_len = value.GetSize(); - // we need to escape \00 and \01 + + // We escape \00 and \01. idx_t escape_count = 0; - for (idx_t r = 0; r < string_len; r++) { - if (string_data[r] <= 1) { + for (idx_t i = 0; i < string_len; i++) { + if (string_data[i] <= 1) { escape_count++; } } + idx_t len = string_len + escape_count + 1; auto data = allocator.Allocate(len); - // copy over the data and add in escapes + + // Copy over the data and add escapes. idx_t pos = 0; - for (idx_t r = 0; r < string_len; r++) { - if (string_data[r] <= 1) { - // escape + for (idx_t i = 0; i < string_len; i++) { + if (string_data[i] <= 1) { + // Add escape. data[pos++] = '\01'; } - data[pos++] = string_data[r]; + data[pos++] = string_data[i]; } - // end with a null-terminator + + // End with a null-terminator. data[pos] = '\0'; - return ARTKey(data, UnsafeNumericCast(len)); + return ARTKey(data, len); } template <> -ARTKey ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, const char *value) { - return ARTKey::CreateARTKey(allocator, type, string_t(value, UnsafeNumericCast(strlen(value)))); +ARTKey ARTKey::CreateARTKey(ArenaAllocator &allocator, const char *value) { + return ARTKey::CreateARTKey(allocator, string_t(value, UnsafeNumericCast(strlen(value)))); } template <> -void ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, ARTKey &key, string_t value) { - key = ARTKey::CreateARTKey(allocator, type, value); +void ARTKey::CreateARTKey(ArenaAllocator &allocator, ARTKey &key, string_t value) { + key = ARTKey::CreateARTKey(allocator, value); } template <> -void ARTKey::CreateARTKey(ArenaAllocator &allocator, const LogicalType &type, ARTKey &key, const char *value) { - ARTKey::CreateARTKey(allocator, type, key, string_t(value, UnsafeNumericCast(strlen(value)))); +void ARTKey::CreateARTKey(ArenaAllocator &allocator, ARTKey &key, const char *value) { + ARTKey::CreateARTKey(allocator, key, string_t(value, UnsafeNumericCast(strlen(value)))); } -bool ARTKey::operator>(const ARTKey &k) const { - for (uint32_t i = 0; i < MinValue(len, k.len); i++) { - if (data[i] > k.data[i]) { +ARTKey ARTKey::CreateKey(ArenaAllocator &allocator, PhysicalType type, Value &value) { + D_ASSERT(type == value.type().InternalType()); + switch (type) { + case PhysicalType::BOOL: + return ARTKey::CreateARTKey(allocator, value); + case PhysicalType::INT8: + return ARTKey::CreateARTKey(allocator, value); + case PhysicalType::INT16: + return ARTKey::CreateARTKey(allocator, value); + case PhysicalType::INT32: + return ARTKey::CreateARTKey(allocator, value); + case PhysicalType::INT64: + return ARTKey::CreateARTKey(allocator, value); + case PhysicalType::UINT8: + return ARTKey::CreateARTKey(allocator, value); + case PhysicalType::UINT16: + return ARTKey::CreateARTKey(allocator, value); + case PhysicalType::UINT32: + return ARTKey::CreateARTKey(allocator, value); + case PhysicalType::UINT64: + return ARTKey::CreateARTKey(allocator, value); + case PhysicalType::INT128: + return ARTKey::CreateARTKey(allocator, value); + case PhysicalType::UINT128: + return ARTKey::CreateARTKey(allocator, value); + case PhysicalType::FLOAT: + return ARTKey::CreateARTKey(allocator, value); + case PhysicalType::DOUBLE: + return ARTKey::CreateARTKey(allocator, value); + case PhysicalType::VARCHAR: + return ARTKey::CreateARTKey(allocator, value); + default: + throw InternalException("Invalid type for the ART key."); + } +} + +bool ARTKey::operator>(const ARTKey &key) const { + for (idx_t i = 0; i < MinValue(len, key.len); i++) { + if (data[i] > key.data[i]) { return true; - } else if (data[i] < k.data[i]) { + } else if (data[i] < key.data[i]) { return false; } } - return len > k.len; + return len > key.len; } -bool ARTKey::operator>=(const ARTKey &k) const { - for (uint32_t i = 0; i < MinValue(len, k.len); i++) { - if (data[i] > k.data[i]) { +bool ARTKey::operator>=(const ARTKey &key) const { + for (idx_t i = 0; i < MinValue(len, key.len); i++) { + if (data[i] > key.data[i]) { return true; - } else if (data[i] < k.data[i]) { + } else if (data[i] < key.data[i]) { return false; } } - return len >= k.len; + return len >= key.len; } -bool ARTKey::operator==(const ARTKey &k) const { - if (len != k.len) { +bool ARTKey::operator==(const ARTKey &key) const { + if (len != key.len) { return false; } - for (uint32_t i = 0; i < len; i++) { - if (data[i] != k.data[i]) { + for (idx_t i = 0; i < len; i++) { + if (data[i] != key.data[i]) { return false; } } return true; } -void ARTKey::ConcatenateARTKey(ArenaAllocator &allocator, ARTKey &other_key) { - - auto compound_data = allocator.Allocate(len + other_key.len); +void ARTKey::Concat(ArenaAllocator &allocator, const ARTKey &other) { + auto compound_data = allocator.Allocate(len + other.len); memcpy(compound_data, data, len); - memcpy(compound_data + len, other_key.data, other_key.len); - len += other_key.len; + memcpy(compound_data + len, other.data, other.len); + len += other.len; data = compound_data; } + +row_t ARTKey::GetRowId() const { + D_ASSERT(len == sizeof(row_t)); + return Radix::DecodeData(data); +} + +idx_t ARTKey::GetMismatchPos(const ARTKey &other, const idx_t start) const { + D_ASSERT(len <= other.len); + D_ASSERT(start <= len); + for (idx_t i = start; i < other.len; i++) { + if (data[i] != other.data[i]) { + return i; + } + } + return DConstants::INVALID_INDEX; +} + +//===--------------------------------------------------------------------===// +// ARTKeySection +//===--------------------------------------------------------------------===// + +ARTKeySection::ARTKeySection(idx_t start, idx_t end, idx_t depth, data_t byte) + : start(start), end(end), depth(depth), key_byte(byte) { +} + +ARTKeySection::ARTKeySection(idx_t start, idx_t end, const unsafe_vector &keys, const ARTKeySection §ion) + : start(start), end(end), depth(section.depth + 1), key_byte(keys[end].data[section.depth]) { +} + +void ARTKeySection::GetChildSections(unsafe_vector §ions, const unsafe_vector &keys) { + auto child_idx = start; + for (idx_t i = start + 1; i <= end; i++) { + if (keys[i - 1].data[depth] != keys[i].data[depth]) { + sections.emplace_back(child_idx, i - 1, keys, *this); + child_idx = i; + } + } + sections.emplace_back(child_idx, end, keys, *this); +} + } // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/base_leaf.cpp b/src/duckdb/src/execution/index/art/base_leaf.cpp new file mode 100644 index 00000000..59492bb4 --- /dev/null +++ b/src/duckdb/src/execution/index/art/base_leaf.cpp @@ -0,0 +1,168 @@ +#include "duckdb/execution/index/art/base_leaf.hpp" + +#include "duckdb/execution/index/art/art_key.hpp" +#include "duckdb/execution/index/art/base_node.hpp" +#include "duckdb/execution/index/art/leaf.hpp" +#include "duckdb/execution/index/art/prefix.hpp" +#include "duckdb/execution/index/art/node256_leaf.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// BaseLeaf +//===--------------------------------------------------------------------===// + +template +void BaseLeaf::InsertByteInternal(BaseLeaf &n, const uint8_t byte) { + // Still space. Insert the child. + uint8_t child_pos = 0; + while (child_pos < n.count && n.key[child_pos] < byte) { + child_pos++; + } + + // Move children backwards to make space. + for (uint8_t i = n.count; i > child_pos; i--) { + n.key[i] = n.key[i - 1]; + } + + n.key[child_pos] = byte; + n.count++; +} + +template +BaseLeaf &BaseLeaf::DeleteByteInternal(ART &art, Node &node, const uint8_t byte) { + auto &n = Node::Ref(art, node, node.GetType()); + uint8_t child_pos = 0; + + for (; child_pos < n.count; child_pos++) { + if (n.key[child_pos] == byte) { + break; + } + } + n.count--; + + // Possibly move children backwards. + for (uint8_t i = child_pos; i < n.count; i++) { + n.key[i] = n.key[i + 1]; + } + return n; +} + +//===--------------------------------------------------------------------===// +// Node7Leaf +//===--------------------------------------------------------------------===// + +void Node7Leaf::InsertByte(ART &art, Node &node, const uint8_t byte) { + // The node is full. Grow to Node15. + auto &n7 = Node::Ref(art, node, NODE_7_LEAF); + if (n7.count == CAPACITY) { + auto node7 = node; + Node15Leaf::GrowNode7Leaf(art, node, node7); + Node15Leaf::InsertByte(art, node, byte); + return; + } + + // Still space. Insert the child. + uint8_t child_pos = 0; + while (child_pos < n7.count && n7.key[child_pos] < byte) { + child_pos++; + } + + InsertByteInternal(n7, byte); +} + +void Node7Leaf::DeleteByte(ART &art, Node &node, Node &prefix, const uint8_t byte, const ARTKey &row_id) { + auto &n7 = DeleteByteInternal(art, node, byte); + + // Compress one-way nodes. + if (n7.count == 1) { + D_ASSERT(node.GetGateStatus() == GateStatus::GATE_NOT_SET); + + // Get the remaining row ID. + auto remainder = UnsafeNumericCast(row_id.GetRowId()) & AND_LAST_BYTE; + remainder |= UnsafeNumericCast(n7.key[0]); + + n7.count--; + Node::Free(art, node); + + if (prefix.GetType() == NType::PREFIX) { + Node::Free(art, prefix); + Leaf::New(prefix, UnsafeNumericCast(remainder)); + } else { + Leaf::New(node, UnsafeNumericCast(remainder)); + } + } +} + +void Node7Leaf::ShrinkNode15Leaf(ART &art, Node &node7_leaf, Node &node15_leaf) { + auto &n7 = New(art, node7_leaf); + auto &n15 = Node::Ref(art, node15_leaf, NType::NODE_15_LEAF); + node7_leaf.SetGateStatus(node15_leaf.GetGateStatus()); + + n7.count = n15.count; + for (uint8_t i = 0; i < n15.count; i++) { + n7.key[i] = n15.key[i]; + } + + n15.count = 0; + Node::Free(art, node15_leaf); +} + +//===--------------------------------------------------------------------===// +// Node15Leaf +//===--------------------------------------------------------------------===// + +void Node15Leaf::InsertByte(ART &art, Node &node, const uint8_t byte) { + // The node is full. Grow to Node256Leaf. + auto &n15 = Node::Ref(art, node, NODE_15_LEAF); + if (n15.count == CAPACITY) { + auto node15 = node; + Node256Leaf::GrowNode15Leaf(art, node, node15); + Node256Leaf::InsertByte(art, node, byte); + return; + } + + InsertByteInternal(n15, byte); +} + +void Node15Leaf::DeleteByte(ART &art, Node &node, const uint8_t byte) { + auto &n15 = DeleteByteInternal(art, node, byte); + + // Shrink node to Node7. + if (n15.count < Node7Leaf::CAPACITY) { + auto node15 = node; + Node7Leaf::ShrinkNode15Leaf(art, node, node15); + } +} + +void Node15Leaf::GrowNode7Leaf(ART &art, Node &node15_leaf, Node &node7_leaf) { + auto &n7 = Node::Ref(art, node7_leaf, NType::NODE_7_LEAF); + auto &n15 = New(art, node15_leaf); + node15_leaf.SetGateStatus(node7_leaf.GetGateStatus()); + + n15.count = n7.count; + for (uint8_t i = 0; i < n7.count; i++) { + n15.key[i] = n7.key[i]; + } + + n7.count = 0; + Node::Free(art, node7_leaf); +} + +void Node15Leaf::ShrinkNode256Leaf(ART &art, Node &node15_leaf, Node &node256_leaf) { + auto &n15 = New(art, node15_leaf); + auto &n256 = Node::Ref(art, node256_leaf, NType::NODE_256_LEAF); + node15_leaf.SetGateStatus(node256_leaf.GetGateStatus()); + + ValidityMask mask(&n256.mask[0]); + for (uint16_t i = 0; i < Node256::CAPACITY; i++) { + if (mask.RowIsValid(i)) { + n15.key[n15.count] = UnsafeNumericCast(i); + n15.count++; + } + } + + Node::Free(art, node256_leaf); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/base_node.cpp b/src/duckdb/src/execution/index/art/base_node.cpp new file mode 100644 index 00000000..228d12e8 --- /dev/null +++ b/src/duckdb/src/execution/index/art/base_node.cpp @@ -0,0 +1,163 @@ +#include "duckdb/execution/index/art/base_node.hpp" + +#include "duckdb/execution/index/art/leaf.hpp" +#include "duckdb/execution/index/art/node48.hpp" +#include "duckdb/execution/index/art/prefix.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// BaseNode +//===--------------------------------------------------------------------===// + +template +void BaseNode::InsertChildInternal(BaseNode &n, const uint8_t byte, const Node child) { + // Still space. Insert the child. + uint8_t child_pos = 0; + while (child_pos < n.count && n.key[child_pos] < byte) { + child_pos++; + } + + // Move children backwards to make space. + for (uint8_t i = n.count; i > child_pos; i--) { + n.key[i] = n.key[i - 1]; + n.children[i] = n.children[i - 1]; + } + + n.key[child_pos] = byte; + n.children[child_pos] = child; + n.count++; +} + +template +BaseNode &BaseNode::DeleteChildInternal(ART &art, Node &node, const uint8_t byte) { + auto &n = Node::Ref(art, node, TYPE); + + uint8_t child_pos = 0; + for (; child_pos < n.count; child_pos++) { + if (n.key[child_pos] == byte) { + break; + } + } + + // Free the child and decrease the count. + Node::Free(art, n.children[child_pos]); + n.count--; + + // Possibly move children backwards. + for (uint8_t i = child_pos; i < n.count; i++) { + n.key[i] = n.key[i + 1]; + n.children[i] = n.children[i + 1]; + } + return n; +} + +//===--------------------------------------------------------------------===// +// Node4 +//===--------------------------------------------------------------------===// + +void Node4::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { + // The node is full. Grow to Node16. + auto &n = Node::Ref(art, node, NODE_4); + if (n.count == CAPACITY) { + auto node4 = node; + Node16::GrowNode4(art, node, node4); + Node16::InsertChild(art, node, byte, child); + return; + } + + InsertChildInternal(n, byte, child); +} + +void Node4::DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte, const GateStatus status) { + auto &n = DeleteChildInternal(art, node, byte); + + // Compress one-way nodes. + if (n.count == 1) { + n.count--; + + auto child = n.children[0]; + auto remainder = n.key[0]; + auto old_status = node.GetGateStatus(); + + Node::Free(art, node); + Prefix::Concat(art, prefix, remainder, old_status, child, status); + } +} + +void Node4::ShrinkNode16(ART &art, Node &node4, Node &node16) { + auto &n4 = New(art, node4); + auto &n16 = Node::Ref(art, node16, NType::NODE_16); + node4.SetGateStatus(node16.GetGateStatus()); + + n4.count = n16.count; + for (uint8_t i = 0; i < n16.count; i++) { + n4.key[i] = n16.key[i]; + n4.children[i] = n16.children[i]; + } + + n16.count = 0; + Node::Free(art, node16); +} + +//===--------------------------------------------------------------------===// +// Node16 +//===--------------------------------------------------------------------===// + +void Node16::DeleteChild(ART &art, Node &node, const uint8_t byte) { + auto &n = DeleteChildInternal(art, node, byte); + + // Shrink node to Node4. + if (n.count < Node4::CAPACITY) { + auto node16 = node; + Node4::ShrinkNode16(art, node, node16); + } +} + +void Node16::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { + // The node is full. Grow to Node48. + auto &n16 = Node::Ref(art, node, NODE_16); + if (n16.count == CAPACITY) { + auto node16 = node; + Node48::GrowNode16(art, node, node16); + Node48::InsertChild(art, node, byte, child); + return; + } + + InsertChildInternal(n16, byte, child); +} + +void Node16::GrowNode4(ART &art, Node &node16, Node &node4) { + auto &n4 = Node::Ref(art, node4, NType::NODE_4); + auto &n16 = New(art, node16); + node16.SetGateStatus(node4.GetGateStatus()); + + n16.count = n4.count; + for (uint8_t i = 0; i < n4.count; i++) { + n16.key[i] = n4.key[i]; + n16.children[i] = n4.children[i]; + } + + n4.count = 0; + Node::Free(art, node4); +} + +void Node16::ShrinkNode48(ART &art, Node &node16, Node &node48) { + auto &n16 = New(art, node16); + auto &n48 = Node::Ref(art, node48, NType::NODE_48); + node16.SetGateStatus(node48.GetGateStatus()); + + n16.count = 0; + for (uint16_t i = 0; i < Node256::CAPACITY; i++) { + if (n48.child_index[i] != Node48::EMPTY_MARKER) { + n16.key[n16.count] = UnsafeNumericCast(i); + n16.children[n16.count] = n48.children[n48.child_index[i]]; + n16.count++; + } + } + + n48.count = 0; + Node::Free(art, node48); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/iterator.cpp b/src/duckdb/src/execution/index/art/iterator.cpp index 0d0290eb..3f1f1f4f 100644 --- a/src/duckdb/src/execution/index/art/iterator.cpp +++ b/src/duckdb/src/execution/index/art/iterator.cpp @@ -7,168 +7,215 @@ namespace duckdb { -bool IteratorKey::operator>(const ARTKey &key) const { - for (idx_t i = 0; i < MinValue(key_bytes.size(), key.len); i++) { - if (key_bytes[i] > key.data[i]) { - return true; - } else if (key_bytes[i] < key.data[i]) { +//===--------------------------------------------------------------------===// +// IteratorKey +//===--------------------------------------------------------------------===// + +bool IteratorKey::Contains(const ARTKey &key) const { + if (Size() < key.len) { + return false; + } + for (idx_t i = 0; i < key.len; i++) { + if (key_bytes[i] != key.data[i]) { return false; } } - return key_bytes.size() > key.len; + return true; } -bool IteratorKey::operator>=(const ARTKey &key) const { - for (idx_t i = 0; i < MinValue(key_bytes.size(), key.len); i++) { +bool IteratorKey::GreaterThan(const ARTKey &key, const bool equal, const uint8_t nested_depth) const { + for (idx_t i = 0; i < MinValue(Size(), key.len); i++) { if (key_bytes[i] > key.data[i]) { return true; } else if (key_bytes[i] < key.data[i]) { return false; } } - return key_bytes.size() >= key.len; -} -bool IteratorKey::operator==(const ARTKey &key) const { - // NOTE: we only use this for finding the LowerBound, in which case the length - // has to be equal - D_ASSERT(key_bytes.size() == key.len); - for (idx_t i = 0; i < key_bytes.size(); i++) { - if (key_bytes[i] != key.data[i]) { - return false; - } - } - return true; + // Returns true, if current_key is greater than (or equal to) key. + D_ASSERT(Size() >= nested_depth); + auto this_len = Size() - nested_depth; + return equal ? this_len > key.len : this_len >= key.len; } -bool Iterator::Scan(const ARTKey &upper_bound, const idx_t max_count, vector &result_ids, const bool equal) { +//===--------------------------------------------------------------------===// +// Iterator +//===--------------------------------------------------------------------===// +bool Iterator::Scan(const ARTKey &upper_bound, const idx_t max_count, unsafe_vector &row_ids, const bool equal) { bool has_next; do { - if (!upper_bound.Empty()) { - // no more row IDs within the key bounds - if (equal) { - if (current_key > upper_bound) { - return true; + // An empty upper bound indicates that no upper bound exists. + if (!upper_bound.Empty() && status == GateStatus::GATE_NOT_SET) { + if (current_key.GreaterThan(upper_bound, equal, nested_depth)) { + return true; + } + } + + switch (last_leaf.GetType()) { + case NType::LEAF_INLINED: + if (row_ids.size() + 1 > max_count) { + return false; + } + row_ids.push_back(last_leaf.GetRowId()); + break; + case NType::LEAF: + if (!Leaf::DeprecatedGetRowIds(art, last_leaf, row_ids, max_count)) { + return false; + } + break; + case NType::NODE_7_LEAF: + case NType::NODE_15_LEAF: + case NType::NODE_256_LEAF: { + uint8_t byte = 0; + while (last_leaf.GetNextByte(art, byte)) { + if (row_ids.size() + 1 > max_count) { + return false; } - } else { - if (current_key >= upper_bound) { - return true; + row_id[ROW_ID_SIZE - 1] = byte; + ARTKey key(&row_id[0], ROW_ID_SIZE); + row_ids.push_back(key.GetRowId()); + if (byte == NumericLimits::Maximum()) { + break; } + byte++; } + break; } - - // copy all row IDs of this leaf into the result IDs (if they don't exceed max_count) - if (!Leaf::GetRowIds(*art, last_leaf, result_ids, max_count)) { - return false; + default: + throw InternalException("Invalid leaf type for index scan."); } - // get the next leaf has_next = Next(); - } while (has_next); - return true; } void Iterator::FindMinimum(const Node &node) { - D_ASSERT(node.HasMetadata()); - // found the minimum - if (node.GetType() == NType::LEAF || node.GetType() == NType::LEAF_INLINED) { + // Found the minimum. + if (node.IsAnyLeaf()) { last_leaf = node; return; } - // traverse the prefix + // We are passing a gate node. + if (node.GetGateStatus() == GateStatus::GATE_SET) { + D_ASSERT(status == GateStatus::GATE_NOT_SET); + status = GateStatus::GATE_SET; + nested_depth = 0; + } + + // Traverse the prefix. if (node.GetType() == NType::PREFIX) { - auto &prefix = Node::Ref(*art, node, NType::PREFIX); - for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { + Prefix prefix(art, node); + for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { current_key.Push(prefix.data[i]); + if (status == GateStatus::GATE_SET) { + row_id[nested_depth] = prefix.data[i]; + nested_depth++; + D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); + } } nodes.emplace(node, 0); - return FindMinimum(prefix.ptr); + return FindMinimum(*prefix.ptr); } - // go to the leftmost entry in the current node and recurse + // Go to the leftmost entry in the current node. uint8_t byte = 0; - auto next = node.GetNextChild(*art, byte); + auto next = node.GetNextChild(art, byte); D_ASSERT(next); + + // Recurse on the leftmost node. current_key.Push(byte); + if (status == GateStatus::GATE_SET) { + row_id[nested_depth] = byte; + nested_depth++; + D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); + } nodes.emplace(node, byte); FindMinimum(*next); } bool Iterator::LowerBound(const Node &node, const ARTKey &key, const bool equal, idx_t depth) { - if (!node.HasMetadata()) { return false; } - // we found the lower bound - if (node.GetType() == NType::LEAF || node.GetType() == NType::LEAF_INLINED) { - if (!equal && current_key == key) { + // We found any leaf node, or a gate. + if (node.IsAnyLeaf() || node.GetGateStatus() == GateStatus::GATE_SET) { + D_ASSERT(status == GateStatus::GATE_NOT_SET); + D_ASSERT(current_key.Size() == key.len); + if (!equal && current_key.Contains(key)) { return Next(); } - last_leaf = node; + + if (node.GetGateStatus() == GateStatus::GATE_SET) { + FindMinimum(node); + } else { + last_leaf = node; + } return true; } + D_ASSERT(node.GetGateStatus() == GateStatus::GATE_NOT_SET); if (node.GetType() != NType::PREFIX) { auto next_byte = key[depth]; - auto child = node.GetNextChild(*art, next_byte); + auto child = node.GetNextChild(art, next_byte); + + // The key is greater than any key in this subtree. if (!child) { - // the key is greater than any key in this subtree return Next(); } current_key.Push(next_byte); nodes.emplace(node, next_byte); + // We return the minimum because all keys are greater than the lower bound. if (next_byte > key[depth]) { - // we only need to find the minimum from here - // because all keys will be greater than the lower bound FindMinimum(*child); return true; } - // recurse into the child + // We recurse into the child. return LowerBound(*child, key, equal, depth + 1); } - // resolve the prefix - auto &prefix = Node::Ref(*art, node, NType::PREFIX); - for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { + // Push back all prefix bytes. + Prefix prefix(art, node); + for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { current_key.Push(prefix.data[i]); } nodes.emplace(node, 0); - for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { - // the key down to this node is less than the lower bound, the next key will be - // greater than the lower bound + // We compare the prefix bytes with the key bytes. + for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { + // We found a prefix byte that is less than its corresponding key byte. + // I.e., the subsequent node is lesser than the key. Thus, the next node + // is the lower bound. if (prefix.data[i] < key[depth + i]) { return Next(); } - // we only need to find the minimum from here - // because all keys will be greater than the lower bound + + // We found a prefix byte that is greater than its corresponding key byte. + // I.e., the subsequent node is greater than the key. Thus, the minimum is + // the lower bound. if (prefix.data[i] > key[depth + i]) { - FindMinimum(prefix.ptr); + FindMinimum(*prefix.ptr); return true; } } - // recurse into the child - depth += prefix.data[Node::PREFIX_SIZE]; - return LowerBound(prefix.ptr, key, equal, depth); + // The prefix matches the key. We recurse into the child. + depth += prefix.data[Prefix::Count(art)]; + return LowerBound(*prefix.ptr, key, equal, depth); } bool Iterator::Next() { - while (!nodes.empty()) { - auto &top = nodes.top(); - D_ASSERT(top.node.GetType() != NType::LEAF && top.node.GetType() != NType::LEAF_INLINED); + D_ASSERT(!top.node.IsAnyLeaf()); if (top.node.GetType() == NType::PREFIX) { PopNode(); @@ -176,20 +223,26 @@ bool Iterator::Next() { } if (top.byte == NumericLimits::Maximum()) { - // no node found: move up the tree, pop key byte of current node + // No more children of this node. + // Move up the tree by popping the key byte of the current node. PopNode(); continue; } top.byte++; - auto next_node = top.node.GetNextChild(*art, top.byte); + auto next_node = top.node.GetNextChild(art, top.byte); if (!next_node) { + // No more children of this node. + // Move up the tree by popping the key byte of the current node. PopNode(); continue; } current_key.Pop(1); current_key.Push(top.byte); + if (status == GateStatus::GATE_SET) { + row_id[nested_depth - 1] = top.byte; + } FindMinimum(*next_node); return true; @@ -198,12 +251,30 @@ bool Iterator::Next() { } void Iterator::PopNode() { - if (nodes.top().node.GetType() == NType::PREFIX) { - auto &prefix = Node::Ref(*art, nodes.top().node, NType::PREFIX); - auto prefix_byte_count = prefix.data[Node::PREFIX_SIZE]; - current_key.Pop(prefix_byte_count); - } else { + // We are popping a gate node. + if (nodes.top().node.GetGateStatus() == GateStatus::GATE_SET) { + D_ASSERT(status == GateStatus::GATE_SET); + status = GateStatus::GATE_NOT_SET; + } + + // Pop the byte and the node. + if (nodes.top().node.GetType() != NType::PREFIX) { current_key.Pop(1); + if (status == GateStatus::GATE_SET) { + nested_depth--; + D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); + } + nodes.pop(); + return; + } + + // Pop all prefix bytes and the node. + Prefix prefix(art, nodes.top().node); + auto prefix_byte_count = prefix.data[Prefix::Count(art)]; + current_key.Pop(prefix_byte_count); + if (status == GateStatus::GATE_SET) { + nested_depth -= prefix_byte_count; + D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); } nodes.pop(); } diff --git a/src/duckdb/src/execution/index/art/leaf.cpp b/src/duckdb/src/execution/index/art/leaf.cpp index 0fb8804e..4a5b346c 100644 --- a/src/duckdb/src/execution/index/art/leaf.cpp +++ b/src/duckdb/src/execution/index/art/leaf.cpp @@ -1,347 +1,243 @@ #include "duckdb/execution/index/art/leaf.hpp" + +#include "duckdb/common/types.hpp" #include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/art/art_key.hpp" +#include "duckdb/execution/index/art/base_leaf.hpp" +#include "duckdb/execution/index/art/base_node.hpp" +#include "duckdb/execution/index/art/iterator.hpp" #include "duckdb/execution/index/art/node.hpp" -#include "duckdb/common/numeric_utils.hpp" +#include "duckdb/execution/index/art/prefix.hpp" namespace duckdb { void Leaf::New(Node &node, const row_t row_id) { - - // we directly inline this row ID into the node pointer D_ASSERT(row_id < MAX_ROW_ID_LOCAL); + + auto status = node.GetGateStatus(); node.Clear(); - node.SetMetadata(static_cast(NType::LEAF_INLINED)); + + node.SetMetadata(static_cast(INLINED)); node.SetRowId(row_id); + node.SetGateStatus(status); } -void Leaf::New(ART &art, reference &node, const row_t *row_ids, idx_t count) { - +void Leaf::New(ART &art, reference &node, const unsafe_vector &row_ids, const idx_t start, + const idx_t count) { D_ASSERT(count > 1); + D_ASSERT(!node.get().HasMetadata()); - idx_t copy_count = 0; - while (count) { - node.get() = Node::GetAllocator(art, NType::LEAF).New(); - node.get().SetMetadata(static_cast(NType::LEAF)); - - auto &leaf = Node::RefMutable(art, node, NType::LEAF); - - leaf.count = UnsafeNumericCast(MinValue((idx_t)Node::LEAF_SIZE, count)); - - for (idx_t i = 0; i < leaf.count; i++) { - leaf.row_ids[i] = row_ids[copy_count + i]; - } - - copy_count += leaf.count; - count -= leaf.count; - - node = leaf.ptr; - leaf.ptr.Clear(); + // We cannot recurse into the leaf during Construct(...) because row IDs are not sorted. + for (idx_t i = 0; i < count; i++) { + idx_t offset = start + i; + art.Insert(node, row_ids[offset], 0, row_ids[offset], GateStatus::GATE_SET); } + node.get().SetGateStatus(GateStatus::GATE_SET); } -Leaf &Leaf::New(ART &art, Node &node) { - node = Node::GetAllocator(art, NType::LEAF).New(); - node.SetMetadata(static_cast(NType::LEAF)); - auto &leaf = Node::RefMutable(art, node, NType::LEAF); - - leaf.count = 0; - leaf.ptr.Clear(); - return leaf; -} - -void Leaf::Free(ART &art, Node &node) { +void Leaf::MergeInlined(ART &art, Node &l_node, Node &r_node) { + D_ASSERT(r_node.GetType() == INLINED); - Node current_node = node; - Node next_node; - while (current_node.HasMetadata()) { - next_node = Node::RefMutable(art, current_node, NType::LEAF).ptr; - Node::GetAllocator(art, NType::LEAF).Free(current_node); - current_node = next_node; - } - - node.Clear(); + ArenaAllocator arena_allocator(Allocator::Get(art.db)); + auto key = ARTKey::CreateARTKey(arena_allocator, r_node.GetRowId()); + art.Insert(l_node, key, 0, key, l_node.GetGateStatus()); + r_node.Clear(); } -void Leaf::InitializeMerge(ART &art, Node &node, const ARTFlags &flags) { - - auto merge_buffer_count = flags.merge_buffer_counts[static_cast(NType::LEAF) - 1]; +void Leaf::InsertIntoInlined(ART &art, Node &node, const ARTKey &row_id, idx_t depth, const GateStatus status) { + D_ASSERT(node.GetType() == INLINED); - Node next_node = node; - node.IncreaseBufferId(merge_buffer_count); + ArenaAllocator allocator(Allocator::Get(art.db)); + auto key = ARTKey::CreateARTKey(allocator, node.GetRowId()); - while (next_node.HasMetadata()) { - auto &leaf = Node::RefMutable(art, next_node, NType::LEAF); - next_node = leaf.ptr; - if (leaf.ptr.HasMetadata()) { - leaf.ptr.IncreaseBufferId(merge_buffer_count); - } + GateStatus new_status; + if (status == GateStatus::GATE_NOT_SET || node.GetGateStatus() == GateStatus::GATE_SET) { + new_status = GateStatus::GATE_SET; + } else { + new_status = GateStatus::GATE_NOT_SET; } -} -void Leaf::Merge(ART &art, Node &l_node, Node &r_node) { - - D_ASSERT(l_node.HasMetadata() && r_node.HasMetadata()); + if (new_status == GateStatus::GATE_SET) { + depth = 0; + } + node.Clear(); - // copy inlined row ID of r_node - if (r_node.GetType() == NType::LEAF_INLINED) { - Insert(art, l_node, r_node.GetRowId()); - r_node.Clear(); - return; + // Get the mismatching position. + D_ASSERT(row_id.len == key.len); + auto pos = row_id.GetMismatchPos(key, depth); + D_ASSERT(pos != DConstants::INVALID_INDEX); + D_ASSERT(pos >= depth); + auto byte = row_id.data[pos]; + + // Create the (optional) prefix and the node. + reference next(node); + auto count = pos - depth; + if (count != 0) { + Prefix::New(art, next, row_id, depth, count); + } + if (pos == Prefix::ROW_ID_COUNT) { + Node7Leaf::New(art, next); + } else { + Node4::New(art, next); } - // l_node has an inlined row ID, swap and insert - if (l_node.GetType() == NType::LEAF_INLINED) { - auto row_id = l_node.GetRowId(); - l_node = r_node; - Insert(art, l_node, row_id); - r_node.Clear(); - return; + // Create the children. + Node row_id_node; + Leaf::New(row_id_node, row_id.GetRowId()); + Node remainder; + if (pos != Prefix::ROW_ID_COUNT) { + Leaf::New(remainder, key.GetRowId()); } - D_ASSERT(l_node.GetType() != NType::LEAF_INLINED); - D_ASSERT(r_node.GetType() != NType::LEAF_INLINED); + Node::InsertChild(art, next, key[pos], remainder); + Node::InsertChild(art, next, byte, row_id_node); + node.SetGateStatus(new_status); +} - reference l_node_ref(l_node); - reference l_leaf = Node::RefMutable(art, l_node_ref, NType::LEAF); +void Leaf::TransformToNested(ART &art, Node &node) { + D_ASSERT(node.GetType() == LEAF); - // find a non-full node - while (l_leaf.get().count == Node::LEAF_SIZE) { - l_node_ref = l_leaf.get().ptr; + ArenaAllocator allocator(Allocator::Get(art.db)); + Node root = Node(); - // the last leaf is full - if (!l_leaf.get().ptr.HasMetadata()) { - break; + // Move all row IDs into the nested leaf. + reference leaf_ref(node); + while (leaf_ref.get().HasMetadata()) { + auto &leaf = Node::Ref(art, leaf_ref, LEAF); + for (uint8_t i = 0; i < leaf.count; i++) { + auto row_id = ARTKey::CreateARTKey(allocator, leaf.row_ids[i]); + art.Insert(root, row_id, 0, row_id, GateStatus::GATE_SET); } - l_leaf = Node::RefMutable(art, l_node_ref, NType::LEAF); + leaf_ref = leaf.ptr; } - // store the last leaf and then append r_node - auto last_leaf_node = l_node_ref.get(); - l_node_ref.get() = r_node; - r_node.Clear(); - - // append the remaining row IDs of the last leaf node - if (last_leaf_node.HasMetadata()) { - // find the tail - l_leaf = Node::RefMutable(art, l_node_ref, NType::LEAF); - while (l_leaf.get().ptr.HasMetadata()) { - l_leaf = Node::RefMutable(art, l_leaf.get().ptr, NType::LEAF); - } - // append the row IDs - auto &last_leaf = Node::RefMutable(art, last_leaf_node, NType::LEAF); - for (idx_t i = 0; i < last_leaf.count; i++) { - l_leaf = l_leaf.get().Append(art, last_leaf.row_ids[i]); - } - Node::GetAllocator(art, NType::LEAF).Free(last_leaf_node); - } + root.SetGateStatus(GateStatus::GATE_SET); + Node::Free(art, node); + node = root; } -void Leaf::Insert(ART &art, Node &node, const row_t row_id) { +void Leaf::TransformToDeprecated(ART &art, Node &node) { + D_ASSERT(node.GetGateStatus() == GateStatus::GATE_SET || node.GetType() == LEAF); - D_ASSERT(node.HasMetadata()); - - if (node.GetType() == NType::LEAF_INLINED) { - MoveInlinedToLeaf(art, node); - Insert(art, node, row_id); + // Early-out, if we never transformed this leaf. + if (node.GetGateStatus() == GateStatus::GATE_NOT_SET) { return; } - // append to the tail - reference leaf = Node::RefMutable(art, node, NType::LEAF); - while (leaf.get().ptr.HasMetadata()) { - leaf = Node::RefMutable(art, leaf.get().ptr, NType::LEAF); - } - leaf.get().Append(art, row_id); -} - -bool Leaf::Remove(ART &art, reference &node, const row_t row_id) { - - D_ASSERT(node.get().HasMetadata()); + // Collect all row IDs and free the nested leaf. + unsafe_vector row_ids; + Iterator it(art); + it.FindMinimum(node); + ARTKey empty_key = ARTKey(); + it.Scan(empty_key, NumericLimits().Maximum(), row_ids, false); + Node::Free(art, node); + D_ASSERT(row_ids.size() > 1); - if (node.get().GetType() == NType::LEAF_INLINED) { - if (node.get().GetRowId() == row_id) { - return true; - } - return false; - } + // Create the deprecated leaves. + idx_t remaining = row_ids.size(); + idx_t copy_count = 0; + reference ref(node); + while (remaining) { + ref.get() = Node::GetAllocator(art, LEAF).New(); + ref.get().SetMetadata(static_cast(LEAF)); - reference leaf = Node::RefMutable(art, node, NType::LEAF); + auto &leaf = Node::Ref(art, ref, LEAF); + auto min = MinValue(UnsafeNumericCast(LEAF_SIZE), remaining); + leaf.count = UnsafeNumericCast(min); - // inline the remaining row ID - if (leaf.get().count == 2) { - if (leaf.get().row_ids[0] == row_id || leaf.get().row_ids[1] == row_id) { - auto remaining_row_id = leaf.get().row_ids[0] == row_id ? leaf.get().row_ids[1] : leaf.get().row_ids[0]; - Node::Free(art, node); - New(node, remaining_row_id); + for (uint8_t i = 0; i < leaf.count; i++) { + leaf.row_ids[i] = row_ids[copy_count + i]; } - return false; - } - - // get the last row ID (the order within a leaf does not matter) - // because we want to overwrite the row ID to remove with that one - - // go to the tail and keep track of the previous leaf node - reference prev_leaf(leaf); - while (leaf.get().ptr.HasMetadata()) { - prev_leaf = leaf; - leaf = Node::RefMutable(art, leaf.get().ptr, NType::LEAF); - } - auto last_idx = leaf.get().count; - auto last_row_id = leaf.get().row_ids[last_idx - 1]; - - // only one row ID in this leaf segment, free it - if (leaf.get().count == 1) { - Node::Free(art, prev_leaf.get().ptr); - if (last_row_id == row_id) { - return false; - } - } else { - leaf.get().count--; - } + copy_count += leaf.count; + remaining -= leaf.count; - // find the row ID and copy the last row ID to that position - while (node.get().HasMetadata()) { - leaf = Node::RefMutable(art, node, NType::LEAF); - for (idx_t i = 0; i < leaf.get().count; i++) { - if (leaf.get().row_ids[i] == row_id) { - leaf.get().row_ids[i] = last_row_id; - return false; - } - } - node = leaf.get().ptr; + ref = leaf.ptr; + leaf.ptr.Clear(); } - return false; } -idx_t Leaf::TotalCount(ART &art, const Node &node) { +//===--------------------------------------------------------------------===// +// Deprecated code paths. +//===--------------------------------------------------------------------===// - D_ASSERT(node.HasMetadata()); - if (node.GetType() == NType::LEAF_INLINED) { - return 1; - } +void Leaf::DeprecatedFree(ART &art, Node &node) { + D_ASSERT(node.GetType() == LEAF); - idx_t count = 0; - reference node_ref(node); - while (node_ref.get().HasMetadata()) { - auto &leaf = Node::Ref(art, node_ref, NType::LEAF); - count += leaf.count; - node_ref = leaf.ptr; + Node next; + while (node.HasMetadata()) { + next = Node::Ref(art, node, LEAF).ptr; + Node::GetAllocator(art, LEAF).Free(node); + node = next; } - return count; + node.Clear(); } -bool Leaf::GetRowIds(ART &art, const Node &node, vector &result_ids, idx_t max_count) { +bool Leaf::DeprecatedGetRowIds(ART &art, const Node &node, unsafe_vector &row_ids, const idx_t max_count) { + D_ASSERT(node.GetType() == LEAF); - // adding more elements would exceed the maximum count - D_ASSERT(node.HasMetadata()); - if (result_ids.size() + TotalCount(art, node) > max_count) { - return false; - } - - if (node.GetType() == NType::LEAF_INLINED) { - // push back the inlined row ID of this leaf - result_ids.push_back(node.GetRowId()); + reference ref(node); + while (ref.get().HasMetadata()) { - } else { - // push back all the row IDs of this leaf - reference last_leaf_ref(node); - while (last_leaf_ref.get().HasMetadata()) { - auto &leaf = Node::Ref(art, last_leaf_ref, NType::LEAF); - for (idx_t i = 0; i < leaf.count; i++) { - result_ids.push_back(leaf.row_ids[i]); - } - last_leaf_ref = leaf.ptr; + auto &leaf = Node::Ref(art, ref, LEAF); + if (row_ids.size() + leaf.count > max_count) { + return false; + } + for (uint8_t i = 0; i < leaf.count; i++) { + row_ids.push_back(leaf.row_ids[i]); } + ref = leaf.ptr; } - return true; } -bool Leaf::ContainsRowId(ART &art, const Node &node, const row_t row_id) { - +void Leaf::DeprecatedVacuum(ART &art, Node &node) { D_ASSERT(node.HasMetadata()); - - if (node.GetType() == NType::LEAF_INLINED) { - return node.GetRowId() == row_id; - } - - reference ref_node(node); - while (ref_node.get().HasMetadata()) { - auto &leaf = Node::Ref(art, ref_node, NType::LEAF); - for (idx_t i = 0; i < leaf.count; i++) { - if (leaf.row_ids[i] == row_id) { - return true; - } + D_ASSERT(node.GetType() == LEAF); + + auto &allocator = Node::GetAllocator(art, LEAF); + reference ref(node); + while (ref.get().HasMetadata()) { + if (allocator.NeedsVacuum(ref)) { + ref.get() = allocator.VacuumPointer(ref); + ref.get().SetMetadata(static_cast(LEAF)); } - ref_node = leaf.ptr; + auto &leaf = Node::Ref(art, ref, LEAF); + ref = leaf.ptr; } - - return false; } -string Leaf::VerifyAndToString(ART &art, const Node &node, const bool only_verify) { - - if (node.GetType() == NType::LEAF_INLINED) { - return only_verify ? "" : "Leaf [count: 1, row ID: " + to_string(node.GetRowId()) + "]"; - } +string Leaf::DeprecatedVerifyAndToString(ART &art, const Node &node, const bool only_verify) { + D_ASSERT(node.GetType() == LEAF); string str = ""; + reference ref(node); - reference node_ref(node); - while (node_ref.get().HasMetadata()) { - - auto &leaf = Node::Ref(art, node_ref, NType::LEAF); - D_ASSERT(leaf.count <= Node::LEAF_SIZE); + while (ref.get().HasMetadata()) { + auto &leaf = Node::Ref(art, ref, LEAF); + D_ASSERT(leaf.count <= LEAF_SIZE); str += "Leaf [count: " + to_string(leaf.count) + ", row IDs: "; - for (idx_t i = 0; i < leaf.count; i++) { + for (uint8_t i = 0; i < leaf.count; i++) { str += to_string(leaf.row_ids[i]) + "-"; } str += "] "; - - node_ref = leaf.ptr; + ref = leaf.ptr; } - return only_verify ? "" : str; -} -void Leaf::Vacuum(ART &art, Node &node) { - - auto &allocator = Node::GetAllocator(art, NType::LEAF); - - reference node_ref(node); - while (node_ref.get().HasMetadata()) { - if (allocator.NeedsVacuum(node_ref)) { - node_ref.get() = allocator.VacuumPointer(node_ref); - node_ref.get().SetMetadata(static_cast(NType::LEAF)); - } - auto &leaf = Node::RefMutable(art, node_ref, NType::LEAF); - node_ref = leaf.ptr; - } -} - -void Leaf::MoveInlinedToLeaf(ART &art, Node &node) { - - D_ASSERT(node.GetType() == NType::LEAF_INLINED); - auto row_id = node.GetRowId(); - auto &leaf = New(art, node); - - leaf.count = 1; - leaf.row_ids[0] = row_id; + return only_verify ? "" : str; } -Leaf &Leaf::Append(ART &art, const row_t row_id) { +void Leaf::DeprecatedVerifyAllocations(ART &art, unordered_map &node_counts) const { + auto idx = Node::GetAllocatorIdx(LEAF); + node_counts[idx]++; - reference leaf(*this); - - // we need a new leaf node - if (leaf.get().count == Node::LEAF_SIZE) { - leaf = New(art, leaf.get().ptr); + reference ref(ptr); + while (ref.get().HasMetadata()) { + auto &leaf = Node::Ref(art, ref, LEAF); + node_counts[idx]++; + ref = leaf.ptr; } - - leaf.get().row_ids[leaf.get().count] = row_id; - leaf.get().count++; - return leaf.get(); } } // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/node.cpp b/src/duckdb/src/execution/index/art/node.cpp index 5c82b748..8a39d832 100644 --- a/src/duckdb/src/execution/index/art/node.cpp +++ b/src/duckdb/src/execution/index/art/node.cpp @@ -3,25 +3,34 @@ #include "duckdb/common/limits.hpp" #include "duckdb/common/swap.hpp" #include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/art/art_key.hpp" +#include "duckdb/execution/index/art/base_leaf.hpp" +#include "duckdb/execution/index/art/base_node.hpp" +#include "duckdb/execution/index/art/iterator.hpp" +#include "duckdb/execution/index/art/leaf.hpp" #include "duckdb/execution/index/art/node256.hpp" +#include "duckdb/execution/index/art/node256_leaf.hpp" #include "duckdb/execution/index/art/node48.hpp" -#include "duckdb/execution/index/art/node16.hpp" -#include "duckdb/execution/index/art/node4.hpp" -#include "duckdb/execution/index/art/leaf.hpp" #include "duckdb/execution/index/art/prefix.hpp" #include "duckdb/storage/table_io_manager.hpp" namespace duckdb { //===--------------------------------------------------------------------===// -// New / Free +// New and free //===--------------------------------------------------------------------===// -void Node::New(ART &art, Node &node, const NType type) { - - // NOTE: leaves and prefixes should not pass through this function - +void Node::New(ART &art, Node &node, NType type) { switch (type) { + case NType::NODE_7_LEAF: + Node7Leaf::New(art, node); + break; + case NType::NODE_15_LEAF: + Node15Leaf::New(art, node); + break; + case NType::NODE_256_LEAF: + Node256Leaf::New(art, node); + break; case NType::NODE_4: Node4::New(art, node); break; @@ -35,25 +44,22 @@ void Node::New(ART &art, Node &node, const NType type) { Node256::New(art, node); break; default: - throw InternalException("Invalid node type for New."); + throw InternalException("Invalid node type for New: %d.", static_cast(type)); } } void Node::Free(ART &art, Node &node) { - if (!node.HasMetadata()) { return node.Clear(); } - // free the children of the nodes + // Free the children. auto type = node.GetType(); switch (type) { case NType::PREFIX: - // iterative return Prefix::Free(art, node); case NType::LEAF: - // iterative - return Leaf::Free(art, node); + return Leaf::DeprecatedFree(art, node); case NType::NODE_4: Node4::Free(art, node); break; @@ -68,6 +74,10 @@ void Node::Free(ART &art, Node &node) { break; case NType::LEAF_INLINED: return node.Clear(); + case NType::NODE_7_LEAF: + case NType::NODE_15_LEAF: + case NType::NODE_256_LEAF: + break; } GetAllocator(art, type).Free(node); @@ -75,11 +85,36 @@ void Node::Free(ART &art, Node &node) { } //===--------------------------------------------------------------------===// -// Get Allocators +// Allocators //===--------------------------------------------------------------------===// FixedSizeAllocator &Node::GetAllocator(const ART &art, const NType type) { - return *(*art.allocators)[static_cast(type) - 1]; + return *(*art.allocators)[GetAllocatorIdx(type)]; +} + +uint8_t Node::GetAllocatorIdx(const NType type) { + switch (type) { + case NType::PREFIX: + return 0; + case NType::LEAF: + return 1; + case NType::NODE_4: + return 2; + case NType::NODE_16: + return 3; + case NType::NODE_48: + return 4; + case NType::NODE_256: + return 5; + case NType::NODE_7_LEAF: + return 6; + case NType::NODE_15_LEAF: + return 7; + case NType::NODE_256_LEAF: + return 8; + default: + throw InternalException("Invalid node type for GetAllocatorIdx: %d.", static_cast(type)); + } } //===--------------------------------------------------------------------===// @@ -87,24 +122,28 @@ FixedSizeAllocator &Node::GetAllocator(const ART &art, const NType type) { //===--------------------------------------------------------------------===// void Node::ReplaceChild(const ART &art, const uint8_t byte, const Node child) const { + D_ASSERT(HasMetadata()); - switch (GetType()) { + auto type = GetType(); + switch (type) { case NType::NODE_4: - return RefMutable(art, *this, NType::NODE_4).ReplaceChild(byte, child); + return Node4::ReplaceChild(Ref(art, *this, type), byte, child); case NType::NODE_16: - return RefMutable(art, *this, NType::NODE_16).ReplaceChild(byte, child); + return Node16::ReplaceChild(Ref(art, *this, type), byte, child); case NType::NODE_48: - return RefMutable(art, *this, NType::NODE_48).ReplaceChild(byte, child); + return Ref(art, *this, type).ReplaceChild(byte, child); case NType::NODE_256: - return RefMutable(art, *this, NType::NODE_256).ReplaceChild(byte, child); + return Ref(art, *this, type).ReplaceChild(byte, child); default: - throw InternalException("Invalid node type for ReplaceChild."); + throw InternalException("Invalid node type for ReplaceChild: %d.", static_cast(type)); } } void Node::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { + D_ASSERT(node.HasMetadata()); - switch (node.GetType()) { + auto type = node.GetType(); + switch (type) { case NType::NODE_4: return Node4::InsertChild(art, node, byte, child); case NType::NODE_16: @@ -113,104 +152,134 @@ void Node::InsertChild(ART &art, Node &node, const uint8_t byte, const Node chil return Node48::InsertChild(art, node, byte, child); case NType::NODE_256: return Node256::InsertChild(art, node, byte, child); + case NType::NODE_7_LEAF: + return Node7Leaf::InsertByte(art, node, byte); + case NType::NODE_15_LEAF: + return Node15Leaf::InsertByte(art, node, byte); + case NType::NODE_256_LEAF: + return Node256Leaf::InsertByte(art, node, byte); default: - throw InternalException("Invalid node type for InsertChild."); + throw InternalException("Invalid node type for InsertChild: %d.", static_cast(type)); } } //===--------------------------------------------------------------------===// -// Deletes +// Delete //===--------------------------------------------------------------------===// -void Node::DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte) { +void Node::DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte, const GateStatus status, + const ARTKey &row_id) { + D_ASSERT(node.HasMetadata()); - switch (node.GetType()) { + auto type = node.GetType(); + switch (type) { case NType::NODE_4: - return Node4::DeleteChild(art, node, prefix, byte); + return Node4::DeleteChild(art, node, prefix, byte, status); case NType::NODE_16: return Node16::DeleteChild(art, node, byte); case NType::NODE_48: return Node48::DeleteChild(art, node, byte); case NType::NODE_256: return Node256::DeleteChild(art, node, byte); + case NType::NODE_7_LEAF: + return Node7Leaf::DeleteByte(art, node, prefix, byte, row_id); + case NType::NODE_15_LEAF: + return Node15Leaf::DeleteByte(art, node, byte); + case NType::NODE_256_LEAF: + return Node256Leaf::DeleteByte(art, node, byte); default: - throw InternalException("Invalid node type for DeleteChild."); + throw InternalException("Invalid node type for DeleteChild: %d.", static_cast(type)); } } //===--------------------------------------------------------------------===// -// Get functions +// Get child and byte. //===--------------------------------------------------------------------===// -optional_ptr Node::GetChild(ART &art, const uint8_t byte) const { - - D_ASSERT(HasMetadata()); +template +unsafe_optional_ptr GetChildInternal(ART &art, NODE &node, const uint8_t byte) { + D_ASSERT(node.HasMetadata()); - switch (GetType()) { + auto type = node.GetType(); + switch (type) { case NType::NODE_4: - return Ref(art, *this, NType::NODE_4).GetChild(byte); + return Node4::GetChild(Node::Ref(art, node, type), byte); case NType::NODE_16: - return Ref(art, *this, NType::NODE_16).GetChild(byte); + return Node16::GetChild(Node::Ref(art, node, type), byte); case NType::NODE_48: - return Ref(art, *this, NType::NODE_48).GetChild(byte); - case NType::NODE_256: - return Ref(art, *this, NType::NODE_256).GetChild(byte); + return Node48::GetChild(Node::Ref(art, node, type), byte); + case NType::NODE_256: { + return Node256::GetChild(Node::Ref(art, node, type), byte); + } default: - throw InternalException("Invalid node type for GetChild."); + throw InternalException("Invalid node type for GetChildInternal: %d.", static_cast(type)); } } -optional_ptr Node::GetChildMutable(ART &art, const uint8_t byte) const { +const unsafe_optional_ptr Node::GetChild(ART &art, const uint8_t byte) const { + return GetChildInternal(art, *this, byte); +} - D_ASSERT(HasMetadata()); +unsafe_optional_ptr Node::GetChildMutable(ART &art, const uint8_t byte) const { + return GetChildInternal(art, *this, byte); +} - switch (GetType()) { +template +unsafe_optional_ptr GetNextChildInternal(ART &art, NODE &node, uint8_t &byte) { + D_ASSERT(node.HasMetadata()); + + auto type = node.GetType(); + switch (type) { case NType::NODE_4: - return RefMutable(art, *this, NType::NODE_4).GetChildMutable(byte); + return Node4::GetNextChild(Node::Ref(art, node, type), byte); case NType::NODE_16: - return RefMutable(art, *this, NType::NODE_16).GetChildMutable(byte); + return Node16::GetNextChild(Node::Ref(art, node, type), byte); case NType::NODE_48: - return RefMutable(art, *this, NType::NODE_48).GetChildMutable(byte); + return Node48::GetNextChild(Node::Ref(art, node, type), byte); case NType::NODE_256: - return RefMutable(art, *this, NType::NODE_256).GetChildMutable(byte); + return Node256::GetNextChild(Node::Ref(art, node, type), byte); default: - throw InternalException("Invalid node type for GetChildMutable."); + throw InternalException("Invalid node type for GetNextChildInternal: %d.", static_cast(type)); } } -optional_ptr Node::GetNextChild(ART &art, uint8_t &byte) const { +const unsafe_optional_ptr Node::GetNextChild(ART &art, uint8_t &byte) const { + return GetNextChildInternal(art, *this, byte); +} + +unsafe_optional_ptr Node::GetNextChildMutable(ART &art, uint8_t &byte) const { + return GetNextChildInternal(art, *this, byte); +} +bool Node::HasByte(ART &art, uint8_t &byte) const { D_ASSERT(HasMetadata()); - switch (GetType()) { - case NType::NODE_4: - return Ref(art, *this, NType::NODE_4).GetNextChild(byte); - case NType::NODE_16: - return Ref(art, *this, NType::NODE_16).GetNextChild(byte); - case NType::NODE_48: - return Ref(art, *this, NType::NODE_48).GetNextChild(byte); - case NType::NODE_256: - return Ref(art, *this, NType::NODE_256).GetNextChild(byte); + auto type = GetType(); + switch (type) { + case NType::NODE_7_LEAF: + return Ref(art, *this, NType::NODE_7_LEAF).HasByte(byte); + case NType::NODE_15_LEAF: + return Ref(art, *this, NType::NODE_15_LEAF).HasByte(byte); + case NType::NODE_256_LEAF: + return Ref(art, *this, NType::NODE_256_LEAF).HasByte(byte); default: - throw InternalException("Invalid node type for GetNextChild."); + throw InternalException("Invalid node type for GetNextByte: %d.", static_cast(type)); } } -optional_ptr Node::GetNextChildMutable(ART &art, uint8_t &byte) const { - +bool Node::GetNextByte(ART &art, uint8_t &byte) const { D_ASSERT(HasMetadata()); - switch (GetType()) { - case NType::NODE_4: - return RefMutable(art, *this, NType::NODE_4).GetNextChildMutable(byte); - case NType::NODE_16: - return RefMutable(art, *this, NType::NODE_16).GetNextChildMutable(byte); - case NType::NODE_48: - return RefMutable(art, *this, NType::NODE_48).GetNextChildMutable(byte); - case NType::NODE_256: - return RefMutable(art, *this, NType::NODE_256).GetNextChildMutable(byte); + auto type = GetType(); + switch (type) { + case NType::NODE_7_LEAF: + return Ref(art, *this, NType::NODE_7_LEAF).GetNextByte(byte); + case NType::NODE_15_LEAF: + return Ref(art, *this, NType::NODE_15_LEAF).GetNextByte(byte); + case NType::NODE_256_LEAF: + return Ref(art, *this, NType::NODE_256_LEAF).GetNextByte(byte); default: - throw InternalException("Invalid node type for GetNextChildMutable."); + throw InternalException("Invalid node type for GetNextByte: %d.", static_cast(type)); } } @@ -218,301 +287,478 @@ optional_ptr Node::GetNextChildMutable(ART &art, uint8_t &byte) const { // Utility //===--------------------------------------------------------------------===// -string Node::VerifyAndToString(ART &art, const bool only_verify) const { - - D_ASSERT(HasMetadata()); - - if (GetType() == NType::LEAF || GetType() == NType::LEAF_INLINED) { - auto str = Leaf::VerifyAndToString(art, *this, only_verify); - return only_verify ? "" : "\n" + str; - } - if (GetType() == NType::PREFIX) { - auto str = Prefix::VerifyAndToString(art, *this, only_verify); - return only_verify ? "" : "\n" + str; +idx_t GetCapacity(NType type) { + switch (type) { + case NType::NODE_4: + return Node4::CAPACITY; + case NType::NODE_7_LEAF: + return Node7Leaf::CAPACITY; + case NType::NODE_15_LEAF: + return Node15Leaf::CAPACITY; + case NType::NODE_16: + return Node16::CAPACITY; + case NType::NODE_48: + return Node48::CAPACITY; + case NType::NODE_256_LEAF: + return Node256::CAPACITY; + case NType::NODE_256: + return Node256::CAPACITY; + default: + throw InternalException("Invalid node type for GetCapacity: %d.", static_cast(type)); } +} - string str = "Node" + to_string(GetCapacity()) + ": ["; - uint8_t byte = 0; - auto child = GetNextChild(art, byte); - - while (child) { - str += "(" + to_string(byte) + ", " + child->VerifyAndToString(art, only_verify) + ")"; - if (byte == NumericLimits::Maximum()) { - break; - } - - byte++; - child = GetNextChild(art, byte); +NType Node::GetNodeType(idx_t count) { + if (count <= Node4::CAPACITY) { + return NType::NODE_4; + } else if (count <= Node16::CAPACITY) { + return NType::NODE_16; + } else if (count <= Node48::CAPACITY) { + return NType::NODE_48; } - - return only_verify ? "" : "\n" + str + "]"; + return NType::NODE_256; } -idx_t Node::GetCapacity() const { - +bool Node::IsNode() const { switch (GetType()) { case NType::NODE_4: - return NODE_4_CAPACITY; case NType::NODE_16: - return NODE_16_CAPACITY; case NType::NODE_48: - return NODE_48_CAPACITY; case NType::NODE_256: - return NODE_256_CAPACITY; + return true; default: - throw InternalException("Invalid node type for GetCapacity."); + return false; } } -NType Node::GetARTNodeTypeByCount(const idx_t count) { +bool Node::IsLeafNode() const { + switch (GetType()) { + case NType::NODE_7_LEAF: + case NType::NODE_15_LEAF: + case NType::NODE_256_LEAF: + return true; + default: + return false; + } +} - if (count <= NODE_4_CAPACITY) { - return NType::NODE_4; - } else if (count <= NODE_16_CAPACITY) { - return NType::NODE_16; - } else if (count <= NODE_48_CAPACITY) { - return NType::NODE_48; +bool Node::IsAnyLeaf() const { + if (IsLeafNode()) { + return true; + } + + switch (GetType()) { + case NType::LEAF_INLINED: + case NType::LEAF: + return true; + default: + return false; } - return NType::NODE_256; } //===--------------------------------------------------------------------===// -// Merging +// Merge //===--------------------------------------------------------------------===// -void Node::InitializeMerge(ART &art, const ARTFlags &flags) { - +void Node::InitMerge(ART &art, const unsafe_vector &upper_bounds) { D_ASSERT(HasMetadata()); + auto type = GetType(); - switch (GetType()) { + switch (type) { case NType::PREFIX: - // iterative - return Prefix::InitializeMerge(art, *this, flags); + return Prefix::InitializeMerge(art, *this, upper_bounds); case NType::LEAF: - // iterative - return Leaf::InitializeMerge(art, *this, flags); + throw InternalException("Failed to initialize merge due to deprecated ART storage."); case NType::NODE_4: - RefMutable(art, *this, NType::NODE_4).InitializeMerge(art, flags); + InitMergeInternal(art, Ref(art, *this, type), upper_bounds); break; case NType::NODE_16: - RefMutable(art, *this, NType::NODE_16).InitializeMerge(art, flags); + InitMergeInternal(art, Ref(art, *this, type), upper_bounds); break; case NType::NODE_48: - RefMutable(art, *this, NType::NODE_48).InitializeMerge(art, flags); + InitMergeInternal(art, Ref(art, *this, type), upper_bounds); break; case NType::NODE_256: - RefMutable(art, *this, NType::NODE_256).InitializeMerge(art, flags); + InitMergeInternal(art, Ref(art, *this, type), upper_bounds); break; case NType::LEAF_INLINED: return; + case NType::NODE_7_LEAF: + case NType::NODE_15_LEAF: + case NType::NODE_256_LEAF: + break; } - IncreaseBufferId(flags.merge_buffer_counts[static_cast(GetType()) - 1]); + auto idx = GetAllocatorIdx(type); + IncreaseBufferId(upper_bounds[idx]); } -bool Node::Merge(ART &art, Node &other) { +bool Node::MergeNormalNodes(ART &art, Node &l_node, Node &r_node, uint8_t &byte, const GateStatus status) { + // Merge N4, N16, N48, N256 nodes. + D_ASSERT(l_node.IsNode() && r_node.IsNode()); + D_ASSERT(l_node.GetGateStatus() == r_node.GetGateStatus()); - if (!HasMetadata()) { - *this = other; - other = Node(); - return true; + auto r_child = r_node.GetNextChildMutable(art, byte); + while (r_child) { + auto l_child = l_node.GetChildMutable(art, byte); + if (!l_child) { + Node::InsertChild(art, l_node, byte, *r_child); + r_node.ReplaceChild(art, byte); + } else { + if (!l_child->MergeInternal(art, *r_child, status)) { + return false; + } + } + + if (byte == NumericLimits::Maximum()) { + break; + } + byte++; + r_child = r_node.GetNextChildMutable(art, byte); } - return ResolvePrefixes(art, other); + Node::Free(art, r_node); + return true; } -bool MergePrefixContainsOtherPrefix(ART &art, reference &l_node, reference &r_node, - idx_t &mismatch_position) { - - // r_node's prefix contains l_node's prefix - // l_node cannot be a leaf, otherwise the key represented by l_node would be a subset of another key - // which is not possible by our construction - D_ASSERT(l_node.get().GetType() != NType::LEAF && l_node.get().GetType() != NType::LEAF_INLINED); +void Node::MergeLeafNodes(ART &art, Node &l_node, Node &r_node, uint8_t &byte) { + // Merge N7, N15, N256 leaf nodes. + D_ASSERT(l_node.IsLeafNode() && r_node.IsLeafNode()); + D_ASSERT(l_node.GetGateStatus() == GateStatus::GATE_NOT_SET); + D_ASSERT(r_node.GetGateStatus() == GateStatus::GATE_NOT_SET); - // test if the next byte (mismatch_position) in r_node (prefix) exists in l_node - auto mismatch_byte = Prefix::GetByte(art, r_node, mismatch_position); - auto child_node = l_node.get().GetChildMutable(art, mismatch_byte); + auto has_next = r_node.GetNextByte(art, byte); + while (has_next) { + // Row IDs are always unique. + Node::InsertChild(art, l_node, byte); + if (byte == NumericLimits::Maximum()) { + break; + } + byte++; + has_next = r_node.GetNextByte(art, byte); + } - // update the prefix of r_node to only consist of the bytes after mismatch_position - Prefix::Reduce(art, r_node, mismatch_position); + Node::Free(art, r_node); +} - if (!child_node) { - // insert r_node as a child of l_node at the empty position - Node::InsertChild(art, l_node, mismatch_byte, r_node); - r_node.get().Clear(); - return true; +bool Node::MergeNodes(ART &art, Node &other, GateStatus status) { + // Merge the smaller node into the bigger node. + if (GetType() < other.GetType()) { + swap(*this, other); } - // recurse - return child_node->ResolvePrefixes(art, r_node); + uint8_t byte = 0; + if (IsNode()) { + return MergeNormalNodes(art, *this, other, byte, status); + } + MergeLeafNodes(art, *this, other, byte); + return true; } -void MergePrefixesDiffer(ART &art, reference &l_node, reference &r_node, idx_t &mismatch_position) { +bool Node::Merge(ART &art, Node &other, const GateStatus status) { + if (HasMetadata()) { + return MergeInternal(art, other, status); + } - // create a new node and insert both nodes as children + *this = other; + other = Node(); + return true; +} - Node l_child; - auto l_byte = Prefix::GetByte(art, l_node, mismatch_position); - Prefix::Split(art, l_node, l_child, mismatch_position); - Node4::New(art, l_node); +bool Node::PrefixContainsOther(ART &art, Node &l_node, Node &r_node, const uint8_t pos, const GateStatus status) { + // r_node's prefix contains l_node's prefix. l_node must be a node with child nodes. + D_ASSERT(l_node.IsNode()); - // insert children - Node4::InsertChild(art, l_node, l_byte, l_child); - auto r_byte = Prefix::GetByte(art, r_node, mismatch_position); - Prefix::Reduce(art, r_node, mismatch_position); - Node4::InsertChild(art, l_node, r_byte, r_node); + // Check if the next byte (pos) in r_node exists in l_node. + auto byte = Prefix::GetByte(art, r_node, pos); + auto child = l_node.GetChildMutable(art, byte); - r_node.get().Clear(); + // Reduce r_node's prefix to the bytes after pos. + Prefix::Reduce(art, r_node, pos); + if (child) { + return child->MergeInternal(art, r_node, status); + } + + Node::InsertChild(art, l_node, byte, r_node); + r_node.Clear(); + return true; } -bool Node::ResolvePrefixes(ART &art, Node &other) { +void Node::MergeIntoNode4(ART &art, Node &l_node, Node &r_node, const uint8_t pos) { + Node l_child; + auto l_byte = Prefix::GetByte(art, l_node, pos); - // NOTE: we always merge into the left ART + reference ref(l_node); + auto status = Prefix::Split(art, ref, l_child, pos); + Node4::New(art, ref); + ref.get().SetGateStatus(status); - D_ASSERT(HasMetadata() && other.HasMetadata()); + Node4::InsertChild(art, ref, l_byte, l_child); - // case 1: both nodes have no prefix - if (GetType() != NType::PREFIX && other.GetType() != NType::PREFIX) { - return MergeInternal(art, other); - } + auto r_byte = Prefix::GetByte(art, r_node, pos); + Prefix::Reduce(art, r_node, pos); + Node4::InsertChild(art, ref, r_byte, r_node); + r_node.Clear(); +} +bool Node::MergePrefixes(ART &art, Node &other, const GateStatus status) { reference l_node(*this); reference r_node(other); + auto pos = DConstants::INVALID_INDEX; - idx_t mismatch_position = DConstants::INVALID_INDEX; - - // traverse prefixes if (l_node.get().GetType() == NType::PREFIX && r_node.get().GetType() == NType::PREFIX) { - - if (!Prefix::Traverse(art, l_node, r_node, mismatch_position)) { + // Traverse prefixes. Possibly change the referenced nodes. + if (!Prefix::Traverse(art, l_node, r_node, pos, status)) { return false; } - // we already recurse because the prefixes matched (so far) - if (mismatch_position == DConstants::INVALID_INDEX) { + if (pos == DConstants::INVALID_INDEX) { return true; } } else { - - // l_prefix contains r_prefix + // l_prefix contains r_prefix. if (l_node.get().GetType() == NType::PREFIX) { swap(*this, other); } - mismatch_position = 0; + pos = 0; } - D_ASSERT(mismatch_position != DConstants::INVALID_INDEX); - // case 2: one prefix contains the other prefix + D_ASSERT(pos != DConstants::INVALID_INDEX); if (l_node.get().GetType() != NType::PREFIX && r_node.get().GetType() == NType::PREFIX) { - return MergePrefixContainsOtherPrefix(art, l_node, r_node, mismatch_position); + return PrefixContainsOther(art, l_node, r_node, UnsafeNumericCast(pos), status); } - // case 3: prefixes differ at a specific byte - MergePrefixesDiffer(art, l_node, r_node, mismatch_position); + // The prefixes differ. + MergeIntoNode4(art, l_node, r_node, UnsafeNumericCast(pos)); return true; } -bool Node::MergeInternal(ART &art, Node &other) { - - D_ASSERT(HasMetadata() && other.HasMetadata()); - D_ASSERT(GetType() != NType::PREFIX && other.GetType() != NType::PREFIX); +bool Node::MergeInternal(ART &art, Node &other, const GateStatus status) { + D_ASSERT(HasMetadata()); + D_ASSERT(other.HasMetadata()); - // always try to merge the smaller node into the bigger node - // because maybe there is enough free space in the bigger node to fit the smaller one - // without too much recursion - if (GetType() < other.GetType()) { + // Merge inlined leaves. + if (GetType() == NType::LEAF_INLINED) { swap(*this, other); } - - Node empty_node; - auto &l_node = *this; - auto &r_node = other; - - if (r_node.GetType() == NType::LEAF || r_node.GetType() == NType::LEAF_INLINED) { - D_ASSERT(l_node.GetType() == NType::LEAF || l_node.GetType() == NType::LEAF_INLINED); + if (other.GetType() == NType::LEAF_INLINED) { + D_ASSERT(status == GateStatus::GATE_NOT_SET); + D_ASSERT(other.GetGateStatus() == GateStatus::GATE_SET || other.GetType() == NType::LEAF_INLINED); + D_ASSERT(GetType() == NType::LEAF_INLINED || GetGateStatus() == GateStatus::GATE_SET); if (art.IsUnique()) { return false; } + Leaf::MergeInlined(art, *this, other); + return true; + } - Leaf::Merge(art, l_node, r_node); + // Enter a gate. + if (GetGateStatus() == GateStatus::GATE_SET && status == GateStatus::GATE_NOT_SET) { + D_ASSERT(other.GetGateStatus() == GateStatus::GATE_SET); + D_ASSERT(GetType() != NType::LEAF_INLINED); + D_ASSERT(other.GetType() != NType::LEAF_INLINED); + + // Get all row IDs. + unsafe_vector row_ids; + Iterator it(art); + it.FindMinimum(other); + ARTKey empty_key = ARTKey(); + it.Scan(empty_key, NumericLimits().Maximum(), row_ids, false); + Node::Free(art, other); + D_ASSERT(row_ids.size() > 1); + + // Insert all row IDs. + ArenaAllocator allocator(Allocator::Get(art.db)); + for (idx_t i = 0; i < row_ids.size(); i++) { + auto row_id = ARTKey::CreateARTKey(allocator, row_ids[i]); + art.Insert(*this, row_id, 0, row_id, GateStatus::GATE_SET); + } return true; } - uint8_t byte = 0; - auto r_child = r_node.GetNextChildMutable(art, byte); + // Merge N4, N16, N48, N256 nodes. + if (IsNode() && other.IsNode()) { + return MergeNodes(art, other, status); + } + // Merge N7, N15, N256 leaf nodes. + if (IsLeafNode() && other.IsLeafNode()) { + D_ASSERT(status == GateStatus::GATE_SET); + return MergeNodes(art, other, status); + } - // while r_node still has children to merge - while (r_child) { - auto l_child = l_node.GetChildMutable(art, byte); - if (!l_child) { - // insert child at empty byte - InsertChild(art, l_node, byte, *r_child); - r_node.ReplaceChild(art, byte, empty_node); + // Merge prefixes. + return MergePrefixes(art, other, status); +} - } else { - // recurse - if (!l_child->ResolvePrefixes(art, *r_child)) { - return false; - } - } +//===--------------------------------------------------------------------===// +// Vacuum +//===--------------------------------------------------------------------===// - if (byte == NumericLimits::Maximum()) { - break; +void Node::Vacuum(ART &art, const unordered_set &indexes) { + D_ASSERT(HasMetadata()); + + auto type = GetType(); + switch (type) { + case NType::LEAF_INLINED: + return; + case NType::PREFIX: + return Prefix::Vacuum(art, *this, indexes); + case NType::LEAF: + if (indexes.find(GetAllocatorIdx(type)) == indexes.end()) { + return; } - byte++; - r_child = r_node.GetNextChildMutable(art, byte); + return Leaf::DeprecatedVacuum(art, *this); + default: + break; } - Free(art, r_node); - return true; + auto idx = GetAllocatorIdx(type); + auto &allocator = GetAllocator(art, type); + auto needs_vacuum = indexes.find(idx) != indexes.end() && allocator.NeedsVacuum(*this); + if (needs_vacuum) { + auto status = GetGateStatus(); + *this = allocator.VacuumPointer(*this); + SetMetadata(static_cast(type)); + SetGateStatus(status); + } + + switch (type) { + case NType::NODE_4: + return VacuumInternal(art, Ref(art, *this, type), indexes); + case NType::NODE_16: + return VacuumInternal(art, Ref(art, *this, type), indexes); + case NType::NODE_48: + return VacuumInternal(art, Ref(art, *this, type), indexes); + case NType::NODE_256: + return VacuumInternal(art, Ref(art, *this, type), indexes); + case NType::NODE_7_LEAF: + case NType::NODE_15_LEAF: + case NType::NODE_256_LEAF: + return; + default: + throw InternalException("Invalid node type for Vacuum: %d.", static_cast(type)); + } } //===--------------------------------------------------------------------===// -// Vacuum +// TransformToDeprecated //===--------------------------------------------------------------------===// -void Node::Vacuum(ART &art, const ARTFlags &flags) { +void Node::TransformToDeprecated(ART &art, Node &node, unsafe_unique_ptr &allocator) { + D_ASSERT(node.HasMetadata()); - D_ASSERT(HasMetadata()); + if (node.GetGateStatus() == GateStatus::GATE_SET) { + return Leaf::TransformToDeprecated(art, node); + } + + auto type = node.GetType(); + switch (type) { + case NType::PREFIX: + return Prefix::TransformToDeprecated(art, node, allocator); + case NType::LEAF_INLINED: + return; + case NType::LEAF: + return; + case NType::NODE_4: + return TransformToDeprecatedInternal(art, InMemoryRef(art, node, type), allocator); + case NType::NODE_16: + return TransformToDeprecatedInternal(art, InMemoryRef(art, node, type), allocator); + case NType::NODE_48: + return TransformToDeprecatedInternal(art, InMemoryRef(art, node, type), allocator); + case NType::NODE_256: + return TransformToDeprecatedInternal(art, InMemoryRef(art, node, type), allocator); + default: + throw InternalException("Invalid node type for TransformToDeprecated: %d.", static_cast(type)); + } +} - auto node_type = GetType(); - auto node_type_idx = static_cast(node_type); +//===--------------------------------------------------------------------===// +// Verification +//===--------------------------------------------------------------------===// + +string Node::VerifyAndToString(ART &art, const bool only_verify) const { + D_ASSERT(HasMetadata()); - // iterative functions - if (node_type == NType::PREFIX) { - return Prefix::Vacuum(art, *this, flags); + auto type = GetType(); + switch (type) { + case NType::LEAF_INLINED: + return only_verify ? "" : "Inlined Leaf [row ID: " + to_string(GetRowId()) + "]"; + case NType::LEAF: + return Leaf::DeprecatedVerifyAndToString(art, *this, only_verify); + case NType::PREFIX: { + auto str = Prefix::VerifyAndToString(art, *this, only_verify); + if (GetGateStatus() == GateStatus::GATE_SET) { + str = "Gate [ " + str + " ]"; + } + return only_verify ? "" : "\n" + str; } - if (node_type == NType::LEAF_INLINED) { - return; + default: + break; } - if (node_type == NType::LEAF) { - if (flags.vacuum_flags[node_type_idx - 1]) { - Leaf::Vacuum(art, *this); + + string str = "Node" + to_string(GetCapacity(type)) + ": [ "; + uint8_t byte = 0; + + if (IsLeafNode()) { + str = "Leaf " + str; + auto has_byte = GetNextByte(art, byte); + while (has_byte) { + str += to_string(byte) + "-"; + if (byte == NumericLimits::Maximum()) { + break; + } + byte++; + has_byte = GetNextByte(art, byte); + } + } else { + auto child = GetNextChild(art, byte); + while (child) { + str += "(" + to_string(byte) + ", " + child->VerifyAndToString(art, only_verify) + ")"; + if (byte == NumericLimits::Maximum()) { + break; + } + byte++; + child = GetNextChild(art, byte); } - return; } - auto &allocator = GetAllocator(art, node_type); - auto needs_vacuum = flags.vacuum_flags[node_type_idx - 1] && allocator.NeedsVacuum(*this); - if (needs_vacuum) { - *this = allocator.VacuumPointer(*this); - SetMetadata(node_type_idx); + if (GetGateStatus() == GateStatus::GATE_SET) { + str = "Gate [ " + str + " ]"; } + return only_verify ? "" : "\n" + str + "]"; +} + +void Node::VerifyAllocations(ART &art, unordered_map &node_counts) const { + D_ASSERT(HasMetadata()); - // recursive functions - switch (node_type) { + auto type = GetType(); + switch (type) { + case NType::PREFIX: + return Prefix::VerifyAllocations(art, *this, node_counts); + case NType::LEAF: + return Ref(art, *this, type).DeprecatedVerifyAllocations(art, node_counts); + case NType::LEAF_INLINED: + return; case NType::NODE_4: - return RefMutable(art, *this, NType::NODE_4).Vacuum(art, flags); + VerifyAllocationsInternal(art, Ref(art, *this, type), node_counts); + break; case NType::NODE_16: - return RefMutable(art, *this, NType::NODE_16).Vacuum(art, flags); + VerifyAllocationsInternal(art, Ref(art, *this, type), node_counts); + break; case NType::NODE_48: - return RefMutable(art, *this, NType::NODE_48).Vacuum(art, flags); + VerifyAllocationsInternal(art, Ref(art, *this, type), node_counts); + break; case NType::NODE_256: - return RefMutable(art, *this, NType::NODE_256).Vacuum(art, flags); - default: - throw InternalException("Invalid node type for Vacuum."); + VerifyAllocationsInternal(art, Ref(art, *this, type), node_counts); + break; + case NType::NODE_7_LEAF: + case NType::NODE_15_LEAF: + case NType::NODE_256_LEAF: + break; } + + node_counts[GetAllocatorIdx(type)]++; } } // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/node256.cpp b/src/duckdb/src/execution/index/art/node256.cpp index 30182858..f08717e1 100644 --- a/src/duckdb/src/execution/index/art/node256.cpp +++ b/src/duckdb/src/execution/index/art/node256.cpp @@ -1,17 +1,16 @@ #include "duckdb/execution/index/art/node256.hpp" + #include "duckdb/execution/index/art/node48.hpp" -#include "duckdb/common/numeric_utils.hpp" namespace duckdb { Node256 &Node256::New(ART &art, Node &node) { - - node = Node::GetAllocator(art, NType::NODE_256).New(); - node.SetMetadata(static_cast(NType::NODE_256)); - auto &n256 = Node::RefMutable(art, node, NType::NODE_256); + node = Node::GetAllocator(art, NODE_256).New(); + node.SetMetadata(static_cast(NODE_256)); + auto &n256 = Node::Ref(art, node, NODE_256); n256.count = 0; - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { + for (uint16_t i = 0; i < CAPACITY; i++) { n256.children[i].Clear(); } @@ -19,120 +18,61 @@ Node256 &Node256::New(ART &art, Node &node) { } void Node256::Free(ART &art, Node &node) { - - D_ASSERT(node.HasMetadata()); - auto &n256 = Node::RefMutable(art, node, NType::NODE_256); - + auto &n256 = Node::Ref(art, node, NODE_256); if (!n256.count) { return; } - // free all children - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (n256.children[i].HasMetadata()) { - Node::Free(art, n256.children[i]); - } - } -} - -Node256 &Node256::GrowNode48(ART &art, Node &node256, Node &node48) { - - auto &n48 = Node::RefMutable(art, node48, NType::NODE_48); - auto &n256 = New(art, node256); - - n256.count = n48.count; - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (n48.child_index[i] != Node::EMPTY_MARKER) { - n256.children[i] = n48.children[n48.child_index[i]]; - } else { - n256.children[i].Clear(); - } - } - - n48.count = 0; - Node::Free(art, node48); - return n256; -} - -void Node256::InitializeMerge(ART &art, const ARTFlags &flags) { - - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (children[i].HasMetadata()) { - children[i].InitializeMerge(art, flags); - } - } + Iterator(n256, [&](Node &child) { Node::Free(art, child); }); } void Node256::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - - D_ASSERT(node.HasMetadata()); - auto &n256 = Node::RefMutable(art, node, NType::NODE_256); - - // ensure that there is no other child at the same byte - D_ASSERT(!n256.children[byte].HasMetadata()); - + auto &n256 = Node::Ref(art, node, NODE_256); n256.count++; - D_ASSERT(n256.count <= Node::NODE_256_CAPACITY); n256.children[byte] = child; } void Node256::DeleteChild(ART &art, Node &node, const uint8_t byte) { + auto &n256 = Node::Ref(art, node, NODE_256); - D_ASSERT(node.HasMetadata()); - auto &n256 = Node::RefMutable(art, node, NType::NODE_256); - - // free the child and decrease the count + // Free the child and decrease the count. Node::Free(art, n256.children[byte]); n256.count--; - // shrink node to Node48 - if (n256.count <= Node::NODE_256_SHRINK_THRESHOLD) { + // Shrink to Node48. + if (n256.count <= SHRINK_THRESHOLD) { auto node256 = node; Node48::ShrinkNode256(art, node, node256); } } -optional_ptr Node256::GetChild(const uint8_t byte) const { - if (children[byte].HasMetadata()) { - return &children[byte]; - } - return nullptr; -} +void Node256::ReplaceChild(const uint8_t byte, const Node child) { + D_ASSERT(count > SHRINK_THRESHOLD); -optional_ptr Node256::GetChildMutable(const uint8_t byte) { - if (children[byte].HasMetadata()) { - return &children[byte]; + auto status = children[byte].GetGateStatus(); + children[byte] = child; + if (status == GateStatus::GATE_SET && child.HasMetadata()) { + children[byte].SetGateStatus(status); } - return nullptr; } -optional_ptr Node256::GetNextChild(uint8_t &byte) const { - for (idx_t i = byte; i < Node::NODE_256_CAPACITY; i++) { - if (children[i].HasMetadata()) { - byte = UnsafeNumericCast(i); - return &children[i]; - } - } - return nullptr; -} +Node256 &Node256::GrowNode48(ART &art, Node &node256, Node &node48) { + auto &n48 = Node::Ref(art, node48, NType::NODE_48); + auto &n256 = New(art, node256); + node256.SetGateStatus(node48.GetGateStatus()); -optional_ptr Node256::GetNextChildMutable(uint8_t &byte) { - for (idx_t i = byte; i < Node::NODE_256_CAPACITY; i++) { - if (children[i].HasMetadata()) { - byte = UnsafeNumericCast(i); - return &children[i]; + n256.count = n48.count; + for (uint16_t i = 0; i < CAPACITY; i++) { + if (n48.child_index[i] != Node48::EMPTY_MARKER) { + n256.children[i] = n48.children[n48.child_index[i]]; + } else { + n256.children[i].Clear(); } } - return nullptr; -} -void Node256::Vacuum(ART &art, const ARTFlags &flags) { - - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (children[i].HasMetadata()) { - children[i].Vacuum(art, flags); - } - } + n48.count = 0; + Node::Free(art, node48); + return n256; } } // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/node256_leaf.cpp b/src/duckdb/src/execution/index/art/node256_leaf.cpp new file mode 100644 index 00000000..01067922 --- /dev/null +++ b/src/duckdb/src/execution/index/art/node256_leaf.cpp @@ -0,0 +1,71 @@ +#include "duckdb/execution/index/art/node256_leaf.hpp" + +#include "duckdb/execution/index/art/base_leaf.hpp" +#include "duckdb/execution/index/art/node48.hpp" + +namespace duckdb { + +Node256Leaf &Node256Leaf::New(ART &art, Node &node) { + node = Node::GetAllocator(art, NODE_256_LEAF).New(); + node.SetMetadata(static_cast(NODE_256_LEAF)); + auto &n256 = Node::Ref(art, node, NODE_256_LEAF); + + n256.count = 0; + ValidityMask mask(&n256.mask[0]); + mask.SetAllInvalid(CAPACITY); + return n256; +} + +void Node256Leaf::InsertByte(ART &art, Node &node, const uint8_t byte) { + auto &n256 = Node::Ref(art, node, NODE_256_LEAF); + n256.count++; + ValidityMask mask(&n256.mask[0]); + mask.SetValid(byte); +} + +void Node256Leaf::DeleteByte(ART &art, Node &node, const uint8_t byte) { + auto &n256 = Node::Ref(art, node, NODE_256_LEAF); + n256.count--; + ValidityMask mask(&n256.mask[0]); + mask.SetInvalid(byte); + + // Shrink node to Node15 + if (n256.count <= Node48::SHRINK_THRESHOLD) { + auto node256 = node; + Node15Leaf::ShrinkNode256Leaf(art, node, node256); + } +} + +bool Node256Leaf::HasByte(uint8_t &byte) { + ValidityMask v_mask(&mask[0]); + return v_mask.RowIsValid(byte); +} + +bool Node256Leaf::GetNextByte(uint8_t &byte) { + ValidityMask v_mask(&mask[0]); + for (uint16_t i = byte; i < CAPACITY; i++) { + if (v_mask.RowIsValid(i)) { + byte = UnsafeNumericCast(i); + return true; + } + } + return false; +} + +Node256Leaf &Node256Leaf::GrowNode15Leaf(ART &art, Node &node256_leaf, Node &node15_leaf) { + auto &n15 = Node::Ref(art, node15_leaf, NType::NODE_15_LEAF); + auto &n256 = New(art, node256_leaf); + node256_leaf.SetGateStatus(node15_leaf.GetGateStatus()); + + n256.count = n15.count; + ValidityMask mask(&n256.mask[0]); + for (uint8_t i = 0; i < n15.count; i++) { + mask.SetValid(n15.key[i]); + } + + n15.count = 0; + Node::Free(art, node15_leaf); + return n256; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/node48.cpp b/src/duckdb/src/execution/index/art/node48.cpp index 2b1ba22b..f9ad0460 100644 --- a/src/duckdb/src/execution/index/art/node48.cpp +++ b/src/duckdb/src/execution/index/art/node48.cpp @@ -1,21 +1,20 @@ #include "duckdb/execution/index/art/node48.hpp" -#include "duckdb/execution/index/art/node16.hpp" + +#include "duckdb/execution/index/art/base_node.hpp" #include "duckdb/execution/index/art/node256.hpp" -#include "duckdb/common/numeric_utils.hpp" namespace duckdb { Node48 &Node48::New(ART &art, Node &node) { - - node = Node::GetAllocator(art, NType::NODE_48).New(); - node.SetMetadata(static_cast(NType::NODE_48)); - auto &n48 = Node::RefMutable(art, node, NType::NODE_48); + node = Node::GetAllocator(art, NODE_48).New(); + node.SetMetadata(static_cast(NODE_48)); + auto &n48 = Node::Ref(art, node, NODE_48); n48.count = 0; - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - n48.child_index[i] = Node::EMPTY_MARKER; + for (uint16_t i = 0; i < Node256::CAPACITY; i++) { + n48.child_index[i] = EMPTY_MARKER; } - for (idx_t i = 0; i < Node::NODE_48_CAPACITY; i++) { + for (uint8_t i = 0; i < CAPACITY; i++) { n48.children[i].Clear(); } @@ -23,39 +22,79 @@ Node48 &Node48::New(ART &art, Node &node) { } void Node48::Free(ART &art, Node &node) { + auto &n48 = Node::Ref(art, node, NODE_48); + if (!n48.count) { + return; + } - D_ASSERT(node.HasMetadata()); - auto &n48 = Node::RefMutable(art, node, NType::NODE_48); + Iterator(n48, [&](Node &child) { Node::Free(art, child); }); +} - if (!n48.count) { +void Node48::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { + auto &n48 = Node::Ref(art, node, NODE_48); + + // The node is full. Grow to Node256. + if (n48.count == CAPACITY) { + auto node48 = node; + Node256::GrowNode48(art, node, node48); + Node256::InsertChild(art, node, byte, child); return; } - // free all children - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (n48.child_index[i] != Node::EMPTY_MARKER) { - Node::Free(art, n48.children[n48.child_index[i]]); + // Still space. Insert the child. + uint8_t child_pos = n48.count; + if (n48.children[child_pos].HasMetadata()) { + // Find an empty position in the node list. + child_pos = 0; + while (n48.children[child_pos].HasMetadata()) { + child_pos++; } } + + n48.children[child_pos] = child; + n48.child_index[byte] = child_pos; + n48.count++; } -Node48 &Node48::GrowNode16(ART &art, Node &node48, Node &node16) { +void Node48::DeleteChild(ART &art, Node &node, const uint8_t byte) { + auto &n48 = Node::Ref(art, node, NODE_48); + + // Free the child and decrease the count. + Node::Free(art, n48.children[n48.child_index[byte]]); + n48.child_index[byte] = EMPTY_MARKER; + n48.count--; + + // Shrink to Node16. + if (n48.count < SHRINK_THRESHOLD) { + auto node48 = node; + Node16::ShrinkNode48(art, node, node48); + } +} + +void Node48::ReplaceChild(const uint8_t byte, const Node child) { + D_ASSERT(count >= SHRINK_THRESHOLD); + + auto status = children[child_index[byte]].GetGateStatus(); + children[child_index[byte]] = child; + if (status == GateStatus::GATE_SET && child.HasMetadata()) { + children[child_index[byte]].SetGateStatus(status); + } +} - auto &n16 = Node::RefMutable(art, node16, NType::NODE_16); +Node48 &Node48::GrowNode16(ART &art, Node &node48, Node &node16) { + auto &n16 = Node::Ref(art, node16, NType::NODE_16); auto &n48 = New(art, node48); + node48.SetGateStatus(node16.GetGateStatus()); n48.count = n16.count; - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - n48.child_index[i] = Node::EMPTY_MARKER; + for (uint16_t i = 0; i < Node256::CAPACITY; i++) { + n48.child_index[i] = EMPTY_MARKER; } - - for (idx_t i = 0; i < n16.count; i++) { - n48.child_index[n16.key[i]] = UnsafeNumericCast(i); + for (uint8_t i = 0; i < n16.count; i++) { + n48.child_index[n16.key[i]] = i; n48.children[i] = n16.children[i]; } - - // necessary for faster child insertion/deletion - for (idx_t i = n16.count; i < Node::NODE_48_CAPACITY; i++) { + for (uint8_t i = n16.count; i < CAPACITY; i++) { n48.children[i].Clear(); } @@ -65,24 +104,21 @@ Node48 &Node48::GrowNode16(ART &art, Node &node48, Node &node16) { } Node48 &Node48::ShrinkNode256(ART &art, Node &node48, Node &node256) { - auto &n48 = New(art, node48); - auto &n256 = Node::RefMutable(art, node256, NType::NODE_256); + auto &n256 = Node::Ref(art, node256, NType::NODE_256); + node48.SetGateStatus(node256.GetGateStatus()); n48.count = 0; - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - D_ASSERT(n48.count <= Node::NODE_48_CAPACITY); - if (n256.children[i].HasMetadata()) { - n48.child_index[i] = n48.count; - n48.children[n48.count] = n256.children[i]; - n48.count++; - } else { - n48.child_index[i] = Node::EMPTY_MARKER; + for (uint16_t i = 0; i < Node256::CAPACITY; i++) { + if (!n256.children[i].HasMetadata()) { + n48.child_index[i] = EMPTY_MARKER; + continue; } + n48.child_index[i] = n48.count; + n48.children[n48.count] = n256.children[i]; + n48.count++; } - - // necessary for faster child insertion/deletion - for (idx_t i = n48.count; i < Node::NODE_48_CAPACITY; i++) { + for (uint8_t i = n48.count; i < CAPACITY; i++) { n48.children[i].Clear(); } @@ -91,108 +127,4 @@ Node48 &Node48::ShrinkNode256(ART &art, Node &node48, Node &node256) { return n48; } -void Node48::InitializeMerge(ART &art, const ARTFlags &flags) { - - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (child_index[i] != Node::EMPTY_MARKER) { - children[child_index[i]].InitializeMerge(art, flags); - } - } -} - -void Node48::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - - D_ASSERT(node.HasMetadata()); - auto &n48 = Node::RefMutable(art, node, NType::NODE_48); - - // ensure that there is no other child at the same byte - D_ASSERT(n48.child_index[byte] == Node::EMPTY_MARKER); - - // insert new child node into node - if (n48.count < Node::NODE_48_CAPACITY) { - // still space, just insert the child - idx_t child_pos = n48.count; - if (n48.children[child_pos].HasMetadata()) { - // find an empty position in the node list if the current position is occupied - child_pos = 0; - while (n48.children[child_pos].HasMetadata()) { - child_pos++; - } - } - n48.children[child_pos] = child; - n48.child_index[byte] = UnsafeNumericCast(child_pos); - n48.count++; - - } else { - // node is full, grow to Node256 - auto node48 = node; - Node256::GrowNode48(art, node, node48); - Node256::InsertChild(art, node, byte, child); - } -} - -void Node48::DeleteChild(ART &art, Node &node, const uint8_t byte) { - - D_ASSERT(node.HasMetadata()); - auto &n48 = Node::RefMutable(art, node, NType::NODE_48); - - // free the child and decrease the count - Node::Free(art, n48.children[n48.child_index[byte]]); - n48.child_index[byte] = Node::EMPTY_MARKER; - n48.count--; - - // shrink node to Node16 - if (n48.count < Node::NODE_48_SHRINK_THRESHOLD) { - auto node48 = node; - Node16::ShrinkNode48(art, node, node48); - } -} - -optional_ptr Node48::GetChild(const uint8_t byte) const { - if (child_index[byte] != Node::EMPTY_MARKER) { - D_ASSERT(children[child_index[byte]].HasMetadata()); - return &children[child_index[byte]]; - } - return nullptr; -} - -optional_ptr Node48::GetChildMutable(const uint8_t byte) { - if (child_index[byte] != Node::EMPTY_MARKER) { - D_ASSERT(children[child_index[byte]].HasMetadata()); - return &children[child_index[byte]]; - } - return nullptr; -} - -optional_ptr Node48::GetNextChild(uint8_t &byte) const { - for (idx_t i = byte; i < Node::NODE_256_CAPACITY; i++) { - if (child_index[i] != Node::EMPTY_MARKER) { - byte = UnsafeNumericCast(i); - D_ASSERT(children[child_index[i]].HasMetadata()); - return &children[child_index[i]]; - } - } - return nullptr; -} - -optional_ptr Node48::GetNextChildMutable(uint8_t &byte) { - for (idx_t i = byte; i < Node::NODE_256_CAPACITY; i++) { - if (child_index[i] != Node::EMPTY_MARKER) { - byte = UnsafeNumericCast(i); - D_ASSERT(children[child_index[i]].HasMetadata()); - return &children[child_index[i]]; - } - } - return nullptr; -} - -void Node48::Vacuum(ART &art, const ARTFlags &flags) { - - for (idx_t i = 0; i < Node::NODE_256_CAPACITY; i++) { - if (child_index[i] != Node::EMPTY_MARKER) { - children[child_index[i]].Vacuum(art, flags); - } - } -} - } // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/prefix.cpp b/src/duckdb/src/execution/index/art/prefix.cpp index 7a485fd9..66904696 100644 --- a/src/duckdb/src/execution/index/art/prefix.cpp +++ b/src/duckdb/src/execution/index/art/prefix.cpp @@ -1,370 +1,550 @@ #include "duckdb/execution/index/art/prefix.hpp" +#include "duckdb/common/swap.hpp" #include "duckdb/execution/index/art/art.hpp" #include "duckdb/execution/index/art/art_key.hpp" +#include "duckdb/execution/index/art/base_leaf.hpp" +#include "duckdb/execution/index/art/base_node.hpp" +#include "duckdb/execution/index/art/leaf.hpp" #include "duckdb/execution/index/art/node.hpp" -#include "duckdb/common/swap.hpp" namespace duckdb { -Prefix &Prefix::New(ART &art, Node &node) { - - node = Node::GetAllocator(art, NType::PREFIX).New(); - node.SetMetadata(static_cast(NType::PREFIX)); +Prefix::Prefix(const ART &art, const Node ptr_p, const bool is_mutable, const bool set_in_memory) { + if (!set_in_memory) { + data = Node::GetAllocator(art, PREFIX).Get(ptr_p, is_mutable); + } else { + data = Node::GetAllocator(art, PREFIX).GetIfLoaded(ptr_p); + if (!data) { + ptr = nullptr; + in_memory = false; + return; + } + } + ptr = reinterpret_cast(data + Count(art) + 1); + in_memory = true; +} - auto &prefix = Node::RefMutable(art, node, NType::PREFIX); - prefix.data[Node::PREFIX_SIZE] = 0; - return prefix; +Prefix::Prefix(unsafe_unique_ptr &allocator, const Node ptr_p, const idx_t count) { + data = allocator->Get(ptr_p, true); + ptr = reinterpret_cast(data + count + 1); + in_memory = true; } -Prefix &Prefix::New(ART &art, Node &node, uint8_t byte, const Node &next) { +idx_t Prefix::GetMismatchWithOther(const Prefix &l_prefix, const Prefix &r_prefix, const idx_t max_count) { + for (idx_t i = 0; i < max_count; i++) { + if (l_prefix.data[i] != r_prefix.data[i]) { + return i; + } + } + return DConstants::INVALID_INDEX; +} - node = Node::GetAllocator(art, NType::PREFIX).New(); - node.SetMetadata(static_cast(NType::PREFIX)); +idx_t Prefix::GetMismatchWithKey(ART &art, const Node &node, const ARTKey &key, idx_t &depth) { + Prefix prefix(art, node); + for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { + if (prefix.data[i] != key[depth]) { + return i; + } + depth++; + } + return DConstants::INVALID_INDEX; +} - auto &prefix = Node::RefMutable(art, node, NType::PREFIX); - prefix.data[Node::PREFIX_SIZE] = 1; - prefix.data[0] = byte; - prefix.ptr = next; - return prefix; +uint8_t Prefix::GetByte(const ART &art, const Node &node, const uint8_t pos) { + D_ASSERT(node.GetType() == PREFIX); + Prefix prefix(art, node); + return prefix.data[pos]; } -void Prefix::New(ART &art, reference &node, const ARTKey &key, const uint32_t depth, uint32_t count) { +Prefix Prefix::NewInternal(ART &art, Node &node, const data_ptr_t data, const uint8_t count, const idx_t offset, + const NType type) { + node = Node::GetAllocator(art, type).New(); + node.SetMetadata(static_cast(type)); - if (count == 0) { - return; + Prefix prefix(art, node, true); + prefix.data[Count(art)] = count; + if (data) { + D_ASSERT(count); + memcpy(prefix.data, data + offset, count); } - idx_t copy_count = 0; + return prefix; +} - while (count) { - node.get() = Node::GetAllocator(art, NType::PREFIX).New(); - node.get().SetMetadata(static_cast(NType::PREFIX)); - auto &prefix = Node::RefMutable(art, node, NType::PREFIX); +void Prefix::New(ART &art, reference &ref, const ARTKey &key, const idx_t depth, idx_t count) { + idx_t offset = 0; - auto this_count = MinValue((uint32_t)Node::PREFIX_SIZE, count); - prefix.data[Node::PREFIX_SIZE] = (uint8_t)this_count; - memcpy(prefix.data, key.data + depth + copy_count, this_count); + while (count) { + auto min = MinValue(UnsafeNumericCast(Count(art)), count); + auto this_count = UnsafeNumericCast(min); + auto prefix = NewInternal(art, ref, key.data, this_count, offset + depth, PREFIX); - node = prefix.ptr; - copy_count += this_count; + ref = *prefix.ptr; + offset += this_count; count -= this_count; } } void Prefix::Free(ART &art, Node &node) { + Node next; - Node current_node = node; - Node next_node; - while (current_node.HasMetadata() && current_node.GetType() == NType::PREFIX) { - next_node = Node::RefMutable(art, current_node, NType::PREFIX).ptr; - Node::GetAllocator(art, NType::PREFIX).Free(current_node); - current_node = next_node; + while (node.HasMetadata() && node.GetType() == PREFIX) { + Prefix prefix(art, node, true); + next = *prefix.ptr; + Node::GetAllocator(art, PREFIX).Free(node); + node = next; } - Node::Free(art, current_node); + Node::Free(art, node); node.Clear(); } -void Prefix::InitializeMerge(ART &art, Node &node, const ARTFlags &flags) { - - auto merge_buffer_count = flags.merge_buffer_counts[static_cast(NType::PREFIX) - 1]; - - Node next_node = node; - reference prefix = Node::RefMutable(art, next_node, NType::PREFIX); +void Prefix::InitializeMerge(ART &art, Node &node, const unsafe_vector &upper_bounds) { + auto buffer_count = upper_bounds[Node::GetAllocatorIdx(PREFIX)]; + Node next = node; + Prefix prefix(art, next, true); - while (next_node.GetType() == NType::PREFIX) { - next_node = prefix.get().ptr; - if (prefix.get().ptr.GetType() == NType::PREFIX) { - prefix.get().ptr.IncreaseBufferId(merge_buffer_count); - prefix = Node::RefMutable(art, next_node, NType::PREFIX); + while (next.GetType() == PREFIX) { + next = *prefix.ptr; + if (prefix.ptr->GetType() == PREFIX) { + prefix.ptr->IncreaseBufferId(buffer_count); + prefix = Prefix(art, next, true); } } - node.IncreaseBufferId(merge_buffer_count); - prefix.get().ptr.InitializeMerge(art, flags); + node.IncreaseBufferId(buffer_count); + prefix.ptr->InitMerge(art, upper_bounds); } -void Prefix::Concatenate(ART &art, Node &prefix_node, const uint8_t byte, Node &child_prefix_node) { +void Prefix::Concat(ART &art, Node &parent, uint8_t byte, const GateStatus old_status, const Node &child, + const GateStatus status) { + D_ASSERT(!parent.IsAnyLeaf()); + D_ASSERT(child.HasMetadata()); - D_ASSERT(prefix_node.HasMetadata() && child_prefix_node.HasMetadata()); - - // append a byte and a child_prefix to prefix - if (prefix_node.GetType() == NType::PREFIX) { - - // get the tail - reference prefix = Node::RefMutable(art, prefix_node, NType::PREFIX); - D_ASSERT(prefix.get().ptr.HasMetadata()); + if (old_status == GateStatus::GATE_SET) { + // Concat Node4. + D_ASSERT(status == GateStatus::GATE_SET); + return ConcatGate(art, parent, byte, child); + } + if (child.GetGateStatus() == GateStatus::GATE_SET) { + // Concat Node4. + D_ASSERT(status == GateStatus::GATE_NOT_SET); + return ConcatChildIsGate(art, parent, byte, child); + } - while (prefix.get().ptr.GetType() == NType::PREFIX) { - prefix = Node::RefMutable(art, prefix.get().ptr, NType::PREFIX); - D_ASSERT(prefix.get().ptr.HasMetadata()); + if (status == GateStatus::GATE_SET && child.GetType() == NType::LEAF_INLINED) { + auto row_id = child.GetRowId(); + if (parent.GetType() == PREFIX) { + auto parent_status = parent.GetGateStatus(); + Free(art, parent); + Leaf::New(parent, row_id); + parent.SetGateStatus(parent_status); + } else { + Leaf::New(parent, row_id); } + return; + } - // append the byte - prefix = prefix.get().Append(art, byte); - - if (child_prefix_node.GetType() == NType::PREFIX) { - // append the child prefix - prefix.get().Append(art, child_prefix_node); + if (parent.GetType() != PREFIX) { + auto prefix = NewInternal(art, parent, &byte, 1, 0, PREFIX); + if (child.GetType() == PREFIX) { + prefix.Append(art, child); } else { - // set child_prefix_node to succeed prefix - prefix.get().ptr = child_prefix_node; + *prefix.ptr = child; } return; } - // create a new prefix node containing the byte, then append the child_prefix to it - if (prefix_node.GetType() != NType::PREFIX && child_prefix_node.GetType() == NType::PREFIX) { + auto tail = GetTail(art, parent); + tail = tail.Append(art, byte); - auto child_prefix = child_prefix_node; - auto &prefix = New(art, prefix_node, byte); - prefix.Append(art, child_prefix); - return; + if (child.GetType() == PREFIX) { + tail.Append(art, child); + } else { + *tail.ptr = child; } - - // neither prefix nor child_prefix are prefix nodes - // create a new prefix containing the byte - New(art, prefix_node, byte, child_prefix_node); } -idx_t Prefix::Traverse(ART &art, reference &prefix_node, const ARTKey &key, idx_t &depth) { +template +idx_t TraverseInternal(ART &art, reference &node, const ARTKey &key, idx_t &depth, + const bool is_mutable = false) { + D_ASSERT(node.get().HasMetadata()); + D_ASSERT(node.get().GetType() == NType::PREFIX); - D_ASSERT(prefix_node.get().HasMetadata()); - D_ASSERT(prefix_node.get().GetType() == NType::PREFIX); + while (node.get().GetType() == NType::PREFIX) { + auto pos = Prefix::GetMismatchWithKey(art, node, key, depth); + if (pos != DConstants::INVALID_INDEX) { + return pos; + } - // compare prefix nodes to key bytes - while (prefix_node.get().GetType() == NType::PREFIX) { - auto &prefix = Node::Ref(art, prefix_node, NType::PREFIX); - for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { - if (prefix.data[i] != key[depth]) { - return i; - } - depth++; + Prefix prefix(art, node, is_mutable); + node = *prefix.ptr; + if (node.get().GetGateStatus() == GateStatus::GATE_SET) { + break; } - prefix_node = prefix.ptr; - D_ASSERT(prefix_node.get().HasMetadata()); } - return DConstants::INVALID_INDEX; } -idx_t Prefix::TraverseMutable(ART &art, reference &prefix_node, const ARTKey &key, idx_t &depth) { - - D_ASSERT(prefix_node.get().HasMetadata()); - D_ASSERT(prefix_node.get().GetType() == NType::PREFIX); - - // compare prefix nodes to key bytes - while (prefix_node.get().GetType() == NType::PREFIX) { - auto &prefix = Node::RefMutable(art, prefix_node, NType::PREFIX); - for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { - if (prefix.data[i] != key[depth]) { - return i; - } - depth++; - } - prefix_node = prefix.ptr; - D_ASSERT(prefix_node.get().HasMetadata()); - } +idx_t Prefix::Traverse(ART &art, reference &node, const ARTKey &key, idx_t &depth) { + return TraverseInternal(art, node, key, depth); +} - return DConstants::INVALID_INDEX; +idx_t Prefix::TraverseMutable(ART &art, reference &node, const ARTKey &key, idx_t &depth) { + return TraverseInternal(art, node, key, depth, true); } -bool Prefix::Traverse(ART &art, reference &l_node, reference &r_node, idx_t &mismatch_position) { +bool Prefix::Traverse(ART &art, reference &l_node, reference &r_node, idx_t &pos, const GateStatus status) { + D_ASSERT(l_node.get().HasMetadata()); + D_ASSERT(r_node.get().HasMetadata()); - auto &l_prefix = Node::RefMutable(art, l_node.get(), NType::PREFIX); - auto &r_prefix = Node::RefMutable(art, r_node.get(), NType::PREFIX); + Prefix l_prefix(art, l_node, true); + Prefix r_prefix(art, r_node, true); - // compare prefix bytes - idx_t max_count = MinValue(l_prefix.data[Node::PREFIX_SIZE], r_prefix.data[Node::PREFIX_SIZE]); - for (idx_t i = 0; i < max_count; i++) { - if (l_prefix.data[i] != r_prefix.data[i]) { - mismatch_position = i; - break; - } + idx_t max_count = MinValue(l_prefix.data[Count(art)], r_prefix.data[Count(art)]); + pos = GetMismatchWithOther(l_prefix, r_prefix, max_count); + if (pos != DConstants::INVALID_INDEX) { + return true; } - if (mismatch_position == DConstants::INVALID_INDEX) { - - // prefixes match (so far) - if (l_prefix.data[Node::PREFIX_SIZE] == r_prefix.data[Node::PREFIX_SIZE]) { - return l_prefix.ptr.ResolvePrefixes(art, r_prefix.ptr); - } - - mismatch_position = max_count; - - // l_prefix contains r_prefix - if (r_prefix.ptr.GetType() != NType::PREFIX && r_prefix.data[Node::PREFIX_SIZE] == max_count) { - swap(l_node.get(), r_node.get()); - l_node = r_prefix.ptr; - - } else { - // r_prefix contains l_prefix - l_node = l_prefix.ptr; - } + // Match. + if (l_prefix.data[Count(art)] == r_prefix.data[Count(art)]) { + auto r_child = *r_prefix.ptr; + r_prefix.ptr->Clear(); + Node::Free(art, r_node); + return l_prefix.ptr->MergeInternal(art, r_child, status); } + pos = max_count; + if (r_prefix.ptr->GetType() != PREFIX && r_prefix.data[Count(art)] == max_count) { + // l_prefix contains r_prefix. + swap(l_node.get(), r_node.get()); + l_node = *r_prefix.ptr; + return true; + } + // r_prefix contains l_prefix. + l_node = *l_prefix.ptr; return true; } -void Prefix::Reduce(ART &art, Node &prefix_node, const idx_t n) { - - D_ASSERT(prefix_node.HasMetadata()); - D_ASSERT(n < Node::PREFIX_SIZE); - - reference prefix = Node::RefMutable(art, prefix_node, NType::PREFIX); +void Prefix::Reduce(ART &art, Node &node, const idx_t pos) { + D_ASSERT(node.HasMetadata()); + D_ASSERT(pos < Count(art)); - // free this prefix node - if (n == (idx_t)(prefix.get().data[Node::PREFIX_SIZE] - 1)) { - auto next_ptr = prefix.get().ptr; - D_ASSERT(next_ptr.HasMetadata()); - prefix.get().ptr.Clear(); - Node::Free(art, prefix_node); - prefix_node = next_ptr; + Prefix prefix(art, node); + if (pos == idx_t(prefix.data[Count(art)] - 1)) { + auto next = *prefix.ptr; + prefix.ptr->Clear(); + Node::Free(art, node); + node = next; return; } - // shift by n bytes in the current prefix - for (idx_t i = 0; i < Node::PREFIX_SIZE - n - 1; i++) { - prefix.get().data[i] = prefix.get().data[n + i + 1]; + for (idx_t i = 0; i < Count(art) - pos - 1; i++) { + prefix.data[i] = prefix.data[pos + i + 1]; } - D_ASSERT(n < (idx_t)(prefix.get().data[Node::PREFIX_SIZE] - 1)); - prefix.get().data[Node::PREFIX_SIZE] -= n + 1; - // append the remaining prefix bytes - prefix.get().Append(art, prefix.get().ptr); + prefix.data[Count(art)] -= pos + 1; + prefix.Append(art, *prefix.ptr); } -void Prefix::Split(ART &art, reference &prefix_node, Node &child_node, idx_t position) { +GateStatus Prefix::Split(ART &art, reference &node, Node &child, const uint8_t pos) { + D_ASSERT(node.get().HasMetadata()); - D_ASSERT(prefix_node.get().HasMetadata()); + Prefix prefix(art, node, true); - auto &prefix = Node::RefMutable(art, prefix_node, NType::PREFIX); - - // the split is at the last byte of this prefix, so the child_node contains all subsequent - // prefix nodes (prefix.ptr) (if any), and the count of this prefix decreases by one, - // then, we reference prefix.ptr, to overwrite it with a new node later - if (position + 1 == Node::PREFIX_SIZE) { - prefix.data[Node::PREFIX_SIZE]--; - prefix_node = prefix.ptr; - child_node = prefix.ptr; - return; + // The split is at the last prefix byte. Decrease the count and return. + if (pos + 1 == Count(art)) { + prefix.data[Count(art)]--; + node = *prefix.ptr; + child = *prefix.ptr; + return GateStatus::GATE_NOT_SET; } - // append the remaining bytes after the split - if (position + 1 < prefix.data[Node::PREFIX_SIZE]) { - reference child_prefix = New(art, child_node); - for (idx_t i = position + 1; i < prefix.data[Node::PREFIX_SIZE]; i++) { - child_prefix = child_prefix.get().Append(art, prefix.data[i]); - } - - D_ASSERT(prefix.ptr.HasMetadata()); + if (pos + 1 < prefix.data[Count(art)]) { + // Create a new prefix and + // 1. copy the remaining bytes of this prefix. + // 2. append remaining prefix nodes. + auto new_prefix = NewInternal(art, child, nullptr, 0, 0, PREFIX); + new_prefix.data[Count(art)] = prefix.data[Count(art)] - pos - 1; + memcpy(new_prefix.data, prefix.data + pos + 1, new_prefix.data[Count(art)]); - if (prefix.ptr.GetType() == NType::PREFIX) { - child_prefix.get().Append(art, prefix.ptr); + if (prefix.ptr->GetType() == PREFIX && prefix.ptr->GetGateStatus() == GateStatus::GATE_NOT_SET) { + new_prefix.Append(art, *prefix.ptr); } else { - // this is the last prefix node of the prefix - child_prefix.get().ptr = prefix.ptr; + *new_prefix.ptr = *prefix.ptr; } + + } else if (pos + 1 == prefix.data[Count(art)]) { + // No prefix bytes after the split. + child = *prefix.ptr; } - // this is the last prefix node of the prefix - if (position + 1 == prefix.data[Node::PREFIX_SIZE]) { - child_node = prefix.ptr; + // Set the new count of this node. + prefix.data[Count(art)] = pos; + + // No bytes left before the split, free this node. + if (pos == 0) { + auto old_status = node.get().GetGateStatus(); + prefix.ptr->Clear(); + Node::Free(art, node); + return old_status; } - // set the new size of this node - prefix.data[Node::PREFIX_SIZE] = UnsafeNumericCast(position); + // There are bytes left before the split. + // The subsequent node replaces the split byte. + node = *prefix.ptr; + return GateStatus::GATE_NOT_SET; +} - // no bytes left before the split, free this node - if (position == 0) { - prefix.ptr.Clear(); - Node::Free(art, prefix_node.get()); - return; +bool Prefix::Insert(ART &art, Node &node, const ARTKey &key, idx_t depth, const ARTKey &row_id, + const GateStatus status) { + reference next(node); + auto pos = TraverseMutable(art, next, key, depth); + + // We recurse into the next node, if + // (1) the prefix matches the key. + // (2) we reach a gate. + if (pos == DConstants::INVALID_INDEX) { + if (next.get().GetType() != NType::PREFIX || next.get().GetGateStatus() == GateStatus::GATE_SET) { + return art.Insert(next, key, depth, row_id, status); + } + } + + Node remainder; + auto byte = GetByte(art, next, UnsafeNumericCast(pos)); + auto split_status = Split(art, next, remainder, UnsafeNumericCast(pos)); + Node4::New(art, next); + next.get().SetGateStatus(split_status); + + // Insert the remaining prefix into the new Node4. + Node4::InsertChild(art, next, byte, remainder); + + if (status == GateStatus::GATE_SET) { + D_ASSERT(pos != ROW_ID_COUNT); + Node new_row_id; + Leaf::New(new_row_id, key.GetRowId()); + Node::InsertChild(art, next, key[depth], new_row_id); + return true; } - // bytes left before the split, reference subsequent node - prefix_node = prefix.ptr; - return; + Node leaf; + reference ref(leaf); + if (depth + 1 < key.len) { + // Create the prefix. + auto count = key.len - depth - 1; + Prefix::New(art, ref, key, depth + 1, count); + } + // Create the inlined leaf. + Leaf::New(ref, row_id.GetRowId()); + Node4::InsertChild(art, next, key[depth], leaf); + return true; } string Prefix::VerifyAndToString(ART &art, const Node &node, const bool only_verify) { - - // NOTE: we could do this recursively, but the function-call overhead can become kinda crazy string str = ""; + reference ref(node); - reference node_ref(node); - while (node_ref.get().GetType() == NType::PREFIX) { + Iterator(art, ref, true, false, [&](Prefix &prefix) { + D_ASSERT(prefix.data[Count(art)] != 0); + D_ASSERT(prefix.data[Count(art)] <= Count(art)); - auto &prefix = Node::Ref(art, node_ref, NType::PREFIX); - D_ASSERT(prefix.data[Node::PREFIX_SIZE] != 0); - D_ASSERT(prefix.data[Node::PREFIX_SIZE] <= Node::PREFIX_SIZE); - - str += " prefix_bytes:["; - for (idx_t i = 0; i < prefix.data[Node::PREFIX_SIZE]; i++) { + str += " Prefix :[ "; + for (idx_t i = 0; i < prefix.data[Count(art)]; i++) { str += to_string(prefix.data[i]) + "-"; } - str += "] "; + str += " ] "; + }); - node_ref = prefix.ptr; - } + auto child = ref.get().VerifyAndToString(art, only_verify); + return only_verify ? "" : str + child; +} - auto subtree = node_ref.get().VerifyAndToString(art, only_verify); - return only_verify ? "" : str + subtree; +void Prefix::VerifyAllocations(ART &art, const Node &node, unordered_map &node_counts) { + auto idx = Node::GetAllocatorIdx(PREFIX); + reference ref(node); + Iterator(art, ref, false, false, [&](Prefix &prefix) { node_counts[idx]++; }); + return ref.get().VerifyAllocations(art, node_counts); } -void Prefix::Vacuum(ART &art, Node &node, const ARTFlags &flags) { +void Prefix::Vacuum(ART &art, Node &node, const unordered_set &indexes) { + bool set = indexes.find(Node::GetAllocatorIdx(PREFIX)) != indexes.end(); + auto &allocator = Node::GetAllocator(art, PREFIX); + + reference ref(node); + while (ref.get().GetType() == PREFIX) { + if (set && allocator.NeedsVacuum(ref)) { + auto status = ref.get().GetGateStatus(); + ref.get() = allocator.VacuumPointer(ref); + ref.get().SetMetadata(static_cast(PREFIX)); + ref.get().SetGateStatus(status); + } + Prefix prefix(art, ref, true); + ref = *prefix.ptr; + } - bool flag_set = flags.vacuum_flags[static_cast(NType::PREFIX) - 1]; - auto &allocator = Node::GetAllocator(art, NType::PREFIX); + ref.get().Vacuum(art, indexes); +} - reference node_ref(node); - while (node_ref.get().GetType() == NType::PREFIX) { - if (flag_set && allocator.NeedsVacuum(node_ref)) { - node_ref.get() = allocator.VacuumPointer(node_ref); - node_ref.get().SetMetadata(static_cast(NType::PREFIX)); +void Prefix::TransformToDeprecated(ART &art, Node &node, unsafe_unique_ptr &allocator) { + // Early-out, if we do not need any transformations. + if (!allocator) { + reference ref(node); + while (ref.get().GetType() == PREFIX && ref.get().GetGateStatus() == GateStatus::GATE_NOT_SET) { + Prefix prefix(art, ref, true, true); + if (!prefix.in_memory) { + return; + } + ref = *prefix.ptr; } - auto &prefix = Node::RefMutable(art, node_ref, NType::PREFIX); - node_ref = prefix.ptr; + return Node::TransformToDeprecated(art, ref, allocator); } - node_ref.get().Vacuum(art, flags); -} + // Fast path. + if (art.prefix_count <= DEPRECATED_COUNT) { + reference ref(node); + while (ref.get().GetType() == PREFIX && ref.get().GetGateStatus() == GateStatus::GATE_NOT_SET) { + Prefix prefix(art, ref, true, true); + if (!prefix.in_memory) { + return; + } + + Node new_node; + new_node = allocator->New(); + new_node.SetMetadata(static_cast(PREFIX)); -Prefix &Prefix::Append(ART &art, const uint8_t byte) { + Prefix new_prefix(allocator, new_node, DEPRECATED_COUNT); + new_prefix.data[DEPRECATED_COUNT] = prefix.data[Count(art)]; + memcpy(new_prefix.data, prefix.data, new_prefix.data[DEPRECATED_COUNT]); + *new_prefix.ptr = *prefix.ptr; - reference prefix(*this); + prefix.ptr->Clear(); + Node::Free(art, ref); + ref.get() = new_node; + ref = *new_prefix.ptr; + } + + return Node::TransformToDeprecated(art, ref, allocator); + } + + // Else, we need to create a new prefix chain. + Node new_node; + new_node = allocator->New(); + new_node.SetMetadata(static_cast(PREFIX)); + Prefix new_prefix(allocator, new_node, DEPRECATED_COUNT); + + reference ref(node); + while (ref.get().GetType() == PREFIX && ref.get().GetGateStatus() == GateStatus::GATE_NOT_SET) { + Prefix prefix(art, ref, true, true); + if (!prefix.in_memory) { + return; + } + + for (idx_t i = 0; i < prefix.data[Count(art)]; i++) { + new_prefix = new_prefix.TransformToDeprecatedAppend(art, allocator, prefix.data[i]); + } - // we need a new prefix node - if (prefix.get().data[Node::PREFIX_SIZE] == Node::PREFIX_SIZE) { - prefix = New(art, prefix.get().ptr); + *new_prefix.ptr = *prefix.ptr; + Node::GetAllocator(art, PREFIX).Free(ref); + ref = *new_prefix.ptr; } - prefix.get().data[prefix.get().data[Node::PREFIX_SIZE]] = byte; - prefix.get().data[Node::PREFIX_SIZE]++; - return prefix.get(); + return Node::TransformToDeprecated(art, ref, allocator); } -void Prefix::Append(ART &art, Node other_prefix) { +Prefix Prefix::Append(ART &art, const uint8_t byte) { + if (data[Count(art)] != Count(art)) { + data[data[Count(art)]] = byte; + data[Count(art)]++; + return *this; + } + + auto prefix = NewInternal(art, *ptr, nullptr, 0, 0, PREFIX); + return prefix.Append(art, byte); +} - D_ASSERT(other_prefix.HasMetadata()); +void Prefix::Append(ART &art, Node other) { + D_ASSERT(other.HasMetadata()); - reference prefix(*this); - while (other_prefix.GetType() == NType::PREFIX) { + Prefix prefix = *this; + while (other.GetType() == PREFIX) { + if (other.GetGateStatus() == GateStatus::GATE_SET) { + *prefix.ptr = other; + return; + } - // copy prefix bytes - auto &other = Node::RefMutable(art, other_prefix, NType::PREFIX); - for (idx_t i = 0; i < other.data[Node::PREFIX_SIZE]; i++) { - prefix = prefix.get().Append(art, other.data[i]); + Prefix other_prefix(art, other, true); + for (idx_t i = 0; i < other_prefix.data[Count(art)]; i++) { + prefix = prefix.Append(art, other_prefix.data[i]); } - D_ASSERT(other.ptr.HasMetadata()); + *prefix.ptr = *other_prefix.ptr; + Node::GetAllocator(art, PREFIX).Free(other); + other = *prefix.ptr; + } +} + +Prefix Prefix::GetTail(ART &art, const Node &node) { + Prefix prefix(art, node, true); + while (prefix.ptr->GetType() == PREFIX) { + prefix = Prefix(art, *prefix.ptr, true); + } + return prefix; +} + +void Prefix::ConcatGate(ART &art, Node &parent, uint8_t byte, const Node &child) { + D_ASSERT(child.HasMetadata()); + Node new_prefix = Node(); + + // Inside gates, inlined row IDs are not prefixed. + if (child.GetType() == NType::LEAF_INLINED) { + Leaf::New(new_prefix, child.GetRowId()); + + } else if (child.GetType() == PREFIX) { + // At least one more row ID in this gate. + auto prefix = NewInternal(art, new_prefix, &byte, 1, 0, PREFIX); + prefix.ptr->Clear(); + prefix.Append(art, child); + new_prefix.SetGateStatus(GateStatus::GATE_SET); + + } else { + // At least one more row ID in this gate. + auto prefix = NewInternal(art, new_prefix, &byte, 1, 0, PREFIX); + *prefix.ptr = child; + new_prefix.SetGateStatus(GateStatus::GATE_SET); + } + + if (parent.GetType() != PREFIX) { + parent = new_prefix; + return; + } + *GetTail(art, parent).ptr = new_prefix; +} + +void Prefix::ConcatChildIsGate(ART &art, Node &parent, uint8_t byte, const Node &child) { + // Create a new prefix and point it to the gate. + if (parent.GetType() != PREFIX) { + auto prefix = NewInternal(art, parent, &byte, 1, 0, PREFIX); + *prefix.ptr = child; + return; + } + + auto tail = GetTail(art, parent); + tail = tail.Append(art, byte); + *tail.ptr = child; +} - prefix.get().ptr = other.ptr; - Node::GetAllocator(art, NType::PREFIX).Free(other_prefix); - other_prefix = prefix.get().ptr; +Prefix Prefix::TransformToDeprecatedAppend(ART &art, unsafe_unique_ptr &allocator, uint8_t byte) { + if (data[DEPRECATED_COUNT] != DEPRECATED_COUNT) { + data[data[DEPRECATED_COUNT]] = byte; + data[DEPRECATED_COUNT]++; + return *this; } - D_ASSERT(prefix.get().ptr.GetType() != NType::PREFIX); + *ptr = allocator->New(); + ptr->SetMetadata(static_cast(PREFIX)); + Prefix prefix(allocator, *ptr, DEPRECATED_COUNT); + return prefix.TransformToDeprecatedAppend(art, allocator, byte); } } // namespace duckdb diff --git a/src/duckdb/src/execution/index/bound_index.cpp b/src/duckdb/src/execution/index/bound_index.cpp index 49a02a05..017f7f5b 100644 --- a/src/duckdb/src/execution/index/bound_index.cpp +++ b/src/duckdb/src/execution/index/bound_index.cpp @@ -62,6 +62,12 @@ string BoundIndex::VerifyAndToString(const bool only_verify) { return VerifyAndToString(state, only_verify); } +void BoundIndex::VerifyAllocations() { + IndexLock state; + InitializeLock(state); + return VerifyAllocations(state); +} + void BoundIndex::Vacuum() { IndexLock state; InitializeLock(state); @@ -97,7 +103,7 @@ bool BoundIndex::IndexIsUpdated(const vector &column_ids_p) const return false; } -IndexStorageInfo BoundIndex::GetStorageInfo(const bool get_buffers) { +IndexStorageInfo BoundIndex::GetStorageInfo(const case_insensitive_map_t &options, const bool to_wal) { throw NotImplementedException("The implementation of this index serialization does not exist."); } diff --git a/src/duckdb/src/execution/index/fixed_size_allocator.cpp b/src/duckdb/src/execution/index/fixed_size_allocator.cpp index 32018b3e..a6ad0f38 100644 --- a/src/duckdb/src/execution/index/fixed_size_allocator.cpp +++ b/src/duckdb/src/execution/index/fixed_size_allocator.cpp @@ -8,9 +8,9 @@ FixedSizeAllocator::FixedSizeAllocator(const idx_t segment_size, BlockManager &b : block_manager(block_manager), buffer_manager(block_manager.buffer_manager), segment_size(segment_size), total_segment_count(0) { - if (segment_size > Storage::BLOCK_SIZE - sizeof(validity_t)) { + if (segment_size > block_manager.GetBlockSize() - sizeof(validity_t)) { throw InternalException("The maximum segment size of fixed-size allocators is " + - to_string(Storage::BLOCK_SIZE - sizeof(validity_t))); + to_string(block_manager.GetBlockSize() - sizeof(validity_t))); } // calculate how many segments fit into one buffer (available_segments_per_buffer) @@ -21,7 +21,7 @@ FixedSizeAllocator::FixedSizeAllocator(const idx_t segment_size, BlockManager &b bitmask_count = 0; available_segments_per_buffer = 0; - while (byte_count < Storage::BLOCK_SIZE) { + while (byte_count < block_manager.GetBlockSize()) { if (!bitmask_count || (bitmask_count * bits_per_value) % available_segments_per_buffer == 0) { // we need to add another validity_t value to the bitmask, to allow storing another // bits_per_value segments on a buffer @@ -29,7 +29,7 @@ FixedSizeAllocator::FixedSizeAllocator(const idx_t segment_size, BlockManager &b byte_count += sizeof(validity_t); } - auto remaining_bytes = Storage::BLOCK_SIZE - byte_count; + auto remaining_bytes = block_manager.GetBlockSize() - byte_count; auto remaining_segments = MinValue(remaining_bytes / segment_size, bits_per_value); if (remaining_segments == 0) { @@ -126,7 +126,7 @@ idx_t FixedSizeAllocator::GetInMemorySize() const { idx_t memory_usage = 0; for (auto &buffer : buffers) { if (buffer.second.InMemory()) { - memory_usage += Storage::BLOCK_SIZE; + memory_usage += block_manager.GetBlockSize(); } } return memory_usage; @@ -172,18 +172,7 @@ bool FixedSizeAllocator::InitializeVacuum() { Reset(); return false; } - - // remove all empty buffers - auto buffer_it = buffers.begin(); - while (buffer_it != buffers.end()) { - if (!buffer_it->second.segment_count) { - buffers_with_free_space.erase(buffer_it->first); - buffer_it->second.Destroy(); - buffer_it = buffers.erase(buffer_it); - } else { - buffer_it++; - } - } + RemoveEmptyBuffers(); // determine if a vacuum is necessary multimap temporary_vacuum_buffers; @@ -209,7 +198,7 @@ bool FixedSizeAllocator::InitializeVacuum() { // calculate the vacuum threshold adaptively D_ASSERT(excess_buffer_count < temporary_vacuum_buffers.size()); idx_t memory_usage = GetInMemorySize(); - idx_t excess_memory_usage = excess_buffer_count * Storage::BLOCK_SIZE; + idx_t excess_memory_usage = excess_buffer_count * block_manager.GetBlockSize(); auto excess_percentage = double(excess_memory_usage) / double(memory_usage); auto threshold = double(VACUUM_THRESHOLD) / 100.0; if (excess_percentage < threshold) { @@ -355,4 +344,19 @@ idx_t FixedSizeAllocator::GetAvailableBufferId() const { return buffer_id; } +void FixedSizeAllocator::RemoveEmptyBuffers() { + + auto buffer_it = buffers.begin(); + while (buffer_it != buffers.end()) { + if (buffer_it->second.segment_count != 0) { + buffer_it++; + continue; + } + + buffers_with_free_space.erase(buffer_it->first); + buffer_it->second.Destroy(); + buffer_it = buffers.erase(buffer_it); + } +} + } // namespace duckdb diff --git a/src/duckdb/src/execution/index/fixed_size_buffer.cpp b/src/duckdb/src/execution/index/fixed_size_buffer.cpp index ab5f4f1d..29bb40f7 100644 --- a/src/duckdb/src/execution/index/fixed_size_buffer.cpp +++ b/src/duckdb/src/execution/index/fixed_size_buffer.cpp @@ -40,7 +40,7 @@ FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager) block_handle(nullptr) { auto &buffer_manager = block_manager.buffer_manager; - buffer_handle = buffer_manager.Allocate(MemoryTag::ART_INDEX, Storage::BLOCK_SIZE, false, &block_handle); + buffer_handle = buffer_manager.Allocate(MemoryTag::ART_INDEX, block_manager.GetBlockSize(), false, &block_handle); } FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager, const idx_t segment_count, const idx_t allocation_size, @@ -68,7 +68,7 @@ void FixedSizeBuffer::Destroy() { void FixedSizeBuffer::Serialize(PartialBlockManager &partial_block_manager, const idx_t available_segments, const idx_t segment_size, const idx_t bitmask_offset) { - // we do not serialize a block that is already on disk and not in memory + // Early-out, if the block is already on disk and not in memory. if (!InMemory()) { if (!OnDisk() || dirty) { throw InternalException("invalid or missing buffer in FixedSizeAllocator"); @@ -76,16 +76,20 @@ void FixedSizeBuffer::Serialize(PartialBlockManager &partial_block_manager, cons return; } - // we do not serialize a block that is already on disk and not dirty + // Early-out, if the buffer is already on disk and not dirty. if (!dirty && OnDisk()) { return; } - // the allocation possibly changed + // Adjust the allocation size. + D_ASSERT(segment_count != 0); SetAllocationSize(available_segments, segment_size, bitmask_offset); // the buffer is in memory, so we copied it onto a new buffer when pinning - D_ASSERT(InMemory() && !OnDisk()); + D_ASSERT(InMemory()); + if (OnDisk()) { + block_manager.MarkBlockAsModified(block_pointer.block_id); + } // now we write the changes, first get a partial block allocation PartialBlockAllocation allocation = @@ -131,17 +135,14 @@ void FixedSizeBuffer::Pin() { buffer_handle = buffer_manager.Pin(block_handle); - // we need to copy the (partial) data into a new (not yet disk-backed) buffer handle + // Copy the (partial) data into a new (not yet disk-backed) buffer handle. shared_ptr new_block_handle; auto new_buffer_handle = - buffer_manager.Allocate(MemoryTag::ART_INDEX, Storage::BLOCK_SIZE, false, &new_block_handle); - + buffer_manager.Allocate(MemoryTag::ART_INDEX, block_manager.GetBlockSize(), false, &new_block_handle); memcpy(new_buffer_handle.Ptr(), buffer_handle.Ptr() + block_pointer.offset, allocation_size); - Destroy(); buffer_handle = std::move(new_buffer_handle); block_handle = std::move(new_block_handle); - block_pointer = BlockPointer(); } uint32_t FixedSizeBuffer::GetOffset(const idx_t bitmask_count) { @@ -195,75 +196,23 @@ uint32_t FixedSizeBuffer::GetOffset(const idx_t bitmask_count) { void FixedSizeBuffer::SetAllocationSize(const idx_t available_segments, const idx_t segment_size, const idx_t bitmask_offset) { - - if (dirty) { - auto max_offset = GetMaxOffset(available_segments); - allocation_size = max_offset * segment_size + bitmask_offset; - } -} - -uint32_t FixedSizeBuffer::GetMaxOffset(const idx_t available_segments) { - - // this function calls Get() on the buffer - D_ASSERT(InMemory()); - - // finds the maximum zero bit in a bitmask, and adds one to it, - // so that max_offset * segment_size = allocated_size of this bitmask's buffer - idx_t entry_size = sizeof(validity_t) * 8; - idx_t bitmask_count = available_segments / entry_size; - if (available_segments % entry_size != 0) { - bitmask_count++; + if (!dirty) { + return; } - auto max_offset = UnsafeNumericCast(bitmask_count * sizeof(validity_t) * 8); - auto bits_in_last_entry = available_segments % (sizeof(validity_t) * 8); - // get the bitmask data + // We traverse from the back. A binary search would be faster. + // However, buffers are often (almost) full, so the overhead is acceptable. auto bitmask_ptr = reinterpret_cast(Get()); - const ValidityMask mask(bitmask_ptr); - const auto data = mask.GetData(); - - D_ASSERT(bitmask_count > 0); - for (idx_t i = bitmask_count; i > 0; i--) { - - auto entry = data[i - 1]; - - // set all bits after bits_in_last_entry - if (i == bitmask_count) { - entry |= ~idx_t(0) << bits_in_last_entry; - } - - if (entry == ~idx_t(0)) { - max_offset -= sizeof(validity_t) * 8; - continue; - } - - // invert data[entry_idx] - auto entry_inv = ~entry; - idx_t first_valid_bit = 0; - - // then find the position of the LEFTMOST set bit - for (idx_t level = 0; level < 6; level++) { + ValidityMask mask(bitmask_ptr); - // set the right half of the bits of this level to zero and test if the entry is still not zero - if (entry_inv & ~BASE[level]) { - // first valid bit is in the leftmost s[level] bits - // shift by s[level] for the next iteration and add s[level] to the position of the leftmost set bit - entry_inv >>= SHIFT[level]; - first_valid_bit += SHIFT[level]; - } else { - // first valid bit is in the rightmost s[level] bits - // permanently set the left half of the bits to zero - entry_inv &= BASE[level]; - } + auto max_offset = available_segments; + for (idx_t i = available_segments; i > 0; i--) { + if (!mask.RowIsValid(i - 1)) { + max_offset = i; + break; } - D_ASSERT(entry_inv); - max_offset -= sizeof(validity_t) * 8 - first_valid_bit; - D_ASSERT(!mask.RowIsValid(max_offset)); - return max_offset + 1; } - - // there are no allocations in this buffer - throw InternalException("tried to serialize empty buffer"); + allocation_size = max_offset * segment_size + bitmask_offset; } void FixedSizeBuffer::SetUninitializedRegions(PartialBlockForIndex &p_block_for_index, const idx_t segment_size, diff --git a/src/duckdb/src/execution/join_hashtable.cpp b/src/duckdb/src/execution/join_hashtable.cpp index c1f27391..e19d2a7e 100644 --- a/src/duckdb/src/execution/join_hashtable.cpp +++ b/src/duckdb/src/execution/join_hashtable.cpp @@ -1,26 +1,42 @@ #include "duckdb/execution/join_hashtable.hpp" #include "duckdb/common/exception.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/column/column_data_collection_segment.hpp" +#include "duckdb/common/radix_partitioning.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/ht_entry.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/storage/buffer_manager.hpp" namespace duckdb { - using ValidityBytes = JoinHashTable::ValidityBytes; using ScanStructure = JoinHashTable::ScanStructure; using ProbeSpill = JoinHashTable::ProbeSpill; using ProbeSpillLocalState = JoinHashTable::ProbeSpillLocalAppendState; -JoinHashTable::JoinHashTable(BufferManager &buffer_manager_p, const vector &conditions_p, - vector btypes, JoinType type_p, const vector &output_columns_p) - : buffer_manager(buffer_manager_p), conditions(conditions_p), build_types(std::move(btypes)), - output_columns(output_columns_p), entry_size(0), tuple_size(0), vfound(Value::BOOLEAN(false)), join_type(type_p), - finalized(false), has_null(false), radix_bits(INITIAL_RADIX_BITS), partition_start(0), partition_end(0) { +JoinHashTable::SharedState::SharedState() + : rhs_row_locations(LogicalType::POINTER), salt_match_sel(STANDARD_VECTOR_SIZE), + key_no_match_sel(STANDARD_VECTOR_SIZE) { +} - for (auto &condition : conditions) { +JoinHashTable::ProbeState::ProbeState() + : SharedState(), salt_v(LogicalType::UBIGINT), ht_offsets_v(LogicalType::UBIGINT), + ht_offsets_dense_v(LogicalType::UBIGINT), non_empty_sel(STANDARD_VECTOR_SIZE) { +} + +JoinHashTable::InsertState::InsertState(const JoinHashTable &ht) + : SharedState(), remaining_sel(STANDARD_VECTOR_SIZE), key_match_sel(STANDARD_VECTOR_SIZE) { + ht.data_collection->InitializeChunk(lhs_data, ht.equality_predicate_columns); + ht.data_collection->InitializeChunkState(chunk_state, ht.equality_predicate_columns); +} + +JoinHashTable::JoinHashTable(ClientContext &context, const vector &conditions_p, + vector btypes, JoinType type_p, const vector &output_columns_p) + : buffer_manager(BufferManager::GetBufferManager(context)), conditions(conditions_p), + build_types(std::move(btypes)), output_columns(output_columns_p), entry_size(0), tuple_size(0), + vfound(Value::BOOLEAN(false)), join_type(type_p), finalized(false), has_null(false), + radix_bits(INITIAL_RADIX_BITS), partition_start(0), partition_end(0) { + for (idx_t i = 0; i < conditions.size(); ++i) { + auto &condition = conditions[i]; D_ASSERT(condition.left->return_type == condition.right->return_type); auto type = condition.left->return_type; if (condition.comparison == ExpressionType::COMPARE_EQUAL || @@ -30,9 +46,15 @@ JoinHashTable::JoinHashTable(BufferManager &buffer_manager_p, const vector(new RowMatcher()); + row_matcher_probe_no_match_sel = unique_ptr(new RowMatcher()); + + row_matcher_probe->Initialize(false, layout, non_equality_predicates, non_equality_predicate_columns); + row_matcher_probe_no_match_sel->Initialize(true, layout, non_equality_predicates, + non_equality_predicate_columns); + + needs_chain_matcher = true; + } else { + needs_chain_matcher = false; + } + + chains_longer_than_one = false; + row_matcher_build.Initialize(true, layout, equality_predicates); const auto &offsets = layout.GetOffsets(); tuple_size = offsets[condition_types.size() + build_types.size()]; @@ -62,6 +100,14 @@ JoinHashTable::JoinHashTable(BufferManager &buffer_manager_p, const vector(buffer_manager, layout); sink_collection = make_uniq(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1); + + dead_end = make_unsafe_uniq_array_uninitialized(layout.GetRowWidth()); + memset(dead_end.get(), 0, layout.GetRowWidth()); + + if (join_type == JoinType::SINGLE) { + auto &config = ClientConfig::GetConfig(context); + single_join_error_on_multiple_rows = config.scalar_subquery_error_on_multiple_rows; + } } JoinHashTable::~JoinHashTable() { @@ -86,32 +132,176 @@ void JoinHashTable::Merge(JoinHashTable &other) { sink_collection->Combine(*other.sink_collection); } -void JoinHashTable::ApplyBitmask(Vector &hashes, idx_t count) { - if (hashes.GetVectorType() == VectorType::CONSTANT_VECTOR) { - D_ASSERT(!ConstantVector::IsNull(hashes)); - auto indices = ConstantVector::GetData(hashes); - *indices = *indices & bitmask; +static void ApplyBitmaskAndGetSaltBuild(Vector &hashes_v, const idx_t &count, const idx_t &bitmask) { + if (hashes_v.GetVectorType() == VectorType::CONSTANT_VECTOR) { + D_ASSERT(!ConstantVector::IsNull(hashes_v)); + auto indices = ConstantVector::GetData(hashes_v); + hash_t salt = ht_entry_t::ExtractSaltWithNulls(*indices); + idx_t offset = *indices & bitmask; + *indices = offset | salt; + hashes_v.Flatten(count); } else { - hashes.Flatten(count); - auto indices = FlatVector::GetData(hashes); + hashes_v.Flatten(count); + auto hashes = FlatVector::GetData(hashes_v); for (idx_t i = 0; i < count; i++) { - indices[i] &= bitmask; + idx_t salt = ht_entry_t::ExtractSaltWithNulls(hashes[i]); + idx_t offset = hashes[i] & bitmask; + hashes[i] = offset | salt; } } } -void JoinHashTable::ApplyBitmask(Vector &hashes, const SelectionVector &sel, idx_t count, Vector &pointers) { - UnifiedVectorFormat hdata; - hashes.ToUnifiedFormat(count, hdata); +//! Gets a pointer to the entry in the HT for each of the hashes_v using linear probing. Will update the key_match_sel +//! vector and the count argument to the number and position of the matches +template +static inline void GetRowPointersInternal(DataChunk &keys, TupleDataChunkState &key_state, + JoinHashTable::ProbeState &state, Vector &hashes_v, + const SelectionVector &sel, idx_t &count, JoinHashTable *ht, + ht_entry_t *entries, Vector &pointers_result_v, SelectionVector &match_sel) { + UnifiedVectorFormat hashes_v_unified; + hashes_v.ToUnifiedFormat(count, hashes_v_unified); + + auto hashes = UnifiedVectorFormat::GetData(hashes_v_unified); + auto salts = FlatVector::GetData(state.salt_v); + + auto ht_offsets = FlatVector::GetData(state.ht_offsets_v); + auto ht_offsets_dense = FlatVector::GetData(state.ht_offsets_dense_v); - auto hash_data = UnifiedVectorFormat::GetData(hdata); - auto result_data = FlatVector::GetData(pointers); - auto main_ht = reinterpret_cast(hash_map.get()); + idx_t non_empty_count = 0; + + // first, filter out the empty rows and calculate the offset + for (idx_t i = 0; i < count; i++) { + const auto row_index = sel.get_index(i); + auto uvf_index = hashes_v_unified.sel->get_index(row_index); + auto ht_offset = hashes[uvf_index] & ht->bitmask; + ht_offsets_dense[i] = ht_offset; + ht_offsets[row_index] = ht_offset; + } + + // have a dense loop to have as few instructions as possible while producing cache misses as this is the + // first location where we access the big entries array for (idx_t i = 0; i < count; i++) { - auto rindex = sel.get_index(i); - auto hindex = hdata.sel->get_index(rindex); - auto hash = hash_data[hindex]; - result_data[rindex] = main_ht + (hash & bitmask); + idx_t ht_offset = ht_offsets_dense[i]; + auto &entry = entries[ht_offset]; + bool occupied = entry.IsOccupied(); + state.non_empty_sel.set_index(non_empty_count, i); + non_empty_count += occupied; + } + + for (idx_t i = 0; i < non_empty_count; i++) { + // transform the dense index to the actual index in the sel vector + idx_t dense_index = state.non_empty_sel.get_index(i); + const auto row_index = sel.get_index(dense_index); + state.non_empty_sel.set_index(i, row_index); + + if (USE_SALTS) { + auto uvf_index = hashes_v_unified.sel->get_index(row_index); + auto hash = hashes[uvf_index]; + hash_t row_salt = ht_entry_t::ExtractSalt(hash); + salts[row_index] = row_salt; + } + } + + auto pointers_result = FlatVector::GetData(pointers_result_v); + auto row_ptr_insert_to = FlatVector::GetData(state.rhs_row_locations); + + const SelectionVector *remaining_sel = &state.non_empty_sel; + idx_t remaining_count = non_empty_count; + + idx_t &match_count = count; + match_count = 0; + + while (remaining_count > 0) { + idx_t salt_match_count = 0; + idx_t key_no_match_count = 0; + + // for each entry, linear probing until + // a) an empty entry is found -> return nullptr (do nothing, as vector is zeroed) + // b) an entry is found where the salt matches -> need to compare the keys + for (idx_t i = 0; i < remaining_count; i++) { + const auto row_index = remaining_sel->get_index(i); + + idx_t &ht_offset = ht_offsets[row_index]; + bool occupied; + ht_entry_t entry; + + if (USE_SALTS) { + hash_t row_salt = salts[row_index]; + // increment the ht_offset of the entry as long as next entry is occupied and salt does not match + while (true) { + entry = entries[ht_offset]; + occupied = entry.IsOccupied(); + bool salt_match = entry.GetSalt() == row_salt; + + // condition for incrementing the ht_offset: occupied and row_salt does not match -> move to next + // entry + if (!occupied || salt_match) { + break; + } + + IncrementAndWrap(ht_offset, ht->bitmask); + } + } else { + entry = entries[ht_offset]; + occupied = entry.IsOccupied(); + } + + // the entries we need to process in the next iteration are the ones that are occupied and the row_salt + // does not match, the ones that are empty need no further processing + state.salt_match_sel.set_index(salt_match_count, row_index); + salt_match_count += occupied; + + // entry might be empty, so the pointer in the entry is nullptr, but this does not matter as the row + // will not be compared anyway as with an empty entry we are already done + row_ptr_insert_to[row_index] = entry.GetPointerOrNull(); + } + + if (salt_match_count != 0) { + // Perform row comparisons, after function call salt_match_sel will point to the keys that match + idx_t key_match_count = ht->row_matcher_build.Match(keys, key_state.vector_data, state.salt_match_sel, + salt_match_count, ht->layout, state.rhs_row_locations, + &state.key_no_match_sel, key_no_match_count); + + D_ASSERT(key_match_count + key_no_match_count == salt_match_count); + + // Set a pointer to the matching row + for (idx_t i = 0; i < key_match_count; i++) { + const auto row_index = state.salt_match_sel.get_index(i); + pointers_result[row_index] = row_ptr_insert_to[row_index]; + + match_sel.set_index(match_count, row_index); + match_count++; + } + + // Linear probing: each of the entries that do not match move to the next entry in the HT + for (idx_t i = 0; i < key_no_match_count; i++) { + const auto row_index = state.key_no_match_sel.get_index(i); + auto &ht_offset = ht_offsets[row_index]; + + IncrementAndWrap(ht_offset, ht->bitmask); + } + } + + remaining_sel = &state.key_no_match_sel; + remaining_count = key_no_match_count; + } +} + +inline bool JoinHashTable::UseSalt() const { + // only use salt for large hash tables and if there is only one equality condition as otherwise + // we potentially need to compare multiple keys + return this->capacity > USE_SALT_THRESHOLD && this->equality_predicate_columns.size() == 1; +} + +void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_state, ProbeState &state, Vector &hashes_v, + const SelectionVector &sel, idx_t &count, Vector &pointers_result_v, + SelectionVector &match_sel) { + if (UseSalt()) { + GetRowPointersInternal(keys, key_state, state, hashes_v, sel, count, this, entries, pointers_result_v, + match_sel); + } else { + GetRowPointersInternal(keys, key_state, state, hashes_v, sel, count, this, entries, pointers_result_v, + match_sel); } } @@ -224,84 +414,295 @@ idx_t JoinHashTable::PrepareKeys(DataChunk &keys, vector // figure out which keys are NULL, and create a selection vector out of them current_sel = FlatVector::IncrementalSelectionVector(); idx_t added_count = keys.size(); - if (build_side && (PropagatesBuildSide(join_type))) { + if (build_side && PropagatesBuildSide(join_type)) { // in case of a right or full outer join, we cannot remove NULL keys from the build side return added_count; } for (idx_t col_idx = 0; col_idx < keys.ColumnCount(); col_idx++) { - if (!null_values_are_equal[col_idx]) { - auto &col_key_data = vector_data[col_idx].unified; - if (col_key_data.validity.AllValid()) { - continue; - } - added_count = FilterNullValues(col_key_data, *current_sel, added_count, sel); - // null values are NOT equal for this column, filter them out - current_sel = &sel; + if (null_values_are_equal[col_idx]) { + continue; } + auto &col_key_data = vector_data[col_idx].unified; + if (col_key_data.validity.AllValid()) { + continue; + } + added_count = FilterNullValues(col_key_data, *current_sel, added_count, sel); + // null values are NOT equal for this column, filter them out + current_sel = &sel; } return added_count; } -template -static inline void InsertHashesLoop(atomic pointers[], const hash_t indices[], const idx_t count, - const data_ptr_t key_locations[], const idx_t pointer_offset) { - for (idx_t i = 0; i < count; i++) { - const auto index = indices[i]; - if (PARALLEL) { - data_ptr_t head; +static void StorePointer(const_data_ptr_t pointer, data_ptr_t target) { + Store(cast_pointer_to_uint64(pointer), target); +} + +static data_ptr_t LoadPointer(const_data_ptr_t source) { + return cast_uint64_to_pointer(Load(source)); +} + +//! If we consider to insert into an entry we expct to be empty, if it was filled in the meantime the insert will not +//! happen and we need to return the pointer to the to row with which the new entry would have collided. In any other +//! case we return a nullptr +template +static inline data_ptr_t InsertRowToEntry(atomic &entry, const data_ptr_t &row_ptr_to_insert, + const hash_t &salt, const idx_t &pointer_offset) { + + if (PARALLEL) { + // if we expect the entry to be empty, if the operation fails we need to cancel the whole operation as another + // key might have been inserted in the meantime that does not match the current key + if (EXPECT_EMPTY) { + + // add nullptr to the end of the list to mark the end + StorePointer(nullptr, row_ptr_to_insert + pointer_offset); + + ht_entry_t new_empty_entry = ht_entry_t::GetDesiredEntry(row_ptr_to_insert, salt); + ht_entry_t expected_empty_entry = ht_entry_t::GetEmptyEntry(); + std::atomic_compare_exchange_weak(&entry, &expected_empty_entry, new_empty_entry); + + // if the expected empty entry actually was null, we can just return the pointer, and it will be a nullptr + // if the expected entry was filled in the meantime, we need to cancel the operation and will return the + // pointer to the next entry + return expected_empty_entry.GetPointerOrNull(); + } + + // if we expect the entry to be full, we know that even if the insert fails the keys still match so we can + // just keep trying until we succeed + else { + ht_entry_t expected_current_entry = entry.load(std::memory_order_relaxed); + ht_entry_t desired_new_entry = ht_entry_t::GetDesiredEntry(row_ptr_to_insert, salt); + D_ASSERT(expected_current_entry.IsOccupied()); + do { - head = pointers[index]; - Store(head, key_locations[i] + pointer_offset); - } while (!std::atomic_compare_exchange_weak(&pointers[index], &head, key_locations[i])); - } else { - // set prev in current key to the value (NOTE: this will be nullptr if there is none) - Store(pointers[index], key_locations[i] + pointer_offset); + data_ptr_t current_row_pointer = expected_current_entry.GetPointer(); + StorePointer(current_row_pointer, row_ptr_to_insert + pointer_offset); + } while (!std::atomic_compare_exchange_weak(&entry, &expected_current_entry, desired_new_entry)); - // set pointer to current tuple - pointers[index] = key_locations[i]; + return nullptr; } + } else { + // if we are not in parallel mode, we can just do the operation without any checks + ht_entry_t current_entry = entry.load(std::memory_order_relaxed); + data_ptr_t current_row_pointer = current_entry.GetPointerOrNull(); + StorePointer(current_row_pointer, row_ptr_to_insert + pointer_offset); + entry = ht_entry_t::GetDesiredEntry(row_ptr_to_insert, salt); + return nullptr; + } +} +static inline void PerformKeyComparison(JoinHashTable::InsertState &state, JoinHashTable &ht, + const TupleDataCollection &data_collection, Vector &row_locations, + const idx_t count, idx_t &key_match_count, idx_t &key_no_match_count) { + // Get the data for the rows that need to be compared + state.lhs_data.Reset(); + state.lhs_data.SetCardinality(count); // the right size + + // The target selection vector says where to write the results into the lhs_data, we just want to write + // sequentially as otherwise we trigger a bug in the Gather function + data_collection.ResetCachedCastVectors(state.chunk_state, ht.equality_predicate_columns); + data_collection.Gather(row_locations, state.salt_match_sel, count, ht.equality_predicate_columns, state.lhs_data, + *FlatVector::IncrementalSelectionVector(), state.chunk_state.cached_cast_vectors); + TupleDataCollection::ToUnifiedFormat(state.chunk_state, state.lhs_data); + + for (idx_t i = 0; i < count; i++) { + state.key_match_sel.set_index(i, i); + } + + // Perform row comparisons + key_match_count = + ht.row_matcher_build.Match(state.lhs_data, state.chunk_state.vector_data, state.key_match_sel, count, ht.layout, + state.rhs_row_locations, &state.key_no_match_sel, key_no_match_count); + + D_ASSERT(key_match_count + key_no_match_count == count); +} + +template +static inline void InsertMatchesAndIncrementMisses(atomic entries[], JoinHashTable::InsertState &state, + JoinHashTable &ht, const data_ptr_t lhs_row_locations[], + idx_t ht_offsets_and_salts[], const idx_t capacity_mask, + const idx_t key_match_count, const idx_t key_no_match_count) { + if (key_match_count != 0) { + ht.chains_longer_than_one = true; + } + + // Insert the rows that match + for (idx_t i = 0; i < key_match_count; i++) { + const auto need_compare_idx = state.key_match_sel.get_index(i); + const auto entry_index = state.salt_match_sel.get_index(need_compare_idx); + + const auto &ht_offset = ht_offsets_and_salts[entry_index] & ht_entry_t::POINTER_MASK; + auto &entry = entries[ht_offset]; + const data_ptr_t row_ptr_to_insert = lhs_row_locations[entry_index]; + + const auto salt = ht_offsets_and_salts[entry_index]; + InsertRowToEntry(entry, row_ptr_to_insert, salt, ht.pointer_offset); + } + + // Linear probing: each of the entries that do not match move to the next entry in the HT + for (idx_t i = 0; i < key_no_match_count; i++) { + const auto need_compare_idx = state.key_no_match_sel.get_index(i); + const auto entry_index = state.salt_match_sel.get_index(need_compare_idx); + + idx_t &ht_offset_and_salt = ht_offsets_and_salts[entry_index]; + IncrementAndWrap(ht_offset_and_salt, capacity_mask); + + state.remaining_sel.set_index(i, entry_index); } } -void JoinHashTable::InsertHashes(Vector &hashes, idx_t count, data_ptr_t key_locations[], bool parallel) { - D_ASSERT(hashes.GetType().id() == LogicalType::HASH); +template +static void InsertHashesLoop(atomic entries[], Vector &row_locations, Vector &hashes_v, const idx_t &count, + JoinHashTable::InsertState &state, const TupleDataCollection &data_collection, + JoinHashTable &ht) { + D_ASSERT(hashes_v.GetType().id() == LogicalType::HASH); + ApplyBitmaskAndGetSaltBuild(hashes_v, count, ht.bitmask); + + // the offset for each row to insert + const auto ht_offsets_and_salts = FlatVector::GetData(hashes_v); + // the row locations of the rows that are already in the hash table + const auto rhs_row_locations = FlatVector::GetData(state.rhs_row_locations); + // the row locations of the rows that are to be inserted + const auto lhs_row_locations = FlatVector::GetData(row_locations); + + // we start off with the entire chunk + idx_t remaining_count = count; + const auto *remaining_sel = FlatVector::IncrementalSelectionVector(); + + if (PropagatesBuildSide(ht.join_type)) { + // if we propagate the build side, we may have added rows with NULL keys to the HT + // these may need to be filtered out depending on the comparison type (exactly like PrepareKeys does) + for (idx_t col_idx = 0; col_idx < ht.conditions.size(); col_idx++) { + // if null values are NOT equal for this column we filter them out + if (ht.NullValuesAreEqual(col_idx)) { + continue; + } + + idx_t entry_idx; + idx_t idx_in_entry; + ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); + + idx_t new_remaining_count = 0; + for (idx_t i = 0; i < remaining_count; i++) { + const auto idx = remaining_sel->get_index(i); + if (ValidityBytes(lhs_row_locations[idx]).RowIsValidUnsafe(col_idx)) { + state.remaining_sel.set_index(new_remaining_count++, idx); + } + } + remaining_count = new_remaining_count; + remaining_sel = &state.remaining_sel; + } + } + + // use the ht bitmask to make the modulo operation faster but keep the salt bits intact + idx_t capacity_mask = ht.bitmask | ht_entry_t::SALT_MASK; + while (remaining_count > 0) { + idx_t salt_match_count = 0; + + // iterate over each entry to find out whether it belongs to an existing list or will start a new list + for (idx_t i = 0; i < remaining_count; i++) { + const idx_t row_index = remaining_sel->get_index(i); + idx_t &ht_offset_and_salt = ht_offsets_and_salts[row_index]; + const hash_t salt = ht_entry_t::ExtractSalt(ht_offset_and_salt); + + // increment the ht_offset_and_salt of the entry as long as next entry is occupied and salt does not match + idx_t ht_offset; + ht_entry_t entry; + bool occupied; + while (true) { + ht_offset = ht_offset_and_salt & ht_entry_t::POINTER_MASK; + atomic &atomic_entry = entries[ht_offset]; + entry = atomic_entry.load(std::memory_order_relaxed); + occupied = entry.IsOccupied(); + + // condition for incrementing the ht_offset: occupied and row_salt does not match -> move to next entry + if (!occupied) { + break; + } + if (entry.GetSalt() == salt) { + break; + } + + IncrementAndWrap(ht_offset_and_salt, capacity_mask); + } + + if (!occupied) { // insert into free + auto &atomic_entry = entries[ht_offset]; + const auto row_ptr_to_insert = lhs_row_locations[row_index]; + const auto potential_collided_ptr = + InsertRowToEntry(atomic_entry, row_ptr_to_insert, salt, ht.pointer_offset); + + if (PARALLEL) { + // if the insertion was not successful, the entry was occupied in the meantime, so we have to + // compare the keys and insert the row to the next entry + if (potential_collided_ptr) { + // if the entry was occupied, we need to compare the keys and insert the row to the next entry + // we need to compare the keys and insert the row to the next entry + state.salt_match_sel.set_index(salt_match_count, row_index); + rhs_row_locations[salt_match_count] = potential_collided_ptr; + salt_match_count += 1; + } + } - // use bitmask to get position in array - ApplyBitmask(hashes, count); + } else { // compare with full entry + state.salt_match_sel.set_index(salt_match_count, row_index); + rhs_row_locations[salt_match_count] = entry.GetPointer(); + salt_match_count += 1; + } + } - hashes.Flatten(count); - D_ASSERT(hashes.GetVectorType() == VectorType::FLAT_VECTOR); + // at this step, for all the rows to insert we stepped either until we found an empty entry or an entry with + // a matching salt, we now need to compare the keys for the ones that have a matching salt + idx_t key_no_match_count = 0; + if (salt_match_count != 0) { + idx_t key_match_count = 0; + PerformKeyComparison(state, ht, data_collection, row_locations, salt_match_count, key_match_count, + key_no_match_count); + InsertMatchesAndIncrementMisses(entries, state, ht, lhs_row_locations, ht_offsets_and_salts, + capacity_mask, key_match_count, key_no_match_count); + } - auto pointers = reinterpret_cast *>(hash_map.get()); - auto indices = FlatVector::GetData(hashes); + // update the overall selection vector to only point the entries that still need to be inserted + // as there was no match found for them yet + remaining_sel = &state.remaining_sel; + remaining_count = key_no_match_count; + } +} +void JoinHashTable::InsertHashes(Vector &hashes_v, const idx_t count, TupleDataChunkState &chunk_state, + InsertState &insert_state, bool parallel) { + auto atomic_entries = reinterpret_cast *>(this->entries); + auto row_locations = chunk_state.row_locations; if (parallel) { - InsertHashesLoop(pointers, indices, count, key_locations, pointer_offset); + InsertHashesLoop(atomic_entries, row_locations, hashes_v, count, insert_state, *data_collection, *this); } else { - InsertHashesLoop(pointers, indices, count, key_locations, pointer_offset); + InsertHashesLoop(atomic_entries, row_locations, hashes_v, count, insert_state, *data_collection, *this); } } void JoinHashTable::InitializePointerTable() { - idx_t capacity = PointerTableCapacity(Count()); + capacity = PointerTableCapacity(Count()); D_ASSERT(IsPowerOfTwo(capacity)); if (hash_map.get()) { // There is already a hash map - auto current_capacity = hash_map.GetSize() / sizeof(data_ptr_t); - if (capacity != current_capacity) { - // Different size, re-allocate - hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(data_ptr_t)); + auto current_capacity = hash_map.GetSize() / sizeof(ht_entry_t); + if (capacity > current_capacity) { + // Need more space + hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(ht_entry_t)); + entries = reinterpret_cast(hash_map.get()); + } else { + // Just use the current hash map + capacity = current_capacity; } } else { // Allocate a hash map - hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(data_ptr_t)); + hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(ht_entry_t)); + entries = reinterpret_cast(hash_map.get()); } - D_ASSERT(hash_map.GetSize() == capacity * sizeof(data_ptr_t)); + D_ASSERT(hash_map.GetSize() == capacity * sizeof(ht_entry_t)); // initialize HT with all-zero entries - std::fill_n(reinterpret_cast(hash_map.get()), capacity, nullptr); + std::fill_n(entries, capacity, ht_entry_t::GetEmptyEntry()); bitmask = capacity - 1; } @@ -316,62 +717,63 @@ void JoinHashTable::Finalize(idx_t chunk_idx_from, idx_t chunk_idx_to, bool para TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::KEEP_EVERYTHING_PINNED, chunk_idx_from, chunk_idx_to, false); const auto row_locations = iterator.GetRowLocations(); + + InsertState insert_state(*this); do { const auto count = iterator.GetCurrentChunkCount(); for (idx_t i = 0; i < count; i++) { hash_data[i] = Load(row_locations[i] + pointer_offset); } - InsertHashes(hashes, count, row_locations, parallel); + TupleDataChunkState &chunk_state = iterator.GetChunkState(); + + InsertHashes(hashes, count, chunk_state, insert_state, parallel); } while (iterator.Next()); } -unique_ptr JoinHashTable::InitializeScanStructure(DataChunk &keys, TupleDataChunkState &key_state, - const SelectionVector *¤t_sel) { +void JoinHashTable::InitializeScanStructure(ScanStructure &scan_structure, DataChunk &keys, + TupleDataChunkState &key_state, const SelectionVector *¤t_sel) { D_ASSERT(Count() > 0); // should be handled before D_ASSERT(finalized); // set up the scan structure - auto ss = make_uniq(*this, key_state); - + scan_structure.is_null = false; + scan_structure.finished = false; if (join_type != JoinType::INNER) { - ss->found_match = make_unsafe_uniq_array(STANDARD_VECTOR_SIZE); - memset(ss->found_match.get(), 0, sizeof(bool) * STANDARD_VECTOR_SIZE); + memset(scan_structure.found_match.get(), 0, sizeof(bool) * STANDARD_VECTOR_SIZE); } // first prepare the keys for probing TupleDataCollection::ToUnifiedFormat(key_state, keys); - ss->count = PrepareKeys(keys, key_state.vector_data, current_sel, ss->sel_vector, false); - return ss; + scan_structure.count = PrepareKeys(keys, key_state.vector_data, current_sel, scan_structure.sel_vector, false); } -unique_ptr JoinHashTable::Probe(DataChunk &keys, TupleDataChunkState &key_state, - Vector *precomputed_hashes) { +void JoinHashTable::Probe(ScanStructure &scan_structure, DataChunk &keys, TupleDataChunkState &key_state, + ProbeState &probe_state, optional_ptr precomputed_hashes) { const SelectionVector *current_sel; - auto ss = InitializeScanStructure(keys, key_state, current_sel); - if (ss->count == 0) { - return ss; + InitializeScanStructure(scan_structure, keys, key_state, current_sel); + if (scan_structure.count == 0) { + return; } if (precomputed_hashes) { - ApplyBitmask(*precomputed_hashes, *current_sel, ss->count, ss->pointers); + GetRowPointers(keys, key_state, probe_state, *precomputed_hashes, *current_sel, scan_structure.count, + scan_structure.pointers, scan_structure.sel_vector); } else { - // hash all the keys Vector hashes(LogicalType::HASH); - Hash(keys, *current_sel, ss->count, hashes); + // hash all the keys + Hash(keys, *current_sel, scan_structure.count, hashes); // now initialize the pointers of the scan structure based on the hashes - ApplyBitmask(hashes, *current_sel, ss->count, ss->pointers); + GetRowPointers(keys, key_state, probe_state, hashes, *current_sel, scan_structure.count, + scan_structure.pointers, scan_structure.sel_vector); } - - // create the selection vector linking to only non-empty entries - ss->InitializeSelectionVector(current_sel); - - return ss; } ScanStructure::ScanStructure(JoinHashTable &ht_p, TupleDataChunkState &key_state_p) - : key_state(key_state_p), pointers(LogicalType::POINTER), sel_vector(STANDARD_VECTOR_SIZE), ht(ht_p), - finished(false) { + : key_state(key_state_p), pointers(LogicalType::POINTER), count(0), sel_vector(STANDARD_VECTOR_SIZE), + chain_match_sel_vector(STANDARD_VECTOR_SIZE), chain_no_match_sel_vector(STANDARD_VECTOR_SIZE), + found_match(make_unsafe_uniq_array_uninitialized(STANDARD_VECTOR_SIZE)), ht(ht_p), finished(false), + is_null(true) { } void ScanStructure::Next(DataChunk &keys, DataChunk &left, DataChunk &result) { @@ -381,8 +783,6 @@ void ScanStructure::Next(DataChunk &keys, DataChunk &left, DataChunk &result) { switch (ht.join_type) { case JoinType::INNER: case JoinType::RIGHT: - case JoinType::RIGHT_ANTI: - case JoinType::RIGHT_SEMI: NextInnerJoin(keys, left, result); break; case JoinType::SEMI: @@ -394,6 +794,10 @@ void ScanStructure::Next(DataChunk &keys, DataChunk &left, DataChunk &result) { case JoinType::ANTI: NextAntiJoin(keys, left, result); break; + case JoinType::RIGHT_ANTI: + case JoinType::RIGHT_SEMI: + NextRightSemiOrAntiJoin(keys); + break; case JoinType::OUTER: case JoinType::LEFT: NextLeftJoin(keys, left, result); @@ -406,7 +810,7 @@ void ScanStructure::Next(DataChunk &keys, DataChunk &left, DataChunk &result) { } } -bool ScanStructure::PointersExhausted() { +bool ScanStructure::PointersExhausted() const { // AdvancePointers creates a "new_count" for every pointer advanced during the // previous advance pointers call. If no pointers are advanced, new_count = 0. // count is then set ot new_count. @@ -414,20 +818,31 @@ bool ScanStructure::PointersExhausted() { } idx_t ScanStructure::ResolvePredicates(DataChunk &keys, SelectionVector &match_sel, SelectionVector *no_match_sel) { - // Start with the scan selection + + // Initialize the found_match array to the current sel_vector for (idx_t i = 0; i < this->count; ++i) { match_sel.set_index(i, this->sel_vector.get_index(i)); } - idx_t no_match_count = 0; - auto &matcher = no_match_sel ? ht.row_matcher_no_match_sel : ht.row_matcher; - return matcher.Match(keys, key_state.vector_data, match_sel, this->count, ht.layout, pointers, no_match_sel, - no_match_count); + // If there is a matcher for the probing side because of non-equality predicates, use it + if (ht.needs_chain_matcher) { + idx_t no_match_count = 0; + auto &matcher = no_match_sel ? ht.row_matcher_probe_no_match_sel : ht.row_matcher_probe; + D_ASSERT(matcher); + + // we need to only use the vectors with the indices of the columns that are used in the probe phase, namely + // the non-equality columns + return matcher->Match(keys, key_state.vector_data, match_sel, this->count, ht.layout, pointers, no_match_sel, + no_match_count, ht.non_equality_predicate_columns); + } else { + // no match sel is the opposite of match sel + return this->count; + } } idx_t ScanStructure::ScanInnerJoin(DataChunk &keys, SelectionVector &result_vector) { while (true) { - // resolve the predicates for this set of keys + // resolve the equality_predicates for this set of keys idx_t result_count = ResolvePredicates(keys, result_vector, nullptr); // after doing all the comparisons set the found_match vector @@ -448,13 +863,19 @@ idx_t ScanStructure::ScanInnerJoin(DataChunk &keys, SelectionVector &result_vect } } -void ScanStructure::AdvancePointers(const SelectionVector &sel, idx_t sel_count) { +void ScanStructure::AdvancePointers(const SelectionVector &sel, const idx_t sel_count) { + + if (!ht.chains_longer_than_one) { + this->count = 0; + return; + } + // now for all the pointers, we move on to the next set of pointers idx_t new_count = 0; auto ptrs = FlatVector::GetData(this->pointers); for (idx_t i = 0; i < sel_count; i++) { auto idx = sel.get_index(i); - ptrs[idx] = Load(ptrs[idx] + ht.pointer_offset); + ptrs[idx] = LoadPointer(ptrs[idx] + ht.pointer_offset); if (ptrs[idx]) { this->sel_vector.set_index(new_count++, idx); } @@ -462,20 +883,6 @@ void ScanStructure::AdvancePointers(const SelectionVector &sel, idx_t sel_count) this->count = new_count; } -void ScanStructure::InitializeSelectionVector(const SelectionVector *¤t_sel) { - idx_t non_empty_count = 0; - auto ptrs = FlatVector::GetData(pointers); - auto cnt = count; - for (idx_t i = 0; i < cnt; i++) { - const auto idx = current_sel->get_index(i); - ptrs[idx] = Load(ptrs[idx]); - if (ptrs[idx]) { - sel_vector.set_index(non_empty_count++, idx); - } - } - count = non_empty_count; -} - void ScanStructure::AdvancePointers() { AdvancePointers(this->sel_vector, this->count); } @@ -499,17 +906,17 @@ void ScanStructure::NextInnerJoin(DataChunk &keys, DataChunk &left, DataChunk &r return; } - SelectionVector result_vector(STANDARD_VECTOR_SIZE); + idx_t result_count = ScanInnerJoin(keys, chain_match_sel_vector); - idx_t result_count = ScanInnerJoin(keys, result_vector); if (result_count > 0) { if (PropagatesBuildSide(ht.join_type)) { // full/right outer join: mark join matches as FOUND in the HT auto ptrs = FlatVector::GetData(pointers); for (idx_t i = 0; i < result_count; i++) { - auto idx = result_vector.get_index(i); - // NOTE: threadsan reports this as a data race because this can be set concurrently by separate threads - // Technically it is, but it does not matter, since the only value that can be written is "true" + auto idx = chain_match_sel_vector.get_index(i); + // NOTE: threadsan reports this as a data race because this can be set concurrently by separate + // threads Technically it is, but it does not matter, since the only value that can be written is + // "true" Store(true, ptrs[idx] + ht.tuple_size); } } @@ -518,14 +925,14 @@ void ScanStructure::NextInnerJoin(DataChunk &keys, DataChunk &left, DataChunk &r // matches were found // construct the result // on the LHS, we create a slice using the result vector - result.Slice(left, result_vector, result_count); + result.Slice(left, chain_match_sel_vector, result_count); // on the RHS, we need to fetch the data from the hash table for (idx_t i = 0; i < ht.output_columns.size(); i++) { auto &vector = result.data[left.ColumnCount() + i]; const auto output_col_idx = ht.output_columns[i]; D_ASSERT(vector.GetType() == ht.layout.GetTypes()[output_col_idx]); - GatherResult(vector, result_vector, result_count, output_col_idx); + GatherResult(vector, chain_match_sel_vector, result_count, output_col_idx); } } AdvancePointers(); @@ -538,18 +945,19 @@ void ScanStructure::ScanKeyMatches(DataChunk &keys) { // we handle the entire chunk in one call to Next(). // for every pointer, we keep chasing pointers and doing comparisons. // this results in a boolean array indicating whether or not the tuple has a match - SelectionVector match_sel(STANDARD_VECTOR_SIZE), no_match_sel(STANDARD_VECTOR_SIZE); + // Start with the scan selection + while (this->count > 0) { - // resolve the predicates for the current set of pointers - idx_t match_count = ResolvePredicates(keys, match_sel, &no_match_sel); + // resolve the equality_predicates for the current set of pointers + idx_t match_count = ResolvePredicates(keys, chain_match_sel_vector, &chain_no_match_sel_vector); idx_t no_match_count = this->count - match_count; // mark each of the matches as found for (idx_t i = 0; i < match_count; i++) { - found_match[match_sel.get_index(i)] = true; + found_match[chain_match_sel_vector.get_index(i)] = true; } // continue searching for the ones where we did not find a match yet - AdvancePointers(no_match_sel, no_match_count); + AdvancePointers(chain_no_match_sel_vector, no_match_count); } } @@ -594,6 +1002,41 @@ void ScanStructure::NextAntiJoin(DataChunk &keys, DataChunk &left, DataChunk &re finished = true; } +void ScanStructure::NextRightSemiOrAntiJoin(DataChunk &keys) { + const auto ptrs = FlatVector::GetData(pointers); + while (!PointersExhausted()) { + // resolve the equality_predicates for this set of keys + idx_t result_count = ResolvePredicates(keys, chain_match_sel_vector, nullptr); + + // for each match, fully follow the chain + for (idx_t i = 0; i < result_count; i++) { + const auto idx = chain_match_sel_vector.get_index(i); + auto &ptr = ptrs[idx]; + if (Load(ptr + ht.tuple_size)) { // Early out: chain has been fully marked as found before + ptr = ht.dead_end.get(); + continue; + } + + // Fully mark chain as found + while (true) { + // NOTE: threadsan reports this as a data race because this can be set concurrently by separate threads + // Technically it is, but it does not matter, since the only value that can be written is "true" + Store(true, ptr + ht.tuple_size); + auto next_ptr = LoadPointer(ptr + ht.pointer_offset); + if (!next_ptr) { + break; + } + ptr = next_ptr; + } + } + + // check the next set of pointers + AdvancePointers(); + } + + finished = true; +} + void ScanStructure::ConstructMarkJoinResult(DataChunk &join_keys, DataChunk &child, DataChunk &result) { // for the initial set of columns we just reference the left side result.SetCardinality(child); @@ -637,15 +1080,15 @@ void ScanStructure::ConstructMarkJoinResult(DataChunk &join_keys, DataChunk &chi } } -void ScanStructure::NextMarkJoin(DataChunk &keys, DataChunk &input, DataChunk &result) { - D_ASSERT(result.ColumnCount() == input.ColumnCount() + 1); +void ScanStructure::NextMarkJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { + D_ASSERT(result.ColumnCount() == left.ColumnCount() + 1); D_ASSERT(result.data.back().GetType() == LogicalType::BOOLEAN); // this method should only be called for a non-empty HT D_ASSERT(ht.Count() > 0); ScanKeyMatches(keys); if (ht.correlated_mark_join_info.correlated_types.empty()) { - ConstructMarkJoinResult(keys, input, result); + ConstructMarkJoinResult(keys, left, result); } else { auto &info = ht.correlated_mark_join_info; lock_guard mj_lock(info.mj_lock); @@ -660,9 +1103,9 @@ void ScanStructure::NextMarkJoin(DataChunk &keys, DataChunk &input, DataChunk &r info.correlated_counts->FetchAggregates(info.group_chunk, info.result_chunk); // for the initial set of columns we just reference the left side - result.SetCardinality(input); - for (idx_t i = 0; i < input.ColumnCount(); i++) { - result.data[i].Reference(input.data[i]); + result.SetCardinality(left); + for (idx_t i = 0; i < left.ColumnCount(); i++) { + result.data[i].Reference(left.data[i]); } // create the result matching vector auto &last_key = keys.data.back(); @@ -674,16 +1117,16 @@ void ScanStructure::NextMarkJoin(DataChunk &keys, DataChunk &input, DataChunk &r switch (last_key.GetVectorType()) { case VectorType::CONSTANT_VECTOR: if (ConstantVector::IsNull(last_key)) { - mask.SetAllInvalid(input.size()); + mask.SetAllInvalid(left.size()); } break; case VectorType::FLAT_VECTOR: - mask.Copy(FlatVector::Validity(last_key), input.size()); + mask.Copy(FlatVector::Validity(last_key), left.size()); break; default: { UnifiedVectorFormat kdata; last_key.ToUnifiedFormat(keys.size(), kdata); - for (idx_t i = 0; i < input.size(); i++) { + for (idx_t i = 0; i < left.size(); i++) { auto kidx = kdata.sel->get_index(i); mask.Set(i, kdata.validity.RowIsValid(kidx)); } @@ -694,7 +1137,7 @@ void ScanStructure::NextMarkJoin(DataChunk &keys, DataChunk &input, DataChunk &r auto count_star = FlatVector::GetData(info.result_chunk.data[0]); auto count = FlatVector::GetData(info.result_chunk.data[1]); // set the entries to either true or false based on whether a match was found - for (idx_t i = 0; i < input.size(); i++) { + for (idx_t i = 0; i < left.size(); i++) { D_ASSERT(count_star[i] >= count[i]); bool_result[i] = found_match ? found_match[i] : false; if (!bool_result[i] && count_star[i] > count[i]) { @@ -742,39 +1185,40 @@ void ScanStructure::NextLeftJoin(DataChunk &keys, DataChunk &left, DataChunk &re } } -void ScanStructure::NextSingleJoin(DataChunk &keys, DataChunk &input, DataChunk &result) { +void ScanStructure::NextSingleJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { // single join // this join is similar to the semi join except that // (1) we actually return data from the RHS and // (2) we return NULL for that data if there is no match + // (3) if single_join_error_on_multiple_rows is set, we need to keep looking for duplicates after fetching idx_t result_count = 0; SelectionVector result_sel(STANDARD_VECTOR_SIZE); - SelectionVector match_sel(STANDARD_VECTOR_SIZE), no_match_sel(STANDARD_VECTOR_SIZE); + while (this->count > 0) { - // resolve the predicates for the current set of pointers - idx_t match_count = ResolvePredicates(keys, match_sel, &no_match_sel); + // resolve the equality_predicates for the current set of pointers + idx_t match_count = ResolvePredicates(keys, chain_match_sel_vector, &chain_no_match_sel_vector); idx_t no_match_count = this->count - match_count; // mark each of the matches as found for (idx_t i = 0; i < match_count; i++) { // found a match for this index - auto index = match_sel.get_index(i); + auto index = chain_match_sel_vector.get_index(i); found_match[index] = true; result_sel.set_index(result_count++, index); } // continue searching for the ones where we did not find a match yet - AdvancePointers(no_match_sel, no_match_count); + AdvancePointers(chain_no_match_sel_vector, no_match_count); } // reference the columns of the left side from the result - D_ASSERT(input.ColumnCount() > 0); - for (idx_t i = 0; i < input.ColumnCount(); i++) { - result.data[i].Reference(input.data[i]); + D_ASSERT(left.ColumnCount() > 0); + for (idx_t i = 0; i < left.ColumnCount(); i++) { + result.data[i].Reference(left.data[i]); } // now fetch the data from the RHS for (idx_t i = 0; i < ht.output_columns.size(); i++) { - auto &vector = result.data[input.ColumnCount() + i]; + auto &vector = result.data[left.ColumnCount() + i]; // set NULL entries for every entry that was not found - for (idx_t j = 0; j < input.size(); j++) { + for (idx_t j = 0; j < left.size(); j++) { if (!found_match[j]) { FlatVector::SetNull(vector, j, true); } @@ -783,13 +1227,31 @@ void ScanStructure::NextSingleJoin(DataChunk &keys, DataChunk &input, DataChunk D_ASSERT(vector.GetType() == ht.layout.GetTypes()[output_col_idx]); GatherResult(vector, result_sel, result_sel, result_count, output_col_idx); } - result.SetCardinality(input.size()); + result.SetCardinality(left.size()); // like the SEMI, ANTI and MARK join types, the SINGLE join only ever does one pass over the HT per input chunk finished = true; + + if (ht.single_join_error_on_multiple_rows && result_count > 0) { + // we need to throw an error if there are multiple rows per key + // advance pointers for those rows + AdvancePointers(result_sel, result_count); + + // now resolve the predicates + idx_t match_count = ResolvePredicates(keys, chain_match_sel_vector, nullptr); + if (match_count > 0) { + // we found at least one duplicate row - throw + throw InvalidInputException( + "More than one row returned by a subquery used as an expression - scalar subqueries can only " + "return a single row.\n\nUse \"SET scalar_subquery_error_on_multiple_rows=false\" to revert to " + "previous behavior of returning a random row."); + } + + this->count = 0; + } } -void JoinHashTable::ScanFullOuter(JoinHTScanState &state, Vector &addresses, DataChunk &result) { +void JoinHashTable::ScanFullOuter(JoinHTScanState &state, Vector &addresses, DataChunk &result) const { // scan the HT starting from the current position and check which rows from the build side did not find a match auto key_locations = FlatVector::GetData(addresses); idx_t found_entries = 0; @@ -899,7 +1361,7 @@ idx_t JoinHashTable::GetTotalSize(const vector &partition_sizes, const ve return total_size + PointerTableSize(total_count); } -idx_t JoinHashTable::GetTotalSize(vector> &local_hts, idx_t &max_partition_size, +idx_t JoinHashTable::GetTotalSize(const vector> &local_hts, idx_t &max_partition_size, idx_t &max_partition_count) const { const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); vector partition_sizes(num_partitions, 0); @@ -911,7 +1373,7 @@ idx_t JoinHashTable::GetTotalSize(vector> &local_hts, return GetTotalSize(partition_sizes, partition_counts, max_partition_size, max_partition_count); } -idx_t JoinHashTable::GetRemainingSize() { +idx_t JoinHashTable::GetRemainingSize() const { const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); auto &partitions = sink_collection->GetPartitions(); @@ -929,21 +1391,21 @@ void JoinHashTable::Unpartition() { data_collection = sink_collection->GetUnpartitioned(); } -void JoinHashTable::SetRepartitionRadixBits(vector> &local_hts, const idx_t max_ht_size, - const idx_t max_partition_size, const idx_t max_partition_count) { +void JoinHashTable::SetRepartitionRadixBits(const idx_t max_ht_size, const idx_t max_partition_size, + const idx_t max_partition_count) { D_ASSERT(max_partition_size + PointerTableSize(max_partition_count) > max_ht_size); const auto max_added_bits = RadixPartitioning::MAX_RADIX_BITS - radix_bits; idx_t added_bits = 1; for (; added_bits < max_added_bits; added_bits++) { - double partition_multiplier = RadixPartitioning::NumberOfPartitions(added_bits); + double partition_multiplier = static_cast(RadixPartitioning::NumberOfPartitions(added_bits)); - auto new_estimated_size = double(max_partition_size) / partition_multiplier; - auto new_estimated_count = double(max_partition_count) / partition_multiplier; + auto new_estimated_size = static_cast(max_partition_size) / partition_multiplier; + auto new_estimated_count = static_cast(max_partition_count) / partition_multiplier; auto new_estimated_ht_size = - new_estimated_size + static_cast(PointerTableSize(NumericCast(new_estimated_count))); + new_estimated_size + static_cast(PointerTableSize(LossyNumericCast(new_estimated_count))); - if (new_estimated_ht_size <= double(max_ht_size) / 4) { + if (new_estimated_ht_size <= static_cast(max_ht_size) / 4) { // Aim for an estimated partition size of max_ht_size / 4 break; } @@ -963,6 +1425,7 @@ void JoinHashTable::Repartition(JoinHashTable &global_ht) { void JoinHashTable::Reset() { data_collection->Reset(); + hash_map.Reset(); finalized = false; } @@ -1019,10 +1482,9 @@ static void CreateSpillChunk(DataChunk &spill_chunk, DataChunk &keys, DataChunk spill_chunk.data[spill_col_idx].Reference(hashes); } -unique_ptr JoinHashTable::ProbeAndSpill(DataChunk &keys, TupleDataChunkState &key_state, - DataChunk &payload, ProbeSpill &probe_spill, - ProbeSpillLocalAppendState &spill_state, - DataChunk &spill_chunk) { +void JoinHashTable::ProbeAndSpill(ScanStructure &scan_structure, DataChunk &keys, TupleDataChunkState &key_state, + ProbeState &probe_state, DataChunk &payload, ProbeSpill &probe_spill, + ProbeSpillLocalAppendState &spill_state, DataChunk &spill_chunk) { // hash all the keys Vector hashes(LogicalType::HASH); Hash(keys, *FlatVector::IncrementalSelectionVector(), keys.size(), hashes); @@ -1049,18 +1511,14 @@ unique_ptr JoinHashTable::ProbeAndSpill(DataChunk &keys, TupleDat payload.Slice(true_sel, true_count); const SelectionVector *current_sel; - auto ss = InitializeScanStructure(keys, key_state, current_sel); - if (ss->count == 0) { - return ss; + InitializeScanStructure(scan_structure, keys, key_state, current_sel); + if (scan_structure.count == 0) { + return; } // now initialize the pointers of the scan structure based on the hashes - ApplyBitmask(hashes, *current_sel, ss->count, ss->pointers); - - // create the selection vector linking to only non-empty entries - ss->InitializeSelectionVector(current_sel); - - return ss; + GetRowPointers(keys, key_state, probe_state, hashes, *current_sel, scan_structure.count, scan_structure.pointers, + scan_structure.sel_vector); } ProbeSpill::ProbeSpill(JoinHashTable &ht, ClientContext &context, const vector &probe_types) diff --git a/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp b/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp index af61fa17..53fe0368 100644 --- a/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp +++ b/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp @@ -16,13 +16,13 @@ AggregateObject::AggregateObject(AggregateFunction function, FunctionData *bind_ AggregateObject::AggregateObject(BoundAggregateExpression *aggr) : AggregateObject(aggr->function, aggr->bind_info.get(), aggr->children.size(), - AlignValue(aggr->function.state_size()), aggr->aggr_type, aggr->return_type.InternalType(), - aggr->filter.get()) { + AlignValue(aggr->function.state_size(aggr->function)), aggr->aggr_type, + aggr->return_type.InternalType(), aggr->filter.get()) { } -AggregateObject::AggregateObject(BoundWindowExpression &window) +AggregateObject::AggregateObject(const BoundWindowExpression &window) : AggregateObject(*window.aggregate, window.bind_info.get(), window.children.size(), - AlignValue(window.aggregate->state_size()), + AlignValue(window.aggregate->state_size(*window.aggregate)), window.distinct ? AggregateType::DISTINCT : AggregateType::NON_DISTINCT, window.return_type.InternalType(), window.filter_expr.get()) { } diff --git a/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp index a3f8044f..c4cf4b55 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp @@ -797,7 +797,6 @@ class HashAggregateGlobalSourceState : public GlobalSourceState { } const PhysicalHashAggregate &op; - mutex lock; atomic state_index; vector> radix_states; @@ -871,7 +870,7 @@ SourceResultType PhysicalHashAggregate::GetData(ExecutionContext &context, DataC } // move to the next table - lock_guard l(gstate.lock); + auto guard = gstate.Lock(); lstate.radix_idx = lstate.radix_idx.GetIndex() + 1; if (lstate.radix_idx.GetIndex() > gstate.state_index) { // we have not yet worked on the table @@ -895,26 +894,32 @@ double PhysicalHashAggregate::GetProgress(ClientContext &context, GlobalSourceSt return total_progress / double(groupings.size()); } -string PhysicalHashAggregate::ParamsToString() const { - string result; +InsertionOrderPreservingMap PhysicalHashAggregate::ParamsToString() const { + InsertionOrderPreservingMap result; auto &groups = grouped_aggregate_data.groups; auto &aggregates = grouped_aggregate_data.aggregates; + string groups_info; for (idx_t i = 0; i < groups.size(); i++) { if (i > 0) { - result += "\n"; + groups_info += "\n"; } - result += groups[i]->GetName(); + groups_info += groups[i]->GetName(); } + result["Groups"] = groups_info; + + string aggregate_info; for (idx_t i = 0; i < aggregates.size(); i++) { auto &aggregate = aggregates[i]->Cast(); - if (i > 0 || !groups.empty()) { - result += "\n"; + if (i > 0) { + aggregate_info += "\n"; } - result += aggregates[i]->GetName(); + aggregate_info += aggregates[i]->GetName(); if (aggregate.filter) { - result += " Filter: " + aggregate.filter->GetName(); + aggregate_info += " Filter: " + aggregate.filter->GetName(); } } + result["Aggregates"] = aggregate_info; + SetEstimatedCardinality(result, estimated_cardinality); return result; } diff --git a/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp index fe7e6c46..778e5e67 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp @@ -200,24 +200,29 @@ SourceResultType PhysicalPerfectHashAggregate::GetData(ExecutionContext &context } } -string PhysicalPerfectHashAggregate::ParamsToString() const { - string result; +InsertionOrderPreservingMap PhysicalPerfectHashAggregate::ParamsToString() const { + InsertionOrderPreservingMap result; + string groups_info; for (idx_t i = 0; i < groups.size(); i++) { if (i > 0) { - result += "\n"; + groups_info += "\n"; } - result += groups[i]->GetName(); + groups_info += groups[i]->GetName(); } + result["Groups"] = groups_info; + + string aggregate_info; for (idx_t i = 0; i < aggregates.size(); i++) { - if (i > 0 || !groups.empty()) { - result += "\n"; + if (i > 0) { + aggregate_info += "\n"; } - result += aggregates[i]->GetName(); + aggregate_info += aggregates[i]->GetName(); auto &aggregate = aggregates[i]->Cast(); if (aggregate.filter) { - result += " Filter: " + aggregate.filter->GetName(); + aggregate_info += " Filter: " + aggregate.filter->GetName(); } } + result["Aggregates"] = aggregate_info; return result; } diff --git a/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp index 9fc58500..903247e7 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp @@ -1,5 +1,6 @@ #include "duckdb/execution/operator/aggregate/physical_streaming_window.hpp" +#include "duckdb/execution/aggregate_hashtable.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/function/aggregate_function.hpp" #include "duckdb/parallel/thread_context.hpp" @@ -8,30 +9,6 @@ namespace duckdb { -bool PhysicalStreamingWindow::IsStreamingFunction(unique_ptr &expr) { - auto &wexpr = expr->Cast(); - if (!wexpr.partitions.empty() || !wexpr.orders.empty() || wexpr.ignore_nulls || - wexpr.exclude_clause != WindowExcludeMode::NO_OTHER) { - return false; - } - switch (wexpr.type) { - // TODO: add more expression types here? - case ExpressionType::WINDOW_AGGREGATE: - // We can stream aggregates if they are "running totals" - // TODO: Support FILTER and DISTINCT - return wexpr.start == WindowBoundary::UNBOUNDED_PRECEDING && wexpr.end == WindowBoundary::CURRENT_ROW_ROWS && - !wexpr.filter_expr && !wexpr.distinct; - case ExpressionType::WINDOW_FIRST_VALUE: - case ExpressionType::WINDOW_PERCENT_RANK: - case ExpressionType::WINDOW_RANK: - case ExpressionType::WINDOW_RANK_DENSE: - case ExpressionType::WINDOW_ROW_NUMBER: - return true; - default: - return false; - } -} - PhysicalStreamingWindow::PhysicalStreamingWindow(vector types, vector> select_list, idx_t estimated_cardinality, PhysicalOperatorType type) : PhysicalOperator(type, std::move(types), estimated_cardinality), select_list(std::move(select_list)) { @@ -48,43 +25,253 @@ class StreamingWindowGlobalState : public GlobalOperatorState { class StreamingWindowState : public OperatorState { public: - using StateBuffer = vector; - - StreamingWindowState() - : initialized(false), allocator(Allocator::DefaultAllocator()), - statev(LogicalType::POINTER, data_ptr_cast(&state_ptr)) { - } + struct AggregateState { + AggregateState(ClientContext &client, BoundWindowExpression &wexpr, Allocator &allocator) + : wexpr(wexpr), arena_allocator(Allocator::DefaultAllocator()), executor(client), filter_executor(client), + statev(LogicalType::POINTER, data_ptr_cast(&state_ptr)), hashes(LogicalType::HASH), + addresses(LogicalType::POINTER) { + D_ASSERT(wexpr.GetExpressionType() == ExpressionType::WINDOW_AGGREGATE); + auto &aggregate = *wexpr.aggregate; + bind_data = wexpr.bind_info.get(); + dtor = aggregate.destructor; + state.resize(aggregate.state_size(aggregate)); + state_ptr = state.data(); + aggregate.initialize(aggregate, state.data()); + for (auto &child : wexpr.children) { + arg_types.push_back(child->return_type); + executor.AddExpression(*child); + } + if (!arg_types.empty()) { + arg_chunk.Initialize(allocator, arg_types); + arg_cursor.Initialize(allocator, arg_types); + } + if (wexpr.filter_expr) { + filter_executor.AddExpression(*wexpr.filter_expr); + filter_sel.Initialize(); + } + if (wexpr.distinct) { + distinct = make_uniq(client, allocator, arg_types); + distinct_args.Initialize(allocator, arg_types); + distinct_sel.Initialize(); + } + } - ~StreamingWindowState() override { - for (size_t i = 0; i < aggregate_dtors.size(); ++i) { - auto dtor = aggregate_dtors[i]; + ~AggregateState() { if (dtor) { - AggregateInputData aggr_input_data(aggregate_bind_data[i], allocator); - state_ptr = aggregate_states[i].data(); + AggregateInputData aggr_input_data(bind_data, arena_allocator); + state_ptr = state.data(); dtor(statev, aggr_input_data, 1); } } + + void Execute(ExecutionContext &context, DataChunk &input, Vector &result); + + //! The aggregate expression + BoundWindowExpression &wexpr; + //! The allocator to use for aggregate data structures + ArenaAllocator arena_allocator; + //! Reusable executor for the children + ExpressionExecutor executor; + //! Shared executor for FILTER clauses + ExpressionExecutor filter_executor; + //! The single aggregate state we update row-by-row + vector state; + //! The pointer to the state stored in the state vector + data_ptr_t state_ptr = nullptr; + //! The state vector for the single state + Vector statev; + //! The aggregate binding data (if any) + FunctionData *bind_data = nullptr; + //! The aggregate state destructor (if any) + aggregate_destructor_t dtor = nullptr; + //! The inputs rows that pass the FILTER + SelectionVector filter_sel; + //! The number of unfiltered rows so far for COUNT(*) + int64_t unfiltered = 0; + //! Argument types + vector arg_types; + //! Argument value buffer + DataChunk arg_chunk; + //! Argument cursor (a one element slice of arg_chunk) + DataChunk arg_cursor; + + //! Hash table for accumulating the distinct values + unique_ptr distinct; + //! Filtered arguments for checking distinctness + DataChunk distinct_args; + //! Reusable hash vector + Vector hashes; + //! Rows that produced new distinct values + SelectionVector distinct_sel; + //! Pointers to groups in the hash table. + Vector addresses; + }; + + struct LeadLagState { + // Fixed size + static constexpr idx_t MAX_BUFFER = 2048U; + + static bool ComputeOffset(ClientContext &context, BoundWindowExpression &wexpr, int64_t &offset) { + offset = 1; + if (wexpr.offset_expr) { + if (wexpr.offset_expr->HasParameter() || !wexpr.offset_expr->IsFoldable()) { + return false; + } + auto offset_value = ExpressionExecutor::EvaluateScalar(context, *wexpr.offset_expr); + if (offset_value.IsNull()) { + return false; + } + Value bigint_value; + if (!offset_value.DefaultTryCastAs(LogicalType::BIGINT, bigint_value, nullptr, false)) { + return false; + } + offset = bigint_value.GetValue(); + } + + // We can only support LEAD and LAG values within one standard vector + if (wexpr.type == ExpressionType::WINDOW_LEAD) { + offset = -offset; + } + return idx_t(std::abs(offset)) < MAX_BUFFER; + } + + static bool ComputeDefault(ClientContext &context, BoundWindowExpression &wexpr, Value &result) { + if (!wexpr.default_expr) { + result = Value(wexpr.return_type); + return true; + } + + if (wexpr.default_expr && (wexpr.default_expr->HasParameter() || !wexpr.default_expr->IsFoldable())) { + return false; + } + auto dflt_value = ExpressionExecutor::EvaluateScalar(context, *wexpr.default_expr); + return dflt_value.DefaultTryCastAs(wexpr.return_type, result, nullptr, false); + } + + LeadLagState(ClientContext &context, BoundWindowExpression &wexpr) + : wexpr(wexpr), executor(context, *wexpr.children[0]), prev(wexpr.return_type), temp(wexpr.return_type) { + ComputeOffset(context, wexpr, offset); + ComputeDefault(context, wexpr, dflt); + + curr_chunk.Initialize(context, {wexpr.return_type}); + + buffered = idx_t(std::abs(offset)); + prev.Reference(dflt); + prev.Flatten(buffered); + temp.Initialize(false, buffered); + } + + void Execute(ExecutionContext &context, DataChunk &input, DataChunk &delayed, Vector &result) { + if (offset >= 0) { + ExecuteLag(context, input, result); + } else { + ExecuteLead(context, input, delayed, result); + } + } + + void ExecuteLag(ExecutionContext &context, DataChunk &input, Vector &result) { + D_ASSERT(offset >= 0); + auto &curr = curr_chunk.data[0]; + curr_chunk.Reset(); + executor.Execute(input, curr_chunk); + const idx_t count = input.size(); + // Copy prev[0, buffered] => result[0, buffered] + idx_t source_count = MinValue(buffered, count); + VectorOperations::Copy(prev, result, source_count, 0, 0); + // Special case when we have buffered enough values for the output + if (count < buffered) { + // Shift down incomplete buffers + // Copy prev[buffered-count, buffered] => temp[0, count] + source_count = buffered - count; + FlatVector::Validity(temp).Reset(); + VectorOperations::Copy(prev, temp, buffered, source_count, 0); + + // Copy temp[0, count] => prev[0, count] + FlatVector::Validity(prev).Reset(); + VectorOperations::Copy(temp, prev, count, 0, 0); + // Copy curr[0, buffered-count] => prev[count, buffered] + VectorOperations::Copy(curr, prev, source_count, 0, count); + } else { + // Copy input values beyond what we have buffered + source_count = count - buffered; + // Copy curr[0, count-buffered] => result[buffered, count] + VectorOperations::Copy(curr, result, source_count, 0, buffered); + // Copy curr[count-buffered, count] => prev[0, buffered] + FlatVector::Validity(prev).Reset(); + VectorOperations::Copy(curr, prev, count, source_count, 0); + } + } + + void ExecuteLead(ExecutionContext &context, DataChunk &input, DataChunk &delayed, Vector &result) { + // We treat input || delayed as a logical unified buffer + D_ASSERT(offset < 0); + // Input has been set up with the number of rows we CAN produce. + const idx_t count = input.size(); + auto &curr = curr_chunk.data[0]; + // Copy unified[buffered:count] => result[pos:] + idx_t pos = 0; + idx_t unified_offset = buffered; + if (unified_offset < count) { + curr_chunk.Reset(); + executor.Execute(input, curr_chunk); + VectorOperations::Copy(curr, result, count, unified_offset, pos); + pos += count - unified_offset; + unified_offset = count; + } + // Copy unified[unified_offset:] => result[pos:] + idx_t unified_count = count + delayed.size(); + if (unified_offset < unified_count) { + curr_chunk.Reset(); + executor.Execute(delayed, curr_chunk); + idx_t delayed_offset = unified_offset - count; + // Only copy as many values as we need + idx_t delayed_count = MinValue(delayed.size(), delayed_offset + (count - pos)); + VectorOperations::Copy(curr, result, delayed_count, delayed_offset, pos); + pos += delayed_count - delayed_offset; + } + // Copy default[:count-pos] => result[pos:] + if (pos < count) { + const idx_t defaulted = count - pos; + VectorOperations::Copy(prev, result, defaulted, 0, pos); + } + } + + //! The aggregate expression + BoundWindowExpression &wexpr; + //! Cache the executor to cut down on memory allocation + ExpressionExecutor executor; + //! The constant offset + int64_t offset; + //! The number of rows we have buffered + idx_t buffered; + //! The constant default value + Value dflt; + //! The current set of values + DataChunk curr_chunk; + //! The previous set of values + Vector prev; + //! The copy buffer + Vector temp; + }; + + explicit StreamingWindowState(ClientContext &client) : initialized(false), allocator(Allocator::Get(client)) { + } + + ~StreamingWindowState() override { } void Initialize(ClientContext &context, DataChunk &input, const vector> &expressions) { const_vectors.resize(expressions.size()); aggregate_states.resize(expressions.size()); - aggregate_bind_data.resize(expressions.size(), nullptr); - aggregate_dtors.resize(expressions.size(), nullptr); + lead_lag_states.resize(expressions.size()); for (idx_t expr_idx = 0; expr_idx < expressions.size(); expr_idx++) { auto &expr = *expressions[expr_idx]; auto &wexpr = expr.Cast(); switch (expr.GetExpressionType()) { - case ExpressionType::WINDOW_AGGREGATE: { - auto &aggregate = *wexpr.aggregate; - auto &state = aggregate_states[expr_idx]; - aggregate_bind_data[expr_idx] = wexpr.bind_info.get(); - aggregate_dtors[expr_idx] = aggregate.destructor; - state.resize(aggregate.state_size()); - aggregate.initialize(state.data()); + case ExpressionType::WINDOW_AGGREGATE: + aggregate_states[expr_idx] = make_uniq(context, wexpr, allocator); break; - } case ExpressionType::WINDOW_FIRST_VALUE: { // Just execute the expression once ExpressionExecutor executor(context); @@ -105,115 +292,200 @@ class StreamingWindowState : public OperatorState { const_vectors[expr_idx] = make_uniq(Value((int64_t)1)); break; } + case ExpressionType::WINDOW_LAG: + case ExpressionType::WINDOW_LEAD: { + lead_lag_states[expr_idx] = make_uniq(context, wexpr); + const auto offset = lead_lag_states[expr_idx]->offset; + if (offset < 0) { + lead_count = MaxValue(idx_t(-offset), lead_count); + } + break; + } default: break; } } + if (lead_count) { + delayed.Initialize(context, input.GetTypes(), lead_count + STANDARD_VECTOR_SIZE); + shifted.Initialize(context, input.GetTypes(), lead_count + STANDARD_VECTOR_SIZE); + } initialized = true; } public: + //! We can't initialise until we have an input chunk bool initialized; + //! The values that are determined by the first row. vector> const_vectors; - ArenaAllocator allocator; - - // Aggregation - vector aggregate_states; - vector aggregate_bind_data; - vector aggregate_dtors; - data_ptr_t state_ptr; - Vector statev; + //! Aggregation states + vector> aggregate_states; + Allocator &allocator; + //! Lead/Lag states + vector> lead_lag_states; + //! The number of rows ahead to buffer for LEAD + idx_t lead_count = 0; + //! A buffer for delayed input + DataChunk delayed; + //! A buffer for shifting delayed input + DataChunk shifted; }; +bool PhysicalStreamingWindow::IsStreamingFunction(ClientContext &context, unique_ptr &expr) { + auto &wexpr = expr->Cast(); + if (!wexpr.partitions.empty() || !wexpr.orders.empty() || wexpr.ignore_nulls || + wexpr.exclude_clause != WindowExcludeMode::NO_OTHER) { + return false; + } + switch (wexpr.type) { + // TODO: add more expression types here? + case ExpressionType::WINDOW_AGGREGATE: + // We can stream aggregates if they are "running totals" + return wexpr.start == WindowBoundary::UNBOUNDED_PRECEDING && wexpr.end == WindowBoundary::CURRENT_ROW_ROWS; + case ExpressionType::WINDOW_FIRST_VALUE: + case ExpressionType::WINDOW_PERCENT_RANK: + case ExpressionType::WINDOW_RANK: + case ExpressionType::WINDOW_RANK_DENSE: + case ExpressionType::WINDOW_ROW_NUMBER: + return true; + case ExpressionType::WINDOW_LAG: + case ExpressionType::WINDOW_LEAD: { + // We can stream LEAD/LAG if the arguments are constant and the delta is less than a block behind + Value dflt; + int64_t offset; + return StreamingWindowState::LeadLagState::ComputeDefault(context, wexpr, dflt) && + StreamingWindowState::LeadLagState::ComputeOffset(context, wexpr, offset); + } + default: + return false; + } +} + unique_ptr PhysicalStreamingWindow::GetGlobalOperatorState(ClientContext &context) const { return make_uniq(); } unique_ptr PhysicalStreamingWindow::GetOperatorState(ExecutionContext &context) const { - return make_uniq(); + return make_uniq(context.client); } -OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate_p, OperatorState &state_p) const { - auto &gstate = gstate_p.Cast(); - auto &state = state_p.Cast(); +void StreamingWindowState::AggregateState::Execute(ExecutionContext &context, DataChunk &input, Vector &result) { + // Establish the aggregation environment + const idx_t count = input.size(); + auto &aggregate = *wexpr.aggregate; + auto &aggr_state = *this; + auto &statev = aggr_state.statev; - if (!state.initialized) { - state.Initialize(context.client, input, select_list); + // Compute the FILTER mask (if any) + ValidityMask filter_mask; + auto filtered = count; + auto &filter_sel = aggr_state.filter_sel; + if (wexpr.filter_expr) { + filtered = filter_executor.SelectExpression(input, filter_sel); + if (filtered < count) { + filter_mask.Initialize(count); + filter_mask.SetAllInvalid(count); + for (idx_t f = 0; f < filtered; ++f) { + filter_mask.SetValid(filter_sel.get_index(f)); + } + } } - // Put payload columns in place - for (idx_t col_idx = 0; col_idx < input.data.size(); col_idx++) { - chunk.data[col_idx].Reference(input.data[col_idx]); + + // Check for COUNT(*) + if (wexpr.children.empty()) { + D_ASSERT(GetTypeIdSize(result.GetType().InternalType()) == sizeof(int64_t)); + auto data = FlatVector::GetData(result); + auto &unfiltered = aggr_state.unfiltered; + for (idx_t i = 0; i < count; ++i) { + unfiltered += int64_t(filter_mask.RowIsValid(i)); + data[i] = unfiltered; + } + return; } - // Compute window function - const idx_t count = input.size(); - for (idx_t expr_idx = 0; expr_idx < select_list.size(); expr_idx++) { - idx_t col_idx = input.data.size() + expr_idx; - auto &expr = *select_list[expr_idx]; - auto &result = chunk.data[col_idx]; - switch (expr.GetExpressionType()) { - case ExpressionType::WINDOW_AGGREGATE: { - // Establish the aggregation environment - auto &wexpr = expr.Cast(); - auto &aggregate = *wexpr.aggregate; - auto &statev = state.statev; - state.state_ptr = state.aggregate_states[expr_idx].data(); - AggregateInputData aggr_input_data(wexpr.bind_info.get(), state.allocator); - - // Check for COUNT(*) - if (wexpr.children.empty()) { - D_ASSERT(GetTypeIdSize(result.GetType().InternalType()) == sizeof(int64_t)); - auto data = FlatVector::GetData(result); - int64_t start_row = gstate.row_number; - for (idx_t i = 0; i < input.size(); ++i) { - data[i] = NumericCast(start_row + NumericCast(i)); - } - break; - } - // Compute the arguments - auto &allocator = Allocator::Get(context.client); - ExpressionExecutor executor(context.client); - vector payload_types; - for (auto &child : wexpr.children) { - payload_types.push_back(child->return_type); - executor.AddExpression(*child); - } + // Compute the arguments + auto &arg_chunk = aggr_state.arg_chunk; + executor.Execute(input, arg_chunk); + arg_chunk.Flatten(); - DataChunk payload; - payload.Initialize(allocator, payload_types); - executor.Execute(input, payload); - - // Iterate through them using a single SV - payload.Flatten(); - DataChunk row; - row.Initialize(allocator, payload_types); - sel_t s = 0; - SelectionVector sel(&s); - row.Slice(sel, 1); - // This doesn't work for STRUCTs because the SV - // is not copied to the children when you slice - vector structs; - for (column_t col_idx = 0; col_idx < payload.ColumnCount(); ++col_idx) { - auto &col_vec = row.data[col_idx]; - DictionaryVector::Child(col_vec).Reference(payload.data[col_idx]); - if (col_vec.GetType().InternalType() == PhysicalType::STRUCT) { - structs.emplace_back(col_idx); - } - } + // Update the distinct hash table + ValidityMask distinct_mask; + if (aggr_state.distinct) { + auto &distinct_args = aggr_state.distinct_args; + distinct_args.Reference(arg_chunk); + if (wexpr.filter_expr) { + distinct_args.Slice(filter_sel, filtered); + } + idx_t distinct = 0; + auto &distinct_sel = aggr_state.distinct_sel; + if (filtered) { + // FindOrCreateGroups assumes non-empty input + auto &hashes = aggr_state.hashes; + distinct_args.Hash(hashes); - // Update the state and finalize it one row at a time. - for (idx_t i = 0; i < input.size(); ++i) { - sel.set_index(0, i); - for (const auto struct_idx : structs) { - row.data[struct_idx].Slice(payload.data[struct_idx], sel, 1); - } - // TODO: FILTER and DISTINCT would just skip this. - aggregate.update(row.data.data(), aggr_input_data, row.ColumnCount(), statev, 1); - aggregate.finalize(statev, aggr_input_data, result, 1, i); + auto &addresses = aggr_state.addresses; + distinct = aggr_state.distinct->FindOrCreateGroups(distinct_args, hashes, addresses, distinct_sel); + } + + // Translate the distinct selection from filtered row numbers + // back to input row numbers. We need to produce output for all input rows, + // so we filter out + if (distinct < filtered) { + distinct_mask.Initialize(count); + distinct_mask.SetAllInvalid(count); + for (idx_t d = 0; d < distinct; ++d) { + const auto f = distinct_sel.get_index(d); + distinct_mask.SetValid(filter_sel.get_index(f)); } - break; } + } + + // Iterate through them using a single SV + sel_t s = 0; + SelectionVector sel(&s); + auto &arg_cursor = aggr_state.arg_cursor; + arg_cursor.Reset(); + arg_cursor.Slice(sel, 1); + // This doesn't work for STRUCTs because the SV + // is not copied to the children when you slice + vector structs; + for (column_t col_idx = 0; col_idx < arg_chunk.ColumnCount(); ++col_idx) { + auto &col_vec = arg_cursor.data[col_idx]; + DictionaryVector::Child(col_vec).Reference(arg_chunk.data[col_idx]); + if (col_vec.GetType().InternalType() == PhysicalType::STRUCT) { + structs.emplace_back(col_idx); + } + } + + // Update the state and finalize it one row at a time. + AggregateInputData aggr_input_data(wexpr.bind_info.get(), aggr_state.arena_allocator); + for (idx_t i = 0; i < count; ++i) { + sel.set_index(0, i); + for (const auto struct_idx : structs) { + arg_cursor.data[struct_idx].Slice(arg_chunk.data[struct_idx], sel, 1); + } + if (filter_mask.RowIsValid(i) && distinct_mask.RowIsValid(i)) { + aggregate.update(arg_cursor.data.data(), aggr_input_data, arg_cursor.ColumnCount(), statev, 1); + } + aggregate.finalize(statev, aggr_input_data, result, 1, i); + } +} + +void PhysicalStreamingWindow::ExecuteFunctions(ExecutionContext &context, DataChunk &chunk, DataChunk &delayed, + GlobalOperatorState &gstate_p, OperatorState &state_p) const { + auto &gstate = gstate_p.Cast(); + auto &state = state_p.Cast(); + + // Compute window functions + const idx_t count = chunk.size(); + const column_t input_width = children[0]->GetTypes().size(); + for (column_t expr_idx = 0; expr_idx < select_list.size(); expr_idx++) { + column_t col_idx = input_width + expr_idx; + auto &expr = *select_list[expr_idx]; + auto &result = chunk.data[col_idx]; + switch (expr.GetExpressionType()) { + case ExpressionType::WINDOW_AGGREGATE: + state.aggregate_states[expr_idx]->Execute(context, chunk, result); + break; case ExpressionType::WINDOW_FIRST_VALUE: case ExpressionType::WINDOW_PERCENT_RANK: case ExpressionType::WINDOW_RANK: @@ -231,23 +503,144 @@ OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, D } break; } + case ExpressionType::WINDOW_LAG: + case ExpressionType::WINDOW_LEAD: + state.lead_lag_states[expr_idx]->Execute(context, chunk, delayed, result); + break; default: throw NotImplementedException("%s for StreamingWindow", ExpressionTypeToString(expr.GetExpressionType())); } } gstate.row_number += NumericCast(count); +} + +void PhysicalStreamingWindow::ExecuteInput(ExecutionContext &context, DataChunk &delayed, DataChunk &input, + DataChunk &chunk, GlobalOperatorState &gstate_p, + OperatorState &state_p) const { + auto &state = state_p.Cast(); + + // Put payload columns in place + for (idx_t col_idx = 0; col_idx < input.data.size(); col_idx++) { + chunk.data[col_idx].Reference(input.data[col_idx]); + } + idx_t count = input.size(); + + // Handle LEAD + if (state.lead_count > 0) { + // Nothing delayed yet, so just truncate and copy the delayed values + count -= state.lead_count; + input.Copy(delayed, count); + } + chunk.SetCardinality(count); + + ExecuteFunctions(context, chunk, state.delayed, gstate_p, state_p); +} + +void PhysicalStreamingWindow::ExecuteShifted(ExecutionContext &context, DataChunk &delayed, DataChunk &input, + DataChunk &chunk, GlobalOperatorState &gstate_p, + OperatorState &state_p) const { + auto &state = state_p.Cast(); + auto &shifted = state.shifted; + + idx_t i = input.size(); + idx_t d = delayed.size(); + shifted.Reset(); + // shifted = delayed + delayed.Copy(shifted); + delayed.Reset(); + for (idx_t col_idx = 0; col_idx < delayed.data.size(); ++col_idx) { + // chunk[0:i] = shifted[0:i] + chunk.data[col_idx].Reference(shifted.data[col_idx]); + // delayed[0:i] = chunk[i:d-i] + VectorOperations::Copy(shifted.data[col_idx], delayed.data[col_idx], d, i, 0); + // delayed[d-i:d] = input[0:i] + VectorOperations::Copy(input.data[col_idx], delayed.data[col_idx], i, 0, d - i); + } + chunk.SetCardinality(i); + delayed.SetCardinality(d); + + ExecuteFunctions(context, chunk, delayed, gstate_p, state_p); +} + +void PhysicalStreamingWindow::ExecuteDelayed(ExecutionContext &context, DataChunk &delayed, DataChunk &input, + DataChunk &chunk, GlobalOperatorState &gstate_p, + OperatorState &state_p) const { + // Put payload columns in place + for (idx_t col_idx = 0; col_idx < delayed.data.size(); col_idx++) { + chunk.data[col_idx].Reference(delayed.data[col_idx]); + } + idx_t count = delayed.size(); chunk.SetCardinality(count); - return OperatorResultType::NEED_MORE_INPUT; + + ExecuteFunctions(context, chunk, input, gstate_p, state_p); +} + +OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate_p, OperatorState &state_p) const { + auto &state = state_p.Cast(); + if (!state.initialized) { + state.Initialize(context.client, input, select_list); + } + + auto &delayed = state.delayed; + // We can Reset delayed now that no one can be referencing it. + if (!delayed.size()) { + delayed.Reset(); + } + const idx_t available = delayed.size() + input.size(); + if (available <= state.lead_count) { + // If we don't have enough to produce a single row, + // then just delay more rows, return nothing + // and ask for more data. + delayed.Append(input); + chunk.SetCardinality(0); + return OperatorResultType::NEED_MORE_INPUT; + } else if (input.size() < delayed.size()) { + // If we can't consume all of the delayed values, + // we need to split them instead of referencing them all + ExecuteShifted(context, delayed, input, chunk, gstate_p, state_p); + // We delayed the unused input so ask for more + return OperatorResultType::NEED_MORE_INPUT; + } else if (delayed.size()) { + // We have enough delayed rows so flush them + ExecuteDelayed(context, delayed, input, chunk, gstate_p, state_p); + // Defer resetting delayed as it may be referenced. + delayed.SetCardinality(0); + // Come back to process the input + return OperatorResultType::HAVE_MORE_OUTPUT; + } else { + // No delayed rows, so emit what we can and delay the rest. + ExecuteInput(context, delayed, input, chunk, gstate_p, state_p); + return OperatorResultType::NEED_MORE_INPUT; + } +} + +OperatorFinalizeResultType PhysicalStreamingWindow::FinalExecute(ExecutionContext &context, DataChunk &chunk, + GlobalOperatorState &gstate_p, + OperatorState &state_p) const { + auto &state = state_p.Cast(); + + if (state.initialized && state.lead_count) { + auto &delayed = state.delayed; + // There are no more input rows + auto &input = state.shifted; + input.Reset(); + ExecuteDelayed(context, delayed, input, chunk, gstate_p, state_p); + } + + return OperatorFinalizeResultType::FINISHED; } -string PhysicalStreamingWindow::ParamsToString() const { - string result; +InsertionOrderPreservingMap PhysicalStreamingWindow::ParamsToString() const { + InsertionOrderPreservingMap result; + string projections; for (idx_t i = 0; i < select_list.size(); i++) { if (i > 0) { - result += "\n"; + projections += "\n"; } - result += select_list[i]->GetName(); + projections += select_list[i]->GetName(); } + result["Projections"] = projections; return result; } diff --git a/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp index eab47dc7..15d082ab 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp @@ -15,6 +15,7 @@ #include "duckdb/parallel/executor_task.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/execution/operator/aggregate/ungrouped_aggregate_state.hpp" #include @@ -34,94 +35,125 @@ PhysicalUngroupedAggregate::PhysicalUngroupedAggregate(vector types } //===--------------------------------------------------------------------===// -// Sink +// Ungrouped Aggregate State //===--------------------------------------------------------------------===// -struct AggregateState { - explicit AggregateState(const vector> &aggregate_expressions) { - counts = make_uniq_array>(aggregate_expressions.size()); - for (idx_t i = 0; i < aggregate_expressions.size(); i++) { - auto &aggregate = aggregate_expressions[i]; - D_ASSERT(aggregate->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); - auto &aggr = aggregate->Cast(); - auto state = make_unsafe_uniq_array(aggr.function.state_size()); - aggr.function.initialize(state.get()); - aggregates.push_back(std::move(state)); - bind_data.push_back(aggr.bind_info.get()); - destructors.push_back(aggr.function.destructor); +UngroupedAggregateState::UngroupedAggregateState(const vector> &aggregate_expressions) + : aggregate_expressions(aggregate_expressions) { + counts = make_uniq_array>(aggregate_expressions.size()); + for (idx_t i = 0; i < aggregate_expressions.size(); i++) { + auto &aggregate = aggregate_expressions[i]; + D_ASSERT(aggregate->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); + auto &aggr = aggregate->Cast(); + auto state = make_unsafe_uniq_array_uninitialized(aggr.function.state_size(aggr.function)); + aggr.function.initialize(aggr.function, state.get()); + aggregate_data.push_back(std::move(state)); + bind_data.push_back(aggr.bind_info.get()); + destructors.push_back(aggr.function.destructor); #ifdef DEBUG - counts[i] = 0; + counts[i] = 0; #endif - } } - ~AggregateState() { - D_ASSERT(destructors.size() == aggregates.size()); - for (idx_t i = 0; i < destructors.size(); i++) { - if (!destructors[i]) { - continue; - } - Vector state_vector(Value::POINTER(CastPointerToValue(aggregates[i].get()))); - state_vector.SetVectorType(VectorType::FLAT_VECTOR); - - ArenaAllocator allocator(Allocator::DefaultAllocator()); - AggregateInputData aggr_input_data(bind_data[i], allocator); - destructors[i](state_vector, aggr_input_data, 1); +} +UngroupedAggregateState::~UngroupedAggregateState() { + D_ASSERT(destructors.size() == aggregate_data.size()); + for (idx_t i = 0; i < destructors.size(); i++) { + if (!destructors[i]) { + continue; } - } + Vector state_vector(Value::POINTER(CastPointerToValue(aggregate_data[i].get()))); + state_vector.SetVectorType(VectorType::FLAT_VECTOR); - void Move(AggregateState &other) { - other.aggregates = std::move(aggregates); - other.destructors = std::move(destructors); + ArenaAllocator allocator(Allocator::DefaultAllocator()); + AggregateInputData aggr_input_data(bind_data[i], allocator); + destructors[i](state_vector, aggr_input_data, 1); } +} - //! The aggregate values - vector> aggregates; - //! The bind data - vector bind_data; - //! The destructors - vector destructors; - //! Counts (used for verification) - unique_array> counts; -}; +void UngroupedAggregateState::Move(UngroupedAggregateState &other) { + other.aggregate_data = std::move(aggregate_data); + other.destructors = std::move(destructors); +} +//===--------------------------------------------------------------------===// +// Global State +//===--------------------------------------------------------------------===// class UngroupedAggregateGlobalSinkState : public GlobalSinkState { public: UngroupedAggregateGlobalSinkState(const PhysicalUngroupedAggregate &op, ClientContext &client) - : state(op.aggregates), finished(false), client_allocator(BufferAllocator::Get(client)), - allocator(client_allocator) { + : state(BufferAllocator::Get(client), op.aggregates), finished(false) { if (op.distinct_data) { distinct_state = make_uniq(*op.distinct_data, client); } } - //! Create an ArenaAllocator with cross-thread lifetime - ArenaAllocator &CreateAllocator() const { - lock_guard glock(lock); - stored_allocators.emplace_back(make_uniq(client_allocator)); - return *stored_allocators.back(); - } - - //! The lock for updating the global aggregate state - mutable mutex lock; //! The global aggregate state - AggregateState state; + GlobalUngroupedAggregateState state; //! Whether or not the aggregate is finished bool finished; //! The data related to the distinct aggregates (if there are any) unique_ptr distinct_state; - //! Client base allocator - Allocator &client_allocator; - //! Global arena allocator - ArenaAllocator allocator; - //! Allocator pool - mutable vector> stored_allocators; }; +ArenaAllocator &GlobalUngroupedAggregateState::CreateAllocator() const { + lock_guard glock(lock); + stored_allocators.emplace_back(make_uniq(client_allocator)); + return *stored_allocators.back(); +} + +void GlobalUngroupedAggregateState::Combine(LocalUngroupedAggregateState &other) { + lock_guard glock(lock); + for (idx_t aggr_idx = 0; aggr_idx < state.aggregate_expressions.size(); aggr_idx++) { + auto &aggregate = state.aggregate_expressions[aggr_idx]->Cast(); + + if (aggregate.IsDistinct()) { + continue; + } + + Vector source_state(Value::POINTER(CastPointerToValue(other.state.aggregate_data[aggr_idx].get()))); + Vector dest_state(Value::POINTER(CastPointerToValue(state.aggregate_data[aggr_idx].get()))); + + AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator, + AggregateCombineType::ALLOW_DESTRUCTIVE); + aggregate.function.combine(source_state, dest_state, aggr_input_data, 1); +#ifdef DEBUG + state.counts[aggr_idx] += other.state.counts[aggr_idx]; +#endif + } +} + +void GlobalUngroupedAggregateState::CombineDistinct(LocalUngroupedAggregateState &other, + DistinctAggregateData &distinct_data) { + lock_guard glock(lock); + for (idx_t aggr_idx = 0; aggr_idx < state.aggregate_expressions.size(); aggr_idx++) { + if (!distinct_data.IsDistinct(aggr_idx)) { + continue; + } + + auto &aggregate = state.aggregate_expressions[aggr_idx]->Cast(); + AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator, + AggregateCombineType::ALLOW_DESTRUCTIVE); + + Vector state_vec(Value::POINTER(CastPointerToValue(other.state.aggregate_data[aggr_idx].get()))); + Vector combined_vec(Value::POINTER(CastPointerToValue(state.aggregate_data[aggr_idx].get()))); + aggregate.function.combine(state_vec, combined_vec, aggr_input_data, 1); +#ifdef DEBUG + state.counts[aggr_idx] += other.state.counts[aggr_idx]; +#endif + } +} + +//===--------------------------------------------------------------------===// +// Local State +//===--------------------------------------------------------------------===// +LocalUngroupedAggregateState::LocalUngroupedAggregateState(GlobalUngroupedAggregateState &gstate) + : allocator(gstate.CreateAllocator()), state(gstate.state.aggregate_expressions) { +} + class UngroupedAggregateLocalSinkState : public LocalSinkState { public: UngroupedAggregateLocalSinkState(const PhysicalUngroupedAggregate &op, const vector &child_types, UngroupedAggregateGlobalSinkState &gstate_p, ExecutionContext &context) - : allocator(gstate_p.CreateAllocator()), state(op.aggregates), child_executor(context.client), - aggregate_input_chunk(), filter_set() { + : state(gstate_p.state), child_executor(context.client), aggregate_input_chunk(), filter_set() { auto &gstate = gstate_p.Cast(); auto &allocator = BufferAllocator::Get(context.client); @@ -145,10 +177,8 @@ class UngroupedAggregateLocalSinkState : public LocalSinkState { filter_set.Initialize(context.client, aggregate_objects, child_types); } - //! Local arena allocator - ArenaAllocator &allocator; //! The local aggregate state - AggregateState state; + LocalUngroupedAggregateState state; //! The executor ExpressionExecutor child_executor; //! The payload chunk, containing all the Vectors for the aggregates @@ -189,6 +219,9 @@ class UngroupedAggregateLocalSinkState : public LocalSinkState { } }; +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// bool PhysicalUngroupedAggregate::SinkOrderDependent() const { for (auto &expr : aggregates) { auto &aggr = expr->Cast(); @@ -291,10 +324,6 @@ SinkResultType PhysicalUngroupedAggregate::Sink(ExecutionContext &context, DataC payload_chunk.SetCardinality(chunk); } -#ifdef DEBUG - sink.state.counts[aggr_idx] += payload_chunk.size(); -#endif - // resolve the child expressions of the aggregate (if any) for (idx_t i = 0; i < aggregate.children.size(); ++i) { sink.child_executor.ExecuteExpression(payload_idx + payload_cnt, @@ -302,14 +331,23 @@ SinkResultType PhysicalUngroupedAggregate::Sink(ExecutionContext &context, DataC payload_cnt++; } - auto start_of_input = payload_cnt == 0 ? nullptr : &payload_chunk.data[payload_idx]; - AggregateInputData aggr_input_data(aggregate.bind_info.get(), sink.allocator); - aggregate.function.simple_update(start_of_input, aggr_input_data, payload_cnt, - sink.state.aggregates[aggr_idx].get(), payload_chunk.size()); + sink.state.Sink(payload_chunk, payload_idx, aggr_idx); } return SinkResultType::NEED_MORE_INPUT; } +void LocalUngroupedAggregateState::Sink(DataChunk &payload_chunk, idx_t payload_idx, idx_t aggr_idx) { +#ifdef DEBUG + state.counts[aggr_idx] += payload_chunk.size(); +#endif + auto &aggregate = state.aggregate_expressions[aggr_idx]->Cast(); + idx_t payload_cnt = aggregate.children.size(); + auto start_of_input = payload_cnt == 0 ? nullptr : &payload_chunk.data[payload_idx]; + AggregateInputData aggr_input_data(state.bind_data[aggr_idx], allocator); + aggregate.function.simple_update(start_of_input, aggr_input_data, payload_cnt, state.aggregate_data[aggr_idx].get(), + payload_chunk.size()); +} + //===--------------------------------------------------------------------===// // Combine //===--------------------------------------------------------------------===// @@ -344,27 +382,10 @@ SinkCombineResultType PhysicalUngroupedAggregate::Combine(ExecutionContext &cont OperatorSinkCombineInput distinct_input {gstate, lstate, input.interrupt_state}; CombineDistinct(context, distinct_input); - lock_guard glock(gstate.lock); - for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { - auto &aggregate = aggregates[aggr_idx]->Cast(); - - if (aggregate.IsDistinct()) { - continue; - } - - Vector source_state(Value::POINTER(CastPointerToValue(lstate.state.aggregates[aggr_idx].get()))); - Vector dest_state(Value::POINTER(CastPointerToValue(gstate.state.aggregates[aggr_idx].get()))); - - AggregateInputData aggr_input_data(aggregate.bind_info.get(), gstate.allocator, - AggregateCombineType::ALLOW_DESTRUCTIVE); - aggregate.function.combine(source_state, dest_state, aggr_input_data, 1); -#ifdef DEBUG - gstate.state.counts[aggr_idx] += lstate.state.counts[aggr_idx]; -#endif - } + gstate.state.Combine(lstate.state); auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this, lstate.child_executor, "child_executor", 0); + context.thread.profiler.Flush(*this); client_profiler.Flush(context.thread.profiler); return SinkCombineResultType::FINISHED; @@ -383,6 +404,14 @@ class UngroupedDistinctAggregateFinalizeEvent : public BasePipelineEvent { public: void Schedule() override; + void FinalizeTask() { + lock_guard finalize(lock); + D_ASSERT(!gstate.finished); + D_ASSERT(tasks_done < tasks_scheduled); + if (++tasks_done == tasks_scheduled) { + gstate.finished = true; + } + } private: ClientContext &context; @@ -390,11 +419,11 @@ class UngroupedDistinctAggregateFinalizeEvent : public BasePipelineEvent { const PhysicalUngroupedAggregate &op; UngroupedAggregateGlobalSinkState &gstate; -public: mutex lock; idx_t tasks_scheduled; idx_t tasks_done; +public: vector> global_source_states; }; @@ -403,8 +432,7 @@ class UngroupedDistinctAggregateFinalizeTask : public ExecutorTask { UngroupedDistinctAggregateFinalizeTask(Executor &executor, shared_ptr event_p, const PhysicalUngroupedAggregate &op, UngroupedAggregateGlobalSinkState &state_p) - : ExecutorTask(executor, std::move(event_p)), op(op), gstate(state_p), allocator(gstate.CreateAllocator()), - aggregate_state(op.aggregates) { + : ExecutorTask(executor, std::move(event_p)), op(op), gstate(state_p), aggregate_state(gstate.state) { } TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; @@ -416,10 +444,8 @@ class UngroupedDistinctAggregateFinalizeTask : public ExecutorTask { const PhysicalUngroupedAggregate &op; UngroupedAggregateGlobalSinkState &gstate; - ArenaAllocator &allocator; - // Distinct aggregation state - AggregateState aggregate_state; + LocalUngroupedAggregateState aggregate_state; idx_t aggregation_idx = 0; unique_ptr radix_table_lstate; bool blocked = false; @@ -520,7 +546,6 @@ TaskExecutionResult UngroupedDistinctAggregateFinalizeTask::AggregateDistinct() payload_chunk.InitializeEmpty(distinct_data.grouped_aggregate_data[table_idx]->group_types); payload_chunk.SetCardinality(0); - AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator); while (true) { output_chunk.Reset(); @@ -540,38 +565,15 @@ TaskExecutionResult UngroupedDistinctAggregateFinalizeTask::AggregateDistinct() } payload_chunk.SetCardinality(output_chunk); -#ifdef DEBUG - gstate.state.counts[agg_idx] += payload_chunk.size(); -#endif - // Update the aggregate state - auto start_of_input = payload_cnt ? &payload_chunk.data[0] : nullptr; - aggregate.function.simple_update(start_of_input, aggr_input_data, payload_cnt, - state.aggregates[agg_idx].get(), payload_chunk.size()); + state.Sink(payload_chunk, 0, agg_idx); } blocked = false; } // After scanning the distinct HTs, we can combine the thread-local agg states with the thread-global - lock_guard guard(finalize_event.lock); - for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) { - if (!distinct_data.IsDistinct(agg_idx)) { - continue; - } - - auto &aggregate = aggregates[agg_idx]->Cast(); - AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator, - AggregateCombineType::ALLOW_DESTRUCTIVE); - - Vector state_vec(Value::POINTER(CastPointerToValue(state.aggregates[agg_idx].get()))); - Vector combined_vec(Value::POINTER(CastPointerToValue(gstate.state.aggregates[agg_idx].get()))); - aggregate.function.combine(state_vec, combined_vec, aggr_input_data, 1); - } - - D_ASSERT(!gstate.finished); - if (++finalize_event.tasks_done == finalize_event.tasks_scheduled) { - gstate.finished = true; - } + gstate.state.CombineDistinct(state, distinct_data); + finalize_event.FinalizeTask(); return TaskExecutionResult::TASK_FINISHED; } @@ -607,7 +609,8 @@ SinkFinalizeType PhysicalUngroupedAggregate::Finalize(Pipeline &pipeline, Event //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -void VerifyNullHandling(DataChunk &chunk, AggregateState &state, const vector> &aggregates) { +void VerifyNullHandling(DataChunk &chunk, UngroupedAggregateState &state, + const vector> &aggregates) { #ifdef DEBUG for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { auto &aggr = aggregates[aggr_idx]->Cast(); @@ -621,37 +624,43 @@ void VerifyNullHandling(DataChunk &chunk, AggregateState &state, const vectorCast(); + + Vector state_vector(Value::POINTER(CastPointerToValue(state.aggregate_data[aggr_idx].get()))); + AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator); + aggregate.function.finalize(state_vector, aggr_input_data, result.data[aggr_idx], 1, 0); + } +} + SourceResultType PhysicalUngroupedAggregate::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { auto &gstate = sink_state->Cast(); D_ASSERT(gstate.finished); // initialize the result chunk with the aggregate values - chunk.SetCardinality(1); - for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { - auto &aggregate = aggregates[aggr_idx]->Cast(); - - Vector state_vector(Value::POINTER(CastPointerToValue(gstate.state.aggregates[aggr_idx].get()))); - AggregateInputData aggr_input_data(aggregate.bind_info.get(), gstate.allocator); - aggregate.function.finalize(state_vector, aggr_input_data, chunk.data[aggr_idx], 1, 0); - } - VerifyNullHandling(chunk, gstate.state, aggregates); + gstate.state.Finalize(chunk); + VerifyNullHandling(chunk, gstate.state.state, aggregates); return SourceResultType::FINISHED; } -string PhysicalUngroupedAggregate::ParamsToString() const { - string result; +InsertionOrderPreservingMap PhysicalUngroupedAggregate::ParamsToString() const { + InsertionOrderPreservingMap result; + string aggregate_info; for (idx_t i = 0; i < aggregates.size(); i++) { auto &aggregate = aggregates[i]->Cast(); if (i > 0) { - result += "\n"; + aggregate_info += "\n"; } - result += aggregates[i]->GetName(); + aggregate_info += aggregates[i]->GetName(); if (aggregate.filter) { - result += " Filter: " + aggregate.filter->GetName(); + aggregate_info += " Filter: " + aggregate.filter->GetName(); } } + result["Aggregates"] = aggregate_info; return result; } diff --git a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp index 7790eb67..e9bf471d 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp @@ -30,22 +30,148 @@ namespace duckdb { // Global sink state -class WindowGlobalSinkState : public GlobalSinkState { +class WindowGlobalSinkState; + +enum WindowGroupStage : uint8_t { SINK, FINALIZE, GETDATA, DONE }; + +class WindowHashGroup { public: - WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &context) - : op(op), mode(DBConfig::GetConfig(context).options.window_mode) { + using HashGroupPtr = unique_ptr; + using OrderMasks = PartitionGlobalHashGroup::OrderMasks; + using ExecutorGlobalStatePtr = unique_ptr; + using ExecutorGlobalStates = vector; + using ExecutorLocalStatePtr = unique_ptr; + using ExecutorLocalStates = vector; + using ThreadLocalStates = vector; + + WindowHashGroup(WindowGlobalSinkState &gstate, const idx_t hash_bin_p); - D_ASSERT(op.select_list[op.order_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); - auto &wexpr = op.select_list[op.order_idx]->Cast(); + ExecutorGlobalStates &Initialize(WindowGlobalSinkState &gstate); - global_partition = - make_uniq(context, wexpr.partitions, wexpr.orders, op.children[0]->types, - wexpr.partitions_stats, op.estimated_cardinality); + // Scan all of the blocks during the build phase + unique_ptr GetBuildScanner(idx_t block_idx) const { + if (!rows) { + return nullptr; + } + return make_uniq(*rows, *heap, layout, external, block_idx, false); } + // Scan a single block during the evaluate phase + unique_ptr GetEvaluateScanner(idx_t block_idx) const { + // Second pass can flush + D_ASSERT(rows); + return make_uniq(*rows, *heap, layout, external, block_idx, true); + } + + // The processing stage for this group + WindowGroupStage GetStage() const { + return stage; + } + + bool TryPrepareNextStage() { + lock_guard prepare_guard(lock); + switch (stage.load()) { + case WindowGroupStage::SINK: + if (sunk == count) { + stage = WindowGroupStage::FINALIZE; + return true; + } + return false; + case WindowGroupStage::FINALIZE: + if (finalized == blocks) { + stage = WindowGroupStage::GETDATA; + return true; + } + return false; + default: + // never block in GETDATA + return true; + } + } + + //! The hash partition data + HashGroupPtr hash_group; + //! The size of the group + idx_t count = 0; + //! The number of blocks in the group + idx_t blocks = 0; + unique_ptr rows; + unique_ptr heap; + RowLayout layout; + //! The partition boundary mask + ValidityMask partition_mask; + //! The order boundary mask + OrderMasks order_masks; + //! External paging + bool external; + // The processing stage for this group + atomic stage; + //! The function global states for this hash group + ExecutorGlobalStates gestates; + //! Executor local states, one per thread + ThreadLocalStates thread_states; + + //! The bin number + idx_t hash_bin; + //! Single threading lock + mutex lock; + //! Count of sunk rows + std::atomic sunk; + //! Count of finalized blocks + std::atomic finalized; + //! The number of tasks left before we should be deleted + std::atomic tasks_remaining; + //! The output ordering batch index this hash group starts at + idx_t batch_base; + +private: + void MaterializeSortedData(); +}; + +class WindowPartitionGlobalSinkState; + +class WindowGlobalSinkState : public GlobalSinkState { +public: + using ExecutorPtr = unique_ptr; + using Executors = vector; + + WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &context); + + //! Parent operator const PhysicalWindow &op; - unique_ptr global_partition; - WindowAggregationMode mode; + //! Execution context + ClientContext &context; + //! The partitioned sunk data + unique_ptr global_partition; + //! The execution functions + Executors executors; +}; + +class WindowPartitionGlobalSinkState : public PartitionGlobalSinkState { +public: + using WindowHashGroupPtr = unique_ptr; + + WindowPartitionGlobalSinkState(WindowGlobalSinkState &gsink, const BoundWindowExpression &wexpr) + : PartitionGlobalSinkState(gsink.context, wexpr.partitions, wexpr.orders, gsink.op.children[0]->types, + wexpr.partitions_stats, gsink.op.estimated_cardinality), + gsink(gsink) { + } + ~WindowPartitionGlobalSinkState() override = default; + + void OnBeginMerge() override { + PartitionGlobalSinkState::OnBeginMerge(); + window_hash_groups.resize(hash_groups.size()); + } + + void OnSortedPartition(const idx_t group_idx) override { + PartitionGlobalSinkState::OnSortedPartition(group_idx); + window_hash_groups[group_idx] = make_uniq(gsink, group_idx); + } + + //! Operator global sink state + WindowGlobalSinkState &gsink; + //! The sorted hash groups + vector window_hash_groups; }; // Per-thread sink state @@ -89,39 +215,54 @@ PhysicalWindow::PhysicalWindow(vector types, vector WindowExecutorFactory(BoundWindowExpression &wexpr, ClientContext &context, - const ValidityMask &partition_mask, - const ValidityMask &order_mask, const idx_t payload_count, WindowAggregationMode mode) { switch (wexpr.type) { case ExpressionType::WINDOW_AGGREGATE: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask, mode); + return make_uniq(wexpr, context, mode); case ExpressionType::WINDOW_ROW_NUMBER: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + return make_uniq(wexpr, context); case ExpressionType::WINDOW_RANK_DENSE: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + return make_uniq(wexpr, context); case ExpressionType::WINDOW_RANK: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + return make_uniq(wexpr, context); case ExpressionType::WINDOW_PERCENT_RANK: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + return make_uniq(wexpr, context); case ExpressionType::WINDOW_CUME_DIST: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + return make_uniq(wexpr, context); case ExpressionType::WINDOW_NTILE: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + return make_uniq(wexpr, context); case ExpressionType::WINDOW_LEAD: case ExpressionType::WINDOW_LAG: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + return make_uniq(wexpr, context); case ExpressionType::WINDOW_FIRST_VALUE: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + return make_uniq(wexpr, context); case ExpressionType::WINDOW_LAST_VALUE: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + return make_uniq(wexpr, context); case ExpressionType::WINDOW_NTH_VALUE: - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); + return make_uniq(wexpr, context); break; default: throw InternalException("Window aggregate type %s", ExpressionTypeToString(wexpr.type)); } } +WindowGlobalSinkState::WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &context) + : op(op), context(context) { + + D_ASSERT(op.select_list[op.order_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); + auto &wexpr = op.select_list[op.order_idx]->Cast(); + + const auto mode = DBConfig::GetConfig(context).options.window_mode; + for (idx_t expr_idx = 0; expr_idx < op.select_list.size(); ++expr_idx) { + D_ASSERT(op.select_list[expr_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); + auto &wexpr = op.select_list[expr_idx]->Cast(); + auto wexec = WindowExecutorFactory(wexpr, context, mode); + executors.emplace_back(std::move(wexec)); + } + + global_partition = make_uniq(*this, wexpr); +} + //===--------------------------------------------------------------------===// // Sink //===--------------------------------------------------------------------===// @@ -171,7 +312,7 @@ SinkFinalizeType PhysicalWindow::Finalize(Pipeline &pipeline, Event &event, Clie } // Schedule all the sorts for maximum thread utilisation - auto new_event = make_shared_ptr(*state.global_partition, pipeline); + auto new_event = make_shared_ptr(*state.global_partition, pipeline, *this); event.InsertEvent(std::move(new_event)); return SinkFinalizeType::READY; @@ -180,123 +321,156 @@ SinkFinalizeType PhysicalWindow::Finalize(Pipeline &pipeline, Event &event, Clie //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -class WindowPartitionSourceState; - class WindowGlobalSourceState : public GlobalSourceState { public: - using HashGroupSourcePtr = unique_ptr; using ScannerPtr = unique_ptr; - using Task = std::pair; + + struct Task { + Task(WindowGroupStage stage, idx_t group_idx, idx_t max_idx) + : stage(stage), group_idx(group_idx), thread_idx(0), max_idx(max_idx) { + } + WindowGroupStage stage; + //! The hash group + idx_t group_idx; + //! The thread index (for local state) + idx_t thread_idx; + //! The total block index count + idx_t max_idx; + //! The first block index count + idx_t begin_idx = 0; + //! The end block index count + idx_t end_idx = 0; + }; + using TaskPtr = optional_ptr; WindowGlobalSourceState(ClientContext &context_p, WindowGlobalSinkState &gsink_p); - //! Get the next task - Task NextTask(idx_t hash_bin); + //! Build task list + void CreateTaskList(); + + //! Are there any more tasks? + bool HasMoreTasks() const { + return !stopped && next_task < tasks.size(); + } + bool HasUnfinishedTasks() const { + return !stopped && finished < tasks.size(); + } + //! Try to advance the group stage + bool TryPrepareNextStage(); + //! Get the next task given the current state + bool TryNextTask(TaskPtr &task); + //! Finish a task + void FinishTask(TaskPtr task); //! Context for executing computations ClientContext &context; //! All the sunk data WindowGlobalSinkState &gsink; - //! The next group to build. - atomic next_build; - //! The built groups - vector built; - //! Serialise access to the built hash groups - mutable mutex built_lock; - //! The number of unfinished tasks - atomic tasks_remaining; + //! The total number of blocks to process; + idx_t total_blocks = 0; + //! The number of local states + atomic locals; + //! The list of tasks + vector tasks; + //! The the next task + atomic next_task; + //! The the number of finished tasks + atomic finished; + //! Stop producing tasks + atomic stopped; //! The number of rows returned atomic returned; public: idx_t MaxThreads() override { - return tasks_remaining; + return total_blocks; } - -private: - Task CreateTask(idx_t hash_bin); - Task StealWork(); }; WindowGlobalSourceState::WindowGlobalSourceState(ClientContext &context_p, WindowGlobalSinkState &gsink_p) - : context(context_p), gsink(gsink_p), next_build(0), tasks_remaining(0), returned(0) { - auto &hash_groups = gsink.global_partition->hash_groups; - + : context(context_p), gsink(gsink_p), locals(0), next_task(0), finished(0), stopped(false), returned(0) { auto &gpart = gsink.global_partition; - if (hash_groups.empty()) { + auto &window_hash_groups = gsink.global_partition->window_hash_groups; + + if (window_hash_groups.empty()) { // OVER() - built.resize(1); - if (gpart->rows) { - tasks_remaining += gpart->rows->blocks.size(); + if (gpart->rows && !gpart->rows->blocks.empty()) { + // We need to construct the single WindowHashGroup here because the sort tasks will not be run. + window_hash_groups.emplace_back(make_uniq(gsink, idx_t(0))); + total_blocks = gpart->rows->blocks.size(); } } else { - built.resize(hash_groups.size()); idx_t batch_base = 0; - for (auto &hash_group : hash_groups) { - if (!hash_group) { + for (auto &window_hash_group : window_hash_groups) { + if (!window_hash_group) { continue; } - auto &global_sort_state = *hash_group->global_sort; - if (global_sort_state.sorted_blocks.empty()) { + auto &rows = window_hash_group->rows; + if (!rows) { continue; } - D_ASSERT(global_sort_state.sorted_blocks.size() == 1); - auto &sb = *global_sort_state.sorted_blocks[0]; - auto &sd = *sb.payload_data; - tasks_remaining += sd.data_blocks.size(); - - hash_group->batch_base = batch_base; - batch_base += sd.data_blocks.size(); + const auto block_count = window_hash_group->rows->blocks.size(); + window_hash_group->batch_base = batch_base; + batch_base += block_count; } + total_blocks = batch_base; } } -// Per-bin evaluation state (build and evaluate) -class WindowPartitionSourceState { -public: - using HashGroupPtr = unique_ptr; - using ExecutorPtr = unique_ptr; - using Executors = vector; - using OrderMasks = PartitionGlobalHashGroup::OrderMasks; - - WindowPartitionSourceState(ClientContext &context, WindowGlobalSourceState &gsource) - : context(context), op(gsource.gsink.op), gsource(gsource), read_block_idx(0), unscanned(0) { - layout.Initialize(gsource.gsink.global_partition->payload_types); +void WindowGlobalSourceState::CreateTaskList() { + // Check whether we have a task list outside the mutex. + if (next_task.load()) { + return; } - unique_ptr GetScanner() const; - void MaterializeSortedData(); - void BuildPartition(WindowGlobalSinkState &gstate, const idx_t hash_bin); + auto guard = Lock(); - ClientContext &context; - const PhysicalWindow &op; - WindowGlobalSourceState &gsource; - - HashGroupPtr hash_group; - //! The generated input chunks - unique_ptr rows; - unique_ptr heap; - RowLayout layout; - //! The partition boundary mask - ValidityMask partition_mask; - //! The order boundary mask - OrderMasks order_masks; - //! External paging - bool external; - //! The current execution functions - Executors executors; + auto &window_hash_groups = gsink.global_partition->window_hash_groups; + if (!tasks.empty()) { + return; + } - //! The bin number - idx_t hash_bin; + // Sort the groups from largest to smallest + if (window_hash_groups.empty()) { + return; + } - //! The next block to read. - mutable atomic read_block_idx; - //! The number of remaining unscanned blocks. - atomic unscanned; -}; + using PartitionBlock = std::pair; + vector partition_blocks; + for (idx_t group_idx = 0; group_idx < window_hash_groups.size(); ++group_idx) { + auto &window_hash_group = window_hash_groups[group_idx]; + partition_blocks.emplace_back(window_hash_group->rows->blocks.size(), group_idx); + } + std::sort(partition_blocks.begin(), partition_blocks.end(), std::greater()); + + // Schedule the largest group on as many threads as possible + const auto threads = locals.load(); + const auto &max_block = partition_blocks.front(); + const auto per_thread = (max_block.first + threads - 1) / threads; + if (!per_thread) { + throw InternalException("No blocks per thread! %ld threads, %ld groups, %ld blocks, %ld hash group", threads, + partition_blocks.size(), max_block.first, max_block.second); + } + + // TODO: Generate dynamically instead of building a big list? + vector states {WindowGroupStage::SINK, WindowGroupStage::FINALIZE, WindowGroupStage::GETDATA}; + for (const auto &b : partition_blocks) { + auto &window_hash_group = *window_hash_groups[b.second]; + for (const auto &state : states) { + idx_t thread_count = 0; + for (Task task(state, b.second, b.first); task.begin_idx < task.max_idx; task.begin_idx += per_thread) { + task.end_idx = MinValue(task.begin_idx + per_thread, task.max_idx); + tasks.emplace_back(task); + window_hash_group.tasks_remaining++; + thread_count = ++task.thread_idx; + } + window_hash_group.thread_states.resize(thread_count); + } + } +} -void WindowPartitionSourceState::MaterializeSortedData() { +void WindowHashGroup::MaterializeSortedData() { auto &global_sort_state = *hash_group->global_sort; if (global_sort_state.sorted_blocks.empty()) { return; @@ -329,38 +503,23 @@ void WindowPartitionSourceState::MaterializeSortedData() { heap->blocks = std::move(sd.heap_blocks); hash_group.reset(); } else { - heap = make_uniq(buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U, true); + heap = make_uniq(buffer_manager, buffer_manager.GetBlockSize(), 1U, true); } heap->count = std::accumulate(heap->blocks.begin(), heap->blocks.end(), idx_t(0), [&](idx_t c, const unique_ptr &b) { return c + b->count; }); } -unique_ptr WindowPartitionSourceState::GetScanner() const { - auto &gsink = *gsource.gsink.global_partition; - if ((gsink.rows && !hash_bin) || hash_bin < gsink.hash_groups.size()) { - const auto block_idx = read_block_idx++; - if (block_idx >= rows->blocks.size()) { - return nullptr; - } - // Second pass can flush - --gsource.tasks_remaining; - return make_uniq(*rows, *heap, layout, external, block_idx, true); - } - return nullptr; -} - -void WindowPartitionSourceState::BuildPartition(WindowGlobalSinkState &gstate, const idx_t hash_bin_p) { - // Get rid of any stale data - hash_bin = hash_bin_p; - +WindowHashGroup::WindowHashGroup(WindowGlobalSinkState &gstate, const idx_t hash_bin_p) + : count(0), blocks(0), stage(WindowGroupStage::SINK), hash_bin(hash_bin_p), sunk(0), finalized(0), + tasks_remaining(0), batch_base(0) { // There are three types of partitions: // 1. No partition (no sorting) // 2. One partition (sorting, but no hashing) // 3. Multiple partitions (sorting and hashing) // How big is the partition? - auto &gpart = *gsource.gsink.global_partition; - idx_t count = 0; + auto &gpart = *gstate.global_partition; + layout.Initialize(gpart.payload_types); if (hash_bin < gpart.hash_groups.size() && gpart.hash_groups[hash_bin]) { count = gpart.hash_groups[hash_bin]->count; } else if (gpart.rows && !hash_bin) { @@ -373,9 +532,9 @@ void WindowPartitionSourceState::BuildPartition(WindowGlobalSinkState &gstate, c partition_mask.Initialize(count); partition_mask.SetAllInvalid(count); - for (idx_t expr_idx = 0; expr_idx < op.select_list.size(); ++expr_idx) { - D_ASSERT(op.select_list[expr_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); - auto &wexpr = op.select_list[expr_idx]->Cast(); + const auto &executors = gstate.executors; + for (auto &wexec : executors) { + auto &wexpr = wexec->wexpr; auto &order_mask = order_masks[wexpr.partitions.size() + wexpr.orders.size()]; if (order_mask.IsMaskSet()) { continue; @@ -404,248 +563,263 @@ void WindowPartitionSourceState::BuildPartition(WindowGlobalSinkState &gstate, c hash_group->ComputeMasks(partition_mask, order_masks); external = hash_group->global_sort->external; MaterializeSortedData(); - } else { - return; - } - - // Create the executors for each function - executors.clear(); - for (idx_t expr_idx = 0; expr_idx < op.select_list.size(); ++expr_idx) { - D_ASSERT(op.select_list[expr_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); - auto &wexpr = op.select_list[expr_idx]->Cast(); - auto &order_mask = order_masks[wexpr.partitions.size() + wexpr.orders.size()]; - auto wexec = WindowExecutorFactory(wexpr, context, partition_mask, order_mask, count, gstate.mode); - executors.emplace_back(std::move(wexec)); - } - - // First pass over the input without flushing - DataChunk input_chunk; - input_chunk.Initialize(gpart.allocator, gpart.payload_types); - auto scanner = make_uniq(*rows, *heap, layout, external, false); - idx_t input_idx = 0; - while (true) { - input_chunk.Reset(); - scanner->Scan(input_chunk); - if (input_chunk.size() == 0) { - break; - } - - // TODO: Parallelization opportunity - for (auto &wexec : executors) { - wexec->Sink(input_chunk, input_idx, scanner->Count()); - } - input_idx += input_chunk.size(); } - // TODO: Parallelization opportunity - for (auto &wexec : executors) { - wexec->Finalize(); + if (rows) { + blocks = rows->blocks.size(); } - - // External scanning assumes all blocks are swizzled. - scanner->ReSwizzle(); - - // Start the block countdown - unscanned = rows->blocks.size(); } // Per-thread scan state class WindowLocalSourceState : public LocalSourceState { public: - using ReadStatePtr = unique_ptr; - using ReadStates = vector; + using Task = WindowGlobalSourceState::Task; + using TaskPtr = optional_ptr; explicit WindowLocalSourceState(WindowGlobalSourceState &gsource); - void UpdateBatchIndex(); - bool NextPartition(); - void Scan(DataChunk &chunk); + + //! Does the task have more work to do? + bool TaskFinished() const { + return !task || task->begin_idx == task->end_idx; + } + //! Assign the next task + bool TryAssignTask(); + //! Execute a step in the current task + void ExecuteTask(DataChunk &chunk); //! The shared source state WindowGlobalSourceState &gsource; - //! The current bin being processed - idx_t hash_bin; //! The current batch index (for output reordering) idx_t batch_index; + //! The task this thread is working on + TaskPtr task; //! The current source being processed - optional_ptr partition_source; - //! The read cursor + optional_ptr window_hash_group; + //! The scan cursor unique_ptr scanner; //! Buffer for the inputs DataChunk input_chunk; - //! Executor read states. - ReadStates read_states; //! Buffer for window results DataChunk output_chunk; -}; -WindowLocalSourceState::WindowLocalSourceState(WindowGlobalSourceState &gsource) - : gsource(gsource), hash_bin(gsource.built.size()), batch_index(0) { - auto &gsink = *gsource.gsink.global_partition; - auto &op = gsource.gsink.op; +protected: + void Sink(); + void Finalize(); + void GetData(DataChunk &chunk); +}; - input_chunk.Initialize(gsink.allocator, gsink.payload_types); +WindowHashGroup::ExecutorGlobalStates &WindowHashGroup::Initialize(WindowGlobalSinkState &gsink) { + // Single-threaded building as this is mostly memory allocation + lock_guard gestate_guard(lock); + const auto &executors = gsink.executors; + if (gestates.size() == executors.size()) { + return gestates; + } - vector output_types; - for (idx_t expr_idx = 0; expr_idx < op.select_list.size(); ++expr_idx) { - D_ASSERT(op.select_list[expr_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); - auto &wexpr = op.select_list[expr_idx]->Cast(); - output_types.emplace_back(wexpr.return_type); + // These can be large so we defer building them until we are ready. + for (auto &wexec : executors) { + auto &wexpr = wexec->wexpr; + auto &order_mask = order_masks[wexpr.partitions.size() + wexpr.orders.size()]; + gestates.emplace_back(wexec->GetGlobalState(count, partition_mask, order_mask)); } - output_chunk.Initialize(Allocator::Get(gsource.context), output_types); + + return gestates; } -WindowGlobalSourceState::Task WindowGlobalSourceState::CreateTask(idx_t hash_bin) { - // Build outside the lock so no one tries to steal before we are done. - auto partition_source = make_uniq(context, *this); - partition_source->BuildPartition(gsink, hash_bin); - Task result(partition_source.get(), partition_source->GetScanner()); +void WindowLocalSourceState::Sink() { + D_ASSERT(task); + D_ASSERT(task->stage == WindowGroupStage::SINK); - // Is there any data to scan? - if (result.second) { - lock_guard built_guard(built_lock); - built[hash_bin] = std::move(partition_source); + auto &gsink = gsource.gsink; + const auto &executors = gsink.executors; - return result; - } + // Create the global state for each function + // These can be large so we defer building them until we are ready. + auto &gestates = window_hash_group->Initialize(gsink); - return Task(); -} + // Set up the local states + auto &local_states = window_hash_group->thread_states.at(task->thread_idx); + if (local_states.empty()) { + for (idx_t w = 0; w < executors.size(); ++w) { + local_states.emplace_back(executors[w]->GetLocalState(*gestates[w])); + } + } -WindowGlobalSourceState::Task WindowGlobalSourceState::StealWork() { - for (idx_t hash_bin = 0; hash_bin < built.size(); ++hash_bin) { - lock_guard built_guard(built_lock); - auto &partition_source = built[hash_bin]; - if (!partition_source) { - continue; + // First pass over the input without flushing + for (; task->begin_idx < task->end_idx; ++task->begin_idx) { + scanner = window_hash_group->GetBuildScanner(task->begin_idx); + if (!scanner) { + break; } + while (true) { + // TODO: Try to align on validity mask boundaries by starting ragged? + idx_t input_idx = scanner->Scanned(); + input_chunk.Reset(); + scanner->Scan(input_chunk); + if (input_chunk.size() == 0) { + break; + } - Task result(partition_source.get(), partition_source->GetScanner()); + for (idx_t w = 0; w < executors.size(); ++w) { + executors[w]->Sink(input_chunk, input_idx, window_hash_group->count, *gestates[w], *local_states[w]); + } - // Is there any data to scan? - if (result.second) { - return result; + window_hash_group->sunk += input_chunk.size(); } + + // External scanning assumes all blocks are swizzled. + scanner->SwizzleBlock(task->begin_idx); + scanner.reset(); + } +} + +void WindowLocalSourceState::Finalize() { + D_ASSERT(task); + D_ASSERT(task->stage == WindowGroupStage::FINALIZE); + + // Finalize all the executors. + // Parallel finalisation is handled internally by the executor, + // and should not return until all threads have completed work. + auto &gsink = gsource.gsink; + const auto &executors = gsink.executors; + auto &gestates = window_hash_group->gestates; + auto &local_states = window_hash_group->thread_states.at(task->thread_idx); + for (idx_t w = 0; w < executors.size(); ++w) { + executors[w]->Finalize(*gestates[w], *local_states[w]); } - // Nothing to steal - return Task(); + // Mark this range as done + window_hash_group->finalized += (task->end_idx - task->begin_idx); + task->begin_idx = task->end_idx; } -WindowGlobalSourceState::Task WindowGlobalSourceState::NextTask(idx_t hash_bin) { - auto &hash_groups = gsink.global_partition->hash_groups; - const auto bin_count = built.size(); - - // Flush unneeded data - if (hash_bin < bin_count) { - // Lock and delete when all blocks have been scanned - // We do this here instead of in NextScan so the WindowLocalSourceState - // has a chance to delete its state objects first, - // which may reference the partition_source - - // Delete data outside the lock in case it is slow - HashGroupSourcePtr killed; - lock_guard built_guard(built_lock); - auto &partition_source = built[hash_bin]; - if (partition_source && !partition_source->unscanned) { - killed = std::move(partition_source); - } +WindowLocalSourceState::WindowLocalSourceState(WindowGlobalSourceState &gsource) : gsource(gsource), batch_index(0) { + auto &gsink = gsource.gsink; + auto &global_partition = *gsink.global_partition; + + input_chunk.Initialize(global_partition.allocator, global_partition.payload_types); + + vector output_types; + for (auto &wexec : gsink.executors) { + auto &wexpr = wexec->wexpr; + output_types.emplace_back(wexpr.return_type); } + output_chunk.Initialize(global_partition.allocator, output_types); - hash_bin = next_build++; - if (hash_bin < bin_count) { - // Find a non-empty hash group. - for (; hash_bin < hash_groups.size(); hash_bin = next_build++) { - if (hash_groups[hash_bin] && hash_groups[hash_bin]->count) { - auto result = CreateTask(hash_bin); - if (result.second) { - return result; - } - } - } + ++gsource.locals; +} - // OVER() doesn't have a hash_group - if (hash_groups.empty()) { - auto result = CreateTask(hash_bin); - if (result.second) { - return result; - } - } +bool WindowGlobalSourceState::TryNextTask(TaskPtr &task) { + auto guard = Lock(); + if (next_task >= tasks.size() || stopped) { + task = nullptr; + return false; } - // Work stealing - while (!context.interrupted && tasks_remaining) { - auto result = StealWork(); - if (result.second) { - return result; - } + // If the next task matches the current state of its group, then we can use it + // Otherwise block. + task = &tasks[next_task]; + + auto &gpart = *gsink.global_partition; + auto &window_hash_group = gpart.window_hash_groups[task->group_idx]; + auto group_stage = window_hash_group->GetStage(); - // If there is nothing to steal but there are unfinished partitions, - // yield until any pending builds are done. - TaskScheduler::YieldThread(); + if (task->stage == group_stage) { + ++next_task; + return true; } - return Task(); + task = nullptr; + return false; } -void WindowLocalSourceState::UpdateBatchIndex() { - D_ASSERT(partition_source); - D_ASSERT(scanner.get()); +void WindowGlobalSourceState::FinishTask(TaskPtr task) { + if (!task) { + return; + } + + auto &gpart = *gsink.global_partition; + auto &finished_hash_group = gpart.window_hash_groups[task->group_idx]; + D_ASSERT(finished_hash_group); - batch_index = partition_source->hash_group ? partition_source->hash_group->batch_base : 0; - batch_index += scanner->BlockIndex(); + if (!--finished_hash_group->tasks_remaining) { + finished_hash_group.reset(); + } } -bool WindowLocalSourceState::NextPartition() { - // Release old states before the source +bool WindowLocalSourceState::TryAssignTask() { + // Because downstream operators may be using our internal buffers, + // we can't "finish" a task until we are about to get the next one. + + // Scanner first, as it may be referencing sort blocks in the hash group scanner.reset(); - read_states.clear(); + gsource.FinishTask(task); - // Get a partition_source that is not finished - while (!scanner) { - auto task = gsource.NextTask(hash_bin); - if (!task.first) { - return false; - } - partition_source = task.first; - scanner = std::move(task.second); - hash_bin = partition_source->hash_bin; - UpdateBatchIndex(); - } + return gsource.TryNextTask(task); +} - for (auto &wexec : partition_source->executors) { - read_states.emplace_back(wexec->GetExecutorState()); +bool WindowGlobalSourceState::TryPrepareNextStage() { + if (next_task >= tasks.size() || stopped) { + return true; } - return true; + auto task = &tasks[next_task]; + auto window_hash_group = gsink.global_partition->window_hash_groups[task->group_idx].get(); + return window_hash_group->TryPrepareNextStage(); } -void WindowLocalSourceState::Scan(DataChunk &result) { - D_ASSERT(scanner); - if (!scanner->Remaining()) { - lock_guard built_guard(gsource.built_lock); - --partition_source->unscanned; - scanner = partition_source->GetScanner(); +void WindowLocalSourceState::ExecuteTask(DataChunk &result) { + auto &gsink = gsource.gsink; - if (!scanner) { - partition_source = nullptr; - read_states.clear(); - return; - } + // Update the hash group + window_hash_group = gsink.global_partition->window_hash_groups[task->group_idx].get(); + + // Process the new state + switch (task->stage) { + case WindowGroupStage::SINK: + Sink(); + D_ASSERT(TaskFinished()); + break; + case WindowGroupStage::FINALIZE: + Finalize(); + D_ASSERT(TaskFinished()); + break; + case WindowGroupStage::GETDATA: + D_ASSERT(!TaskFinished()); + GetData(result); + break; + default: + throw InternalException("Invalid window source state."); + } - UpdateBatchIndex(); + // Count this task as finished. + if (TaskFinished()) { + ++gsource.finished; + } +} + +void WindowLocalSourceState::GetData(DataChunk &result) { + D_ASSERT(window_hash_group->GetStage() == WindowGroupStage::GETDATA); + + if (!scanner || !scanner->Remaining()) { + scanner = window_hash_group->GetEvaluateScanner(task->begin_idx); + batch_index = window_hash_group->batch_base + task->begin_idx; } const auto position = scanner->Scanned(); input_chunk.Reset(); scanner->Scan(input_chunk); - auto &executors = partition_source->executors; + const auto &executors = gsource.gsink.executors; + auto &gestates = window_hash_group->gestates; + auto &local_states = window_hash_group->thread_states.at(task->thread_idx); output_chunk.Reset(); for (idx_t expr_idx = 0; expr_idx < executors.size(); ++expr_idx) { auto &executor = *executors[expr_idx]; - auto &lstate = *read_states[expr_idx]; + auto &gstate = *gestates[expr_idx]; + auto &lstate = *local_states[expr_idx]; auto &result = output_chunk.data[expr_idx]; - executor.Evaluate(position, input_chunk, result, lstate); + executor.Evaluate(position, input_chunk, result, lstate, gstate); } output_chunk.SetCardinality(input_chunk); output_chunk.Verify(); @@ -658,6 +832,16 @@ void WindowLocalSourceState::Scan(DataChunk &result) { for (idx_t col_idx = 0; col_idx < output_chunk.ColumnCount(); col_idx++) { result.data[out_idx++].Reference(output_chunk.data[col_idx]); } + + // If we done with this block, move to the next one + if (!scanner->Remaining()) { + ++task->begin_idx; + } + + // If that was the last block, release out local state memory. + if (TaskFinished()) { + local_states.clear(); + } result.Verify(); } @@ -676,11 +860,22 @@ bool PhysicalWindow::SupportsBatchIndex() const { // We can only preserve order for single partitioning // or work stealing causes out of order batch numbers auto &wexpr = select_list[order_idx]->Cast(); - return wexpr.partitions.empty() && !wexpr.orders.empty(); + return wexpr.partitions.empty(); // NOLINT } OrderPreservationType PhysicalWindow::SourceOrder() const { - return SupportsBatchIndex() ? OrderPreservationType::FIXED_ORDER : OrderPreservationType::NO_ORDER; + auto &wexpr = select_list[order_idx]->Cast(); + if (!wexpr.partitions.empty()) { + // if we have partitions the window order is not defined + return OrderPreservationType::NO_ORDER; + } + // without partitions we can maintain order + if (wexpr.orders.empty()) { + // if we have no orders we maintain insertion order + return OrderPreservationType::INSERTION_ORDER; + } + // otherwise we can maintain the fixed order + return OrderPreservationType::FIXED_ORDER; } double PhysicalWindow::GetProgress(ClientContext &context, GlobalSourceState &gsource_p) const { @@ -689,7 +884,7 @@ double PhysicalWindow::GetProgress(ClientContext &context, GlobalSourceState &gs auto &gsink = gsource.gsink; const auto count = gsink.global_partition->count.load(); - return count ? (returned / double(count)) : -1; + return count ? (double(returned) / double(count)) : -1; } idx_t PhysicalWindow::GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate_p, @@ -702,29 +897,53 @@ SourceResultType PhysicalWindow::GetData(ExecutionContext &context, DataChunk &c OperatorSourceInput &input) const { auto &gsource = input.global_state.Cast(); auto &lsource = input.local_state.Cast(); - while (chunk.size() == 0) { - // Move to the next bin if we are done. - while (!lsource.scanner) { - if (!lsource.NextPartition()) { - return chunk.size() > 0 ? SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; + + gsource.CreateTaskList(); + + while (gsource.HasUnfinishedTasks() && chunk.size() == 0) { + if (!lsource.TaskFinished() || lsource.TryAssignTask()) { + try { + lsource.ExecuteTask(chunk); + } catch (...) { + gsource.stopped = true; + throw; + } + } else { + auto guard = gsource.Lock(); + if (!gsource.HasMoreTasks()) { + // no more tasks - exit + gsource.UnblockTasks(guard); + break; + } + if (gsource.TryPrepareNextStage()) { + // we successfully prepared the next stage - unblock tasks + gsource.UnblockTasks(guard); + } else { + // there are more tasks available, but we can't execute them yet + // block the source + return gsource.BlockSource(guard, input.interrupt_state); } } - - lsource.Scan(chunk); - gsource.returned += chunk.size(); } - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + gsource.returned += chunk.size(); + + if (chunk.size() == 0) { + return SourceResultType::FINISHED; + } + return SourceResultType::HAVE_MORE_OUTPUT; } -string PhysicalWindow::ParamsToString() const { - string result; +InsertionOrderPreservingMap PhysicalWindow::ParamsToString() const { + InsertionOrderPreservingMap result; + string projections; for (idx_t i = 0; i < select_list.size(); i++) { if (i > 0) { - result += "\n"; + projections += "\n"; } - result += select_list[i]->GetName(); + projections += select_list[i]->GetName(); } + result["Projections"] = projections; return result; } diff --git a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp b/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp index 93402542..bf0d7aeb 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp @@ -53,7 +53,7 @@ shared_ptr CSVBuffer::Next(CSVFileHandle &file_handle, idx_t buffer_s void CSVBuffer::AllocateBuffer(idx_t buffer_size) { auto &buffer_manager = BufferManager::GetBufferManager(context); bool can_destroy = !is_pipe; - handle = buffer_manager.Allocate(MemoryTag::CSV_READER, MaxValue(Storage::BLOCK_SIZE, buffer_size), + handle = buffer_manager.Allocate(MemoryTag::CSV_READER, MaxValue(buffer_manager.GetBlockSize(), buffer_size), can_destroy, &block); } diff --git a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp b/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp index 948e91cf..c8a7d167 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp @@ -4,8 +4,9 @@ namespace duckdb { CSVBufferManager::CSVBufferManager(ClientContext &context_p, const CSVReaderOptions &options, const string &file_path_p, - const idx_t file_idx_p) - : context(context_p), file_idx(file_idx_p), file_path(file_path_p), buffer_size(CSVBuffer::CSV_BUFFER_SIZE) { + const idx_t file_idx_p, bool per_file_single_threaded_p) + : context(context_p), per_file_single_threaded(per_file_single_threaded_p), file_idx(file_idx_p), + file_path(file_path_p), buffer_size(CSVBuffer::CSV_BUFFER_SIZE) { D_ASSERT(!file_path.empty()); file_handle = ReadCSV::OpenCSV(file_path, options.compression, context); is_pipe = file_handle->IsPipe(); @@ -71,7 +72,7 @@ shared_ptr CSVBufferManager::GetBuffer(const idx_t pos) { done = true; } } - if (pos != 0 && (sniffing || file_handle->CanSeek())) { + if (pos != 0 && (sniffing || file_handle->CanSeek() || per_file_single_threaded)) { // We don't need to unpin the buffers here if we are not sniffing since we // control it per-thread on the scan if (cached_buffers[pos - 1]) { diff --git a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_file_handle.cpp b/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_file_handle.cpp index 5b970d8c..e8b502fb 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_file_handle.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_file_handle.cpp @@ -1,17 +1,19 @@ #include "duckdb/execution/operator/csv_scanner/csv_file_handle.hpp" #include "duckdb/common/exception/binder_exception.hpp" #include "duckdb/common/numeric_utils.hpp" +#include "duckdb/common/compressed_file_system.hpp" +#include "duckdb/common/string_util.hpp" namespace duckdb { CSVFileHandle::CSVFileHandle(FileSystem &fs, Allocator &allocator, unique_ptr file_handle_p, const string &path_p, FileCompressionType compression) - : file_handle(std::move(file_handle_p)), path(path_p) { + : compression_type(compression), file_handle(std::move(file_handle_p)), path(path_p) { can_seek = file_handle->CanSeek(); on_disk_file = file_handle->OnDiskFile(); file_size = file_handle->GetFileSize(); is_pipe = file_handle->IsPipe(); - uncompressed = compression == FileCompressionType::UNCOMPRESSED; + compression_type = file_handle->GetFileCompressionType(); } unique_ptr CSVFileHandle::OpenFileHandle(FileSystem &fs, Allocator &allocator, const string &path, @@ -29,6 +31,10 @@ unique_ptr CSVFileHandle::OpenFile(FileSystem &fs, Allocator &all return make_uniq(fs, allocator, std::move(file_handle), path, compression); } +double CSVFileHandle::GetProgress() { + return static_cast(file_handle->GetProgress()); +} + bool CSVFileHandle::CanSeek() { return can_seek; } @@ -72,6 +78,7 @@ idx_t CSVFileHandle::Read(void *buffer, idx_t nr_bytes) { if (!finished) { finished = bytes_read == 0; } + uncompressed_bytes_read += static_cast(bytes_read); return UnsafeNumericCast(bytes_read); } diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/base_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/base_scanner.cpp index 71a80064..c8fe1d6f 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/base_scanner.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/scanner/base_scanner.cpp @@ -5,15 +5,15 @@ namespace duckdb { -ScannerResult::ScannerResult(CSVStates &states_p, CSVStateMachine &state_machine_p) - : state_machine(state_machine_p), states(states_p) { +ScannerResult::ScannerResult(CSVStates &states_p, CSVStateMachine &state_machine_p, idx_t result_size_p) + : result_size(result_size_p), state_machine(state_machine_p), states(states_p) { } BaseScanner::BaseScanner(shared_ptr buffer_manager_p, shared_ptr state_machine_p, shared_ptr error_handler_p, bool sniffing_p, shared_ptr csv_file_scan_p, CSVIterator iterator_p) : csv_file_scan(std::move(csv_file_scan_p)), sniffing(sniffing_p), error_handler(std::move(error_handler_p)), - state_machine(std::move(state_machine_p)), iterator(iterator_p), buffer_manager(std::move(buffer_manager_p)) { + state_machine(std::move(state_machine_p)), buffer_manager(std::move(buffer_manager_p)), iterator(iterator_p) { D_ASSERT(buffer_manager); D_ASSERT(state_machine); // Initialize current buffer handle @@ -41,19 +41,15 @@ bool BaseScanner::FinishedFile() { return iterator.pos.buffer_pos + 1 == cur_buffer_handle->actual_size; } -void BaseScanner::SkipCSVRows(idx_t rows_to_skip) { +CSVIterator BaseScanner::SkipCSVRows(shared_ptr buffer_manager, + const shared_ptr &state_machine, idx_t rows_to_skip) { if (rows_to_skip == 0) { - return; + return {}; } - SkipScanner row_skipper(buffer_manager, state_machine, error_handler, rows_to_skip); + auto error_handler = make_shared_ptr(); + SkipScanner row_skipper(std::move(buffer_manager), state_machine, error_handler, rows_to_skip); row_skipper.ParseChunk(); - iterator.pos.buffer_pos = row_skipper.GetIteratorPosition(); - if (row_skipper.state_machine->options.dialect_options.state_machine_options.new_line == - NewLineIdentifier::CARRY_ON && - row_skipper.states.states[1] == CSVState::CARRIAGE_RETURN) { - iterator.pos.buffer_pos++; - } - lines_read += row_skipper.GetLinesRead(); + return row_skipper.GetIterator(); } CSVIterator &BaseScanner::GetIterator() { @@ -65,19 +61,19 @@ void BaseScanner::SetIterator(const CSVIterator &it) { } ScannerResult &BaseScanner::ParseChunk() { - throw InternalException("ParseChunk() from CSV Base Scanner is mot implemented"); + throw InternalException("ParseChunk() from CSV Base Scanner is not implemented"); } ScannerResult &BaseScanner::GetResult() { - throw InternalException("GetResult() from CSV Base Scanner is mot implemented"); + throw InternalException("GetResult() from CSV Base Scanner is not implemented"); } void BaseScanner::Initialize() { - throw InternalException("Initialize() from CSV Base Scanner is mot implemented"); + throw InternalException("Initialize() from CSV Base Scanner is not implemented"); } void BaseScanner::FinalizeChunkProcess() { - throw InternalException("FinalizeChunkProcess() from CSV Base Scanner is mot implemented"); + throw InternalException("FinalizeChunkProcess() from CSV Base Scanner is not implemented"); } CSVStateMachine &BaseScanner::GetStateMachine() { diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/column_count_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/column_count_scanner.cpp index 70542ef4..f66b180e 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/column_count_scanner.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/scanner/column_count_scanner.cpp @@ -2,37 +2,66 @@ namespace duckdb { -ColumnCountResult::ColumnCountResult(CSVStates &states, CSVStateMachine &state_machine) - : ScannerResult(states, state_machine) { +ColumnCountResult::ColumnCountResult(CSVStates &states, CSVStateMachine &state_machine, idx_t result_size) + : ScannerResult(states, state_machine, result_size) { + column_counts.resize(result_size); } -void ColumnCountResult::AddValue(ColumnCountResult &result, const idx_t buffer_pos) { +void ColumnCountResult::AddValue(ColumnCountResult &result, idx_t buffer_pos) { result.current_column_count++; } inline void ColumnCountResult::InternalAddRow() { - column_counts[result_position++] = current_column_count + 1; + column_counts[result_position].number_of_columns = current_column_count + 1; current_column_count = 0; } -bool ColumnCountResult::AddRow(ColumnCountResult &result, const idx_t buffer_pos) { +bool ColumnCountResult::AddRow(ColumnCountResult &result, idx_t buffer_pos) { result.InternalAddRow(); if (!result.states.EmptyLastValue()) { - result.last_value_always_empty = false; + idx_t col_count_idx = result.result_position; + for (idx_t i = 0; i < result.result_position + 1; i++) { + if (!result.column_counts[col_count_idx].last_value_always_empty) { + break; + } + result.column_counts[col_count_idx--].last_value_always_empty = false; + } } - if (result.result_position >= STANDARD_VECTOR_SIZE) { + result.result_position++; + if (result.result_position >= result.result_size) { // We sniffed enough rows return true; } return false; } +void ColumnCountResult::SetComment(ColumnCountResult &result, idx_t buffer_pos) { + if (!result.states.WasStandard()) { + result.cur_line_starts_as_comment = true; + } + result.comment = true; +} + +bool ColumnCountResult::UnsetComment(ColumnCountResult &result, idx_t buffer_pos) { + // If we are unsetting a comment, it means this row started with a comment char. + // We add the row but tag it as a comment + bool done = result.AddRow(result, buffer_pos); + if (result.cur_line_starts_as_comment) { + result.column_counts[result.result_position - 1].is_comment = true; + } else { + result.column_counts[result.result_position - 1].is_mid_comment = true; + } + result.comment = false; + result.cur_line_starts_as_comment = false; + return done; +} + void ColumnCountResult::InvalidState(ColumnCountResult &result) { result.result_position = 0; result.error = true; } -bool ColumnCountResult::EmptyLine(ColumnCountResult &result, const idx_t buffer_pos) { +bool ColumnCountResult::EmptyLine(ColumnCountResult &result, idx_t buffer_pos) { // nop return false; } @@ -43,15 +72,24 @@ void ColumnCountResult::QuotedNewLine(ColumnCountResult &result) { ColumnCountScanner::ColumnCountScanner(shared_ptr buffer_manager, const shared_ptr &state_machine, - shared_ptr error_handler) - : BaseScanner(std::move(buffer_manager), state_machine, std::move(error_handler)), result(states, *state_machine), - column_count(1) { + shared_ptr error_handler, idx_t result_size_p, + CSVIterator iterator) + : BaseScanner(std::move(buffer_manager), state_machine, std::move(error_handler), true, nullptr, iterator), + result(states, *state_machine, result_size_p), column_count(1), result_size(result_size_p) { sniffing = true; } unique_ptr ColumnCountScanner::UpgradeToStringValueScanner() { - auto scanner = make_uniq(0U, buffer_manager, state_machine, error_handler, nullptr, true); - return scanner; + idx_t rows_to_skip = + std::max(state_machine->dialect_options.skip_rows.GetValue(), state_machine->dialect_options.rows_until_header); + auto iterator = SkipCSVRows(buffer_manager, state_machine, rows_to_skip); + if (iterator.done) { + CSVIterator it {}; + return make_uniq(0U, buffer_manager, state_machine, error_handler, nullptr, true, it, + result_size); + } + return make_uniq(0U, buffer_manager, state_machine, error_handler, nullptr, true, iterator, + result_size); } ColumnCountResult &ColumnCountScanner::ParseChunk() { @@ -70,12 +108,12 @@ void ColumnCountScanner::Initialize() { } void ColumnCountScanner::FinalizeChunkProcess() { - if (result.result_position == STANDARD_VECTOR_SIZE || result.error) { + if (result.result_position == result.result_size || result.error) { // We are done return; } // We run until we have a full chunk, or we are done scanning - while (!FinishedFile() && result.result_position < STANDARD_VECTOR_SIZE && !result.error) { + while (!FinishedFile() && result.result_position < result.result_size && !result.error) { if (iterator.pos.buffer_pos == cur_buffer_handle->actual_size) { // Move to next buffer cur_buffer_handle = buffer_manager->GetBuffer(++iterator.pos.buffer_idx); @@ -85,7 +123,13 @@ void ColumnCountScanner::FinalizeChunkProcess() { return; } // This means we reached the end of the file, we must add a last line if there is any to be added - result.InternalAddRow(); + if (result.comment) { + // If it's a comment we add the last line via unsetcomment + result.UnsetComment(result, NumericLimits::Maximum()); + } else { + // OW, we do a regular AddRow + result.AddRow(result, NumericLimits::Maximum()); + } return; } iterator.pos.buffer_pos = 0; diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/csv_schema.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/csv_schema.cpp new file mode 100644 index 00000000..5d6a9b0d --- /dev/null +++ b/src/duckdb/src/execution/operator/csv_scanner/scanner/csv_schema.cpp @@ -0,0 +1,105 @@ +#include "duckdb/execution/operator/csv_scanner/csv_schema.hpp" + +namespace duckdb { + +struct TypeIdxPair { + TypeIdxPair(LogicalType type_p, idx_t idx_p) : type(std::move(type_p)), idx(idx_p) { + } + TypeIdxPair() { + } + LogicalType type; + idx_t idx {}; +}; + +// We only really care about types that can be set in the sniffer_auto, or are sniffed by default +// If the user manually sets them, we should never get a cast issue from the sniffer! +bool CSVSchema::CanWeCastIt(LogicalTypeId source, LogicalTypeId destination) { + if (destination == LogicalTypeId::VARCHAR || source == destination) { + // We can always cast to varchar + // And obviously don't have to do anything if they are equal. + return true; + } + switch (source) { + case LogicalTypeId::SQLNULL: + return true; + case LogicalTypeId::TINYINT: + return destination == LogicalTypeId::SMALLINT || destination == LogicalTypeId::INTEGER || + destination == LogicalTypeId::BIGINT || destination == LogicalTypeId::DECIMAL || + destination == LogicalTypeId::FLOAT || destination == LogicalTypeId::DOUBLE; + case LogicalTypeId::SMALLINT: + return destination == LogicalTypeId::INTEGER || destination == LogicalTypeId::BIGINT || + destination == LogicalTypeId::DECIMAL || destination == LogicalTypeId::FLOAT || + destination == LogicalTypeId::DOUBLE; + case LogicalTypeId::INTEGER: + return destination == LogicalTypeId::BIGINT || destination == LogicalTypeId::DECIMAL || + destination == LogicalTypeId::FLOAT || destination == LogicalTypeId::DOUBLE; + case LogicalTypeId::BIGINT: + return destination == LogicalTypeId::DECIMAL || destination == LogicalTypeId::FLOAT || + destination == LogicalTypeId::DOUBLE; + case LogicalTypeId::FLOAT: + return destination == LogicalTypeId::DOUBLE; + default: + return false; + } +} + +void CSVSchema::Initialize(vector &names, vector &types, const string &file_path_p) { + if (!columns.empty()) { + throw InternalException("CSV Schema is already populated, this should not happen."); + } + file_path = file_path_p; + D_ASSERT(names.size() == types.size() && !names.empty()); + for (idx_t i = 0; i < names.size(); i++) { + // Populate our little schema + columns.push_back({names[i], types[i]}); + name_idx_map[names[i]] = i; + } +} + +bool CSVSchema::Empty() const { + return columns.empty(); +} + +bool CSVSchema::SchemasMatch(string &error_message, vector &names, vector &types, + const string &cur_file_path) { + D_ASSERT(names.size() == types.size()); + bool match = true; + unordered_map current_schema; + for (idx_t i = 0; i < names.size(); i++) { + // Populate our little schema + current_schema[names[i]] = {types[i], i}; + } + // Here we check if the schema of a given file matched our original schema + // We consider it's not a match if: + // 1. The file misses columns that were defined in the original schema. + // 2. They have a column match, but the types do not match. + std::ostringstream error; + error << "Schema mismatch between globbed files." + << "\n"; + error << "Main file schema: " << file_path << "\n"; + error << "Current file: " << cur_file_path << "\n"; + + for (auto &column : columns) { + if (current_schema.find(column.name) == current_schema.end()) { + error << "Column with name: \"" << column.name << "\" is missing" + << "\n"; + match = false; + } else { + if (!CanWeCastIt(current_schema[column.name].type.id(), column.type.id())) { + error << "Column with name: \"" << column.name + << "\" is expected to have type: " << column.type.ToString(); + error << " But has type: " << current_schema[column.name].type.ToString() << "\n"; + match = false; + } + } + } + + // Lets suggest some potential fixes + error << "Potential Fix: Since your schema has a mismatch, consider setting union_by_name=true."; + if (!match) { + error_message = error.str(); + } + return match; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/scanner_boundary.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/scanner_boundary.cpp index aa2c3aea..4ddd938b 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/scanner_boundary.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/scanner/scanner_boundary.cpp @@ -2,28 +2,15 @@ namespace duckdb { -CSVPosition::CSVPosition(idx_t file_idx_p, idx_t buffer_idx_p, idx_t buffer_pos_p) - : file_idx(file_idx_p), buffer_idx(buffer_idx_p), buffer_pos(buffer_pos_p) { +CSVPosition::CSVPosition(idx_t buffer_idx_p, idx_t buffer_pos_p) : buffer_idx(buffer_idx_p), buffer_pos(buffer_pos_p) { } CSVPosition::CSVPosition() { } -CSVBoundary::CSVBoundary(idx_t file_idx_p, idx_t buffer_idx_p, idx_t buffer_pos_p, idx_t boundary_idx_p, - idx_t end_pos_p) - : file_idx(file_idx_p), buffer_idx(buffer_idx_p), buffer_pos(buffer_pos_p), boundary_idx(boundary_idx_p), - end_pos(end_pos_p) { +CSVBoundary::CSVBoundary(idx_t buffer_idx_p, idx_t buffer_pos_p, idx_t boundary_idx_p, idx_t end_pos_p) + : buffer_idx(buffer_idx_p), buffer_pos(buffer_pos_p), boundary_idx(boundary_idx_p), end_pos(end_pos_p) { } -CSVBoundary::CSVBoundary() - : file_idx(0), buffer_idx(0), buffer_pos(0), boundary_idx(0), end_pos(NumericLimits::Maximum()) { -} -CSVIterator::CSVIterator(idx_t file_idx, idx_t buffer_idx, idx_t buffer_pos, idx_t boundary_idx, idx_t buffer_size) - : pos(file_idx, buffer_idx, buffer_pos), is_set(true) { - // The end of our boundary will be the buffer size itself it that's smaller than where we want to go - if (buffer_size < buffer_pos + BYTES_PER_THREAD) { - boundary = {file_idx, buffer_idx, buffer_pos, boundary_idx, buffer_size}; - } else { - boundary = {file_idx, buffer_idx, buffer_pos, boundary_idx, buffer_pos + BYTES_PER_THREAD}; - } +CSVBoundary::CSVBoundary() : buffer_idx(0), buffer_pos(0), boundary_idx(0), end_pos(NumericLimits::Maximum()) { } CSVIterator::CSVIterator() : is_set(false) { @@ -32,7 +19,6 @@ CSVIterator::CSVIterator() : is_set(false) { void CSVBoundary::Print() { #ifndef DUCKDB_DISABLE_PRINT std::cout << "---Boundary: " << boundary_idx << " ---" << '\n'; - std::cout << "File Index:: " << file_idx << '\n'; std::cout << "Buffer Index: " << buffer_idx << '\n'; std::cout << "Buffer Pos: " << buffer_pos << '\n'; std::cout << "End Pos: " << end_pos << '\n'; @@ -51,6 +37,8 @@ bool CSVIterator::Next(CSVBufferManager &buffer_manager) { if (!is_set) { return false; } + // If we are calling next this is not the first one anymore + first_one = false; boundary.boundary_idx++; // This is our start buffer auto buffer = buffer_manager.GetBuffer(boundary.buffer_idx); @@ -84,12 +72,8 @@ idx_t CSVIterator::GetEndPos() const { return boundary.end_pos; } -idx_t CSVIterator::GetFileIdx() const { - return pos.file_idx; -} - idx_t CSVIterator::GetBufferIdx() const { - return boundary.buffer_idx; + return pos.buffer_idx; } idx_t CSVIterator::GetBoundaryIdx() const { @@ -97,11 +81,27 @@ idx_t CSVIterator::GetBoundaryIdx() const { } void CSVIterator::SetCurrentPositionToBoundary() { - pos.file_idx = boundary.file_idx; pos.buffer_idx = boundary.buffer_idx; pos.buffer_pos = boundary.buffer_pos; } +void CSVIterator::SetCurrentBoundaryToPosition(bool single_threaded) { + if (single_threaded) { + is_set = false; + return; + } + boundary.buffer_idx = pos.buffer_idx; + if (pos.buffer_pos == 0) { + boundary.end_pos = CSVIterator::BYTES_PER_THREAD; + } else { + boundary.end_pos = ((pos.buffer_pos + CSVIterator::BYTES_PER_THREAD - 1) / CSVIterator::BYTES_PER_THREAD) * + CSVIterator::BYTES_PER_THREAD; + } + + boundary.buffer_pos = boundary.end_pos - CSVIterator::BYTES_PER_THREAD; + is_set = true; +} + void CSVIterator::SetStart(idx_t start) { boundary.buffer_pos = start; } diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/skip_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/skip_scanner.cpp index 486c26b9..3afe22d6 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/skip_scanner.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/scanner/skip_scanner.cpp @@ -4,7 +4,7 @@ namespace duckdb { SkipResult::SkipResult(CSVStates &states, CSVStateMachine &state_machine, idx_t rows_to_skip_p) - : ScannerResult(states, state_machine), rows_to_skip(rows_to_skip_p) { + : ScannerResult(states, state_machine, STANDARD_VECTOR_SIZE), rows_to_skip(rows_to_skip_p) { } void SkipResult::AddValue(SkipResult &result, const idx_t buffer_pos) { @@ -19,6 +19,14 @@ void SkipResult::QuotedNewLine(SkipResult &result) { // nop } +bool SkipResult::UnsetComment(SkipResult &result, idx_t buffer_pos) { + // If we are unsetting a comment, it means this row started with a comment char. + // We add the row but tag it as a comment + bool done = result.AddRow(result, buffer_pos); + result.comment = false; + return done; +} + bool SkipResult::AddRow(SkipResult &result, const idx_t buffer_pos) { result.InternalAddRow(); if (result.row_count >= result.rows_to_skip) { @@ -38,6 +46,7 @@ bool SkipResult::EmptyLine(SkipResult &result, const idx_t buffer_pos) { } return false; } + SkipScanner::SkipScanner(shared_ptr buffer_manager, const shared_ptr &state_machine, shared_ptr error_handler, idx_t rows_to_skip) : BaseScanner(std::move(buffer_manager), state_machine, std::move(error_handler)), @@ -58,6 +67,20 @@ void SkipScanner::Initialize() { } void SkipScanner::FinalizeChunkProcess() { - // nop + // We continue skipping until we skipped enough rows, or we have nothing else to read. + while (!FinishedFile() && result.row_count < result.rows_to_skip) { + cur_buffer_handle = buffer_manager->GetBuffer(++iterator.pos.buffer_idx); + if (cur_buffer_handle) { + iterator.pos.buffer_pos = 0; + buffer_handle_ptr = cur_buffer_handle->Ptr(); + Process(result); + } + } + // Skip Carriage Return + if (state_machine->options.dialect_options.state_machine_options.new_line == NewLineIdentifier::CARRY_ON && + states.states[1] == CSVState::CARRIAGE_RETURN) { + iterator.pos.buffer_pos++; + } + iterator.done = FinishedFile(); } } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp index d9363074..876fd024 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp @@ -14,18 +14,18 @@ #include namespace duckdb { - StringValueResult::StringValueResult(CSVStates &states, CSVStateMachine &state_machine, const shared_ptr &buffer_handle, Allocator &buffer_allocator, - bool figure_out_new_line_p, idx_t buffer_position, CSVErrorHandler &error_hander_p, + idx_t result_size_p, idx_t buffer_position, CSVErrorHandler &error_hander_p, CSVIterator &iterator_p, bool store_line_size_p, - shared_ptr csv_file_scan_p, idx_t &lines_read_p, bool sniffing_p) - : ScannerResult(states, state_machine), + shared_ptr csv_file_scan_p, idx_t &lines_read_p, bool sniffing_p, + string path_p) + : ScannerResult(states, state_machine, result_size_p), number_of_columns(NumericCast(state_machine.dialect_options.num_cols)), null_padding(state_machine.options.null_padding), ignore_errors(state_machine.options.ignore_errors.GetValue()), - figure_out_new_line(figure_out_new_line_p), error_handler(error_hander_p), iterator(iterator_p), - store_line_size(store_line_size_p), csv_file_scan(std::move(csv_file_scan_p)), lines_read(lines_read_p), - current_errors(state_machine.options.IgnoreErrors()), sniffing(sniffing_p) { + error_handler(error_hander_p), iterator(iterator_p), store_line_size(store_line_size_p), + csv_file_scan(std::move(csv_file_scan_p)), lines_read(lines_read_p), + current_errors(state_machine.options.IgnoreErrors()), sniffing(sniffing_p), path(std::move(path_p)) { // Vector information D_ASSERT(number_of_columns > 0); buffer_handles[buffer_handle->buffer_idx] = buffer_handle; @@ -34,8 +34,6 @@ StringValueResult::StringValueResult(CSVStates &states, CSVStateMachine &state_m buffer_size = buffer_handle->actual_size; last_position = {buffer_handle->buffer_idx, buffer_position, buffer_size}; requested_size = buffer_handle->requested_size; - result_size = figure_out_new_line ? 1 : STANDARD_VECTOR_SIZE; - // Current Result information current_line_position.begin = {iterator.pos.buffer_idx, iterator.pos.buffer_pos, buffer_handle->actual_size}; current_line_position.end = current_line_position.begin; @@ -56,9 +54,13 @@ StringValueResult::StringValueResult(CSVStates &states, CSVStateMachine &state_m "Mismatch between the number of columns (%d) in the CSV file and what is expected in the scanner (%d).", number_of_columns, csv_file_scan->file_types.size()); } + bool icu_loaded = csv_file_scan->buffer_manager->context.db->ExtensionIsLoaded("icu"); for (idx_t i = 0; i < csv_file_scan->file_types.size(); i++) { auto &type = csv_file_scan->file_types[i]; - if (StringValueScanner::CanDirectlyCast(type)) { + if (type.IsJSONType()) { + type = LogicalType::VARCHAR; + } + if (StringValueScanner::CanDirectlyCast(type, icu_loaded)) { parse_types[i] = ParseTypeInfo(type, true); logical_types.emplace_back(type); } else { @@ -98,8 +100,8 @@ StringValueResult::StringValueResult(CSVStates &states, CSVStateMachine &state_m // Setup the NullStr information null_str_count = state_machine.options.null_str.size(); - null_str_ptr = make_unsafe_uniq_array(null_str_count); - null_str_size = make_unsafe_uniq_array(null_str_count); + null_str_ptr = make_unsafe_uniq_array_uninitialized(null_str_count); + null_str_size = make_unsafe_uniq_array_uninitialized(null_str_count); for (idx_t i = 0; i < null_str_count; i++) { null_str_ptr[i] = state_machine.options.null_str[i].c_str(); null_str_size[i] = state_machine.options.null_str[i].size(); @@ -107,6 +109,14 @@ StringValueResult::StringValueResult(CSVStates &states, CSVStateMachine &state_m date_format = state_machine.options.dialect_options.date_format.at(LogicalTypeId::DATE).GetValue(); timestamp_format = state_machine.options.dialect_options.date_format.at(LogicalTypeId::TIMESTAMP).GetValue(); decimal_separator = state_machine.options.decimal_separator[0]; + + if (iterator.first_one) { + lines_read += + state_machine.dialect_options.skip_rows.GetValue() + state_machine.dialect_options.header.GetValue(); + if (lines_read == 0) { + SkipBOM(); + } + } } StringValueResult::~StringValueResult() { @@ -141,7 +151,7 @@ bool StringValueResult::HandleTooManyColumnsError(const char *value_ptr, const i } if (error) { // We error pointing to the current value error. - current_errors.Insert(CSVErrorType::TOO_MANY_COLUMNS, cur_col_id, chunk_col_id, last_position); + current_errors.Insert(TOO_MANY_COLUMNS, cur_col_id, chunk_col_id, last_position); cur_col_id++; } // We had an error @@ -149,6 +159,53 @@ bool StringValueResult::HandleTooManyColumnsError(const char *value_ptr, const i } return false; } + +void StringValueResult::SetComment(StringValueResult &result, idx_t buffer_pos) { + if (!result.comment) { + result.position_before_comment = buffer_pos; + result.comment = true; + } +} + +bool StringValueResult::UnsetComment(StringValueResult &result, idx_t buffer_pos) { + bool done = false; + if (result.last_position.buffer_pos < result.position_before_comment) { + bool all_empty = true; + for (idx_t i = result.last_position.buffer_pos; i < result.position_before_comment; i++) { + if (result.buffer_ptr[i] != ' ') { + all_empty = false; + break; + } + } + if (!all_empty) { + done = AddRow(result, result.position_before_comment); + } + } else { + if (result.cur_col_id != 0) { + done = AddRow(result, result.position_before_comment); + } + } + if (result.number_of_rows == 0) { + result.first_line_is_comment = true; + } + result.comment = false; + if (result.state_machine.dialect_options.state_machine_options.new_line.GetValue() != NewLineIdentifier::CARRY_ON) { + result.last_position.buffer_pos = buffer_pos + 1; + } else { + result.last_position.buffer_pos = buffer_pos + 2; + } + result.cur_col_id = 0; + result.chunk_col_id = 0; + return done; +} + +static void SanitizeError(string &value) { + std::vector char_array(value.begin(), value.end()); + char_array.push_back('\0'); // Null-terminate the character array + Utf8Proc::MakeValid(&char_array[0], char_array.size()); + value = {char_array.begin(), char_array.end() - 1}; +} + void StringValueResult::AddValueToVector(const char *value_ptr, const idx_t size, bool allocate) { if (HandleTooManyColumnsError(value_ptr, size)) { return; @@ -165,7 +222,7 @@ void StringValueResult::AddValueToVector(const char *value_ptr, const idx_t size } if (error) { // We error pointing to the current value error. - current_errors.Insert(CSVErrorType::TOO_MANY_COLUMNS, cur_col_id, chunk_col_id, last_position); + current_errors.Insert(TOO_MANY_COLUMNS, cur_col_id, chunk_col_id, last_position); cur_col_id++; } return; @@ -188,7 +245,7 @@ void StringValueResult::AddValueToVector(const char *value_ptr, const idx_t size if (empty) { if (parse_types[chunk_col_id].type_id != LogicalTypeId::VARCHAR) { // If it is not a varchar, empty values are not accepted, we must error. - current_errors.Insert(CSVErrorType::CAST_ERROR, cur_col_id, chunk_col_id, last_position); + current_errors.Insert(CAST_ERROR, cur_col_id, chunk_col_id, last_position); } static_cast(vector_ptr[chunk_col_id])[number_of_rows] = string_t(); } else { @@ -266,7 +323,8 @@ void StringValueResult::AddValueToVector(const char *value_ptr, const idx_t size static_cast(vector_ptr[chunk_col_id])[number_of_rows], false); break; } - case LogicalTypeId::TIMESTAMP: { + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: { if (!timestamp_format.Empty()) { success = timestamp_format.TryParseTimestamp( value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows]); @@ -347,7 +405,7 @@ void StringValueResult::AddValueToVector(const char *value_ptr, const idx_t size HandleUnicodeError(cur_col_id, last_position); } // If we got here, we are ingoring errors, hence we must ignore this line. - current_errors.Insert(CSVErrorType::INVALID_UNICODE, cur_col_id, chunk_col_id, last_position); + current_errors.Insert(INVALID_UNICODE, cur_col_id, chunk_col_id, last_position); break; } if (allocate) { @@ -362,14 +420,17 @@ void StringValueResult::AddValueToVector(const char *value_ptr, const idx_t size } } if (!success) { - current_errors.Insert(CSVErrorType::CAST_ERROR, cur_col_id, chunk_col_id, last_position); + current_errors.Insert(CAST_ERROR, cur_col_id, chunk_col_id, last_position); if (!state_machine.options.IgnoreErrors()) { // We have to write the cast error message. std::ostringstream error; // Casting Error Message error << "Could not convert string \"" << std::string(value_ptr, size) << "\" to \'" << LogicalTypeIdToString(parse_types[chunk_col_id].type_id) << "\'"; - current_errors.ModifyErrorMessageOfLastError(error.str()); + auto error_string = error.str(); + SanitizeError(error_string); + + current_errors.ModifyErrorMessageOfLastError(error_string); } } cur_col_id++; @@ -417,11 +478,32 @@ void StringValueResult::AddQuotedValue(StringValueResult &result, const idx_t bu if (!result.HandleTooManyColumnsError(result.buffer_ptr + result.quoted_position + 1, buffer_pos - result.quoted_position - 2)) { // If it's an escaped value we have to remove all the escapes, this is not really great - auto value = StringValueScanner::RemoveEscape( - result.buffer_ptr + result.quoted_position + 1, buffer_pos - result.quoted_position - 2, - result.state_machine.dialect_options.state_machine_options.escape.GetValue(), - result.parse_chunk.data[result.chunk_col_id]); - result.AddValueToVector(value.GetData(), value.GetSize()); + // If we are going to escape, this vector must be a varchar vector + if (result.parse_chunk.data[result.chunk_col_id].GetType() != LogicalType::VARCHAR) { + result.current_errors.Insert(CAST_ERROR, result.cur_col_id, result.chunk_col_id, result.last_position); + if (!result.state_machine.options.IgnoreErrors()) { + // We have to write the cast error message. + std::ostringstream error; + // Casting Error Message + + error << "Could not convert string \"" + << std::string(result.buffer_ptr + result.quoted_position + 1, + buffer_pos - result.quoted_position - 2) + << "\" to \'" << LogicalTypeIdToString(result.parse_types[result.chunk_col_id].type_id) + << "\'"; + auto error_string = error.str(); + SanitizeError(error_string); + result.current_errors.ModifyErrorMessageOfLastError(error_string); + } + result.cur_col_id++; + result.chunk_col_id++; + } else { + auto value = StringValueScanner::RemoveEscape( + result.buffer_ptr + result.quoted_position + 1, buffer_pos - result.quoted_position - 2, + result.state_machine.dialect_options.state_machine_options.escape.GetValue(), + result.parse_chunk.data[result.chunk_col_id]); + result.AddValueToVector(value.GetData(), value.GetSize()); + } } } else { if (buffer_pos < result.last_position.buffer_pos + 2) { @@ -453,100 +535,108 @@ void StringValueResult::AddValue(StringValueResult &result, const idx_t buffer_p void StringValueResult::HandleUnicodeError(idx_t col_idx, LinePosition &error_position) { bool first_nl; - auto borked_line = current_line_position.ReconstructCurrentLine(first_nl, buffer_handles); + auto borked_line = current_line_position.ReconstructCurrentLine(first_nl, buffer_handles, PrintErrorLine()); LinesPerBoundary lines_per_batch(iterator.GetBoundaryIdx(), lines_read); if (current_line_position.begin == error_position) { auto csv_error = CSVError::InvalidUTF8(state_machine.options, col_idx, lines_per_batch, borked_line, current_line_position.begin.GetGlobalPosition(requested_size, first_nl), - error_position.GetGlobalPosition(requested_size, first_nl)); + error_position.GetGlobalPosition(requested_size, first_nl), path); error_handler.Error(csv_error, true); } else { auto csv_error = CSVError::InvalidUTF8(state_machine.options, col_idx, lines_per_batch, borked_line, current_line_position.begin.GetGlobalPosition(requested_size, first_nl), - error_position.GetGlobalPosition(requested_size)); + error_position.GetGlobalPosition(requested_size), path); error_handler.Error(csv_error, true); } } bool LineError::HandleErrors(StringValueResult &result) { - if (ignore_errors && is_error_in_line && !result.figure_out_new_line) { - result.cur_col_id = 0; - result.chunk_col_id = 0; - result.number_of_rows--; + bool skip_sniffing = false; + for (auto &cur_error : current_errors) { + if (cur_error.type == CSVErrorType::INVALID_UNICODE) { + skip_sniffing = true; + } + } + skip_sniffing = result.sniffing && skip_sniffing; + + if ((ignore_errors || skip_sniffing) && is_error_in_line && !result.figure_out_new_line) { + result.RemoveLastLine(); Reset(); return true; } // Reconstruct CSV Line for (auto &cur_error : current_errors) { LinesPerBoundary lines_per_batch(result.iterator.GetBoundaryIdx(), result.lines_read); - bool first_nl; - auto borked_line = result.current_line_position.ReconstructCurrentLine(first_nl, result.buffer_handles); + bool first_nl = false; + auto borked_line = result.current_line_position.ReconstructCurrentLine(first_nl, result.buffer_handles, + result.PrintErrorLine()); CSVError csv_error; auto col_idx = cur_error.col_idx; auto &line_pos = cur_error.error_position; switch (cur_error.type) { - case CSVErrorType::TOO_MANY_COLUMNS: - case CSVErrorType::TOO_FEW_COLUMNS: + case TOO_MANY_COLUMNS: + case TOO_FEW_COLUMNS: if (result.current_line_position.begin == line_pos) { csv_error = CSVError::IncorrectColumnAmountError( result.state_machine.options, col_idx, lines_per_batch, borked_line, result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size, first_nl)); + line_pos.GetGlobalPosition(result.requested_size, first_nl), result.path); } else { csv_error = CSVError::IncorrectColumnAmountError( result.state_machine.options, col_idx, lines_per_batch, borked_line, result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size)); + line_pos.GetGlobalPosition(result.requested_size), result.path); } break; - case CSVErrorType::INVALID_UNICODE: { + case INVALID_UNICODE: { if (result.current_line_position.begin == line_pos) { csv_error = CSVError::InvalidUTF8( result.state_machine.options, col_idx, lines_per_batch, borked_line, result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size, first_nl)); + line_pos.GetGlobalPosition(result.requested_size, first_nl), result.path); } else { csv_error = CSVError::InvalidUTF8( result.state_machine.options, col_idx, lines_per_batch, borked_line, result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size)); + line_pos.GetGlobalPosition(result.requested_size), result.path); } break; } - case CSVErrorType::UNTERMINATED_QUOTES: + case UNTERMINATED_QUOTES: if (result.current_line_position.begin == line_pos) { csv_error = CSVError::UnterminatedQuotesError( result.state_machine.options, col_idx, lines_per_batch, borked_line, result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size, first_nl)); + line_pos.GetGlobalPosition(result.requested_size, first_nl), result.path); } else { csv_error = CSVError::UnterminatedQuotesError( result.state_machine.options, col_idx, lines_per_batch, borked_line, result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size)); + line_pos.GetGlobalPosition(result.requested_size), result.path); } break; - case CSVErrorType::CAST_ERROR: + case CAST_ERROR: if (result.current_line_position.begin == line_pos) { csv_error = CSVError::CastError( result.state_machine.options, result.names[cur_error.col_idx], cur_error.error_message, cur_error.col_idx, borked_line, lines_per_batch, result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), line_pos.GetGlobalPosition(result.requested_size, first_nl), - result.parse_types[cur_error.chunk_idx].type_id); + result.parse_types[cur_error.chunk_idx].type_id, result.path); } else { csv_error = CSVError::CastError( result.state_machine.options, result.names[cur_error.col_idx], cur_error.error_message, cur_error.col_idx, borked_line, lines_per_batch, result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size), result.parse_types[cur_error.chunk_idx].type_id); + line_pos.GetGlobalPosition(result.requested_size), result.parse_types[cur_error.chunk_idx].type_id, + result.path); } break; - case CSVErrorType::MAXIMUM_LINE_SIZE: + case MAXIMUM_LINE_SIZE: csv_error = CSVError::LineSizeError( result.state_machine.options, cur_error.current_line_size, lines_per_batch, borked_line, - result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl)); + result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), result.path); break; default: throw InvalidInputException("CSV Error not allowed when inserting row"); @@ -567,19 +657,23 @@ void StringValueResult::QuotedNewLine(StringValueResult &result) { result.quoted_new_line = true; } -void StringValueResult::NullPaddingQuotedNewlineCheck() { +void StringValueResult::NullPaddingQuotedNewlineCheck() const { // We do some checks for null_padding correctness if (state_machine.options.null_padding && iterator.IsBoundarySet() && quoted_new_line) { // If we have null_padding set, we found a quoted new line, we are scanning the file in parallel; We error. LinesPerBoundary lines_per_batch(iterator.GetBoundaryIdx(), lines_read); - auto csv_error = CSVError::NullPaddingFail(state_machine.options, lines_per_batch); + auto csv_error = CSVError::NullPaddingFail(state_machine.options, lines_per_batch, path); error_handler.Error(csv_error); } } //! Reconstructs the current line to be used in error messages string FullLinePosition::ReconstructCurrentLine(bool &first_char_nl, - unordered_map> &buffer_handles) { + unordered_map> &buffer_handles, + bool reconstruct_line) const { + if (!reconstruct_line) { + return {}; + } string result; if (end.buffer_idx == begin.buffer_idx) { if (buffer_handles.find(end.buffer_idx) == buffer_handles.end()) { @@ -609,10 +703,7 @@ string FullLinePosition::ReconstructCurrentLine(bool &first_char_nl, } } // sanitize borked line - std::vector char_array(result.begin(), result.end()); - char_array.push_back('\0'); // Null-terminate the character array - Utf8Proc::MakeValid(&char_array[0], char_array.size()); - result = {char_array.begin(), char_array.end() - 1}; + SanitizeError(result); return result; } @@ -625,11 +716,11 @@ bool StringValueResult::AddRowInternal() { current_line_position.begin = current_line_position.end; current_line_position.end = current_line_start; if (current_line_size > state_machine.options.maximum_line_size) { - current_errors.Insert(CSVErrorType::MAXIMUM_LINE_SIZE, 1, chunk_col_id, last_position, current_line_size); + current_errors.Insert(MAXIMUM_LINE_SIZE, 1, chunk_col_id, last_position, current_line_size); } if (!state_machine.options.null_padding) { for (idx_t col_idx = cur_col_id; col_idx < number_of_columns; col_idx++) { - current_errors.Insert(CSVErrorType::TOO_FEW_COLUMNS, col_idx - 1, chunk_col_id, last_position); + current_errors.Insert(TOO_FEW_COLUMNS, col_idx - 1, chunk_col_id, last_position); } } @@ -672,24 +763,25 @@ bool StringValueResult::AddRowInternal() { // If we are not null-padding this is an error if (!state_machine.options.IgnoreErrors()) { bool first_nl; - auto borked_line = current_line_position.ReconstructCurrentLine(first_nl, buffer_handles); + auto borked_line = + current_line_position.ReconstructCurrentLine(first_nl, buffer_handles, PrintErrorLine()); LinesPerBoundary lines_per_batch(iterator.GetBoundaryIdx(), lines_read); if (current_line_position.begin == last_position) { auto csv_error = CSVError::IncorrectColumnAmountError( state_machine.options, cur_col_id - 1, lines_per_batch, borked_line, current_line_position.begin.GetGlobalPosition(requested_size, first_nl), - last_position.GetGlobalPosition(requested_size, first_nl)); + last_position.GetGlobalPosition(requested_size, first_nl), path); error_handler.Error(csv_error); } else { auto csv_error = CSVError::IncorrectColumnAmountError( state_machine.options, cur_col_id - 1, lines_per_batch, borked_line, current_line_position.begin.GetGlobalPosition(requested_size, first_nl), - last_position.GetGlobalPosition(requested_size)); + last_position.GetGlobalPosition(requested_size), path); error_handler.Error(csv_error); } } // If we are here we ignore_errors, so we delete this line - number_of_rows--; + RemoveLastLine(); } } line_positions_per_row[number_of_rows] = current_line_position; @@ -707,7 +799,7 @@ bool StringValueResult::AddRow(StringValueResult &result, const idx_t buffer_pos if (result.last_position.buffer_pos <= buffer_pos) { // We add the value if (result.quoted) { - StringValueResult::AddQuotedValue(result, buffer_pos); + AddQuotedValue(result, buffer_pos); } else { result.AddValueToVector(result.buffer_ptr + result.last_position.buffer_pos, buffer_pos - result.last_position.buffer_pos); @@ -734,8 +826,7 @@ void StringValueResult::InvalidState(StringValueResult &result) { if (force_error) { result.HandleUnicodeError(result.cur_col_id, result.last_position); } - result.current_errors.Insert(CSVErrorType::UNTERMINATED_QUOTES, result.cur_col_id, result.chunk_col_id, - result.last_position); + result.current_errors.Insert(UNTERMINATED_QUOTES, result.cur_col_id, result.chunk_col_id, result.last_position); } bool StringValueResult::EmptyLine(StringValueResult &result, const idx_t buffer_pos) { @@ -772,21 +863,24 @@ StringValueScanner::StringValueScanner(idx_t scanner_idx_p, const shared_ptr &state_machine, const shared_ptr &error_handler, const shared_ptr &csv_file_scan, bool sniffing, - CSVIterator boundary, bool figure_out_nl) + const CSVIterator &boundary, idx_t result_size) : BaseScanner(buffer_manager, state_machine, error_handler, sniffing, csv_file_scan, boundary), scanner_idx(scanner_idx_p), - result(states, *state_machine, cur_buffer_handle, BufferAllocator::Get(buffer_manager->context), figure_out_nl, + result(states, *state_machine, cur_buffer_handle, BufferAllocator::Get(buffer_manager->context), result_size, iterator.pos.buffer_pos, *error_handler, iterator, - buffer_manager->context.client_data->debug_set_max_line_length, csv_file_scan, lines_read, sniffing) { + buffer_manager->context.client_data->debug_set_max_line_length, csv_file_scan, lines_read, sniffing, + buffer_manager->GetFilePath()) { } StringValueScanner::StringValueScanner(const shared_ptr &buffer_manager, const shared_ptr &state_machine, - const shared_ptr &error_handler) - : BaseScanner(buffer_manager, state_machine, error_handler, false, nullptr, {}), scanner_idx(0), - result(states, *state_machine, cur_buffer_handle, Allocator::DefaultAllocator(), false, iterator.pos.buffer_pos, - *error_handler, iterator, buffer_manager->context.client_data->debug_set_max_line_length, csv_file_scan, - lines_read, sniffing) { + const shared_ptr &error_handler, idx_t result_size, + const CSVIterator &boundary) + : BaseScanner(buffer_manager, state_machine, error_handler, false, nullptr, boundary), scanner_idx(0), + result(states, *state_machine, cur_buffer_handle, Allocator::DefaultAllocator(), result_size, + iterator.pos.buffer_pos, *error_handler, iterator, + buffer_manager->context.client_data->debug_set_max_line_length, csv_file_scan, lines_read, sniffing, + buffer_manager->GetFilePath()) { } unique_ptr StringValueScanner::GetCSVScanner(ClientContext &context, CSVReaderOptions &options) { @@ -796,13 +890,18 @@ unique_ptr StringValueScanner::GetCSVScanner(ClientContext & state_machine->dialect_options.num_cols = options.dialect_options.num_cols; state_machine->dialect_options.header = options.dialect_options.header; auto buffer_manager = make_shared_ptr(context, options, options.file_path, 0); - auto scanner = make_uniq(buffer_manager, state_machine, make_shared_ptr()); + idx_t rows_to_skip = state_machine->options.GetSkipRows() + state_machine->options.GetHeader(); + rows_to_skip = std::max(rows_to_skip, state_machine->dialect_options.rows_until_header + + state_machine->dialect_options.header.GetValue()); + auto it = BaseScanner::SkipCSVRows(buffer_manager, state_machine, rows_to_skip); + auto scanner = make_uniq(buffer_manager, state_machine, make_shared_ptr(), + STANDARD_VECTOR_SIZE, it); scanner->csv_file_scan = make_shared_ptr(context, options.file_path, options); scanner->csv_file_scan->InitializeProjection(); return scanner; } -bool StringValueScanner::FinishedIterator() { +bool StringValueScanner::FinishedIterator() const { return iterator.done; } @@ -842,7 +941,8 @@ void StringValueScanner::Flush(DataChunk &insert_chunk) { auto &result_vector = insert_chunk.data[result_idx]; auto &type = result_vector.GetType(); auto &parse_type = parse_vector.GetType(); - if (type == LogicalType::VARCHAR || (type != LogicalType::VARCHAR && parse_type != LogicalType::VARCHAR)) { + if (!type.IsJSONType() && + (type == LogicalType::VARCHAR || (type != LogicalType::VARCHAR && parse_type != LogicalType::VARCHAR))) { // reinterpret rather than reference result_vector.Reinterpret(parse_vector); } else { @@ -865,9 +965,9 @@ void StringValueScanner::Flush(DataChunk &insert_chunk) { } } { - vector row; if (state_machine->options.ignore_errors.GetValue()) { + vector row; for (idx_t col = 0; col < parse_chunk.ColumnCount(); col++) { row.push_back(parse_chunk.GetValue(col, line_error)); } @@ -877,16 +977,17 @@ void StringValueScanner::Flush(DataChunk &insert_chunk) { lines_read - parse_chunk.size() + line_error); bool first_nl; auto borked_line = result.line_positions_per_row[line_error].ReconstructCurrentLine( - first_nl, result.buffer_handles); + first_nl, result.buffer_handles, result.PrintErrorLine()); std::ostringstream error; error << "Could not convert string \"" << parse_vector.GetValue(line_error) << "\" to \'" - << LogicalTypeIdToString(type.id()) << "\'"; + << type.ToString() << "\'"; string error_msg = error.str(); + SanitizeError(error_msg); auto csv_error = CSVError::CastError( state_machine->options, csv_file_scan->names[col_idx], error_msg, col_idx, borked_line, lines_per_batch, result.line_positions_per_row[line_error].begin.GetGlobalPosition(result.result_size, first_nl), - optional_idx::Invalid(), result_vector.GetType().id()); + optional_idx::Invalid(), result_vector.GetType().id(), result.path); error_handler->Error(csv_error); } } @@ -906,18 +1007,19 @@ void StringValueScanner::Flush(DataChunk &insert_chunk) { lines_read - parse_chunk.size() + line_error); bool first_nl; auto borked_line = result.line_positions_per_row[line_error].ReconstructCurrentLine( - first_nl, result.buffer_handles); + first_nl, result.buffer_handles, result.PrintErrorLine()); std::ostringstream error; // Casting Error Message error << "Could not convert string \"" << parse_vector.GetValue(line_error) << "\" to \'" << LogicalTypeIdToString(type.id()) << "\'"; string error_msg = error.str(); + SanitizeError(error_msg); auto csv_error = CSVError::CastError(state_machine->options, csv_file_scan->names[col_idx], error_msg, col_idx, borked_line, lines_per_batch, result.line_positions_per_row[line_error].begin.GetGlobalPosition( result.result_size, first_nl), - optional_idx::Invalid(), result_vector.GetType().id()); + optional_idx::Invalid(), result_vector.GetType().id(), result.path); error_handler->Error(csv_error); } } @@ -940,7 +1042,6 @@ void StringValueScanner::Flush(DataChunk &insert_chunk) { void StringValueScanner::Initialize() { states.Initialize(); - if (result.result_size != 1 && !(sniffing && state_machine->options.null_padding && !state_machine->options.dialect_options.skip_rows.IsSetByUser())) { SetStart(); @@ -968,7 +1069,11 @@ void StringValueScanner::ProcessExtraRow() { lines_read++; return; } else if (states.states[0] != CSVState::CARRIAGE_RETURN) { - result.AddRow(result, iterator.pos.buffer_pos); + if (result.IsCommentSet(result)) { + result.UnsetComment(result, iterator.pos.buffer_pos); + } else { + result.AddRow(result, iterator.pos.buffer_pos); + } iterator.pos.buffer_pos++; lines_read++; return; @@ -978,7 +1083,11 @@ void StringValueScanner::ProcessExtraRow() { break; case CSVState::CARRIAGE_RETURN: if (states.states[0] != CSVState::RECORD_SEPARATOR) { - result.AddRow(result, iterator.pos.buffer_pos); + if (result.IsCommentSet(result)) { + result.UnsetComment(result, iterator.pos.buffer_pos); + } else { + result.AddRow(result, iterator.pos.buffer_pos); + } iterator.pos.buffer_pos++; lines_read++; return; @@ -988,6 +1097,7 @@ void StringValueScanner::ProcessExtraRow() { lines_read++; return; } + break; case CSVState::DELIMITER: result.AddValue(result, iterator.pos.buffer_pos); iterator.pos.buffer_pos++; @@ -1016,6 +1126,15 @@ void StringValueScanner::ProcessExtraRow() { iterator.pos.buffer_pos++; } break; + case CSVState::COMMENT: + result.SetComment(result, iterator.pos.buffer_pos); + iterator.pos.buffer_pos++; + while (state_machine->transition_array + .skip_comment[static_cast(buffer_handle_ptr[iterator.pos.buffer_pos])] && + iterator.pos.buffer_pos < to_pos - 1) { + iterator.pos.buffer_pos++; + } + break; case CSVState::QUOTED_NEW_LINE: result.quoted_new_line = true; result.NullPaddingQuotedNewlineCheck(); @@ -1077,7 +1196,9 @@ void StringValueScanner::ProcessOverbufferValue() { if (states.NewRow() || states.NewValue()) { break; } else { - overbuffer_string += previous_buffer[i]; + if (!result.comment) { + overbuffer_string += previous_buffer[i]; + } } if (states.IsQuoted()) { result.SetQuoted(result, j); @@ -1085,6 +1206,9 @@ void StringValueScanner::ProcessOverbufferValue() { if (states.IsEscaped()) { result.escaped = true; } + if (states.IsComment()) { + result.comment = true; + } if (states.IsInvalid()) { result.InvalidState(result); } @@ -1109,11 +1233,16 @@ void StringValueScanner::ProcessOverbufferValue() { if (states.NewRow() || states.NewValue()) { break; } else { - overbuffer_string += buffer_handle_ptr[iterator.pos.buffer_pos]; + if (!result.comment && !states.IsComment()) { + overbuffer_string += buffer_handle_ptr[iterator.pos.buffer_pos]; + } } if (states.IsQuoted()) { result.SetQuoted(result, j); } + if (states.IsComment()) { + result.comment = true; + } if (states.IsEscaped()) { result.escaped = true; } @@ -1124,7 +1253,7 @@ void StringValueScanner::ProcessOverbufferValue() { } bool skip_value = false; if (result.projecting_columns) { - if (!result.projected_columns[result.cur_col_id]) { + if (!result.projected_columns[result.cur_col_id] && result.cur_col_id != result.number_of_columns) { result.cur_col_id++; skip_value = true; } @@ -1135,18 +1264,17 @@ void StringValueScanner::ProcessOverbufferValue() { value = string_t(overbuffer_string.c_str() + result.quoted_position, UnsafeNumericCast(overbuffer_string.size() - 1 - result.quoted_position)); if (result.escaped) { - const auto str_ptr = static_cast(overbuffer_string.c_str() + result.quoted_position); - value = StringValueScanner::RemoveEscape( - str_ptr, overbuffer_string.size() - 2, - state_machine->dialect_options.state_machine_options.escape.GetValue(), - result.parse_chunk.data[result.chunk_col_id]); + const auto str_ptr = overbuffer_string.c_str() + result.quoted_position; + value = RemoveEscape(str_ptr, overbuffer_string.size() - 2, + state_machine->dialect_options.state_machine_options.escape.GetValue(), + result.parse_chunk.data[result.chunk_col_id]); } } else { value = string_t(overbuffer_string.c_str(), UnsafeNumericCast(overbuffer_string.size())); } if (states.EmptyLine() && state_machine->dialect_options.num_cols == 1) { result.EmptyLine(result, iterator.pos.buffer_pos); - } else if (!states.IsNotSet()) { + } else if (!states.IsNotSet() && (!result.comment || !value.Empty())) { result.AddValueToVector(value.GetData(), value.GetSize(), true); } } else { @@ -1156,7 +1284,11 @@ void StringValueScanner::ProcessOverbufferValue() { } if (states.NewRow() && !states.IsNotSet()) { - result.AddRowInternal(); + if (result.IsCommentSet(result)) { + result.UnsetComment(result, iterator.pos.buffer_pos); + } else { + result.AddRowInternal(); + } lines_read++; } @@ -1196,7 +1328,11 @@ bool StringValueScanner::MoveToNextBuffer() { // we add the value result.AddValue(result, previous_buffer_handle->actual_size); // And an extra empty value to represent what comes after the delimiter - result.AddRow(result, previous_buffer_handle->actual_size); + if (result.IsCommentSet(result)) { + result.UnsetComment(result, iterator.pos.buffer_pos); + } else { + result.AddRow(result, previous_buffer_handle->actual_size); + } lines_read++; } else if (states.IsQuotedCurrent()) { // Unterminated quote @@ -1206,7 +1342,11 @@ bool StringValueScanner::MoveToNextBuffer() { result.current_line_position.end = current_line_start; result.InvalidState(result); } else { - result.AddRow(result, previous_buffer_handle->actual_size); + if (result.IsCommentSet(result)) { + result.UnsetComment(result, iterator.pos.buffer_pos); + } else { + result.AddRow(result, previous_buffer_handle->actual_size); + } lines_read++; } return false; @@ -1224,13 +1364,31 @@ bool StringValueScanner::MoveToNextBuffer() { return false; } -void StringValueScanner::SkipBOM() { - if (cur_buffer_handle->actual_size >= 3 && result.buffer_ptr[0] == '\xEF' && result.buffer_ptr[1] == '\xBB' && - result.buffer_ptr[2] == '\xBF') { +void StringValueResult::SkipBOM() const { + if (buffer_size >= 3 && buffer_ptr[0] == '\xEF' && buffer_ptr[1] == '\xBB' && buffer_ptr[2] == '\xBF' && + iterator.pos.buffer_pos == 0) { iterator.pos.buffer_pos = 3; } } +void StringValueResult::RemoveLastLine() { + // potentially de-nullify values + for (idx_t i = 0; i < chunk_col_id; i++) { + validity_mask[i]->SetValid(number_of_rows); + } + // reset column trackers + cur_col_id = 0; + chunk_col_id = 0; + // decrement row counter + number_of_rows--; +} +bool StringValueResult::PrintErrorLine() const { + // To print a lint, result size must be different, than one (i.e., this is a SetStart() trying to figure out new + // lines) And must either not be ignoring errors OR must be storing them in a rejects table. + return result_size != 1 && + (state_machine.options.store_rejects.GetValue() || !state_machine.options.ignore_errors.GetValue()); +} + void StringValueScanner::SkipUntilNewLine() { // Now skip until next newline if (state_machine->options.dialect_options.state_machine_options.new_line.GetValue() == @@ -1261,7 +1419,7 @@ void StringValueScanner::SkipUntilNewLine() { } } -bool StringValueScanner::CanDirectlyCast(const LogicalType &type) { +bool StringValueScanner::CanDirectlyCast(const LogicalType &type, bool icu_loaded) { switch (type.id()) { case LogicalTypeId::TINYINT: @@ -1280,29 +1438,33 @@ bool StringValueScanner::CanDirectlyCast(const LogicalType &type) { case LogicalTypeId::DECIMAL: case LogicalType::VARCHAR: return true; + case LogicalType::TIMESTAMP_TZ: + // We only try to do direct cast of timestamp tz if the ICU extension is not loaded, otherwise, it needs to go + // through string -> timestamp_tz casting + return !icu_loaded; default: return false; } } void StringValueScanner::SetStart() { - if (iterator.pos.buffer_idx == 0 && iterator.pos.buffer_pos == 0) { - // This means this is the very first buffer - // This CSV is not from auto-detect, so we don't know where exactly it starts - // Hence we potentially have to skip empty lines and headers. - SkipBOM(); - SkipCSVRows(state_machine->dialect_options.skip_rows.GetValue() + - state_machine->dialect_options.header.GetValue()); + if (iterator.first_one) { if (result.store_line_size) { result.error_handler.NewMaxLineSize(iterator.pos.buffer_pos); } return; } + if (state_machine->options.IgnoreErrors()) { + // If we are ignoring errors we don't really need to figure out a line. + return; + } + // The result size of the data after skipping the row is one line // We have to look for a new line that fits our schema // 1. We walk until the next new line bool line_found; unique_ptr scan_finder; do { + constexpr idx_t result_size = 1; SkipUntilNewLine(); if (state_machine->options.null_padding) { // When Null Padding, we assume we start from the correct new-line @@ -1310,11 +1472,12 @@ void StringValueScanner::SetStart() { } scan_finder = make_uniq(0U, buffer_manager, state_machine, make_shared_ptr(true), - csv_file_scan, false, iterator, true); + csv_file_scan, false, iterator, result_size); auto &tuples = scan_finder->ParseChunk(); line_found = true; if (tuples.number_of_rows != 1 || - (!tuples.borked_rows.empty() && !state_machine->options.ignore_errors.GetValue())) { + (!tuples.borked_rows.empty() && !state_machine->options.ignore_errors.GetValue()) || + tuples.first_line_is_comment) { line_found = false; // If no tuples were parsed, this is not the correct start, we need to skip until the next new line // Or if columns don't match, this is not the correct start, we need to skip until the next new line @@ -1358,7 +1521,7 @@ void StringValueScanner::FinalizeChunkProcess() { // If we are not done we have two options. // 1) If a boundary is set. if (iterator.IsBoundarySet()) { - if (!result.current_errors.HasErrorType(CSVErrorType::UNTERMINATED_QUOTES)) { + if (!result.current_errors.HasErrorType(UNTERMINATED_QUOTES)) { iterator.done = true; } // We read until the next line or until we have nothing else to read. @@ -1368,7 +1531,7 @@ void StringValueScanner::FinalizeChunkProcess() { } bool moved = MoveToNextBuffer(); if (cur_buffer_handle) { - if (moved && result.cur_col_id < result.number_of_columns && result.cur_col_id > 0) { + if (moved && result.cur_col_id > 0) { ProcessExtraRow(); } else if (!moved) { ProcessExtraRow(); @@ -1398,7 +1561,7 @@ void StringValueScanner::FinalizeChunkProcess() { } } iterator.done = FinishedFile(); - if (result.null_padding && result.number_of_rows < STANDARD_VECTOR_SIZE) { + if (result.null_padding && result.number_of_rows < STANDARD_VECTOR_SIZE && result.chunk_col_id > 0) { while (result.chunk_col_id < result.parse_chunk.ColumnCount()) { result.validity_mask[result.chunk_col_id++]->SetInvalid(result.number_of_rows); result.cur_col_id++; diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp index da1c93ce..96aa9e2d 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp @@ -1,11 +1,12 @@ #include "duckdb/execution/operator/csv_scanner/csv_sniffer.hpp" +#include "duckdb/common/types/value.hpp" namespace duckdb { CSVSniffer::CSVSniffer(CSVReaderOptions &options_p, shared_ptr buffer_manager_p, - CSVStateMachineCache &state_machine_cache_p, SetColumns set_columns_p) + CSVStateMachineCache &state_machine_cache_p, bool default_null_to_varchar_p) : state_machine_cache(state_machine_cache_p), options(options_p), buffer_manager(std::move(buffer_manager_p)), - set_columns(set_columns_p) { + default_null_to_varchar(default_null_to_varchar_p) { // Initialize Format Candidates for (const auto &format_template : format_template_candidates) { auto &logical_type = format_template.first; @@ -15,16 +16,19 @@ CSVSniffer::CSVSniffer(CSVReaderOptions &options_p, shared_ptr max_columns_found = set_columns.Size(); error_handler = make_shared_ptr(options.ignore_errors.GetValue()); detection_error_handler = make_shared_ptr(true); + if (options.columns_set) { + set_columns = SetColumns(&options.sql_type_list, &options.name_list); + } } -bool SetColumns::IsSet() { +bool SetColumns::IsSet() const { if (!types) { return false; } return !types->empty(); } -idx_t SetColumns::Size() { +idx_t SetColumns::Size() const { if (!types) { return 0; } @@ -57,6 +61,7 @@ void MatchAndRepaceUserSetVariables(DialectOptions &original, DialectOptions &sn error); MatchAndReplace(original.state_machine_options.quote, sniffed.state_machine_options.quote, "Quote", error); MatchAndReplace(original.state_machine_options.escape, sniffed.state_machine_options.escape, "Escape", error); + MatchAndReplace(original.state_machine_options.comment, sniffed.state_machine_options.comment, "Comment", error); if (found_date) { MatchAndReplace(original.date_format[LogicalTypeId::DATE], sniffed.date_format[LogicalTypeId::DATE], "Date Format", error); @@ -80,8 +85,99 @@ void CSVSniffer::SetResultOptions() { MatchAndRepaceUserSetVariables(options.dialect_options, best_candidate->GetStateMachine().dialect_options, options.sniffer_user_mismatch_error, found_date, found_timestamp); options.dialect_options.num_cols = best_candidate->GetStateMachine().dialect_options.num_cols; + options.dialect_options.rows_until_header = best_candidate->GetStateMachine().dialect_options.rows_until_header; } +SnifferResult CSVSniffer::MinimalSniff() { + if (set_columns.IsSet()) { + // Nothing to see here + return SnifferResult(*set_columns.types, *set_columns.names); + } + // Return Types detected + vector return_types; + // Column Names detected + vector names; + + buffer_manager->sniffing = true; + constexpr idx_t result_size = 2; + + auto state_machine = + make_shared_ptr(options, options.dialect_options.state_machine_options, state_machine_cache); + ColumnCountScanner count_scanner(buffer_manager, state_machine, error_handler, result_size); + auto &sniffed_column_counts = count_scanner.ParseChunk(); + if (sniffed_column_counts.result_position == 0) { + return {{}, {}}; + } + + state_machine->dialect_options.num_cols = sniffed_column_counts[0].number_of_columns; + options.dialect_options.num_cols = sniffed_column_counts[0].number_of_columns; + + // First figure out the number of columns on this configuration + auto scanner = count_scanner.UpgradeToStringValueScanner(); + // Parse chunk and read csv with info candidate + auto &data_chunk = scanner->ParseChunk().ToChunk(); + idx_t start_row = 0; + if (sniffed_column_counts.result_position == 2) { + // If equal to two, we will only use the second row for type checking + start_row = 1; + } + + // Gather Types + for (idx_t i = 0; i < state_machine->dialect_options.num_cols; i++) { + best_sql_types_candidates_per_column_idx[i] = state_machine->options.auto_type_candidates; + } + SniffTypes(data_chunk, *state_machine, best_sql_types_candidates_per_column_idx, start_row); + + // Possibly Gather Header + vector potential_header; + if (start_row != 0) { + for (idx_t col_idx = 0; col_idx < data_chunk.ColumnCount(); col_idx++) { + auto &cur_vector = data_chunk.data[col_idx]; + auto vector_data = FlatVector::GetData(cur_vector); + HeaderValue val(vector_data[0]); + potential_header.emplace_back(val); + } + } + names = DetectHeaderInternal(buffer_manager->context, potential_header, *state_machine, set_columns, + best_sql_types_candidates_per_column_idx, options, *error_handler); + + for (idx_t column_idx = 0; column_idx < best_sql_types_candidates_per_column_idx.size(); column_idx++) { + LogicalType d_type = best_sql_types_candidates_per_column_idx[column_idx].back(); + if (best_sql_types_candidates_per_column_idx[column_idx].size() == options.auto_type_candidates.size()) { + d_type = LogicalType::VARCHAR; + } + detected_types.push_back(d_type); + } + + return {detected_types, names}; +} + +SnifferResult CSVSniffer::AdaptiveSniff(CSVSchema &file_schema) { + auto min_sniff_res = MinimalSniff(); + bool run_full = error_handler->AnyErrors() || detection_error_handler->AnyErrors(); + // Check if we are happy with the result or if we need to do more sniffing + if (!error_handler->AnyErrors() && !detection_error_handler->AnyErrors()) { + // If we got no errors, we also run full if schemas do not match. + if (!set_columns.IsSet() && !options.file_options.AnySet()) { + string error; + run_full = + !file_schema.SchemasMatch(error, min_sniff_res.names, min_sniff_res.return_types, options.file_path); + } + } + if (run_full) { + // We run full sniffer + auto full_sniffer = SniffCSV(); + if (!set_columns.IsSet() && !options.file_options.AnySet()) { + string error; + if (!file_schema.SchemasMatch(error, full_sniffer.names, full_sniffer.return_types, options.file_path) && + !options.ignore_errors.GetValue()) { + throw InvalidInputException(error); + } + } + return full_sniffer; + } + return min_sniff_res; +} SnifferResult CSVSniffer::SniffCSV(bool force_match) { buffer_manager->sniffing = true; // 1. Dialect Detection @@ -98,21 +194,19 @@ SnifferResult CSVSniffer::SniffCSV(bool force_match) { // We reset the buffer for compressed files // This is done because we can't easily seek on compressed files, if a buffer goes out of scope we must read from // the start - if (!buffer_manager->file_handle->uncompressed) { + if (buffer_manager->file_handle->compression_type != FileCompressionType::UNCOMPRESSED) { buffer_manager->ResetBufferManager(); } buffer_manager->sniffing = false; if (!best_candidate->error_handler->errors.empty() && !options.ignore_errors.GetValue()) { for (auto &error_vector : best_candidate->error_handler->errors) { for (auto &error : error_vector.second) { - if (error.type == CSVErrorType::MAXIMUM_LINE_SIZE) { + if (error.type == MAXIMUM_LINE_SIZE) { // If it's a maximum line size error, we can do it now. error_handler->Error(error); } } } - auto error = CSVError::SniffingError(options.file_path); - error_handler->Error(error); } D_ASSERT(best_sql_types_candidates_per_column_idx.size() == names.size()); // We are done, Set the CSV Options in the reference. Construct and return the result. @@ -142,7 +236,7 @@ SnifferResult CSVSniffer::SniffCSV(bool force_match) { string type_error = "The Column types set by the user do not match the ones found by the sniffer. \n"; auto &set_types = *set_columns.types; for (idx_t i = 0; i < set_columns.Size(); i++) { - if (set_types[i] != detected_types[i] && !(set_types[i].IsNumeric() && detected_types[i].IsNumeric())) { + if (set_types[i] != detected_types[i]) { type_error += "Column at position: " + to_string(i) + " Set type: " + set_types[i].ToString() + " Sniffed type: " + detected_types[i].ToString() + "\n"; detected_types[i] = set_types[i]; @@ -158,13 +252,14 @@ SnifferResult CSVSniffer::SniffCSV(bool force_match) { throw InvalidInputException(error); } options.was_type_manually_set = manually_set; - // We do not need to run type refinement, since the types have been given by the user - return SnifferResult({}, {}); } if (!error.empty() && force_match) { throw InvalidInputException(error); } options.was_type_manually_set = manually_set; + if (set_columns.IsSet()) { + return SnifferResult(*set_columns.types, *set_columns.names); + } return SnifferResult(detected_types, names); } diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp index b167cbed..44d17909 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp @@ -1,9 +1,12 @@ +#include "duckdb/common/shared_ptr.hpp" #include "duckdb/execution/operator/csv_scanner/csv_sniffer.hpp" #include "duckdb/main/client_data.hpp" -#include "duckdb/common/shared_ptr.hpp" +#include "duckdb/execution/operator/csv_scanner/csv_reader_options.hpp" namespace duckdb { +constexpr idx_t CSVReaderOptions::sniff_size; + bool IsQuoteDefault(char quote) { if (quote == '\"' || quote == '\'' || quote == '\0') { return true; @@ -11,56 +14,129 @@ bool IsQuoteDefault(char quote) { return false; } -void CSVSniffer::GenerateCandidateDetectionSearchSpace(vector &delim_candidates, - vector "erule_candidates, - unordered_map> "e_candidates_map, - unordered_map> &escape_candidates_map) { - if (options.dialect_options.state_machine_options.delimiter.IsSetByUser()) { +vector DialectCandidates::GetDefaultDelimiter() { + return {',', '|', ';', '\t'}; +} + +vector> DialectCandidates::GetDefaultQuote() { + return {{'\"'}, {'\"', '\''}, {'\0'}}; +} + +vector DialectCandidates::GetDefaultQuoteRule() { + return {QuoteRule::QUOTES_RFC, QuoteRule::QUOTES_OTHER, QuoteRule::NO_QUOTES}; +} + +vector> DialectCandidates::GetDefaultEscape() { + return {{'\"', '\0', '\''}, {'\\'}, {'\0'}}; +} + +vector DialectCandidates::GetDefaultComment() { + return {'#', '\0'}; +} + +string DialectCandidates::Print() { + std::ostringstream search_space; + + search_space << "Delimiter Candidates: "; + for (idx_t i = 0; i < delim_candidates.size(); i++) { + search_space << "\'" << delim_candidates[i] << "\'"; + if (i < delim_candidates.size() - 1) { + search_space << ", "; + } + } + search_space << "\n"; + search_space << "Quote/Escape Candidates: "; + for (uint8_t i = 0; i < static_cast(quoterule_candidates.size()); i++) { + auto quote_candidate = quote_candidates_map[i]; + auto escape_candidate = escape_candidates_map[i]; + for (idx_t j = 0; j < quote_candidate.size(); j++) { + for (idx_t k = 0; k < escape_candidate.size(); k++) { + search_space << "[\'" << quote_candidate[j] << "\',\'" << escape_candidate[k] << "\']"; + if (k < escape_candidate.size() - 1) { + search_space << ","; + } + } + if (j < quote_candidate.size() - 1) { + search_space << ","; + } + } + if (i < quoterule_candidates.size() - 1) { + search_space << ","; + } + } + search_space << "\n"; + + search_space << "Comment Candidates: "; + for (idx_t i = 0; i < comment_candidates.size(); i++) { + search_space << "\'" << comment_candidates[i] << "\'"; + if (i < comment_candidates.size() - 1) { + search_space << ", "; + } + } + search_space << "\n"; + + return search_space.str(); +} + +DialectCandidates::DialectCandidates(const CSVStateMachineOptions &options) { + // assert that quotes escapes and rules have equal size + auto default_quote = GetDefaultQuote(); + auto default_escape = GetDefaultEscape(); + auto default_quote_rule = GetDefaultQuoteRule(); + auto default_delimiter = GetDefaultDelimiter(); + auto default_comment = GetDefaultComment(); + + D_ASSERT(default_quote.size() == default_quote_rule.size() && default_quote_rule.size() == default_escape.size()); + // fill the escapes + for (idx_t i = 0; i < default_quote_rule.size(); i++) { + escape_candidates_map[static_cast(default_quote_rule[i])] = default_escape[i]; + } + + if (options.delimiter.IsSetByUser()) { // user provided a delimiter: use that delimiter - delim_candidates = {options.dialect_options.state_machine_options.delimiter.GetValue()}; + delim_candidates = {options.delimiter.GetValue()}; } else { // no delimiter provided: try standard/common delimiters - delim_candidates = {',', '|', ';', '\t'}; + delim_candidates = default_delimiter; + } + if (options.comment.IsSetByUser()) { + // user provided comment character: use that as a comment + comment_candidates = {options.comment.GetValue()}; + } else { + // no comment provided: try standard/common comments + comment_candidates = default_comment; } - if (options.dialect_options.state_machine_options.quote.IsSetByUser()) { + if (options.quote.IsSetByUser()) { // user provided quote: use that quote rule - quote_candidates_map[(uint8_t)QuoteRule::QUOTES_RFC] = { - options.dialect_options.state_machine_options.quote.GetValue()}; - quote_candidates_map[(uint8_t)QuoteRule::QUOTES_OTHER] = { - options.dialect_options.state_machine_options.quote.GetValue()}; - quote_candidates_map[(uint8_t)QuoteRule::NO_QUOTES] = { - options.dialect_options.state_machine_options.quote.GetValue()}; + for (auto "e_rule : default_quote_rule) { + quote_candidates_map[static_cast(quote_rule)] = {options.quote.GetValue()}; + } // also add it as a escape rule - if (!IsQuoteDefault(options.dialect_options.state_machine_options.quote.GetValue())) { - escape_candidates_map[(uint8_t)QuoteRule::QUOTES_RFC].emplace_back( - options.dialect_options.state_machine_options.quote.GetValue()); + if (!IsQuoteDefault(options.quote.GetValue())) { + escape_candidates_map[static_cast(QuoteRule::QUOTES_RFC)].emplace_back(options.quote.GetValue()); } } else { // no quote rule provided: use standard/common quotes - quote_candidates_map[(uint8_t)QuoteRule::QUOTES_RFC] = {'\"'}; - quote_candidates_map[(uint8_t)QuoteRule::QUOTES_OTHER] = {'\"', '\''}; - quote_candidates_map[(uint8_t)QuoteRule::NO_QUOTES] = {'\0'}; + for (idx_t i = 0; i < default_quote_rule.size(); i++) { + quote_candidates_map[static_cast(default_quote_rule[i])] = {default_quote[i]}; + } } - if (options.dialect_options.state_machine_options.escape.IsSetByUser()) { + if (options.escape.IsSetByUser()) { // user provided escape: use that escape rule - if (options.dialect_options.state_machine_options.escape == '\0') { + if (options.escape == '\0') { quoterule_candidates = {QuoteRule::QUOTES_RFC}; } else { quoterule_candidates = {QuoteRule::QUOTES_OTHER}; } - escape_candidates_map[(uint8_t)quoterule_candidates[0]] = { - options.dialect_options.state_machine_options.escape.GetValue()}; + escape_candidates_map[static_cast(quoterule_candidates[0])] = {options.escape.GetValue()}; } else { // no escape provided: try standard/common escapes - quoterule_candidates = {QuoteRule::QUOTES_RFC, QuoteRule::QUOTES_OTHER, QuoteRule::NO_QUOTES}; + quoterule_candidates = default_quote_rule; } } void CSVSniffer::GenerateStateMachineSearchSpace(vector> &column_count_scanners, - const vector &delimiter_candidates, - const vector "erule_candidates, - const unordered_map> "e_candidates_map, - const unordered_map> &escape_candidates_map) { + const DialectCandidates &dialect_candidates) { // Generate state machines for all option combinations NewLineIdentifier new_line_id; if (options.dialect_options.state_machine_options.new_line.IsSetByUser()) { @@ -68,74 +144,159 @@ void CSVSniffer::GenerateStateMachineSearchSpace(vector(quoterule)); for (const auto "e : quote_candidates) { - for (const auto &delimiter : delimiter_candidates) { - const auto &escape_candidates = escape_candidates_map.at((uint8_t)quoterule); + for (const auto &delimiter : dialect_candidates.delim_candidates) { + const auto &escape_candidates = + dialect_candidates.escape_candidates_map.at(static_cast(quoterule)); for (const auto &escape : escape_candidates) { - D_ASSERT(buffer_manager); - CSVStateMachineOptions state_machine_options(delimiter, quote, escape, new_line_id); - auto sniffing_state_machine = - make_uniq(options, state_machine_options, state_machine_cache); - column_count_scanners.emplace_back(make_uniq( - buffer_manager, std::move(sniffing_state_machine), detection_error_handler)); + for (const auto &comment : dialect_candidates.comment_candidates) { + D_ASSERT(buffer_manager); + CSVStateMachineOptions state_machine_options(delimiter, quote, escape, comment, new_line_id); + auto sniffing_state_machine = + make_shared_ptr(options, state_machine_options, state_machine_cache); + if (options.dialect_options.skip_rows.IsSetByUser()) { + if (!iterator_set) { + first_iterator = BaseScanner::SkipCSVRows(buffer_manager, sniffing_state_machine, + options.dialect_options.skip_rows.GetValue()); + iterator_set = true; + } + column_count_scanners.emplace_back(make_uniq( + buffer_manager, std::move(sniffing_state_machine), detection_error_handler, + CSVReaderOptions::sniff_size, first_iterator)); + continue; + } + column_count_scanners.emplace_back( + make_uniq(buffer_manager, std::move(sniffing_state_machine), + detection_error_handler, CSVReaderOptions::sniff_size)); + } } } } } } +// Returns true if a comment is acceptable +bool AreCommentsAcceptable(const ColumnCountResult &result, idx_t num_cols, bool comment_set_by_user) { + // For a comment to be acceptable, we want 3/5th's majority of unmatches in the columns + constexpr double min_majority = 0.6; + // detected comments, are all lines that started with a comment character. + double detected_comments = 0; + // If at least one comment is a full line comment + bool has_full_line_comment = false; + // valid comments are all lines where the number of columns does not fit our expected number of columns. + double valid_comments = 0; + for (idx_t i = 0; i < result.result_position; i++) { + if (result.column_counts[i].is_comment || result.column_counts[i].is_mid_comment) { + detected_comments++; + if (result.column_counts[i].number_of_columns != num_cols && result.column_counts[i].is_comment) { + has_full_line_comment = true; + valid_comments++; + } + if (result.column_counts[i].number_of_columns == num_cols && result.column_counts[i].is_mid_comment) { + valid_comments++; + } + } + } + // If we do not encounter at least one full line comment, we do not consider this comment option. + if (valid_comments == 0 || (!has_full_line_comment && !comment_set_by_user)) { + // this is only valid if our comment character is \0 + if (result.state_machine.state_machine_options.comment.GetValue() == '\0') { + return true; + } + return false; + } + + return valid_comments / detected_comments >= min_majority; +} + void CSVSniffer::AnalyzeDialectCandidate(unique_ptr scanner, idx_t &rows_read, - idx_t &best_consistent_rows, idx_t &prev_padding_count) { + idx_t &best_consistent_rows, idx_t &prev_padding_count, + idx_t &min_ignored_rows) { // The sniffed_column_counts variable keeps track of the number of columns found for each row auto &sniffed_column_counts = scanner->ParseChunk(); idx_t dirty_notes = 0; + idx_t dirty_notes_minus_comments = 0; if (sniffed_column_counts.error) { // This candidate has an error (i.e., over maximum line size or never unquoting quoted values) return; } idx_t consistent_rows = 0; - idx_t num_cols = sniffed_column_counts.result_position == 0 ? 1 : sniffed_column_counts[0]; + idx_t num_cols = sniffed_column_counts.result_position == 0 ? 1 : sniffed_column_counts[0].number_of_columns; idx_t padding_count = 0; + idx_t comment_rows = 0; + idx_t ignored_rows = 0; bool allow_padding = options.null_padding; + bool first_valid = false; if (sniffed_column_counts.result_position > rows_read) { rows_read = sniffed_column_counts.result_position; } if (set_columns.IsCandidateUnacceptable(num_cols, options.null_padding, options.ignore_errors.GetValue(), - sniffed_column_counts.last_value_always_empty)) { + sniffed_column_counts[0].last_value_always_empty)) { // Not acceptable return; } + idx_t header_idx = 0; for (idx_t row = 0; row < sniffed_column_counts.result_position; row++) { - if (set_columns.IsCandidateUnacceptable(sniffed_column_counts[row], options.null_padding, + if (set_columns.IsCandidateUnacceptable(sniffed_column_counts[row].number_of_columns, options.null_padding, options.ignore_errors.GetValue(), - sniffed_column_counts.last_value_always_empty)) { + sniffed_column_counts[row].last_value_always_empty)) { // Not acceptable return; } - if (sniffed_column_counts[row] == num_cols || (options.ignore_errors.GetValue() && !options.null_padding)) { + if (sniffed_column_counts[row].is_comment) { + comment_rows++; + } else if (sniffed_column_counts[row].last_value_always_empty && + sniffed_column_counts[row].number_of_columns == + sniffed_column_counts[header_idx].number_of_columns + 1) { + // we allow for the first row to miss one column IF last_value_always_empty is true + // This is so we can sniff files that have an extra delimiter on the data part. + // e.g., C1|C2\n1|2|\n3|4| consistent_rows++; - } else if (num_cols < sniffed_column_counts[row] && !options.dialect_options.skip_rows.IsSetByUser() && + } else if (num_cols < sniffed_column_counts[row].number_of_columns && + (!options.dialect_options.skip_rows.IsSetByUser() || comment_rows > 0) && (!set_columns.IsSet() || options.null_padding)) { // all rows up to this point will need padding + if (!first_valid) { + first_valid = true; + sniffed_column_counts.state_machine.dialect_options.rows_until_header = row; + } padding_count = 0; // we use the maximum amount of num_cols that we find - num_cols = sniffed_column_counts[row]; + num_cols = sniffed_column_counts[row].number_of_columns; dirty_notes = row; + // sniffed_column_counts.state_machine.dialect_options.rows_until_header = dirty_notes; + dirty_notes_minus_comments = dirty_notes - comment_rows; + header_idx = row; consistent_rows = 1; - - } else if (num_cols >= sniffed_column_counts[row]) { + } else if (sniffed_column_counts[row].number_of_columns == num_cols || + (options.ignore_errors.GetValue() && !options.null_padding)) { + if (!first_valid) { + first_valid = true; + sniffed_column_counts.state_machine.dialect_options.rows_until_header = row; + } + if (sniffed_column_counts[row].number_of_columns != num_cols) { + ignored_rows++; + } + consistent_rows++; + } else if (num_cols >= sniffed_column_counts[row].number_of_columns) { // we are missing some columns, we can parse this as long as we add padding padding_count++; } } + if (sniffed_column_counts.state_machine.options.dialect_options.skip_rows.IsSetByUser()) { + sniffed_column_counts.state_machine.dialect_options.rows_until_header += + sniffed_column_counts.state_machine.options.dialect_options.skip_rows.GetValue(); + } // Calculate the total number of consistent rows after adding padding. consistent_rows += padding_count; // Whether there are more values (rows) available that are consistent, exceeding the current best. - bool more_values = (consistent_rows > best_consistent_rows && num_cols >= max_columns_found); + bool more_values = consistent_rows > best_consistent_rows && num_cols >= max_columns_found; // If additional padding is required when compared to the previous padding count. bool require_more_padding = padding_count > prev_padding_count; @@ -148,21 +309,25 @@ void CSVSniffer::AnalyzeDialectCandidate(unique_ptr scanner, // If the number of rows is consistent with the calculated value after accounting for skipped rows and the // start row. - bool rows_consistent = consistent_rows + (dirty_notes - options.dialect_options.skip_rows.GetValue()) == - sniffed_column_counts.result_position - options.dialect_options.skip_rows.GetValue(); + bool rows_consistent = + consistent_rows + (dirty_notes_minus_comments - options.dialect_options.skip_rows.GetValue()) + comment_rows == + sniffed_column_counts.result_position - options.dialect_options.skip_rows.GetValue(); // If there are more than one consistent row. - bool more_than_one_row = (consistent_rows > 1); + bool more_than_one_row = consistent_rows > 1; // If there are more than one column. - bool more_than_one_column = (num_cols > 1); + bool more_than_one_column = num_cols > 1; // If the start position is valid. bool start_good = !candidates.empty() && - (dirty_notes <= candidates.front()->GetStateMachine().dialect_options.skip_rows.GetValue()); + dirty_notes <= candidates.front()->GetStateMachine().dialect_options.skip_rows.GetValue(); // If padding happened but it is not allowed. bool invalid_padding = !allow_padding && padding_count > 0; + bool comments_are_acceptable = AreCommentsAcceptable( + sniffed_column_counts, num_cols, options.dialect_options.state_machine_options.comment.IsSetByUser()); + // If rows are consistent and no invalid padding happens, this is the best suitable candidate if one of the // following is valid: // - There's a single column before. @@ -171,7 +336,7 @@ void CSVSniffer::AnalyzeDialectCandidate(unique_ptr scanner, if (rows_consistent && (single_column_before || (more_values && !require_more_padding) || (more_than_one_column && require_less_padding)) && - !invalid_padding) { + !invalid_padding && comments_are_acceptable) { if (!candidates.empty() && set_columns.IsSet() && max_columns_found == candidates.size()) { // We have a candidate that fits our requirements better return; @@ -182,18 +347,23 @@ void CSVSniffer::AnalyzeDialectCandidate(unique_ptr scanner, // Give preference to quoted boys. return; } + if (max_columns_found == num_cols && ignored_rows > min_ignored_rows) { + return; + } best_consistent_rows = consistent_rows; max_columns_found = num_cols; prev_padding_count = padding_count; + min_ignored_rows = ignored_rows; + if (options.dialect_options.skip_rows.IsSetByUser()) { // If skip rows is set by user, and we found dirty notes, we only accept it if either null_padding or - // ignore_errors is set - if (dirty_notes != 0 && !options.null_padding && !options.ignore_errors.GetValue()) { + // ignore_errors is set we have comments + if (dirty_notes != 0 && !options.null_padding && !options.ignore_errors.GetValue() && comment_rows == 0) { return; } sniffing_state_machine.dialect_options.skip_rows = options.dialect_options.skip_rows.GetValue(); - } else if (!options.null_padding && !options.ignore_errors.GetValue()) { + } else if (!options.null_padding) { sniffing_state_machine.dialect_options.skip_rows = dirty_notes; } @@ -206,7 +376,7 @@ void CSVSniffer::AnalyzeDialectCandidate(unique_ptr scanner, // no additional padding is required, and there is no invalid padding, and there is not yet a candidate // with the same quote, we add this state_machine as a suitable candidate. if (more_than_one_row && more_than_one_column && start_good && rows_consistent && !require_more_padding && - !invalid_padding && num_cols == max_columns_found) { + !invalid_padding && num_cols == max_columns_found && comments_are_acceptable) { auto &sniffing_state_machine = scanner->GetStateMachine(); bool same_quote_is_candidate = false; @@ -224,7 +394,7 @@ void CSVSniffer::AnalyzeDialectCandidate(unique_ptr scanner, return; } sniffing_state_machine.dialect_options.skip_rows = options.dialect_options.skip_rows.GetValue(); - } else if (!options.null_padding && !options.ignore_errors.GetValue()) { + } else if (!options.null_padding) { sniffing_state_machine.dialect_options.skip_rows = dirty_notes; } @@ -238,14 +408,13 @@ bool CSVSniffer::RefineCandidateNextChunk(ColumnCountScanner &candidate) { auto &sniffed_column_counts = candidate.ParseChunk(); for (idx_t i = 0; i < sniffed_column_counts.result_position; i++) { if (set_columns.IsSet()) { - return !set_columns.IsCandidateUnacceptable(sniffed_column_counts[i], options.null_padding, - options.ignore_errors.GetValue(), - sniffed_column_counts.last_value_always_empty); - } else { - if (max_columns_found != sniffed_column_counts[i] && - (!options.null_padding && !options.ignore_errors.GetValue())) { - return false; - } + return !set_columns.IsCandidateUnacceptable(sniffed_column_counts[i].number_of_columns, + options.null_padding, options.ignore_errors.GetValue(), + sniffed_column_counts[i].last_value_always_empty); + } + if (max_columns_found != sniffed_column_counts[i].number_of_columns && + (!options.null_padding && !options.ignore_errors.GetValue() && !sniffed_column_counts[i].is_comment)) { + return false; } } return true; @@ -281,9 +450,8 @@ void CSVSniffer::RefineCandidates() { // that have actually quoted values, otherwise we will choose quotes = \0 candidates.clear(); if (!successful_candidates.empty()) { - unique_ptr cc_best_candidate; for (idx_t i = 0; i < successful_candidates.size(); i++) { - cc_best_candidate = std::move(successful_candidates[i]); + unique_ptr cc_best_candidate = std::move(successful_candidates[i]); if (cc_best_candidate->state_machine->state_machine_options.quote != '\0' && cc_best_candidate->ever_quoted) { candidates.clear(); @@ -292,9 +460,7 @@ void CSVSniffer::RefineCandidates() { } candidates.push_back(std::move(cc_best_candidate)); } - return; } - return; } NewLineIdentifier CSVSniffer::DetectNewLineDelimiter(CSVBufferManager &buffer_manager) { @@ -316,25 +482,10 @@ NewLineIdentifier CSVSniffer::DetectNewLineDelimiter(CSVBufferManager &buffer_ma if (carriage_return && n) { return NewLineIdentifier::CARRY_ON; } - return NewLineIdentifier::SINGLE; -} - -void CSVSniffer::SkipLines(vector> &csv_state_machines) { - if (csv_state_machines.empty()) { - return; - } - auto &first_scanner = *csv_state_machines[0]; - // We figure out the iterator position for the first scanner - if (options.dialect_options.skip_rows.IsSetByUser()) { - first_scanner.SkipCSVRows(options.dialect_options.skip_rows.GetValue()); - } - // The iterator position is the same regardless of the scanner configuration, hence we apply the same iterator - // To the remaining scanners - const auto first_iterator = first_scanner.GetIterator(); - for (idx_t i = 1; i < csv_state_machines.size(); i++) { - auto &cur_scanner = *csv_state_machines[i]; - cur_scanner.SetIterator(first_iterator); + if (carriage_return) { + return NewLineIdentifier::SINGLE_R; } + return NewLineIdentifier::SINGLE_N; } // Dialect Detection consists of five steps: @@ -344,43 +495,30 @@ void CSVSniffer::SkipLines(vector> &csv_state_mac // 4. Analyze the remaining chunks of the file and find the best dialect candidate void CSVSniffer::DetectDialect() { // Variables for Dialect Detection - // Candidates for the delimiter - vector delim_candidates; - // Quote-Rule Candidates - vector quoterule_candidates; - // Candidates for the quote option - unordered_map> quote_candidates_map; - // Candidates for the escape option - unordered_map> escape_candidates_map; - escape_candidates_map[(uint8_t)QuoteRule::QUOTES_RFC] = {'\"', '\'', '\0'}; - escape_candidates_map[(uint8_t)QuoteRule::QUOTES_OTHER] = {'\\'}; - escape_candidates_map[(uint8_t)QuoteRule::NO_QUOTES] = {'\0'}; + DialectCandidates dialect_candidates(options.dialect_options.state_machine_options); // Number of rows read idx_t rows_read = 0; // Best Number of consistent rows (i.e., presenting all columns) idx_t best_consistent_rows = 0; // If padding was necessary (i.e., rows are missing some columns, how many) idx_t prev_padding_count = 0; + // Min number of ignores rows + idx_t best_ignored_rows = 0; // Vector of CSV State Machines vector> csv_state_machines; - - // Step 1: Generate search space - GenerateCandidateDetectionSearchSpace(delim_candidates, quoterule_candidates, quote_candidates_map, - escape_candidates_map); - // Step 2: Generate state machines - GenerateStateMachineSearchSpace(csv_state_machines, delim_candidates, quoterule_candidates, quote_candidates_map, - escape_candidates_map); - SkipLines(csv_state_machines); - // Step 3: Analyze all candidates on the first chunk + // Step 1: Generate state machines + GenerateStateMachineSearchSpace(csv_state_machines, dialect_candidates); + // Step 2: Analyze all candidates on the first chunk for (auto &state_machine : csv_state_machines) { - AnalyzeDialectCandidate(std::move(state_machine), rows_read, best_consistent_rows, prev_padding_count); + AnalyzeDialectCandidate(std::move(state_machine), rows_read, best_consistent_rows, prev_padding_count, + best_ignored_rows); } - // Step 4: Loop over candidates and find if they can still produce good results for the remaining chunks + // Step 3: Loop over candidates and find if they can still produce good results for the remaining chunks RefineCandidates(); // if no dialect candidate was found, we throw an exception if (candidates.empty()) { - auto error = CSVError::SniffingError(options.file_path); + auto error = CSVError::SniffingError(options, dialect_candidates.Print()); error_handler->Error(error); } } diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/header_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/header_detection.cpp index e5498a57..c3b7d59c 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/header_detection.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/header_detection.cpp @@ -6,7 +6,6 @@ #include "utf8proc.hpp" namespace duckdb { - // Helper function to generate column names static string GenerateColumnName(const idx_t total_cols, const idx_t col_number, const string &prefix = "column") { auto max_digits = NumericHelper::UnsignedLength(total_cols - 1); @@ -98,26 +97,33 @@ static string NormalizeColumnName(const string &col_name) { } // If our columns were set by the user, we verify if their names match with the first row -bool CSVSniffer::DetectHeaderWithSetColumn() { +bool CSVSniffer::DetectHeaderWithSetColumn(ClientContext &context, vector &best_header_row, + SetColumns &set_columns, CSVReaderOptions &options) { bool has_header = true; bool all_varchar = true; bool first_row_consistent = true; + std::ostringstream error; // User set the names, we must check if they match the first row // We do a +1 to check for situations where the csv file has an extra all null column if (set_columns.Size() != best_header_row.size() && set_columns.Size() + 1 != best_header_row.size()) { return false; - } else { - // Let's do a match-aroo - for (idx_t i = 0; i < set_columns.Size(); i++) { - if (best_header_row[i].IsNull()) { - return false; - } - if (best_header_row[i].value.GetString() != (*set_columns.names)[i]) { - has_header = false; - break; - } + } + + // Let's do a match-aroo + for (idx_t i = 0; i < set_columns.Size(); i++) { + if (best_header_row[i].IsNull()) { + return false; + } + if (best_header_row[i].value != (*set_columns.names)[i]) { + error << "Header Mismatch at position:" << i << "\n"; + error << "Expected Name: \"" << (*set_columns.names)[i] << "\"."; + error << "Actual Name: \"" << best_header_row[i].value << "\"." + << "\n"; + has_header = false; + break; } } + if (!has_header) { // We verify if the types are consistent for (idx_t col = 0; col < set_columns.Size(); col++) { @@ -125,32 +131,60 @@ bool CSVSniffer::DetectHeaderWithSetColumn() { const auto &sql_type = (*set_columns.types)[col]; if (sql_type != LogicalType::VARCHAR) { all_varchar = false; - if (!CanYouCastIt(best_header_row[col].value, sql_type, options.dialect_options, - best_header_row[col].IsNull(), options.decimal_separator[0])) { + if (!CSVSniffer::CanYouCastIt(context, best_header_row[col].value, sql_type, options.dialect_options, + best_header_row[col].IsNull(), options.decimal_separator[0])) { first_row_consistent = false; } } } + if (!first_row_consistent) { + options.sniffer_user_mismatch_error += error.str(); + } if (all_varchar) { - // Can't be the header - return false; + return true; } return !first_row_consistent; } return has_header; } -void CSVSniffer::DetectHeader() { - auto &sniffer_state_machine = best_candidate->GetStateMachine(); + +bool EmptyHeader(const string &col_name, bool is_null, bool normalize) { + if (col_name.empty() || is_null) { + return true; + } + if (normalize) { + // normalize has special logic to trim white spaces and generate names + return false; + } + // check if it's all white spaces + for (auto &c : col_name) { + if (!StringUtil::CharacterIsSpace(c)) { + return false; + } + } + // if we are not normalizing the name and is all white spaces, then we generate a name + return true; +} + +vector +CSVSniffer::DetectHeaderInternal(ClientContext &context, vector &best_header_row, + CSVStateMachine &state_machine, SetColumns &set_columns, + unordered_map> &best_sql_types_candidates_per_column_idx, + CSVReaderOptions &options, CSVErrorHandler &error_handler) { + vector detected_names; + auto &dialect_options = state_machine.dialect_options; if (best_header_row.empty()) { - sniffer_state_machine.dialect_options.header = false; - for (idx_t col = 0; col < sniffer_state_machine.dialect_options.num_cols; col++) { - names.push_back(GenerateColumnName(sniffer_state_machine.dialect_options.num_cols, col)); + dialect_options.header = false; + for (idx_t col = 0; col < dialect_options.num_cols; col++) { + detected_names.push_back(GenerateColumnName(dialect_options.num_cols, col)); } // If the user provided names, we must replace our header with the user provided names - for (idx_t i = 0; i < MinValue(names.size(), sniffer_state_machine.options.name_list.size()); i++) { - names[i] = sniffer_state_machine.options.name_list[i]; + if (!options.columns_set) { + for (idx_t i = 0; i < MinValue(best_header_row.size(), options.name_list.size()); i++) { + detected_names[i] = options.name_list[i]; + } } - return; + return detected_names; } // information for header detection bool first_row_consistent = true; @@ -158,16 +192,17 @@ void CSVSniffer::DetectHeader() { bool first_row_nulls = true; // If null-padding is not allowed and there is a mismatch between our header candidate and the number of columns // We can't detect the dialect/type options properly - if (!sniffer_state_machine.options.null_padding && - best_sql_types_candidates_per_column_idx.size() != best_header_row.size()) { - auto error = CSVError::SniffingError(options.file_path); - error_handler->Error(error); + if (!options.null_padding && best_sql_types_candidates_per_column_idx.size() != best_header_row.size()) { + auto error = + CSVError::HeaderSniffingError(options, best_header_row, best_sql_types_candidates_per_column_idx.size(), + state_machine.dialect_options.state_machine_options.delimiter.GetValue()); + error_handler.Error(error); } bool all_varchar = true; bool has_header; if (set_columns.IsSet()) { - has_header = DetectHeaderWithSetColumn(); + has_header = DetectHeaderWithSetColumn(context, best_header_row, set_columns, options); } else { for (idx_t col = 0; col < best_header_row.size(); col++) { if (!best_header_row[col].IsNull()) { @@ -177,7 +212,7 @@ void CSVSniffer::DetectHeader() { const auto &sql_type = best_sql_types_candidates_per_column_idx[col].back(); if (sql_type != LogicalType::VARCHAR) { all_varchar = false; - if (!CanYouCastIt(best_header_row[col].value, sql_type, sniffer_state_machine.dialect_options, + if (!CanYouCastIt(context, best_header_row[col].value, sql_type, dialect_options, best_header_row[col].IsNull(), options.decimal_separator[0])) { first_row_consistent = false; } @@ -191,33 +226,31 @@ void CSVSniffer::DetectHeader() { } } - if (sniffer_state_machine.options.dialect_options.header.IsSetByUser()) { + if (options.dialect_options.header.IsSetByUser()) { // Header is defined by user, use that. - has_header = sniffer_state_machine.options.dialect_options.header.GetValue(); + has_header = options.dialect_options.header.GetValue(); } // update parser info, and read, generate & set col_names based on previous findings if (has_header) { - sniffer_state_machine.dialect_options.header = true; - if (sniffer_state_machine.options.null_padding && - !sniffer_state_machine.options.dialect_options.skip_rows.IsSetByUser()) { - if (sniffer_state_machine.dialect_options.skip_rows.GetValue() > 0) { - sniffer_state_machine.dialect_options.skip_rows = - sniffer_state_machine.dialect_options.skip_rows.GetValue() - 1; + dialect_options.header = true; + if (options.null_padding && !options.dialect_options.skip_rows.IsSetByUser()) { + if (dialect_options.skip_rows.GetValue() > 0) { + dialect_options.skip_rows = dialect_options.skip_rows.GetValue() - 1; } } case_insensitive_map_t name_collision_count; // get header names from CSV for (idx_t col = 0; col < best_header_row.size(); col++) { - string col_name = best_header_row[col].value.GetString(); + string &col_name = best_header_row[col].value; // generate name if field is empty - if (col_name.empty() || best_header_row[col].IsNull()) { - col_name = GenerateColumnName(sniffer_state_machine.dialect_options.num_cols, col); + if (EmptyHeader(col_name, best_header_row[col].is_null, options.normalize_names)) { + col_name = GenerateColumnName(dialect_options.num_cols, col); } // normalize names or at least trim whitespace - if (sniffer_state_machine.options.normalize_names) { + if (options.normalize_names) { col_name = NormalizeColumnName(col_name); } else { col_name = TrimWhitespace(col_name); @@ -228,27 +261,35 @@ void CSVSniffer::DetectHeader() { name_collision_count[col_name] += 1; col_name = col_name + "_" + to_string(name_collision_count[col_name]); } - names.push_back(col_name); + detected_names.push_back(col_name); name_collision_count[col_name] = 0; } - if (best_header_row.size() < sniffer_state_machine.dialect_options.num_cols && options.null_padding) { - for (idx_t col = best_header_row.size(); col < sniffer_state_machine.dialect_options.num_cols; col++) { - names.push_back(GenerateColumnName(sniffer_state_machine.dialect_options.num_cols, col)); + if (best_header_row.size() < dialect_options.num_cols && options.null_padding) { + for (idx_t col = best_header_row.size(); col < dialect_options.num_cols; col++) { + detected_names.push_back(GenerateColumnName(dialect_options.num_cols, col)); } - } else if (best_header_row.size() < sniffer_state_machine.dialect_options.num_cols) { + } else if (best_header_row.size() < dialect_options.num_cols) { throw InternalException("Detected header has number of columns inferior to dialect detection"); } } else { - sniffer_state_machine.dialect_options.header = false; - for (idx_t col = 0; col < sniffer_state_machine.dialect_options.num_cols; col++) { - names.push_back(GenerateColumnName(sniffer_state_machine.dialect_options.num_cols, col)); + dialect_options.header = false; + for (idx_t col = 0; col < dialect_options.num_cols; col++) { + detected_names.push_back(GenerateColumnName(dialect_options.num_cols, col)); } } // If the user provided names, we must replace our header with the user provided names - for (idx_t i = 0; i < MinValue(names.size(), sniffer_state_machine.options.name_list.size()); i++) { - names[i] = sniffer_state_machine.options.name_list[i]; + if (!options.columns_set) { + for (idx_t i = 0; i < MinValue(detected_names.size(), options.name_list.size()); i++) { + detected_names[i] = options.name_list[i]; + } } + return detected_names; +} +void CSVSniffer::DetectHeader() { + auto &sniffer_state_machine = best_candidate->GetStateMachine(); + names = DetectHeaderInternal(buffer_manager->context, best_header_row, sniffer_state_machine, set_columns, + best_sql_types_candidates_per_column_idx, options, *error_handler); } } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp index 807449ba..11d79c40 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp @@ -95,8 +95,8 @@ void CSVSniffer::SetDateFormat(CSVStateMachine &candidate, const string &format_ candidate.dialect_options.date_format[sql_type].Set(strpformat, false); } -bool CSVSniffer::CanYouCastIt(const string_t value, const LogicalType &type, const DialectOptions &dialect_options, - const bool is_null, const char decimal_separator) { +bool CSVSniffer::CanYouCastIt(ClientContext &context, const string_t value, const LogicalType &type, + const DialectOptions &dialect_options, const bool is_null, const char decimal_separator) { if (is_null) { return true; } @@ -137,11 +137,11 @@ bool CSVSniffer::CanYouCastIt(const string_t value, const LogicalType &type, con } case LogicalTypeId::DOUBLE: { double dummy_value; - return TryDoubleCast(value_ptr, value_size, dummy_value, true, options.decimal_separator[0]); + return TryDoubleCast(value_ptr, value_size, dummy_value, true, decimal_separator); } case LogicalTypeId::FLOAT: { float dummy_value; - return TryDoubleCast(value_ptr, value_size, dummy_value, true, options.decimal_separator[0]); + return TryDoubleCast(value_ptr, value_size, dummy_value, true, decimal_separator); } case LogicalTypeId::DATE: { if (!dialect_options.date_format.find(LogicalTypeId::DATE)->second.GetValue().Empty()) { @@ -150,12 +150,11 @@ bool CSVSniffer::CanYouCastIt(const string_t value, const LogicalType &type, con return dialect_options.date_format.find(LogicalTypeId::DATE) ->second.GetValue() .TryParseDate(value, result, error_message); - } else { - idx_t pos; - bool special; - date_t dummy_value; - return Date::TryConvertDate(value_ptr, value_size, pos, dummy_value, special, true); } + idx_t pos; + bool special; + date_t dummy_value; + return Date::TryConvertDate(value_ptr, value_size, pos, dummy_value, special, true); } case LogicalTypeId::TIMESTAMP: { timestamp_t dummy_value; @@ -164,9 +163,8 @@ bool CSVSniffer::CanYouCastIt(const string_t value, const LogicalType &type, con return dialect_options.date_format.find(LogicalTypeId::TIMESTAMP) ->second.GetValue() .TryParseTimestamp(value, dummy_value, error_message); - } else { - return Timestamp::TryConvertTimestamp(value_ptr, value_size, dummy_value) == TimestampCastResult::SUCCESS; } + return Timestamp::TryConvertTimestamp(value_ptr, value_size, dummy_value) == TimestampCastResult::SUCCESS; } case LogicalTypeId::TIME: { idx_t pos; @@ -229,9 +227,8 @@ bool CSVSniffer::CanYouCastIt(const string_t value, const LogicalType &type, con throw InternalException("Invalid Physical Type for Decimal Value. Physical Type: " + TypeIdToString(type.InternalType())); } - } else { - throw InvalidInputException("Decimals can only have ',' and '.' as decimal separators"); } + throw InvalidInputException("Decimals can only have ',' and '.' as decimal separators"); } case LogicalTypeId::VARCHAR: return true; @@ -240,7 +237,7 @@ bool CSVSniffer::CanYouCastIt(const string_t value, const LogicalType &type, con Value new_value; string error_message; Value str_value(value); - return str_value.TryCastAs(buffer_manager->context, type, new_value, &error_message, true); + return str_value.TryCastAs(context, type, new_value, &error_message, true); } } } @@ -274,8 +271,15 @@ void CSVSniffer::InitializeDateAndTimeStampDetection(CSVStateMachine &candidate, SetDateFormat(candidate, format_candidate.format.back(), sql_type.id()); } +bool ValidSeparator(const string &separator) { + // We use https://en.wikipedia.org/wiki/List_of_date_formats_by_country as reference + return separator == "-" || separator == "." || separator == "/" || separator == " "; +} void CSVSniffer::DetectDateAndTimeStampFormats(CSVStateMachine &candidate, const LogicalType &sql_type, - const string &separator, string_t &dummy_val) { + const string &separator, const string_t &dummy_val) { + if (!ValidSeparator(separator)) { + return; + } // If it is the first time running date/timestamp detection we must initialize the format variables InitializeDateAndTimeStampDetection(candidate, separator, sql_type); // generate date format candidates the first time through @@ -289,7 +293,7 @@ void CSVSniffer::DetectDateAndTimeStampFormats(CSVStateMachine &candidate, const while (!type_format_candidates.empty()) { // avoid using exceptions for flow control... auto ¤t_format = candidate.dialect_options.date_format[sql_type.id()].GetValue(); - if (current_format.Parse(dummy_val, result)) { + if (current_format.Parse(dummy_val, result, true)) { format_candidates[sql_type.id()].had_match = true; break; } @@ -317,6 +321,76 @@ void CSVSniffer::DetectDateAndTimeStampFormats(CSVStateMachine &candidate, const } } +void CSVSniffer::SniffTypes(DataChunk &data_chunk, CSVStateMachine &state_machine, + unordered_map> &info_sql_types_candidates, + idx_t start_idx_detection) { + const idx_t chunk_size = data_chunk.size(); + HasType has_type; + for (idx_t col_idx = 0; col_idx < data_chunk.ColumnCount(); col_idx++) { + auto &cur_vector = data_chunk.data[col_idx]; + D_ASSERT(cur_vector.GetVectorType() == VectorType::FLAT_VECTOR); + D_ASSERT(cur_vector.GetType() == LogicalType::VARCHAR); + auto vector_data = FlatVector::GetData(cur_vector); + auto null_mask = FlatVector::Validity(cur_vector); + auto &col_type_candidates = info_sql_types_candidates[col_idx]; + for (idx_t row_idx = start_idx_detection; row_idx < chunk_size; row_idx++) { + // col_type_candidates can't be empty since anything in a CSV file should at least be a string + // and we validate utf-8 compatibility when creating the type + D_ASSERT(!col_type_candidates.empty()); + auto cur_top_candidate = col_type_candidates.back(); + // try cast from string to sql_type + while (col_type_candidates.size() > 1) { + const auto &sql_type = col_type_candidates.back(); + // try formatting for date types if the user did not specify one and it starts with numeric + // values. + string separator; + // If Value is not Null, Has a numeric date format, and the current investigated candidate is + // either a timestamp or a date + if (null_mask.RowIsValid(row_idx) && StartsWithNumericDate(separator, vector_data[row_idx]) && + ((col_type_candidates.back().id() == LogicalTypeId::TIMESTAMP && !has_type.timestamp) || + (col_type_candidates.back().id() == LogicalTypeId::DATE && !has_type.date))) { + DetectDateAndTimeStampFormats(state_machine, sql_type, separator, vector_data[row_idx]); + } + // try cast from string to sql_type + if (sql_type == LogicalType::VARCHAR) { + // Nothing to convert it to + continue; + } + if (CanYouCastIt(buffer_manager->context, vector_data[row_idx], sql_type, state_machine.dialect_options, + !null_mask.RowIsValid(row_idx), state_machine.options.decimal_separator[0])) { + break; + } + + if (row_idx != start_idx_detection && cur_top_candidate == LogicalType::BOOLEAN) { + // If we thought this was a boolean value (i.e., T,F, True, False) and it is not, we + // immediately pop to varchar. + while (col_type_candidates.back() != LogicalType::VARCHAR) { + col_type_candidates.pop_back(); + } + break; + } + col_type_candidates.pop_back(); + } + } + if (col_type_candidates.back().id() == LogicalTypeId::DATE) { + has_type.date = true; + } + if (col_type_candidates.back().id() == LogicalTypeId::TIMESTAMP) { + has_type.timestamp = true; + } + } +} + +// If we have a predefined date/timestamp format we set it +void CSVSniffer::SetUserDefinedDateTimeFormat(CSVStateMachine &candidate) { + const vector data_time_formats {LogicalTypeId::DATE, LogicalTypeId::TIMESTAMP}; + for (auto &date_time_format : data_time_formats) { + auto &user_option = options.dialect_options.date_format.at(date_time_format); + if (user_option.IsSetByUser()) { + SetDateFormat(candidate, user_option.GetValue().format_specifier, date_time_format); + } + } +} void CSVSniffer::DetectTypes() { idx_t min_varchar_cols = max_columns_found + 1; idx_t min_errors = NumericLimits::Maximum(); @@ -336,72 +410,36 @@ void CSVSniffer::DetectTypes() { // Reset candidate for parsing auto candidate = candidate_cc->UpgradeToStringValueScanner(); - + SetUserDefinedDateTimeFormat(*candidate->state_machine); // Parse chunk and read csv with info candidate auto &data_chunk = candidate->ParseChunk().ToChunk(); - idx_t row_idx = 0; + if (!candidate->error_handler->errors.empty()) { + bool break_loop = false; + for (auto &errors : candidate->error_handler->errors) { + for (auto &error : errors.second) { + if (error.type != CSVErrorType::MAXIMUM_LINE_SIZE) { + break_loop = true; + break; + } + } + } + if (break_loop) { + continue; + } + } + idx_t start_idx_detection = 0; idx_t chunk_size = data_chunk.size(); if (chunk_size > 1 && (!options.dialect_options.header.IsSetByUser() || (options.dialect_options.header.IsSetByUser() && options.dialect_options.header.GetValue()))) { // This means we have more than one row, hence we can use the first row to detect if we have a header - row_idx = 1; + start_idx_detection = 1; } // First line where we start our type detection - const idx_t start_idx_detection = row_idx; - - for (idx_t col_idx = 0; col_idx < data_chunk.ColumnCount(); col_idx++) { - auto &cur_vector = data_chunk.data[col_idx]; - D_ASSERT(cur_vector.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(cur_vector.GetType() == LogicalType::VARCHAR); - auto vector_data = FlatVector::GetData(cur_vector); - auto null_mask = FlatVector::Validity(cur_vector); - auto &col_type_candidates = info_sql_types_candidates[col_idx]; - for (row_idx = start_idx_detection; row_idx < chunk_size; row_idx++) { - // col_type_candidates can't be empty since anything in a CSV file should at least be a string - // and we validate utf-8 compatibility when creating the type - D_ASSERT(!col_type_candidates.empty()); - auto cur_top_candidate = col_type_candidates.back(); - // try cast from string to sql_type - while (col_type_candidates.size() > 1) { - const auto &sql_type = col_type_candidates.back(); - // try formatting for date types if the user did not specify one and it starts with numeric - // values. - string separator; - // If Value is not Null, Has a numeric date format, and the current investigated candidate is - // either a timestamp or a date - if (null_mask.RowIsValid(row_idx) && StartsWithNumericDate(separator, vector_data[row_idx]) && - (col_type_candidates.back().id() == LogicalTypeId::TIMESTAMP || - col_type_candidates.back().id() == LogicalTypeId::DATE)) { - DetectDateAndTimeStampFormats(candidate->GetStateMachine(), sql_type, separator, - vector_data[row_idx]); - } - // try cast from string to sql_type - if (sql_type == LogicalType::VARCHAR) { - // Nothing to convert it to - continue; - } - if (CanYouCastIt(vector_data[row_idx], sql_type, sniffing_state_machine.dialect_options, - !null_mask.RowIsValid(row_idx), - sniffing_state_machine.options.decimal_separator[0])) { - break; - } else { - if (row_idx != start_idx_detection && cur_top_candidate == LogicalType::BOOLEAN) { - // If we thought this was a boolean value (i.e., T,F, True, False) and it is not, we - // immediately pop to varchar. - while (col_type_candidates.back() != LogicalType::VARCHAR) { - col_type_candidates.pop_back(); - } - break; - } - col_type_candidates.pop_back(); - } - } - } - } + SniffTypes(data_chunk, sniffing_state_machine, info_sql_types_candidates, start_idx_detection); + // Count the number of varchar columns idx_t varchar_cols = 0; - for (idx_t col = 0; col < info_sql_types_candidates.size(); col++) { auto &col_type_candidates = info_sql_types_candidates[col]; // check number of varchar columns @@ -413,9 +451,10 @@ void CSVSniffer::DetectTypes() { // it's good if the dialect creates more non-varchar columns, but only if we sacrifice < 30% of // best_num_cols. - if (varchar_cols(info_sql_types_candidates.size())>( - static_cast(max_columns_found) * 0.7) && - (!options.ignore_errors.GetValue() || candidate->error_handler->errors.size() < min_errors)) { + if (!best_candidate || + (varchar_cols(info_sql_types_candidates.size())>( + static_cast(max_columns_found) * 0.7) && + (!options.ignore_errors.GetValue() || candidate->error_handler->errors.size() < min_errors))) { min_errors = candidate->error_handler->errors.size(); best_header_row.clear(); // we have a new best_options candidate @@ -441,8 +480,9 @@ void CSVSniffer::DetectTypes() { } } if (!best_candidate) { - auto error = CSVError::SniffingError(options.file_path); - error_handler->Error(error, true); + DialectCandidates dialect_candidates(options.dialect_options.state_machine_options); + auto error = CSVError::SniffingError(options, dialect_candidates.Print()); + error_handler->Error(error); } // Assert that it's all good at this point. D_ASSERT(best_candidate && !best_format_candidates.empty()); diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp index 8d3fe412..43d69318 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp @@ -2,13 +2,12 @@ #include "duckdb/execution/operator/csv_scanner/csv_casting.hpp" namespace duckdb { - bool CSVSniffer::TryCastVector(Vector &parse_chunk_col, idx_t size, const LogicalType &sql_type) { auto &sniffing_state_machine = best_candidate->GetStateMachine(); // try vector-cast from string to sql_type - Vector dummy_result(sql_type); + Vector dummy_result(sql_type, size); if (!sniffing_state_machine.dialect_options.date_format[LogicalTypeId::DATE].GetValue().Empty() && - sql_type == LogicalTypeId::DATE) { + sql_type.id() == LogicalTypeId::DATE) { // use the date format to cast the chunk string error_message; CastParameters parameters(false, &error_message); @@ -17,13 +16,28 @@ bool CSVSniffer::TryCastVector(Vector &parse_chunk_col, idx_t size, const Logica dummy_result, size, parameters, line_error); } if (!sniffing_state_machine.dialect_options.date_format[LogicalTypeId::TIMESTAMP].GetValue().Empty() && - sql_type == LogicalTypeId::TIMESTAMP) { + sql_type.id() == LogicalTypeId::TIMESTAMP) { // use the timestamp format to cast the chunk string error_message; CastParameters parameters(false, &error_message); return CSVCast::TryCastTimestampVector(sniffing_state_machine.dialect_options.date_format, parse_chunk_col, dummy_result, size, parameters); } + if ((sql_type.id() == LogicalTypeId::DOUBLE || sql_type.id() == LogicalTypeId::FLOAT) && + options.decimal_separator == ",") { + string error_message; + CastParameters parameters(false, &error_message); + idx_t line_error; + return CSVCast::TryCastFloatingVectorCommaSeparated(options, parse_chunk_col, dummy_result, size, parameters, + sql_type, line_error); + } + if (sql_type.id() == LogicalTypeId::DECIMAL && options.decimal_separator == ",") { + string error_message; + CastParameters parameters(false, &error_message); + idx_t line_error; + return CSVCast::TryCastDecimalVectorCommaSeparated(options, parse_chunk_col, dummy_result, size, parameters, + sql_type, line_error); + } // target type is not varchar: perform a cast string error_message; return VectorOperations::DefaultTryCast(parse_chunk_col, dummy_result, size, &error_message, true); @@ -62,28 +76,29 @@ void CSVSniffer::RefineTypes() { const auto &sql_type = col_type_candidates.back(); if (TryCastVector(parse_chunk.data[col], parse_chunk.size(), sql_type)) { break; - } else { - if (col_type_candidates.back() == LogicalType::BOOLEAN && is_bool_type) { - // If we thought this was a boolean value (i.e., T,F, True, False) and it is not, we - // immediately pop to varchar. - while (col_type_candidates.back() != LogicalType::VARCHAR) { - col_type_candidates.pop_back(); - } - break; + } + if (col_type_candidates.back() == LogicalType::BOOLEAN && is_bool_type) { + // If we thought this was a boolean value (i.e., T,F, True, False) and it is not, we + // immediately pop to varchar. + while (col_type_candidates.back() != LogicalType::VARCHAR) { + col_type_candidates.pop_back(); } - col_type_candidates.pop_back(); + break; } + col_type_candidates.pop_back(); } } // reset parse chunk for the next iteration parse_chunk.Reset(); + parse_chunk.SetCapacity(CSVReaderOptions::sniff_size); } detected_types.clear(); // set sql types for (idx_t column_idx = 0; column_idx < best_sql_types_candidates_per_column_idx.size(); column_idx++) { LogicalType d_type = best_sql_types_candidates_per_column_idx[column_idx].back(); if (best_sql_types_candidates_per_column_idx[column_idx].size() == - best_candidate->GetStateMachine().options.auto_type_candidates.size()) { + best_candidate->GetStateMachine().options.auto_type_candidates.size() && + default_null_to_varchar) { d_type = LogicalType::VARCHAR; } detected_types.push_back(d_type); diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp index eb9244fc..34fa4146 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp @@ -4,7 +4,7 @@ namespace duckdb { void CSVSniffer::ReplaceTypes() { auto &sniffing_state_machine = best_candidate->GetStateMachine(); manually_set = vector(detected_types.size(), false); - if (sniffing_state_machine.options.sql_type_list.empty()) { + if (sniffing_state_machine.options.sql_type_list.empty() || sniffing_state_machine.options.columns_set) { return; } // user-defined types were supplied for certain columns diff --git a/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine_cache.cpp b/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine_cache.cpp index fbe07523..fc9d7f47 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine_cache.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine_cache.cpp @@ -1,5 +1,6 @@ #include "duckdb/execution/operator/csv_scanner/csv_state_machine.hpp" #include "duckdb/execution/operator/csv_scanner/csv_state_machine_cache.hpp" +#include "duckdb/execution/operator/csv_scanner/csv_sniffer.hpp" namespace duckdb { @@ -8,6 +9,13 @@ void InitializeTransitionArray(StateMachine &transition_array, const CSVState cu transition_array[i][static_cast(cur_state)] = state; } } + +// Shift and OR to replicate across all bytes +void ShiftAndReplicateBits(uint64_t &value) { + value |= value << 8; + value |= value << 16; + value |= value << 32; +} void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_options) { D_ASSERT(state_machine_cache.find(state_machine_options) == state_machine_cache.end()); // Initialize transition array with default values to the Standard option @@ -24,6 +32,9 @@ void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_op case CSVState::ESCAPE: InitializeTransitionArray(transition_array, cur_state, CSVState::INVALID); break; + case CSVState::COMMENT: + InitializeTransitionArray(transition_array, cur_state, CSVState::COMMENT); + break; default: InitializeTransitionArray(transition_array, cur_state, CSVState::STANDARD); break; @@ -33,6 +44,7 @@ void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_op uint8_t delimiter = static_cast(state_machine_options.delimiter.GetValue()); uint8_t quote = static_cast(state_machine_options.quote.GetValue()); uint8_t escape = static_cast(state_machine_options.escape.GetValue()); + uint8_t comment = static_cast(state_machine_options.comment.GetValue()); auto new_line_id = state_machine_options.new_line.GetValue(); @@ -47,6 +59,9 @@ void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_op } else { transition_array[static_cast('\r')][state] = CSVState::RECORD_SEPARATOR; } + if (comment != '\0') { + transition_array[comment][state] = CSVState::COMMENT; + } } // 2) Field Separator State transition_array[delimiter][static_cast(CSVState::DELIMITER)] = CSVState::DELIMITER; @@ -63,6 +78,9 @@ void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_op if (delimiter != ' ') { transition_array[' '][static_cast(CSVState::DELIMITER)] = CSVState::EMPTY_SPACE; } + if (comment != '\0') { + transition_array[comment][static_cast(CSVState::DELIMITER)] = CSVState::COMMENT; + } // 3) Record Separator State transition_array[delimiter][static_cast(CSVState::RECORD_SEPARATOR)] = CSVState::DELIMITER; @@ -79,6 +97,9 @@ void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_op if (delimiter != ' ') { transition_array[' '][static_cast(CSVState::RECORD_SEPARATOR)] = CSVState::EMPTY_SPACE; } + if (comment != '\0') { + transition_array[comment][static_cast(CSVState::RECORD_SEPARATOR)] = CSVState::COMMENT; + } // 4) Carriage Return State transition_array[static_cast('\n')][static_cast(CSVState::CARRIAGE_RETURN)] = @@ -89,6 +110,9 @@ void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_op if (delimiter != ' ') { transition_array[' '][static_cast(CSVState::CARRIAGE_RETURN)] = CSVState::EMPTY_SPACE; } + if (comment != '\0') { + transition_array[comment][static_cast(CSVState::CARRIAGE_RETURN)] = CSVState::COMMENT; + } // 5) Quoted State transition_array[quote][static_cast(CSVState::QUOTED)] = CSVState::UNQUOTED; @@ -111,12 +135,15 @@ void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_op if (state_machine_options.quote == state_machine_options.escape) { transition_array[escape][static_cast(CSVState::UNQUOTED)] = CSVState::QUOTED; } + if (comment != '\0') { + transition_array[comment][static_cast(CSVState::UNQUOTED)] = CSVState::COMMENT; + } // 7) Escaped State transition_array[quote][static_cast(CSVState::ESCAPE)] = CSVState::QUOTED; transition_array[escape][static_cast(CSVState::ESCAPE)] = CSVState::QUOTED; // 8) Not Set - transition_array[delimiter][static_cast(static_cast(CSVState::NOT_SET))] = CSVState::DELIMITER; + transition_array[delimiter][static_cast(CSVState::NOT_SET)] = CSVState::DELIMITER; transition_array[static_cast('\n')][static_cast(CSVState::NOT_SET)] = CSVState::RECORD_SEPARATOR; if (new_line_id == NewLineIdentifier::CARRY_ON) { transition_array[static_cast('\r')][static_cast(CSVState::NOT_SET)] = @@ -125,19 +152,21 @@ void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_op transition_array[static_cast('\r')][static_cast(CSVState::NOT_SET)] = CSVState::RECORD_SEPARATOR; } - transition_array[static_cast(quote)][static_cast(CSVState::NOT_SET)] = CSVState::QUOTED; + transition_array[quote][static_cast(CSVState::NOT_SET)] = CSVState::QUOTED; if (delimiter != ' ') { transition_array[' '][static_cast(CSVState::NOT_SET)] = CSVState::EMPTY_SPACE; } + if (comment != '\0') { + transition_array[comment][static_cast(CSVState::NOT_SET)] = CSVState::COMMENT; + } // 9) Quoted NewLine transition_array[quote][static_cast(CSVState::QUOTED_NEW_LINE)] = CSVState::UNQUOTED; if (state_machine_options.quote != state_machine_options.escape) { transition_array[escape][static_cast(CSVState::QUOTED_NEW_LINE)] = CSVState::ESCAPE; } - // 10) Empty Value State - transition_array[delimiter][static_cast(static_cast(CSVState::EMPTY_SPACE))] = - CSVState::DELIMITER; + // 10) Empty Value State (Not first value) + transition_array[delimiter][static_cast(CSVState::EMPTY_SPACE)] = CSVState::DELIMITER; transition_array[static_cast('\n')][static_cast(CSVState::EMPTY_SPACE)] = CSVState::RECORD_SEPARATOR; if (new_line_id == NewLineIdentifier::CARRY_ON) { @@ -148,15 +177,31 @@ void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_op CSVState::RECORD_SEPARATOR; } transition_array[quote][static_cast(CSVState::EMPTY_SPACE)] = CSVState::QUOTED; + if (comment != '\0') { + transition_array[comment][static_cast(CSVState::EMPTY_SPACE)] = CSVState::COMMENT; + } + + // 11) Comment State + transition_array[static_cast('\n')][static_cast(CSVState::COMMENT)] = CSVState::RECORD_SEPARATOR; + if (new_line_id == NewLineIdentifier::CARRY_ON) { + transition_array[static_cast('\r')][static_cast(CSVState::COMMENT)] = + CSVState::CARRIAGE_RETURN; + } else { + transition_array[static_cast('\r')][static_cast(CSVState::COMMENT)] = + CSVState::RECORD_SEPARATOR; + } + // Initialize characters we can skip during processing, for Standard and Quoted states for (idx_t i = 0; i < StateMachine::NUM_TRANSITIONS; i++) { transition_array.skip_standard[i] = true; transition_array.skip_quoted[i] = true; + transition_array.skip_comment[i] = true; } // For standard states we only care for delimiters \r and \n transition_array.skip_standard[delimiter] = false; transition_array.skip_standard[static_cast('\n')] = false; transition_array.skip_standard[static_cast('\r')] = false; + transition_array.skip_standard[comment] = false; // For quoted we only care about quote, escape and for delimiters \r and \n transition_array.skip_quoted[quote] = false; @@ -164,6 +209,9 @@ void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_op transition_array.skip_quoted[static_cast('\n')] = false; transition_array.skip_quoted[static_cast('\r')] = false; + transition_array.skip_comment[static_cast('\r')] = false; + transition_array.skip_comment[static_cast('\n')] = false; + transition_array.delimiter = delimiter; transition_array.new_line = static_cast('\n'); transition_array.carriage_return = static_cast('\r'); @@ -171,36 +219,32 @@ void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_op transition_array.escape = escape; // Shift and OR to replicate across all bytes - transition_array.delimiter |= transition_array.delimiter << 8; - transition_array.delimiter |= transition_array.delimiter << 16; - transition_array.delimiter |= transition_array.delimiter << 32; - - transition_array.new_line |= transition_array.new_line << 8; - transition_array.new_line |= transition_array.new_line << 16; - transition_array.new_line |= transition_array.new_line << 32; - - transition_array.carriage_return |= transition_array.carriage_return << 8; - transition_array.carriage_return |= transition_array.carriage_return << 16; - transition_array.carriage_return |= transition_array.carriage_return << 32; - - transition_array.quote |= transition_array.quote << 8; - transition_array.quote |= transition_array.quote << 16; - transition_array.quote |= transition_array.quote << 32; - - transition_array.escape |= transition_array.escape << 8; - transition_array.escape |= transition_array.escape << 16; - transition_array.escape |= transition_array.escape << 32; + ShiftAndReplicateBits(transition_array.delimiter); + ShiftAndReplicateBits(transition_array.new_line); + ShiftAndReplicateBits(transition_array.carriage_return); + ShiftAndReplicateBits(transition_array.quote); + ShiftAndReplicateBits(transition_array.escape); + ShiftAndReplicateBits(transition_array.comment); } CSVStateMachineCache::CSVStateMachineCache() { + auto default_quote = DialectCandidates::GetDefaultQuote(); + auto default_escape = DialectCandidates::GetDefaultEscape(); + auto default_quote_rule = DialectCandidates::GetDefaultQuoteRule(); + auto default_delimiter = DialectCandidates::GetDefaultDelimiter(); + auto default_comment = DialectCandidates::GetDefaultComment(); + for (auto quoterule : default_quote_rule) { const auto "e_candidates = default_quote[static_cast(quoterule)]; for (const auto "e : quote_candidates) { for (const auto &delimiter : default_delimiter) { const auto &escape_candidates = default_escape[static_cast(quoterule)]; for (const auto &escape : escape_candidates) { - Insert({delimiter, quote, escape, NewLineIdentifier::SINGLE}); - Insert({delimiter, quote, escape, NewLineIdentifier::CARRY_ON}); + for (const auto &comment : default_comment) { + Insert({delimiter, quote, escape, comment, NewLineIdentifier::SINGLE_N}); + Insert({delimiter, quote, escape, comment, NewLineIdentifier::SINGLE_R}); + Insert({delimiter, quote, escape, comment, NewLineIdentifier::CARRY_ON}); + } } } } diff --git a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp index c0227041..e3589486 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp @@ -1,13 +1,16 @@ #include "duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp" -#include "duckdb/function/table/read_csv.hpp" + #include "duckdb/execution/operator/csv_scanner/csv_sniffer.hpp" +#include "duckdb/execution/operator/csv_scanner/skip_scanner.hpp" +#include "duckdb/function/table/read_csv.hpp" namespace duckdb { +CSVUnionData::~CSVUnionData() { +} CSVFileScan::CSVFileScan(ClientContext &context, shared_ptr buffer_manager_p, shared_ptr state_machine_p, const CSVReaderOptions &options_p, - const ReadCSVData &bind_data, const vector &column_ids, - vector &file_schema) + const ReadCSVData &bind_data, const vector &column_ids, CSVSchema &file_schema) : file_path(options_p.file_path), file_idx(0), buffer_manager(std::move(buffer_manager_p)), state_machine(std::move(state_machine_p)), file_size(buffer_manager->file_handle->FileSize()), error_handler(make_shared_ptr(options_p.ignore_errors.GetValue())), @@ -23,7 +26,8 @@ CSVFileScan::CSVFileScan(ClientContext &context, shared_ptr bu bind_data.return_names, column_ids, nullptr, file_path, context, nullptr); InitializeFileNamesTypes(); return; - } else if (!bind_data.column_info.empty()) { + } + if (!bind_data.column_info.empty()) { // Serialized Union By name names = bind_data.column_info[0].names; types = bind_data.column_info[0].types; @@ -32,51 +36,56 @@ CSVFileScan::CSVFileScan(ClientContext &context, shared_ptr bu InitializeFileNamesTypes(); return; } - names = bind_data.return_names; - types = bind_data.return_types; - file_schema = bind_data.return_types; + names = bind_data.csv_names; + types = bind_data.csv_types; + file_schema.Initialize(names, types, file_path); multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types, bind_data.return_names, column_ids, nullptr, file_path, context, nullptr); InitializeFileNamesTypes(); + SetStart(); +} + +void CSVFileScan::SetStart() { + idx_t rows_to_skip = options.GetSkipRows() + state_machine->dialect_options.header.GetValue(); + rows_to_skip = std::max(rows_to_skip, state_machine->dialect_options.rows_until_header + + state_machine->dialect_options.header.GetValue()); + if (rows_to_skip == 0) { + start_iterator.first_one = true; + return; + } + SkipScanner skip_scanner(buffer_manager, state_machine, error_handler, rows_to_skip); + skip_scanner.ParseChunk(); + start_iterator = skip_scanner.GetIterator(); } CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, const CSVReaderOptions &options_p, const idx_t file_idx_p, const ReadCSVData &bind_data, const vector &column_ids, - const vector &file_schema) + CSVSchema &file_schema, bool per_file_single_threaded) : file_path(file_path_p), file_idx(file_idx_p), error_handler(make_shared_ptr(options_p.ignore_errors.GetValue())), options(options_p) { auto multi_file_reader = MultiFileReader::CreateDefault("CSV Scan"); - if (file_idx < bind_data.union_readers.size()) { - // we are doing UNION BY NAME - fetch the options from the union reader for this file - optional_ptr union_reader_ptr; - if (file_idx == 0) { - union_reader_ptr = bind_data.initial_reader.get(); - } else { - union_reader_ptr = bind_data.union_readers[file_idx].get(); - } - if (union_reader_ptr) { - auto &union_reader = *union_reader_ptr; - // Initialize Buffer Manager - buffer_manager = union_reader.buffer_manager; - // Initialize On Disk and Size of file - on_disk_file = union_reader.on_disk_file; - file_size = union_reader.file_size; - names = union_reader.GetNames(); - options = union_reader.options; - types = union_reader.GetTypes(); - state_machine = union_reader.state_machine; - multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, - bind_data.return_types, bind_data.return_names, column_ids, nullptr, - file_path, context, nullptr); - - InitializeFileNamesTypes(); - return; - } + if (file_idx == 0 && bind_data.initial_reader) { + auto &union_reader = *bind_data.initial_reader; + // Initialize Buffer Manager + buffer_manager = union_reader.buffer_manager; + // Initialize On Disk and Size of file + on_disk_file = union_reader.on_disk_file; + file_size = union_reader.file_size; + names = union_reader.GetNames(); + options = union_reader.options; + types = union_reader.GetTypes(); + state_machine = union_reader.state_machine; + multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types, + bind_data.return_names, column_ids, nullptr, file_path, context, nullptr); + + InitializeFileNamesTypes(); + SetStart(); + return; } // Initialize Buffer Manager - buffer_manager = make_shared_ptr(context, options, file_path, file_idx); + buffer_manager = make_shared_ptr(context, options, file_path, file_idx, per_file_single_threaded); // Initialize On Disk and Size of file on_disk_file = buffer_manager->file_handle->OnDiskFile(); file_size = buffer_manager->file_handle->FileSize(); @@ -84,13 +93,21 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons auto &state_machine_cache = CSVStateMachineCache::Get(context); if (file_idx < bind_data.column_info.size()) { - // Serialized Union By name + // (Serialized) Union By name names = bind_data.column_info[file_idx].names; types = bind_data.column_info[file_idx].types; - options.dialect_options.num_cols = names.size(); - if (options.auto_detect) { - CSVSniffer sniffer(options, buffer_manager, state_machine_cache); - sniffer.SniffCSV(); + if (file_idx < bind_data.union_readers.size()) { + // union readers - use cached options + D_ASSERT(names == bind_data.union_readers[file_idx]->names); + D_ASSERT(types == bind_data.union_readers[file_idx]->types); + options = bind_data.union_readers[file_idx]->options; + } else { + // Serialized union by name - sniff again + options.dialect_options.num_cols = names.size(); + if (options.auto_detect) { + CSVSniffer sniffer(options, buffer_manager, state_machine_cache); + sniffer.SniffCSV(); + } } state_machine = make_shared_ptr( state_machine_cache.Get(options.dialect_options.state_machine_options), options); @@ -98,40 +115,41 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types, bind_data.return_names, column_ids, nullptr, file_path, context, nullptr); InitializeFileNamesTypes(); + SetStart(); return; } - // Sniff it (We only really care about dialect detection, if types or number of columns are different this will - // error out during scanning) - if (options.auto_detect && file_idx > 0) { - CSVSniffer sniffer(options, buffer_manager, state_machine_cache); - auto result = sniffer.SniffCSV(); - if (!file_schema.empty()) { - if (!options.file_options.filename && !options.file_options.hive_partitioning && - file_schema.size() != result.return_types.size()) { - throw InvalidInputException("Mismatch between the schema of different files"); - } + // Sniff it! + names = bind_data.csv_names; + types = bind_data.csv_types; + if (options.auto_detect && bind_data.files.size() > 1) { + if (file_schema.Empty()) { + CSVSniffer sniffer(options, buffer_manager, state_machine_cache); + auto result = sniffer.SniffCSV(); + file_schema.Initialize(result.names, result.return_types, options.file_path); + } else if (file_idx > 0 && buffer_manager->file_handle->FileSize() > 0) { + options.file_path = file_path; + CSVSniffer sniffer(options, buffer_manager, state_machine_cache, false); + auto result = sniffer.AdaptiveSniff(file_schema); + names = result.names; + types = result.return_types; } } if (options.dialect_options.num_cols == 0) { // We need to define the number of columns, if the sniffer is not running this must be in the sql_type_list options.dialect_options.num_cols = options.sql_type_list.size(); } - if (options.dialect_options.state_machine_options.new_line == NewLineIdentifier::NOT_SET) { options.dialect_options.state_machine_options.new_line = CSVSniffer::DetectNewLineDelimiter(*buffer_manager); } - - names = bind_data.csv_names; - types = bind_data.csv_types; state_machine = make_shared_ptr( state_machine_cache.Get(options.dialect_options.state_machine_options), options); - multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types, bind_data.return_names, column_ids, nullptr, file_path, context, nullptr); InitializeFileNamesTypes(); + SetStart(); } -CSVFileScan::CSVFileScan(ClientContext &context, const string &file_name, CSVReaderOptions &options_p) +CSVFileScan::CSVFileScan(ClientContext &context, const string &file_name, const CSVReaderOptions &options_p) : file_path(file_name), file_idx(0), error_handler(make_shared_ptr(options_p.ignore_errors.GetValue())), options(options_p) { buffer_manager = make_shared_ptr(context, options, file_path, file_idx); @@ -156,6 +174,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_name, CSVRea // Initialize State Machine state_machine = make_shared_ptr( state_machine_cache.Get(options.dialect_options.state_machine_options), options); + SetStart(); } void CSVFileScan::InitializeFileNamesTypes() { @@ -180,9 +199,11 @@ void CSVFileScan::InitializeFileNamesTypes() { } // We need to be sure that our types are also following the cast_map - for (idx_t i = 0; i < reader_data.column_ids.size(); i++) { - if (reader_data.cast_map.find(reader_data.column_ids[i]) != reader_data.cast_map.end()) { - file_types[i] = reader_data.cast_map[reader_data.column_ids[i]]; + if (!reader_data.cast_map.empty()) { + for (idx_t i = 0; i < reader_data.column_ids.size(); i++) { + if (reader_data.cast_map.find(reader_data.column_ids[i]) != reader_data.cast_map.end()) { + file_types[i] = reader_data.cast_map[reader_data.column_ids[i]]; + } } } @@ -195,7 +216,7 @@ void CSVFileScan::InitializeFileNamesTypes() { file_types = sorted_types; } -const string &CSVFileScan::GetFileName() { +const string &CSVFileScan::GetFileName() const { return file_path; } const vector &CSVFileScan::GetNames() { diff --git a/src/duckdb/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp b/src/duckdb/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp index acf19eff..4f3e9dce 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp @@ -1,9 +1,11 @@ #include "duckdb/execution/operator/csv_scanner/global_csv_state.hpp" -#include "duckdb/main/client_data.hpp" -#include "duckdb/execution/operator/csv_scanner/scanner_boundary.hpp" + #include "duckdb/execution/operator/csv_scanner/csv_sniffer.hpp" +#include "duckdb/execution/operator/csv_scanner/scanner_boundary.hpp" +#include "duckdb/execution/operator/csv_scanner/skip_scanner.hpp" #include "duckdb/execution/operator/persistent/csv_rejects_table.hpp" #include "duckdb/main/appender.hpp" +#include "duckdb/main/client_data.hpp" namespace duckdb { @@ -22,7 +24,7 @@ CSVGlobalState::CSVGlobalState(ClientContext &context_p, const shared_ptr(context, files[0], options, 0U, bind_data, column_ids, file_schema)); + make_uniq(context, files[0], options, 0U, bind_data, column_ids, file_schema, false)); }; // There are situations where we only support single threaded scanning bool many_csv_files = files.size() > 1 && files.size() > system_threads * 2; @@ -30,13 +32,18 @@ CSVGlobalState::CSVGlobalState(ClientContext &context_p, const shared_ptrbuffer_manager->GetBuffer(0)->actual_size; - current_boundary = CSVIterator(0, 0, 0, 0, buffer_size); + current_boundary = file_scans.back()->start_iterator; + current_boundary.SetCurrentBoundaryToPosition(single_threaded); + if (current_boundary.done && context.client_data->debug_set_max_line_length) { + context.client_data->debug_max_line_length = current_boundary.pos.buffer_pos; } - current_buffer_in_use = make_shared_ptr(*file_scans.back()->buffer_manager, 0); + current_buffer_in_use = + make_shared_ptr(*file_scans.back()->buffer_manager, current_boundary.GetBufferIdx()); +} + +bool CSVGlobalState::IsDone() const { + lock_guard parallel_lock(main_mutex); + return current_boundary.done; } double CSVGlobalState::GetProgress(const ReadCSVData &bind_data_p) const { @@ -44,13 +51,24 @@ double CSVGlobalState::GetProgress(const ReadCSVData &bind_data_p) const { idx_t total_files = bind_data.files.size(); // get the progress WITHIN the current file double percentage = 0; - if (file_scans.back()->file_size == 0) { + if (file_scans.front()->file_size == 0) { percentage = 1.0; } else { // for compressed files, readed bytes may greater than files size. for (auto &file : file_scans) { - percentage += - (double(1) / double(total_files)) * std::min(1.0, double(file->bytes_read) / double(file->file_size)); + double file_progress; + if (!file->buffer_manager) { + // We are done with this file, so it's 100% + file_progress = 1.0; + } else if (file->buffer_manager->file_handle->compression_type == FileCompressionType::GZIP || + file->buffer_manager->file_handle->compression_type == FileCompressionType::ZSTD) { + // This file is not done, and is a compressed file + file_progress = file->buffer_manager->file_handle->GetProgress(); + } else { + file_progress = static_cast(file->bytes_read); + } + // This file is an uncompressed file, so we use the more price bytes_read from the scanner + percentage += (double(1) / double(total_files)) * std::min(1.0, file_progress / double(file->file_size)); } } return percentage * 100; @@ -58,29 +76,45 @@ double CSVGlobalState::GetProgress(const ReadCSVData &bind_data_p) const { unique_ptr CSVGlobalState::Next(optional_ptr previous_scanner) { if (single_threaded) { - idx_t cur_idx = last_file_idx++; - if (cur_idx >= bind_data.files.size()) { - return nullptr; - } - shared_ptr current_file; - if (cur_idx == 0) { - current_file = file_scans.back(); - } else { - lock_guard parallel_lock(main_mutex); - file_scans.emplace_back(make_shared_ptr(context, bind_data.files[cur_idx], bind_data.options, - cur_idx, bind_data, column_ids, file_schema)); - current_file = file_scans.back(); - } - if (previous_scanner) { - lock_guard parallel_lock(main_mutex); - previous_scanner->buffer_tracker.reset(); - current_buffer_in_use.reset(); - previous_scanner->csv_file_scan->Finish(); - } - auto csv_scanner = - make_uniq(scanner_idx++, current_file->buffer_manager, current_file->state_machine, - current_file->error_handler, current_file, false, current_boundary); - return csv_scanner; + idx_t cur_idx; + bool empty_file = false; + do { + { + lock_guard parallel_lock(main_mutex); + cur_idx = last_file_idx++; + if (cur_idx >= bind_data.files.size()) { + // No more files to scan + return nullptr; + } + if (cur_idx == 0) { + D_ASSERT(!previous_scanner); + auto current_file = file_scans.front(); + return make_uniq(scanner_idx++, current_file->buffer_manager, + current_file->state_machine, current_file->error_handler, + current_file, false, current_boundary); + } + } + auto file_scan = make_shared_ptr(context, bind_data.files[cur_idx], bind_data.options, cur_idx, + bind_data, column_ids, file_schema, true); + empty_file = file_scan->file_size == 0; + if (!empty_file) { + lock_guard parallel_lock(main_mutex); + file_scans.emplace_back(std::move(file_scan)); + auto current_file = file_scans.back(); + current_boundary = current_file->start_iterator; + current_boundary.SetCurrentBoundaryToPosition(single_threaded); + current_buffer_in_use = make_shared_ptr(*file_scans.back()->buffer_manager, + current_boundary.GetBufferIdx()); + if (previous_scanner) { + previous_scanner->buffer_tracker.reset(); + current_buffer_in_use.reset(); + previous_scanner->csv_file_scan->Finish(); + } + return make_uniq(scanner_idx++, current_file->buffer_manager, + current_file->state_machine, current_file->error_handler, + current_file, false, current_boundary); + } + } while (empty_file); } lock_guard parallel_lock(main_mutex); if (finished) { @@ -108,20 +142,24 @@ unique_ptr CSVGlobalState::Next(optional_ptr(context, bind_data.files[current_file_idx], - bind_data.options, current_file_idx, bind_data, - column_ids, file_schema)); - // And re-start the boundary-iterator - auto buffer_size = file_scans.back()->buffer_manager->GetBuffer(0)->actual_size; - current_boundary = CSVIterator(current_file_idx, 0, 0, 0, buffer_size); - current_buffer_in_use = make_shared_ptr(*file_scans.back()->buffer_manager, 0); - } else { - // If not we are done with this CSV Scanning - finished = true; - } + do { + auto current_file_idx = file_scans.back()->file_idx + 1; + if (current_file_idx < bind_data.files.size()) { + // If we have a next file we have to construct the file scan for that + file_scans.emplace_back(make_shared_ptr(context, bind_data.files[current_file_idx], + bind_data.options, current_file_idx, bind_data, + column_ids, file_schema, false)); + // And re-start the boundary-iterator + current_boundary = file_scans.back()->start_iterator; + current_boundary.SetCurrentBoundaryToPosition(single_threaded); + current_buffer_in_use = make_shared_ptr(*file_scans.back()->buffer_manager, + current_boundary.GetBufferIdx()); + } else { + // If not we are done with this CSV Scanning + finished = true; + break; + } + } while (current_boundary.done); } // We initialize the scan return csv_scanner; @@ -132,7 +170,7 @@ idx_t CSVGlobalState::MaxThreads() const { if (single_threaded) { return system_threads; } - idx_t total_threads = file_scans.back()->file_size / CSVIterator::BYTES_PER_THREAD + 1; + idx_t total_threads = file_scans.front()->file_size / CSVIterator::BYTES_PER_THREAD + 1; if (total_threads < system_threads) { return total_threads; diff --git a/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp b/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp index 208c2337..e7a41f3a 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp @@ -15,10 +15,10 @@ LinesPerBoundary::LinesPerBoundary(idx_t boundary_idx_p, idx_t lines_in_batch_p) CSVErrorHandler::CSVErrorHandler(bool ignore_errors_p) : ignore_errors(ignore_errors_p) { } -void CSVErrorHandler::ThrowError(CSVError csv_error) { +void CSVErrorHandler::ThrowError(const CSVError &csv_error) { std::ostringstream error; if (PrintLineNumber(csv_error)) { - error << "CSV Error on Line: " << GetLine(csv_error.error_info) << '\n'; + error << "CSV Error on Line: " << GetLineInternal(csv_error.error_info) << '\n'; if (!csv_error.csv_row.empty()) { error << "Original Line: " << csv_error.csv_row << '\n'; } @@ -30,11 +30,11 @@ void CSVErrorHandler::ThrowError(CSVError csv_error) { } switch (csv_error.type) { - case CSVErrorType::CAST_ERROR: + case CAST_ERROR: throw ConversionException(error.str()); - case CSVErrorType::COLUMN_NAME_TYPE_MISMATCH: + case COLUMN_NAME_TYPE_MISMATCH: throw BinderException(error.str()); - case CSVErrorType::NULLPADDED_QUOTED_NEW_VALUE: + case NULLPADDED_QUOTED_NEW_VALUE: throw ParameterNotAllowedException(error.str()); default: throw InvalidInputException(error.str()); @@ -42,8 +42,8 @@ void CSVErrorHandler::ThrowError(CSVError csv_error) { } void CSVErrorHandler::Error(CSVError csv_error, bool force_error) { + lock_guard parallel_lock(main_mutex); if ((ignore_errors && !force_error) || (PrintLineNumber(csv_error) && !CanGetLine(csv_error.GetBoundaryIndex()))) { - lock_guard parallel_lock(main_mutex); // We store this error, we can't throw it now, or we are ignoring it errors[csv_error.error_info].push_back(std::move(csv_error)); return; @@ -53,15 +53,12 @@ void CSVErrorHandler::Error(CSVError csv_error, bool force_error) { } void CSVErrorHandler::ErrorIfNeeded() { - CSVError first_error; - { - lock_guard parallel_lock(main_mutex); - if (ignore_errors || errors.empty()) { - // Nothing to error - return; - } - first_error = errors.begin()->second[0]; + lock_guard parallel_lock(main_mutex); + if (ignore_errors || errors.empty()) { + // Nothing to error + return; } + CSVError first_error = errors.begin()->second[0]; if (CanGetLine(first_error.error_info.boundary_idx)) { ThrowError(first_error); @@ -82,13 +79,18 @@ void CSVErrorHandler::NewMaxLineSize(idx_t scan_line_size) { max_line_length = std::max(scan_line_size, max_line_length); } +bool CSVErrorHandler::AnyErrors() { + lock_guard parallel_lock(main_mutex); + return !errors.empty(); +} + CSVError::CSVError(string error_message_p, CSVErrorType type_p, LinesPerBoundary error_info_p) : error_message(std::move(error_message_p)), type(type_p), error_info(error_info_p) { } CSVError::CSVError(string error_message_p, CSVErrorType type_p, idx_t column_idx_p, string csv_row_p, LinesPerBoundary error_info_p, idx_t row_byte_position, optional_idx byte_position_p, - const CSVReaderOptions &reader_options, const string &fixes) + const CSVReaderOptions &reader_options, const string &fixes, const string ¤t_path) : error_message(std::move(error_message_p)), type(type_p), column_idx(column_idx_p), csv_row(std::move(csv_row_p)), error_info(error_info_p), row_byte_position(row_byte_position), byte_position(byte_position_p) { // What were the options @@ -98,7 +100,7 @@ CSVError::CSVError(string error_message_p, CSVErrorType type_p, idx_t column_idx } error << error_message << '\n'; error << fixes << '\n'; - error << reader_options.ToString(); + error << reader_options.ToString(current_path); error << '\n'; full_error_message = error.str(); } @@ -112,7 +114,7 @@ CSVError CSVError::ColumnTypesError(case_insensitive_map_t sql_types_per_ } } if (sql_types_per_column.empty()) { - return CSVError("", CSVErrorType::COLUMN_NAME_TYPE_MISMATCH, {}); + return CSVError("", COLUMN_NAME_TYPE_MISMATCH, {}); } string exception = "COLUMN_TYPES error: Columns with names: "; for (auto &col : sql_types_per_column) { @@ -120,7 +122,7 @@ CSVError CSVError::ColumnTypesError(case_insensitive_map_t sql_types_per_ } exception.pop_back(); exception += " do not exist in the CSV File"; - return CSVError(exception, CSVErrorType::COLUMN_NAME_TYPE_MISMATCH, {}); + return CSVError(exception, COLUMN_NAME_TYPE_MISMATCH, {}); } void CSVError::RemoveNewLine(string &error) { @@ -129,7 +131,7 @@ void CSVError::RemoveNewLine(string &error) { CSVError CSVError::CastError(const CSVReaderOptions &options, string &column_name, string &cast_error, idx_t column_idx, string &csv_row, LinesPerBoundary error_info, idx_t row_byte_position, - optional_idx byte_position, LogicalTypeId type) { + optional_idx byte_position, LogicalTypeId type, const string ¤t_path) { std::ostringstream error; // Which column error << "Error when converting column \"" << column_name << "\". "; @@ -154,12 +156,12 @@ CSVError CSVError::CastError(const CSVReaderOptions &options, string &column_nam << '\n'; } - return CSVError(error.str(), CSVErrorType::CAST_ERROR, column_idx, csv_row, error_info, row_byte_position, - byte_position, options, how_to_fix_it.str()); + return CSVError(error.str(), CAST_ERROR, column_idx, csv_row, error_info, row_byte_position, byte_position, options, + how_to_fix_it.str(), current_path); } CSVError CSVError::LineSizeError(const CSVReaderOptions &options, idx_t actual_size, LinesPerBoundary error_info, - string &csv_row, idx_t byte_position) { + string &csv_row, idx_t byte_position, const string ¤t_path) { std::ostringstream error; error << "Maximum line size of " << options.maximum_line_size << " bytes exceeded. "; error << "Actual Size:" << actual_size << " bytes." << '\n'; @@ -168,42 +170,146 @@ CSVError CSVError::LineSizeError(const CSVReaderOptions &options, idx_t actual_s how_to_fix_it << "Possible Solution: Change the maximum length size, e.g., max_line_size=" << actual_size + 1 << "\n"; - return CSVError(error.str(), CSVErrorType::MAXIMUM_LINE_SIZE, 0, csv_row, error_info, byte_position, byte_position, - options, how_to_fix_it.str()); + return CSVError(error.str(), MAXIMUM_LINE_SIZE, 0, csv_row, error_info, byte_position, byte_position, options, + how_to_fix_it.str(), current_path); } -CSVError CSVError::SniffingError(string &file_path) { +CSVError CSVError::HeaderSniffingError(const CSVReaderOptions &options, const vector &best_header_row, + idx_t column_count, char delimiter) { std::ostringstream error; - // Which column - error << "Error when sniffing file \"" << file_path << "\"." << '\n'; - error << "CSV options could not be auto-detected. Consider setting parser options manually." << '\n'; - return CSVError(error.str(), CSVErrorType::SNIFFING, {}); + // 1. Which file + error << "Error when sniffing file \"" << options.file_path << "\"." << '\n'; + // 2. What's the error + error << "It was not possible to detect the CSV Header, due to the header having less columns than expected" + << '\n'; + // 2.1 What's the expected number of columns + error << "Number of expected columns: " << column_count << ". Actual number of columns " << best_header_row.size() + << '\n'; + // 2.2 What was the detected row + error << "Detected row as Header:" << '\n'; + for (idx_t i = 0; i < best_header_row.size(); i++) { + if (best_header_row[i].is_null) { + error << "NULL"; + } else { + error << best_header_row[i].value; + } + if (i < best_header_row.size() - 1) { + error << delimiter << " "; + } + } + error << "\n"; + + // 3. Suggest how to fix it! + error << "Possible fixes:" << '\n'; + // header + if (!options.dialect_options.header.IsSetByUser()) { + error << "* Set header (header = true) if your CSV has a header, or (header = false) if it doesn't" << '\n'; + } else { + error << "* Header is set to \'" << options.dialect_options.header.GetValue() << "\'. Consider unsetting it." + << '\n'; + } + // skip_rows + if (!options.dialect_options.skip_rows.IsSetByUser()) { + error << "* Set skip (skip=${n}) to skip ${n} lines at the top of the file" << '\n'; + } else { + error << "* Skip is set to \'" << options.dialect_options.skip_rows.GetValue() << "\'. Consider unsetting it." + << '\n'; + } + // ignore_errors + if (!options.ignore_errors.GetValue()) { + error << "* Enable ignore errors (ignore_errors=true) to ignore potential errors" << '\n'; + } + // null_padding + if (!options.null_padding) { + error << "* Enable null padding (null_padding=true) to pad missing columns with NULL values" << '\n'; + } + return CSVError(error.str(), SNIFFING, {}); } -CSVError CSVError::NullPaddingFail(const CSVReaderOptions &options, LinesPerBoundary error_info) { +CSVError CSVError::SniffingError(const CSVReaderOptions &options, const string &search_space) { + std::ostringstream error; + // 1. Which file + error << "Error when sniffing file \"" << options.file_path << "\"." << '\n'; + // 2. What's the error + error << "It was not possible to automatically detect the CSV Parsing dialect/types" << '\n'; + + // 2. What was the search space? + error << "The search space used was:" << '\n'; + error << search_space; + // 3. Suggest how to fix it! + error << "Possible fixes:" << '\n'; + // 3.1 Inform the reader of the dialect + // delimiter + if (!options.dialect_options.state_machine_options.delimiter.IsSetByUser()) { + error << "* Set delimiter (e.g., delim=\',\')" << '\n'; + } else { + error << "* Delimiter is set to \'" << options.dialect_options.state_machine_options.delimiter.GetValue() + << "\'. Consider unsetting it." << '\n'; + } + // quote + if (!options.dialect_options.state_machine_options.quote.IsSetByUser()) { + error << "* Set quote (e.g., quote=\'\"\')" << '\n'; + } else { + error << "* Quote is set to \'" << options.dialect_options.state_machine_options.quote.GetValue() + << "\'. Consider unsetting it." << '\n'; + } + // escape + if (!options.dialect_options.state_machine_options.escape.IsSetByUser()) { + error << "* Set escape (e.g., escape=\'\"\')" << '\n'; + } else { + error << "* Escape is set to \'" << options.dialect_options.state_machine_options.escape.GetValue() + << "\'. Consider unsetting it." << '\n'; + } + // comment + if (!options.dialect_options.state_machine_options.comment.IsSetByUser()) { + error << "* Set comment (e.g., comment=\'#\')" << '\n'; + } else { + error << "* Comment is set to \'" << options.dialect_options.state_machine_options.comment.GetValue() + << "\'. Consider unsetting it." << '\n'; + } + // 3.2 skip_rows + if (!options.dialect_options.skip_rows.IsSetByUser()) { + error << "* Set skip (skip=${n}) to skip ${n} lines at the top of the file" << '\n'; + } + // 3.3 ignore_errors + if (!options.ignore_errors.GetValue()) { + error << "* Enable ignore errors (ignore_errors=true) to ignore potential errors" << '\n'; + } + // 3.4 null_padding + if (!options.null_padding) { + error << "* Enable null padding (null_padding=true) to pad missing columns with NULL values" << '\n'; + } + error << "* Check you are using the correct file compression, otherwise set it (e.g., compression = \'zstd\')" + << '\n'; + + return CSVError(error.str(), SNIFFING, {}); +} + +CSVError CSVError::NullPaddingFail(const CSVReaderOptions &options, LinesPerBoundary error_info, + const string ¤t_path) { std::ostringstream error; error << " The parallel scanner does not support null_padding in conjunction with quoted new lines. Please " "disable the parallel csv reader with parallel=false" << '\n'; // What were the options - error << options.ToString(); - return CSVError(error.str(), CSVErrorType::NULLPADDED_QUOTED_NEW_VALUE, error_info); + error << options.ToString(current_path); + return CSVError(error.str(), NULLPADDED_QUOTED_NEW_VALUE, error_info); } CSVError CSVError::UnterminatedQuotesError(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - optional_idx byte_position) { + optional_idx byte_position, const string ¤t_path) { std::ostringstream error; error << "Value with unterminated quote found." << '\n'; std::ostringstream how_to_fix_it; how_to_fix_it << "Possible Solution: Enable ignore errors (ignore_errors=true) to skip this row" << '\n'; - return CSVError(error.str(), CSVErrorType::UNTERMINATED_QUOTES, current_column, csv_row, error_info, - row_byte_position, byte_position, options, how_to_fix_it.str()); + return CSVError(error.str(), UNTERMINATED_QUOTES, current_column, csv_row, error_info, row_byte_position, + byte_position, options, how_to_fix_it.str(), current_path); } CSVError CSVError::IncorrectColumnAmountError(const CSVReaderOptions &options, idx_t actual_columns, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - optional_idx byte_position) { + optional_idx byte_position, const string ¤t_path) { std::ostringstream error; // We don't have a fix for this std::ostringstream how_to_fix_it; @@ -217,37 +323,38 @@ CSVError CSVError::IncorrectColumnAmountError(const CSVReaderOptions &options, i // How many columns were expected and how many were found error << "Expected Number of Columns: " << options.dialect_options.num_cols << " Found: " << actual_columns + 1; if (actual_columns >= options.dialect_options.num_cols) { - return CSVError(error.str(), CSVErrorType::TOO_MANY_COLUMNS, actual_columns, csv_row, error_info, - row_byte_position, byte_position.GetIndex() - 1, options, how_to_fix_it.str()); + return CSVError(error.str(), TOO_MANY_COLUMNS, actual_columns, csv_row, error_info, row_byte_position, + byte_position.GetIndex() - 1, options, how_to_fix_it.str(), current_path); } else { - return CSVError(error.str(), CSVErrorType::TOO_FEW_COLUMNS, actual_columns, csv_row, error_info, - row_byte_position, byte_position.GetIndex() - 1, options, how_to_fix_it.str()); + return CSVError(error.str(), TOO_FEW_COLUMNS, actual_columns, csv_row, error_info, row_byte_position, + byte_position.GetIndex() - 1, options, how_to_fix_it.str(), current_path); } } CSVError CSVError::InvalidUTF8(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, - string &csv_row, idx_t row_byte_position, optional_idx byte_position) { + string &csv_row, idx_t row_byte_position, optional_idx byte_position, + const string ¤t_path) { std::ostringstream error; // How many columns were expected and how many were found error << "Invalid unicode (byte sequence mismatch) detected." << '\n'; std::ostringstream how_to_fix_it; how_to_fix_it << "Possible Solution: Enable ignore errors (ignore_errors=true) to skip this row" << '\n'; - return CSVError(error.str(), CSVErrorType::INVALID_UNICODE, current_column, csv_row, error_info, row_byte_position, - byte_position, options, how_to_fix_it.str()); + return CSVError(error.str(), INVALID_UNICODE, current_column, csv_row, error_info, row_byte_position, byte_position, + options, how_to_fix_it.str(), current_path); } -bool CSVErrorHandler::PrintLineNumber(CSVError &error) { +bool CSVErrorHandler::PrintLineNumber(const CSVError &error) const { if (!print_line) { return false; } switch (error.type) { - case CSVErrorType::CAST_ERROR: - case CSVErrorType::UNTERMINATED_QUOTES: - case CSVErrorType::TOO_FEW_COLUMNS: - case CSVErrorType::TOO_MANY_COLUMNS: - case CSVErrorType::MAXIMUM_LINE_SIZE: - case CSVErrorType::NULLPADDED_QUOTED_NEW_VALUE: - case CSVErrorType::INVALID_UNICODE: + case CAST_ERROR: + case UNTERMINATED_QUOTES: + case TOO_FEW_COLUMNS: + case TOO_MANY_COLUMNS: + case MAXIMUM_LINE_SIZE: + case NULLPADDED_QUOTED_NEW_VALUE: + case INVALID_UNICODE: return true; default: return false; @@ -265,6 +372,9 @@ bool CSVErrorHandler::CanGetLine(idx_t boundary_index) { idx_t CSVErrorHandler::GetLine(const LinesPerBoundary &error_info) { lock_guard parallel_lock(main_mutex); + return GetLineInternal(error_info); +} +idx_t CSVErrorHandler::GetLineInternal(const LinesPerBoundary &error_info) { // We start from one, since the lines are 1-indexed idx_t current_line = 1 + error_info.lines_in_batch; for (idx_t boundary_idx = 0; boundary_idx < error_info.boundary_idx; boundary_idx++) { diff --git a/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp b/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp index ded31aee..21f910ec 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp @@ -89,8 +89,8 @@ void CSVReaderOptions::SetEscape(const string &input) { this->dialect_options.state_machine_options.escape.Set(escape_str[0]); } -int64_t CSVReaderOptions::GetSkipRows() const { - return NumericCast(this->dialect_options.skip_rows.GetValue()); +idx_t CSVReaderOptions::GetSkipRows() const { + return NumericCast(this->dialect_options.skip_rows.GetValue()); } void CSVReaderOptions::SetSkipRows(int64_t skip_rows) { @@ -130,13 +130,41 @@ void CSVReaderOptions::SetQuote(const string "e_p) { this->dialect_options.state_machine_options.quote.Set(quote_str[0]); } -NewLineIdentifier CSVReaderOptions::GetNewline() const { - return dialect_options.state_machine_options.new_line.GetValue(); +string CSVReaderOptions::GetComment() const { + return std::string(1, this->dialect_options.state_machine_options.comment.GetValue()); +} + +void CSVReaderOptions::SetComment(const string &comment_p) { + auto comment_str = comment_p; + if (comment_str.size() > 1) { + throw InvalidInputException("The comment option cannot exceed a size of 1 byte."); + } + if (comment_str.empty()) { + comment_str = string("\0", 1); + } + this->dialect_options.state_machine_options.comment.Set(comment_str[0]); +} + +string CSVReaderOptions::GetNewline() const { + switch (dialect_options.state_machine_options.new_line.GetValue()) { + case NewLineIdentifier::CARRY_ON: + return "\\r\\n"; + case NewLineIdentifier::SINGLE_R: + return "\\r"; + case NewLineIdentifier::SINGLE_N: + return "\\n"; + case NewLineIdentifier::NOT_SET: + return ""; + default: + throw NotImplementedException("New line type not supported"); + } } void CSVReaderOptions::SetNewline(const string &input) { - if (input == "\\n" || input == "\\r") { - dialect_options.state_machine_options.new_line.Set(NewLineIdentifier::SINGLE); + if (input == "\\n") { + dialect_options.state_machine_options.new_line.Set(NewLineIdentifier::SINGLE_N); + } else if (input == "\\r") { + dialect_options.state_machine_options.new_line.Set(NewLineIdentifier::SINGLE_R); } else if (input == "\\r\\n") { dialect_options.state_machine_options.new_line.Set(NewLineIdentifier::CARRY_ON); } else { @@ -290,6 +318,8 @@ bool CSVReaderOptions::SetBaseOption(const string &loption, const Value &value, SetDelimiter(ParseString(value, loption)); } else if (loption == "quote") { SetQuote(ParseString(value, loption)); + } else if (loption == "comment") { + SetComment(ParseString(value, loption)); } else if (loption == "new_line") { SetNewline(ParseString(value, loption)); } else if (loption == "escape") { @@ -365,15 +395,16 @@ bool CSVReaderOptions::WasTypeManuallySet(idx_t i) const { return was_type_manually_set[i]; } -string CSVReaderOptions::ToString() const { +string CSVReaderOptions::ToString(const string ¤t_file_path) const { auto &delimiter = dialect_options.state_machine_options.delimiter; auto "e = dialect_options.state_machine_options.quote; auto &escape = dialect_options.state_machine_options.escape; + auto &comment = dialect_options.state_machine_options.comment; auto &new_line = dialect_options.state_machine_options.new_line; auto &skip_rows = dialect_options.skip_rows; auto &header = dialect_options.header; - string error = " file=" + file_path + "\n "; + string error = " file=" + current_file_path + "\n "; // Let's first print options that can either be set by the user or by the sniffer // delimiter error += FormatOptionLine("delimiter", delimiter); @@ -387,6 +418,8 @@ string CSVReaderOptions::ToString() const { error += FormatOptionLine("header", header); // skip_rows error += FormatOptionLine("skip_rows", skip_rows); + // comment + error += FormatOptionLine("comment", comment); // date format error += FormatOptionLine("date_format", dialect_options.date_format.at(LogicalType::DATE)); // timestamp format @@ -444,8 +477,7 @@ bool StoreUserDefinedParameter(string &option) { } return true; } -void CSVReaderOptions::FromNamedParameters(named_parameter_map_t &in, ClientContext &context, - vector &return_types, vector &names) { +void CSVReaderOptions::FromNamedParameters(named_parameter_map_t &in, ClientContext &context) { map ordered_user_defined_parameters; for (auto &kv : in) { if (MultiFileReader().ParseOption(kv.first, kv.second, file_options, context)) { @@ -457,6 +489,10 @@ void CSVReaderOptions::FromNamedParameters(named_parameter_map_t &in, ClientCont ordered_user_defined_parameters[loption] = kv.second.ToSQLString(); } if (loption == "columns") { + if (!name_list.empty()) { + throw BinderException("read_csv_auto column_names/names can only be supplied once"); + } + columns_set = true; auto &child_type = kv.second.type(); if (child_type.id() != LogicalTypeId::STRUCT) { throw BinderException("read_csv columns requires a struct as input"); @@ -466,13 +502,14 @@ void CSVReaderOptions::FromNamedParameters(named_parameter_map_t &in, ClientCont for (idx_t i = 0; i < struct_children.size(); i++) { auto &name = StructType::GetChildName(child_type, i); auto &val = struct_children[i]; - names.push_back(name); + name_list.push_back(name); if (val.type().id() != LogicalTypeId::VARCHAR) { throw BinderException("read_csv requires a type specification as string"); } - return_types.emplace_back(TransformStringToLogicalType(StringValue::Get(val), context)); + sql_types_per_column[name] = i; + sql_type_list.emplace_back(TransformStringToLogicalType(StringValue::Get(val), context)); } - if (names.empty()) { + if (name_list.empty()) { throw BinderException("read_csv requires at least a single column as input!"); } } else if (loption == "auto_type_candidates") { @@ -557,7 +594,7 @@ void CSVReaderOptions::FromNamedParameters(named_parameter_map_t &in, ClientCont } else if (loption == "normalize_names") { normalize_names = BooleanValue::Get(kv.second); } else { - SetReadOption(loption, kv.second, names); + SetReadOption(loption, kv.second, name_list); } } for (auto &udf_parameter : ordered_user_defined_parameters) { @@ -573,12 +610,13 @@ void CSVReaderOptions::ToNamedParameters(named_parameter_map_t &named_params) { auto &delimiter = dialect_options.state_machine_options.delimiter; auto "e = dialect_options.state_machine_options.quote; auto &escape = dialect_options.state_machine_options.escape; + auto &comment = dialect_options.state_machine_options.comment; auto &header = dialect_options.header; if (delimiter.IsSetByUser()) { named_params["delim"] = Value(GetDelimiter()); } if (dialect_options.state_machine_options.new_line.IsSetByUser()) { - named_params["newline"] = Value(EnumUtil::ToString(GetNewline())); + named_params["new_line"] = Value(GetNewline()); } if (quote.IsSetByUser()) { named_params["quote"] = Value(GetQuote()); @@ -586,12 +624,15 @@ void CSVReaderOptions::ToNamedParameters(named_parameter_map_t &named_params) { if (escape.IsSetByUser()) { named_params["escape"] = Value(GetEscape()); } + if (comment.IsSetByUser()) { + named_params["comment"] = Value(GetComment()); + } if (header.IsSetByUser()) { named_params["header"] = Value(GetHeader()); } named_params["max_line_size"] = Value::BIGINT(NumericCast(maximum_line_size)); if (dialect_options.skip_rows.IsSetByUser()) { - named_params["skip"] = Value::BIGINT(GetSkipRows()); + named_params["skip"] = Value::UBIGINT(GetSkipRows()); } named_params["null_padding"] = Value::BOOLEAN(null_padding); named_params["parallel"] = Value::BOOLEAN(parallel); @@ -605,7 +646,8 @@ void CSVReaderOptions::ToNamedParameters(named_parameter_map_t &named_params) { } named_params["normalize_names"] = Value::BOOLEAN(normalize_names); - if (!name_list.empty() && !named_params.count("column_names") && !named_params.count("names")) { + if (!name_list.empty() && !named_params.count("columns") && !named_params.count("column_names") && + !named_params.count("names")) { named_params["column_names"] = StringVectorToValue(name_list); } named_params["all_varchar"] = Value::BOOLEAN(all_varchar); diff --git a/src/duckdb/src/execution/operator/filter/physical_filter.cpp b/src/duckdb/src/execution/operator/filter/physical_filter.cpp index d70d5093..e2a25d9f 100644 --- a/src/duckdb/src/execution/operator/filter/physical_filter.cpp +++ b/src/duckdb/src/execution/operator/filter/physical_filter.cpp @@ -31,7 +31,7 @@ class FilterState : public CachingOperatorState { public: void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - context.thread.profiler.Flush(op, executor, "filter", 0); + context.thread.profiler.Flush(op); } }; @@ -52,10 +52,10 @@ OperatorResultType PhysicalFilter::ExecuteInternal(ExecutionContext &context, Da return OperatorResultType::NEED_MORE_INPUT; } -string PhysicalFilter::ParamsToString() const { - auto result = expression->GetName(); - result += "\n[INFOSEPARATOR]\n"; - result += StringUtil::Format("EC: %llu", estimated_cardinality); +InsertionOrderPreservingMap PhysicalFilter::ParamsToString() const { + InsertionOrderPreservingMap result; + result["__expression__"] = expression->GetName(); + SetEstimatedCardinality(result, estimated_cardinality); return result; } diff --git a/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp index 2c0d6ac8..c489124e 100644 --- a/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp @@ -9,27 +9,6 @@ namespace duckdb { PhysicalBatchCollector::PhysicalBatchCollector(PreparedStatementData &data) : PhysicalResultCollector(data) { } -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class BatchCollectorGlobalState : public GlobalSinkState { -public: - BatchCollectorGlobalState(ClientContext &context, const PhysicalBatchCollector &op) : data(context, op.types) { - } - - mutex glock; - BatchedDataCollection data; - unique_ptr result; -}; - -class BatchCollectorLocalState : public LocalSinkState { -public: - BatchCollectorLocalState(ClientContext &context, const PhysicalBatchCollector &op) : data(context, op.types) { - } - - BatchedDataCollection data; -}; - SinkResultType PhysicalBatchCollector::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { auto &state = input.local_state.Cast(); diff --git a/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp new file mode 100644 index 00000000..f881da2a --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp @@ -0,0 +1,109 @@ +#include "duckdb/execution/operator/helper/physical_buffered_batch_collector.hpp" + +#include "duckdb/common/types/batched_data_collection.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/materialized_query_result.hpp" +#include "duckdb/main/buffered_data/buffered_data.hpp" +#include "duckdb/main/buffered_data/batched_buffered_data.hpp" +#include "duckdb/main/stream_query_result.hpp" + +namespace duckdb { + +PhysicalBufferedBatchCollector::PhysicalBufferedBatchCollector(PreparedStatementData &data) + : PhysicalResultCollector(data) { +} + +//===--------------------------------------------------------------------===// +// Sink +//===--------------------------------------------------------------------===// +class BufferedBatchCollectorGlobalState : public GlobalSinkState { +public: + weak_ptr context; + shared_ptr buffered_data; +}; + +BufferedBatchCollectorLocalState::BufferedBatchCollectorLocalState() { +} + +SinkResultType PhysicalBufferedBatchCollector::Sink(ExecutionContext &context, DataChunk &chunk, + OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + lstate.current_batch = lstate.partition_info.batch_index.GetIndex(); + auto batch = lstate.partition_info.batch_index.GetIndex(); + auto min_batch_index = lstate.partition_info.min_batch_index.GetIndex(); + + auto &buffered_data = gstate.buffered_data->Cast(); + buffered_data.UpdateMinBatchIndex(min_batch_index); + + if (buffered_data.ShouldBlockBatch(batch)) { + auto callback_state = input.interrupt_state; + buffered_data.BlockSink(callback_state, batch); + return SinkResultType::BLOCKED; + } + + // FIXME: if we want to make this more accurate, we should grab a reservation on the buffer space + // while we're unlocked some other thread could also append, causing us to potentially cross our buffer size + + buffered_data.Append(chunk, batch); + + return SinkResultType::NEED_MORE_INPUT; +} + +SinkNextBatchType PhysicalBufferedBatchCollector::NextBatch(ExecutionContext &context, + OperatorSinkNextBatchInput &input) const { + + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + auto batch = lstate.current_batch; + auto min_batch_index = lstate.partition_info.min_batch_index.GetIndex(); + auto new_index = lstate.partition_info.batch_index.GetIndex(); + + auto &buffered_data = gstate.buffered_data->Cast(); + buffered_data.CompleteBatch(batch); + lstate.current_batch = new_index; + // FIXME: this can move from the buffer to the read queue, increasing the 'read_queue_byte_count' + // We might want to block here if 'read_queue_byte_count' has already reached the ReadQueueCapacity() + // So we don't completely disregard the 'streaming_buffer_size' that was set + buffered_data.UpdateMinBatchIndex(min_batch_index); + return SinkNextBatchType::READY; +} + +SinkCombineResultType PhysicalBufferedBatchCollector::Combine(ExecutionContext &context, + OperatorSinkCombineInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + auto min_batch_index = lstate.partition_info.min_batch_index.GetIndex(); + auto &buffered_data = gstate.buffered_data->Cast(); + + // FIXME: this can move from the buffer to the read queue, increasing the 'read_queue_byte_count' + // We might want to block here if 'read_queue_byte_count' has already reached the ReadQueueCapacity() + // So we don't completely disregard the 'streaming_buffer_size' that was set + buffered_data.UpdateMinBatchIndex(min_batch_index); + return SinkCombineResultType::FINISHED; +} + +unique_ptr PhysicalBufferedBatchCollector::GetLocalSinkState(ExecutionContext &context) const { + auto state = make_uniq(); + return std::move(state); +} + +unique_ptr PhysicalBufferedBatchCollector::GetGlobalSinkState(ClientContext &context) const { + auto state = make_uniq(); + state->context = context.shared_from_this(); + state->buffered_data = make_shared_ptr(state->context); + return std::move(state); +} + +unique_ptr PhysicalBufferedBatchCollector::GetResult(GlobalSinkState &state) { + auto &gstate = state.Cast(); + auto cc = gstate.context.lock(); + auto result = make_uniq(statement_type, properties, types, names, cc->GetClientProperties(), + gstate.buffered_data); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp index 9f2ac70d..6a036109 100644 --- a/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp @@ -19,31 +19,23 @@ class BufferedCollectorGlobalState : public GlobalSinkState { shared_ptr buffered_data; }; -class BufferedCollectorLocalState : public LocalSinkState { -public: - bool blocked = false; -}; +class BufferedCollectorLocalState : public LocalSinkState {}; SinkResultType PhysicalBufferedCollector::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); + (void)lstate; lock_guard l(gstate.glock); auto &buffered_data = gstate.buffered_data->Cast(); - if (!lstate.blocked || buffered_data.BufferIsFull()) { - lstate.blocked = true; + if (buffered_data.BufferIsFull()) { auto callback_state = input.interrupt_state; - auto blocked_sink = BlockedSink(callback_state, chunk.size()); - buffered_data.BlockSink(blocked_sink); + buffered_data.BlockSink(callback_state); return SinkResultType::BLOCKED; } - - auto to_append = make_uniq(); - to_append->Initialize(Allocator::DefaultAllocator(), chunk.GetTypes()); - chunk.Copy(*to_append, 0); - buffered_data.Append(std::move(to_append)); + buffered_data.Append(chunk); return SinkResultType::NEED_MORE_INPUT; } diff --git a/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp b/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp index b5a43bfe..8d2c1281 100644 --- a/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp @@ -21,7 +21,7 @@ SinkFinalizeType PhysicalExplainAnalyze::Finalize(Pipeline &pipeline, Event &eve OperatorSinkFinalizeInput &input) const { auto &gstate = input.global_state.Cast(); auto &profiler = QueryProfiler::Get(context); - gstate.analyzed_plan = profiler.ToString(); + gstate.analyzed_plan = profiler.ToString(format); return SinkFinalizeType::READY; } diff --git a/src/duckdb/src/execution/operator/helper/physical_load.cpp b/src/duckdb/src/execution/operator/helper/physical_load.cpp index e20b5af0..5f0e7a02 100644 --- a/src/duckdb/src/execution/operator/helper/physical_load.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_load.cpp @@ -16,15 +16,23 @@ static void InstallFromRepository(ClientContext &context, const LoadInfo &info) repository = ExtensionRepository::GetRepositoryByUrl(info.repository); } - ExtensionHelper::InstallExtension(context, info.filename, info.load_type == LoadType::FORCE_INSTALL, repository, - true, info.version); + ExtensionInstallOptions options; + options.force_install = info.load_type == LoadType::FORCE_INSTALL; + options.throw_on_origin_mismatch = true; + options.version = info.version; + options.repository = repository; + + ExtensionHelper::InstallExtension(context, info.filename, options); } SourceResultType PhysicalLoad::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { if (info->load_type == LoadType::INSTALL || info->load_type == LoadType::FORCE_INSTALL) { if (info->repository.empty()) { - ExtensionHelper::InstallExtension(context.client, info->filename, - info->load_type == LoadType::FORCE_INSTALL, nullptr, true, info->version); + ExtensionInstallOptions options; + options.force_install = info->load_type == LoadType::FORCE_INSTALL; + options.throw_on_origin_mismatch = true; + options.version = info->version; + ExtensionHelper::InstallExtension(context.client, info->filename, options); } else { InstallFromRepository(context.client, *info); } diff --git a/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp index c5f0309f..2492a3a8 100644 --- a/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp @@ -9,22 +9,6 @@ PhysicalMaterializedCollector::PhysicalMaterializedCollector(PreparedStatementDa : PhysicalResultCollector(data), parallel(parallel) { } -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class MaterializedCollectorGlobalState : public GlobalSinkState { -public: - mutex glock; - unique_ptr collection; - shared_ptr context; -}; - -class MaterializedCollectorLocalState : public LocalSinkState { -public: - unique_ptr collection; - ColumnDataAppendState append_state; -}; - SinkResultType PhysicalMaterializedCollector::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { auto &lstate = input.local_state.Cast(); diff --git a/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp b/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp index b253025f..869db44f 100644 --- a/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp @@ -91,8 +91,10 @@ SourceResultType PhysicalReservoirSample::GetData(ExecutionContext &context, Dat return SourceResultType::HAVE_MORE_OUTPUT; } -string PhysicalReservoirSample::ParamsToString() const { - return options->sample_size.ToString() + (options->is_percentage ? "%" : " rows"); +InsertionOrderPreservingMap PhysicalReservoirSample::ParamsToString() const { + InsertionOrderPreservingMap result; + result["Sample Size"] = options->sample_size.ToString() + (options->is_percentage ? "%" : " rows"); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_reset.cpp b/src/duckdb/src/execution/operator/helper/physical_reset.cpp index 916670cf..b6a219a4 100644 --- a/src/duckdb/src/execution/operator/helper/physical_reset.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_reset.cpp @@ -20,6 +20,11 @@ void PhysicalReset::ResetExtensionVariable(ExecutionContext &context, DBConfig & } SourceResultType PhysicalReset::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { + if (scope == SetScope::VARIABLE) { + auto &client_config = ClientConfig::GetConfig(context.client); + client_config.ResetUserVariable(name); + return SourceResultType::FINISHED; + } auto &config = DBConfig::GetConfig(context.client); config.CheckLock(name); auto option = DBConfig::GetOptionByName(name); diff --git a/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp index 8b2bbdf8..5dcb356e 100644 --- a/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp @@ -1,12 +1,14 @@ #include "duckdb/execution/operator/helper/physical_result_collector.hpp" #include "duckdb/execution/operator/helper/physical_batch_collector.hpp" +#include "duckdb/execution/operator/helper/physical_buffered_batch_collector.hpp" #include "duckdb/execution/operator/helper/physical_materialized_collector.hpp" #include "duckdb/execution/operator/helper/physical_buffered_collector.hpp" #include "duckdb/execution/physical_plan_generator.hpp" #include "duckdb/main/config.hpp" #include "duckdb/main/prepared_statement_data.hpp" #include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/main/query_result.hpp" #include "duckdb/parallel/pipeline.hpp" namespace duckdb { @@ -35,7 +37,7 @@ unique_ptr PhysicalResultCollector::GetResultCollector( // we care about maintaining insertion order and the sources all support batch indexes // use a batch collector if (data.is_streaming) { - return make_uniq_base(data, false); + return make_uniq_base(data); } return make_uniq_base(data); } diff --git a/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp b/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp new file mode 100644 index 00000000..37482d2c --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp @@ -0,0 +1,39 @@ +#include "duckdb/execution/operator/helper/physical_set_variable.hpp" +#include "duckdb/main/client_config.hpp" + +namespace duckdb { + +PhysicalSetVariable::PhysicalSetVariable(string name_p, idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::SET_VARIABLE, {LogicalType::BOOLEAN}, estimated_cardinality), + name(std::move(name_p)) { +} + +SourceResultType PhysicalSetVariable::GetData(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + return SourceResultType::FINISHED; +} + +class SetVariableGlobalState : public GlobalSinkState { +public: + SetVariableGlobalState() { + } + + bool is_set = false; +}; + +unique_ptr PhysicalSetVariable::GetGlobalSinkState(ClientContext &context) const { + return make_uniq(); +} + +SinkResultType PhysicalSetVariable::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { + auto &gstate = input.global_state.Cast(); + if (chunk.size() != 1 || gstate.is_set) { + throw InvalidInputException("PhysicalSetVariable can only handle a single value"); + } + auto &config = ClientConfig::GetConfig(context.client); + config.SetUserVariable(name, chunk.GetValue(0, 0)); + gstate.is_set = true; + return SinkResultType::NEED_MORE_INPUT; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp b/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp index cda2fa25..30925624 100644 --- a/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp @@ -68,8 +68,10 @@ OperatorResultType PhysicalStreamingSample::Execute(ExecutionContext &context, D return OperatorResultType::NEED_MORE_INPUT; } -string PhysicalStreamingSample::ParamsToString() const { - return EnumUtil::ToString(method) + ": " + to_string(100 * percentage) + "%"; +InsertionOrderPreservingMap PhysicalStreamingSample::ParamsToString() const { + InsertionOrderPreservingMap result; + result["Sample Method"] = EnumUtil::ToString(method) + ": " + to_string(100 * percentage) + "%"; + return result; } } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_transaction.cpp b/src/duckdb/src/execution/operator/helper/physical_transaction.cpp index 9361cbd5..a8ad410d 100644 --- a/src/duckdb/src/execution/operator/helper/physical_transaction.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_transaction.cpp @@ -1,11 +1,12 @@ #include "duckdb/execution/operator/helper/physical_transaction.hpp" + +#include "duckdb/common/exception/transaction_exception.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/database_manager.hpp" #include "duckdb/main/valid_checker.hpp" -#include "duckdb/common/exception/transaction_exception.hpp" #include "duckdb/transaction/meta_transaction.hpp" #include "duckdb/transaction/transaction_manager.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/main/database_manager.hpp" namespace duckdb { @@ -28,6 +29,9 @@ SourceResultType PhysicalTransaction::GetData(ExecutionContext &context, DataChu // preserving the transaction context for the next query client.transaction.SetAutoCommit(false); auto &config = DBConfig::GetConfig(context.client); + if (info->modifier == TransactionModifierType::TRANSACTION_READ_ONLY) { + client.transaction.SetReadOnly(); + } if (config.options.immediate_transaction_mode) { // if immediate transaction mode is enabled then start all transactions immediately auto databases = DatabaseManager::Get(client).GetDatabases(client); @@ -53,8 +57,15 @@ SourceResultType PhysicalTransaction::GetData(ExecutionContext &context, DataChu if (client.transaction.IsAutoCommit()) { throw TransactionException("cannot rollback - no transaction is active"); } else { - // explicitly rollback the current transaction - client.transaction.Rollback(); + // Explicitly rollback the current transaction + // If it is because of an invalidated transaction, we need to rollback with an error + auto &valid_checker = ValidChecker::Get(client.transaction.ActiveTransaction()); + if (valid_checker.IsInvalidated()) { + ErrorData error(ExceptionType::TRANSACTION, valid_checker.InvalidatedMessage()); + client.transaction.Rollback(error); + } else { + client.transaction.Rollback(nullptr); + } } break; } diff --git a/src/duckdb/src/execution/operator/join/outer_join_marker.cpp b/src/duckdb/src/execution/operator/join/outer_join_marker.cpp index 53d7de3d..645c702a 100644 --- a/src/duckdb/src/execution/operator/join/outer_join_marker.cpp +++ b/src/duckdb/src/execution/operator/join/outer_join_marker.cpp @@ -10,7 +10,7 @@ void OuterJoinMarker::Initialize(idx_t count_p) { return; } this->count = count_p; - found_match = make_unsafe_uniq_array(count); + found_match = make_unsafe_uniq_array_uninitialized(count); Reset(); } diff --git a/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp b/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp index bfac5402..f5ced297 100644 --- a/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp +++ b/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp @@ -25,7 +25,7 @@ bool PerfectHashJoinExecutor::BuildPerfectHashTable(LogicalType &key_type) { } // and for duplicate_checking - bitmap_build_idx = make_unsafe_uniq_array(build_size); + bitmap_build_idx = make_unsafe_uniq_array_uninitialized(build_size); memset(bitmap_build_idx.get(), 0, sizeof(bool) * build_size); // set false // Now fill columns with build data diff --git a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp index b0e4dcc4..91dda01e 100644 --- a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp @@ -169,7 +169,7 @@ SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, Cl } // Schedule all the sorts for maximum thread utilisation - auto new_event = make_shared_ptr(gstate.rhs_sink, pipeline); + auto new_event = make_shared_ptr(gstate.rhs_sink, pipeline, *this); event.InsertEvent(std::move(new_event)); return SinkFinalizeType::READY; diff --git a/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp b/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp index f1f870ab..1b6eb619 100644 --- a/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp @@ -203,10 +203,11 @@ OperatorResultType PhysicalBlockwiseNLJoin::ExecuteInternal(ExecutionContext &co return OperatorResultType::HAVE_MORE_OUTPUT; } -string PhysicalBlockwiseNLJoin::ParamsToString() const { - string extra_info = EnumUtil::ToString(join_type) + "\n"; - extra_info += condition->GetName(); - return extra_info; +InsertionOrderPreservingMap PhysicalBlockwiseNLJoin::ParamsToString() const { + InsertionOrderPreservingMap result; + result["Join Type"] = EnumUtil::ToString(join_type); + result["Condition"] = condition->GetName(); + return result; } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp b/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp index ba0c4e5a..1091e216 100644 --- a/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp @@ -7,32 +7,69 @@ namespace duckdb { PhysicalComparisonJoin::PhysicalComparisonJoin(LogicalOperator &op, PhysicalOperatorType type, vector conditions_p, JoinType join_type, idx_t estimated_cardinality) - : PhysicalJoin(op, type, join_type, estimated_cardinality) { - conditions.resize(conditions_p.size()); - // we reorder conditions so the ones with COMPARE_EQUAL occur first - idx_t equal_position = 0; - idx_t other_position = conditions_p.size() - 1; - for (idx_t i = 0; i < conditions_p.size(); i++) { - if (conditions_p[i].comparison == ExpressionType::COMPARE_EQUAL || - conditions_p[i].comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { - // COMPARE_EQUAL and COMPARE_NOT_DISTINCT_FROM, move to the start - conditions[equal_position++] = std::move(conditions_p[i]); - } else { - // other expression, move to the end - conditions[other_position--] = std::move(conditions_p[i]); + : PhysicalJoin(op, type, join_type, estimated_cardinality), conditions(std::move(conditions_p)) { + ReorderConditions(conditions); +} + +InsertionOrderPreservingMap PhysicalComparisonJoin::ParamsToString() const { + InsertionOrderPreservingMap result; + result["Join Type"] = EnumUtil::ToString(join_type); + string condition_info; + for (idx_t i = 0; i < conditions.size(); i++) { + auto &join_condition = conditions[i]; + if (i > 0) { + condition_info += "\n"; } + condition_info += + StringUtil::Format("%s %s %s", join_condition.left->GetName(), + ExpressionTypeToOperator(join_condition.comparison), join_condition.right->GetName()); + // string op = ExpressionTypeToOperator(it.comparison); + // extra_info += it.left->GetName() + " " + op + " " + it.right->GetName() + "\n"; } + result["Conditions"] = condition_info; + SetEstimatedCardinality(result, estimated_cardinality); + return result; } -string PhysicalComparisonJoin::ParamsToString() const { - string extra_info = EnumUtil::ToString(join_type) + "\n"; - for (auto &it : conditions) { - string op = ExpressionTypeToOperator(it.comparison); - extra_info += it.left->GetName() + " " + op + " " + it.right->GetName() + "\n"; +void PhysicalComparisonJoin::ReorderConditions(vector &conditions) { + // we reorder conditions so the ones with COMPARE_EQUAL occur first + // check if this is already the case + bool is_ordered = true; + bool seen_non_equal = false; + for (auto &cond : conditions) { + if (cond.comparison == ExpressionType::COMPARE_EQUAL || + cond.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + if (seen_non_equal) { + is_ordered = false; + break; + } + } else { + seen_non_equal = true; + } + } + if (is_ordered) { + // no need to re-order + return; + } + // gather lists of equal/other conditions + vector equal_conditions; + vector other_conditions; + for (auto &cond : conditions) { + if (cond.comparison == ExpressionType::COMPARE_EQUAL || + cond.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + equal_conditions.push_back(std::move(cond)); + } else { + other_conditions.push_back(std::move(cond)); + } + } + conditions.clear(); + // reconstruct the sorted conditions + for (auto &cond : equal_conditions) { + conditions.push_back(std::move(cond)); + } + for (auto &cond : other_conditions) { + conditions.push_back(std::move(cond)); } - extra_info += "\n[INFOSEPARATOR]\n"; - extra_info += StringUtil::Format("EC: %llu\n", estimated_cardinality); - return extra_info; } void PhysicalComparisonJoin::ConstructEmptyJoinResult(JoinType join_type, bool has_null, DataChunk &input, @@ -80,4 +117,5 @@ void PhysicalComparisonJoin::ConstructEmptyJoinResult(JoinType join_type, bool h } } } + } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_delim_join.cpp b/src/duckdb/src/execution/operator/join/physical_delim_join.cpp index 5d9f2806..1da39904 100644 --- a/src/duckdb/src/execution/operator/join/physical_delim_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_delim_join.cpp @@ -6,9 +6,10 @@ namespace duckdb { PhysicalDelimJoin::PhysicalDelimJoin(PhysicalOperatorType type, vector types, unique_ptr original_join, - vector> delim_scans, idx_t estimated_cardinality) + vector> delim_scans, idx_t estimated_cardinality, + optional_idx delim_idx) : PhysicalOperator(type, std::move(types), estimated_cardinality), join(std::move(original_join)), - delim_scans(std::move(delim_scans)) { + delim_scans(std::move(delim_scans)), delim_idx(delim_idx) { D_ASSERT(type == PhysicalOperatorType::LEFT_DELIM_JOIN || type == PhysicalOperatorType::RIGHT_DELIM_JOIN); } @@ -22,8 +23,10 @@ vector> PhysicalDelimJoin::GetChildren() const return result; } -string PhysicalDelimJoin::ParamsToString() const { - return join->ParamsToString(); +InsertionOrderPreservingMap PhysicalDelimJoin::ParamsToString() const { + auto result = join->ParamsToString(); + result["Delim Index"] = StringUtil::Format("%llu", delim_idx.GetIndex()); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp index e9d880ec..3deb27b6 100644 --- a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp @@ -1,18 +1,22 @@ #include "duckdb/execution/operator/join/physical_hash_join.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/radix_partitioning.hpp" #include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/operator/aggregate/ungrouped_aggregate_state.hpp" #include "duckdb/function/aggregate/distributive_functions.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/query_profiler.hpp" #include "duckdb/parallel/base_pipeline_event.hpp" +#include "duckdb/parallel/executor_task.hpp" #include "duckdb/parallel/interrupt.hpp" #include "duckdb/parallel/pipeline.hpp" #include "duckdb/parallel/thread_context.hpp" -#include "duckdb/parallel/executor_task.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/planner/filter/null_filter.hpp" +#include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/buffer_manager.hpp" #include "duckdb/storage/storage_manager.hpp" #include "duckdb/storage/temporary_memory_manager.hpp" @@ -23,11 +27,14 @@ PhysicalHashJoin::PhysicalHashJoin(LogicalOperator &op, unique_ptr right, vector cond, JoinType join_type, const vector &left_projection_map, const vector &right_projection_map, vector delim_types, idx_t estimated_cardinality, - PerfectHashJoinStats perfect_join_stats) + PerfectHashJoinStats perfect_join_stats, + unique_ptr pushdown_info_p) : PhysicalComparisonJoin(op, PhysicalOperatorType::HASH_JOIN, std::move(cond), join_type, estimated_cardinality), delim_types(std::move(delim_types)), perfect_join_statistics(std::move(perfect_join_stats)) { D_ASSERT(left_projection_map.empty()); + filter_pushdown = std::move(pushdown_info_p); + children.push_back(std::move(left)); children.push_back(std::move(right)); @@ -79,30 +86,51 @@ PhysicalHashJoin::PhysicalHashJoin(LogicalOperator &op, unique_ptr right, vector cond, JoinType join_type, idx_t estimated_cardinality, PerfectHashJoinStats perfect_join_state) : PhysicalHashJoin(op, std::move(left), std::move(right), std::move(cond), join_type, {}, {}, {}, - estimated_cardinality, std::move(perfect_join_state)) { + estimated_cardinality, std::move(perfect_join_state), nullptr) { } //===--------------------------------------------------------------------===// // Sink //===--------------------------------------------------------------------===// +JoinFilterGlobalState::~JoinFilterGlobalState() { +} + +JoinFilterLocalState::~JoinFilterLocalState() { +} + +unique_ptr JoinFilterPushdownInfo::GetGlobalState(ClientContext &context, + const PhysicalOperator &op) const { + // clear any previously set filters + // we can have previous filters for this operator in case of e.g. recursive CTEs + dynamic_filters->ClearFilters(op); + auto result = make_uniq(); + result->global_aggregate_state = + make_uniq(BufferAllocator::Get(context), min_max_aggregates); + return result; +} + class HashJoinGlobalSinkState : public GlobalSinkState { public: - HashJoinGlobalSinkState(const PhysicalHashJoin &op, ClientContext &context_p) - : context(context_p), num_threads(NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads())), - temporary_memory_update_count(0), + HashJoinGlobalSinkState(const PhysicalHashJoin &op_p, ClientContext &context_p) + : context(context_p), op(op_p), + num_threads(NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads())), temporary_memory_state(TemporaryMemoryManager::Get(context).Register(context)), finalized(false), - scanned_data(false) { + active_local_states(0), total_size(0), max_partition_size(0), max_partition_count(0), scanned_data(false) { hash_table = op.InitializeHashTable(context); - // for perfect hash join + // For perfect hash join perfect_join_executor = make_uniq(op, *hash_table, op.perfect_join_statistics); - // for external hash join + // For external hash join external = ClientConfig::GetConfig(context).force_external; // Set probe types const auto &payload_types = op.children[0]->types; probe_types.insert(probe_types.end(), op.condition_types.begin(), op.condition_types.end()); probe_types.insert(probe_types.end(), payload_types.begin(), payload_types.end()); probe_types.emplace_back(LogicalType::HASH); + + if (op.filter_pushdown) { + global_filter_state = op.filter_pushdown->GetGlobalState(context, op); + } } void ScheduleFinalize(Pipeline &pipeline, Event &event); @@ -110,9 +138,9 @@ class HashJoinGlobalSinkState : public GlobalSinkState { public: ClientContext &context; + const PhysicalHashJoin &op; const idx_t num_threads; - atomic temporary_memory_update_count; //! Temporary memory state for managing this operator's memory usage unique_ptr temporary_memory_state; @@ -121,13 +149,17 @@ class HashJoinGlobalSinkState : public GlobalSinkState { //! The perfect hash join executor (if any) unique_ptr perfect_join_executor; //! Whether or not the hash table has been finalized - bool finalized = false; + bool finalized; + //! The number of active local states + atomic active_local_states; - //! Whether we are doing an external join + //! Whether we are doing an external + some sizes bool external; + idx_t total_size; + idx_t max_partition_size; + idx_t max_partition_count; //! Hash tables built by each thread - mutex lock; vector> local_hash_tables; //! Excess probe data gathered during Sink @@ -136,12 +168,20 @@ class HashJoinGlobalSinkState : public GlobalSinkState { //! Whether or not we have started scanning data using GetData atomic scanned_data; + + unique_ptr global_filter_state; }; +unique_ptr JoinFilterPushdownInfo::GetLocalState(JoinFilterGlobalState &gstate) const { + auto result = make_uniq(); + result->local_aggregate_state = make_uniq(*gstate.global_aggregate_state); + return result; +} + class HashJoinLocalSinkState : public LocalSinkState { public: - HashJoinLocalSinkState(const PhysicalHashJoin &op, ClientContext &context) - : join_key_executor(context), chunk_count(0) { + HashJoinLocalSinkState(const PhysicalHashJoin &op, ClientContext &context, HashJoinGlobalSinkState &gstate) + : join_key_executor(context) { auto &allocator = BufferAllocator::Get(context); for (auto &cond : op.conditions) { @@ -155,6 +195,12 @@ class HashJoinLocalSinkState : public LocalSinkState { hash_table = op.InitializeHashTable(context); hash_table->GetSinkCollection().InitializeAppendState(append_state); + + gstate.active_local_states++; + + if (op.filter_pushdown) { + local_filter_state = op.filter_pushdown->GetLocalState(*gstate.global_filter_state); + } } public: @@ -168,14 +214,11 @@ class HashJoinLocalSinkState : public LocalSinkState { //! Thread-local HT unique_ptr hash_table; - //! For updating the temporary memory state - idx_t chunk_count; - static constexpr const idx_t CHUNK_COUNT_UPDATE_INTERVAL = 60; + unique_ptr local_filter_state; }; unique_ptr PhysicalHashJoin::InitializeHashTable(ClientContext &context) const { - auto result = make_uniq(BufferManager::GetBufferManager(context), conditions, payload_types, - join_type, rhs_output_columns); + auto result = make_uniq(context, conditions, payload_types, join_type, rhs_output_columns); if (!delim_types.empty() && join_type == JoinType::MARK) { // correlated MARK join if (delim_types.size() + 1 == conditions.size()) { @@ -229,7 +272,19 @@ unique_ptr PhysicalHashJoin::GetGlobalSinkState(ClientContext & } unique_ptr PhysicalHashJoin::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(*this, context.client); + auto &gstate = sink_state->Cast(); + return make_uniq(*this, context.client, gstate); +} + +void JoinFilterPushdownInfo::Sink(DataChunk &chunk, JoinFilterLocalState &lstate) const { + // if we are pushing any filters into a probe-side, compute the min/max over the columns that we are pushing + for (idx_t pushdown_idx = 0; pushdown_idx < filters.size(); pushdown_idx++) { + auto &pushdown = filters[pushdown_idx]; + for (idx_t i = 0; i < 2; i++) { + idx_t aggr_idx = pushdown_idx * 2 + i; + lstate.local_aggregate_state->Sink(chunk, pushdown.join_condition, aggr_idx); + } + } } SinkResultType PhysicalHashJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { @@ -239,6 +294,10 @@ SinkResultType PhysicalHashJoin::Sink(ExecutionContext &context, DataChunk &chun lstate.join_keys.Reset(); lstate.join_key_executor.Execute(chunk, lstate.join_keys); + if (filter_pushdown) { + filter_pushdown->Sink(lstate.join_keys, *lstate.local_filter_state); + } + // build the HT auto &ht = *lstate.hash_table; if (payload_types.empty()) { @@ -255,29 +314,31 @@ SinkResultType PhysicalHashJoin::Sink(ExecutionContext &context, DataChunk &chun ht.Build(lstate.append_state, lstate.join_keys, lstate.payload_chunk); } - if (++lstate.chunk_count % HashJoinLocalSinkState::CHUNK_COUNT_UPDATE_INTERVAL == 0) { - auto &gstate = input.global_state.Cast(); - if (++gstate.temporary_memory_update_count % gstate.num_threads == 0) { - auto &sink_collection = lstate.hash_table->GetSinkCollection(); - auto ht_size = sink_collection.SizeInBytes() + JoinHashTable::PointerTableSize(sink_collection.Count()); - gstate.temporary_memory_state->SetRemainingSize(context.client, gstate.num_threads * ht_size); - } - } - return SinkResultType::NEED_MORE_INPUT; } +void JoinFilterPushdownInfo::Combine(JoinFilterGlobalState &gstate, JoinFilterLocalState &lstate) const { + gstate.global_aggregate_state->Combine(*lstate.local_aggregate_state); +} + SinkCombineResultType PhysicalHashJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - if (lstate.hash_table) { - lstate.hash_table->GetSinkCollection().FlushAppendState(lstate.append_state); - lock_guard local_ht_lock(gstate.lock); - gstate.local_hash_tables.push_back(std::move(lstate.hash_table)); + + lstate.hash_table->GetSinkCollection().FlushAppendState(lstate.append_state); + auto guard = gstate.Lock(); + gstate.local_hash_tables.push_back(std::move(lstate.hash_table)); + if (gstate.local_hash_tables.size() == gstate.active_local_states) { + // Set to 0 until PrepareFinalize + gstate.temporary_memory_state->SetZero(); } + auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this, lstate.join_key_executor, "join_key_executor", 1); + context.thread.profiler.Flush(*this); client_profiler.Flush(context.thread.profiler); + if (filter_pushdown) { + filter_pushdown->Combine(*gstate.global_filter_state, *lstate.local_filter_state); + } return SinkCombineResultType::FINISHED; } @@ -285,11 +346,48 @@ SinkCombineResultType PhysicalHashJoin::Combine(ExecutionContext &context, Opera //===--------------------------------------------------------------------===// // Finalize //===--------------------------------------------------------------------===// +static idx_t GetTupleWidth(const vector &types, bool &all_constant) { + idx_t tuple_width = 0; + all_constant = true; + for (auto &type : types) { + tuple_width += GetTypeIdSize(type.InternalType()); + all_constant &= TypeIsConstantSize(type.InternalType()); + } + return tuple_width + AlignValue(types.size()) / 8 + GetTypeIdSize(PhysicalType::UINT64); +} + +static idx_t GetPartitioningSpaceRequirement(ClientContext &context, const vector &types, + const idx_t radix_bits, const idx_t num_threads) { + auto &buffer_manager = BufferManager::GetBufferManager(context); + bool all_constant; + idx_t tuple_width = GetTupleWidth(types, all_constant); + + auto tuples_per_block = buffer_manager.GetBlockSize() / tuple_width; + auto blocks_per_chunk = (STANDARD_VECTOR_SIZE + tuples_per_block) / tuples_per_block + 1; + if (!all_constant) { + blocks_per_chunk += 2; + } + auto size_per_partition = blocks_per_chunk * buffer_manager.GetBlockAllocSize(); + auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); + + return num_threads * num_partitions * size_per_partition; +} + +void PhysicalHashJoin::PrepareFinalize(ClientContext &context, GlobalSinkState &global_state) const { + auto &gstate = global_state.Cast(); + auto &ht = *gstate.hash_table; + gstate.total_size = + ht.GetTotalSize(gstate.local_hash_tables, gstate.max_partition_size, gstate.max_partition_count); + bool all_constant; + gstate.temporary_memory_state->SetMaterializationPenalty(GetTupleWidth(children[0]->types, all_constant)); + gstate.temporary_memory_state->SetRemainingSize(gstate.total_size); +} + class HashJoinFinalizeTask : public ExecutorTask { public: HashJoinFinalizeTask(shared_ptr event_p, ClientContext &context, HashJoinGlobalSinkState &sink_p, - idx_t chunk_idx_from_p, idx_t chunk_idx_to_p, bool parallel_p) - : ExecutorTask(context, std::move(event_p)), sink(sink_p), chunk_idx_from(chunk_idx_from_p), + idx_t chunk_idx_from_p, idx_t chunk_idx_to_p, bool parallel_p, const PhysicalOperator &op_p) + : ExecutorTask(context, std::move(event_p), op_p), sink(sink_p), chunk_idx_from(chunk_idx_from_p), chunk_idx_to(chunk_idx_to_p), parallel(parallel_p) { } @@ -321,11 +419,11 @@ class HashJoinFinalizeEvent : public BasePipelineEvent { vector> finalize_tasks; auto &ht = *sink.hash_table; const auto chunk_count = ht.GetDataCollection().ChunkCount(); - const auto num_threads = NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); + const auto num_threads = NumericCast(sink.num_threads); if (num_threads == 1 || (ht.Count() < PARALLEL_CONSTRUCT_THRESHOLD && !context.config.verify_parallelism)) { // Single-threaded finalize finalize_tasks.push_back( - make_uniq(shared_from_this(), context, sink, 0U, chunk_count, false)); + make_uniq(shared_from_this(), context, sink, 0U, chunk_count, false, sink.op)); } else { // Parallel finalize auto chunks_per_thread = MaxValue((chunk_count + num_threads - 1) / num_threads, 1); @@ -335,7 +433,7 @@ class HashJoinFinalizeEvent : public BasePipelineEvent { auto chunk_idx_from = chunk_idx; auto chunk_idx_to = MinValue(chunk_idx_from + chunks_per_thread, chunk_count); finalize_tasks.push_back(make_uniq(shared_from_this(), context, sink, - chunk_idx_from, chunk_idx_to, true)); + chunk_idx_from, chunk_idx_to, true, sink.op)); chunk_idx = chunk_idx_to; if (chunk_idx == chunk_count) { break; @@ -364,7 +462,7 @@ void HashJoinGlobalSinkState::ScheduleFinalize(Pipeline &pipeline, Event &event) } void HashJoinGlobalSinkState::InitializeProbeSpill() { - lock_guard guard(lock); + auto guard = Lock(); if (!probe_spill) { probe_spill = make_uniq(*hash_table, context, probe_types); } @@ -373,8 +471,8 @@ void HashJoinGlobalSinkState::InitializeProbeSpill() { class HashJoinRepartitionTask : public ExecutorTask { public: HashJoinRepartitionTask(shared_ptr event_p, ClientContext &context, JoinHashTable &global_ht, - JoinHashTable &local_ht) - : ExecutorTask(context, std::move(event_p)), global_ht(global_ht), local_ht(local_ht) { + JoinHashTable &local_ht, const PhysicalOperator &op_p) + : ExecutorTask(context, std::move(event_p), op_p), global_ht(global_ht), local_ht(local_ht) { } TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { @@ -390,17 +488,19 @@ class HashJoinRepartitionTask : public ExecutorTask { class HashJoinRepartitionEvent : public BasePipelineEvent { public: - HashJoinRepartitionEvent(Pipeline &pipeline_p, HashJoinGlobalSinkState &sink, + HashJoinRepartitionEvent(Pipeline &pipeline_p, const PhysicalHashJoin &op_p, HashJoinGlobalSinkState &sink, vector> &local_hts) - : BasePipelineEvent(pipeline_p), sink(sink), local_hts(local_hts) { + : BasePipelineEvent(pipeline_p), op(op_p), sink(sink), local_hts(local_hts) { } + const PhysicalHashJoin &op; HashJoinGlobalSinkState &sink; vector> &local_hts; public: void Schedule() override { D_ASSERT(sink.hash_table->GetRadixBits() > JoinHashTable::INITIAL_RADIX_BITS); + auto block_size = sink.hash_table->buffer_manager.GetBlockSize(); idx_t total_size = 0; idx_t total_count = 0; @@ -409,14 +509,14 @@ class HashJoinRepartitionEvent : public BasePipelineEvent { total_size += sink_collection.SizeInBytes(); total_count += sink_collection.Count(); } - auto total_blocks = NumericCast((double(total_size) + Storage::BLOCK_SIZE - 1) / Storage::BLOCK_SIZE); + auto total_blocks = (total_size + block_size - 1) / block_size; auto count_per_block = total_count / total_blocks; auto blocks_per_vector = MaxValue(STANDARD_VECTOR_SIZE / count_per_block, 2); // Assume 8 blocks per partition per thread (4 input, 4 output) auto partition_multiplier = RadixPartitioning::NumberOfPartitions(sink.hash_table->GetRadixBits() - JoinHashTable::INITIAL_RADIX_BITS); - auto thread_memory = 2 * blocks_per_vector * partition_multiplier * Storage::BLOCK_SIZE; + auto thread_memory = 2 * blocks_per_vector * partition_multiplier * block_size; auto repartition_threads = MaxValue(sink.temporary_memory_state->GetReservation() / thread_memory, 1); if (repartition_threads < local_hts.size()) { @@ -433,7 +533,7 @@ class HashJoinRepartitionEvent : public BasePipelineEvent { partition_tasks.reserve(local_hts.size()); for (auto &local_ht : local_hts) { partition_tasks.push_back( - make_uniq(shared_from_this(), context, *sink.hash_table, *local_ht)); + make_uniq(shared_from_this(), context, *sink.hash_table, *local_ht, op)); } SetTasks(std::move(partition_tasks)); } @@ -445,40 +545,88 @@ class HashJoinRepartitionEvent : public BasePipelineEvent { const auto num_partitions = RadixPartitioning::NumberOfPartitions(sink.hash_table->GetRadixBits()); vector partition_sizes(num_partitions, 0); vector partition_counts(num_partitions, 0); - idx_t max_partition_size; - idx_t max_partition_count; - sink.hash_table->GetTotalSize(partition_sizes, partition_counts, max_partition_size, max_partition_count); - sink.temporary_memory_state->SetMinimumReservation(max_partition_size + - JoinHashTable::PointerTableSize(max_partition_count)); + sink.total_size = sink.hash_table->GetTotalSize(partition_sizes, partition_counts, sink.max_partition_size, + sink.max_partition_count); + const auto probe_side_requirement = + GetPartitioningSpaceRequirement(sink.context, op.types, sink.hash_table->GetRadixBits(), sink.num_threads); + + sink.temporary_memory_state->SetMinimumReservation(sink.max_partition_size + + JoinHashTable::PointerTableSize(sink.max_partition_count) + + probe_side_requirement); + sink.temporary_memory_state->UpdateReservation(executor.context); + sink.hash_table->PrepareExternalFinalize(sink.temporary_memory_state->GetReservation()); sink.ScheduleFinalize(*pipeline, *this); } }; +void JoinFilterPushdownInfo::PushFilters(JoinFilterGlobalState &gstate, const PhysicalOperator &op) const { + // finalize the min/max aggregates + vector min_max_types; + for (auto &aggr_expr : min_max_aggregates) { + min_max_types.push_back(aggr_expr->return_type); + } + DataChunk final_min_max; + final_min_max.Initialize(Allocator::DefaultAllocator(), min_max_types); + + gstate.global_aggregate_state->Finalize(final_min_max); + + // create a filter for each of the aggregates + for (idx_t filter_idx = 0; filter_idx < filters.size(); filter_idx++) { + auto &filter = filters[filter_idx]; + auto filter_col_idx = filter.probe_column_index.column_index; + auto min_idx = filter_idx * 2; + auto max_idx = min_idx + 1; + + auto min_val = final_min_max.data[min_idx].GetValue(0); + auto max_val = final_min_max.data[max_idx].GetValue(0); + if (min_val.IsNull() || max_val.IsNull()) { + // min/max is NULL + // this can happen in case all values in the RHS column are NULL, but they are still pushed into the hash + // table e.g. because they are part of a RIGHT join + continue; + } + if (Value::NotDistinctFrom(min_val, max_val)) { + // min = max - generate an equality filter + auto constant_filter = make_uniq(ExpressionType::COMPARE_EQUAL, std::move(min_val)); + dynamic_filters->PushFilter(op, filter_col_idx, std::move(constant_filter)); + } else { + // min != max - generate a range filter + auto greater_equals = + make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, std::move(min_val)); + dynamic_filters->PushFilter(op, filter_col_idx, std::move(greater_equals)); + auto less_equals = make_uniq(ExpressionType::COMPARE_LESSTHANOREQUALTO, std::move(max_val)); + dynamic_filters->PushFilter(op, filter_col_idx, std::move(less_equals)); + } + // not null filter + dynamic_filters->PushFilter(op, filter_col_idx, make_uniq()); + } +} + SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, OperatorSinkFinalizeInput &input) const { auto &sink = input.global_state.Cast(); auto &ht = *sink.hash_table; - idx_t max_partition_size; - idx_t max_partition_count; - auto const total_size = ht.GetTotalSize(sink.local_hash_tables, max_partition_size, max_partition_count); - sink.temporary_memory_state->SetRemainingSize(context, total_size); - - sink.external = sink.temporary_memory_state->GetReservation() < total_size; + sink.temporary_memory_state->UpdateReservation(context); + sink.external = sink.temporary_memory_state->GetReservation() < sink.total_size; if (sink.external) { - const auto max_partition_ht_size = max_partition_size + JoinHashTable::PointerTableSize(max_partition_count); // External Hash Join sink.perfect_join_executor.reset(); + + const auto max_partition_ht_size = + sink.max_partition_size + JoinHashTable::PointerTableSize(sink.max_partition_count); if (max_partition_ht_size > sink.temporary_memory_state->GetReservation()) { // We have to repartition - ht.SetRepartitionRadixBits(sink.local_hash_tables, sink.temporary_memory_state->GetReservation(), - max_partition_size, max_partition_count); - auto new_event = make_shared_ptr(pipeline, sink, sink.local_hash_tables); + ht.SetRepartitionRadixBits(sink.temporary_memory_state->GetReservation(), sink.max_partition_size, + sink.max_partition_count); + auto new_event = make_shared_ptr(pipeline, *this, sink, sink.local_hash_tables); event.InsertEvent(std::move(new_event)); } else { - // No repartitioning! - sink.temporary_memory_state->SetMinimumReservation(max_partition_ht_size); + // No repartitioning! We do need some space for partitioning the probe-side, though + const auto probe_side_requirement = + GetPartitioningSpaceRequirement(context, children[0]->types, ht.GetRadixBits(), sink.num_threads); + sink.temporary_memory_state->SetMinimumReservation(max_partition_ht_size + probe_side_requirement); for (auto &local_ht : sink.local_hash_tables) { ht.Merge(*local_ht); } @@ -488,13 +636,17 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl } sink.finalized = true; return SinkFinalizeType::READY; - } else { - // In-memory Hash Join - for (auto &local_ht : sink.local_hash_tables) { - ht.Merge(*local_ht); - } - sink.local_hash_tables.clear(); - ht.Unpartition(); + } + + // In-memory Hash Join + for (auto &local_ht : sink.local_hash_tables) { + ht.Merge(*local_ht); + } + sink.local_hash_tables.clear(); + ht.Unpartition(); + + if (filter_pushdown && ht.Count() > 0) { + filter_pushdown->PushFilters(*sink.global_filter_state, *this); } // check for possible perfect hash table @@ -521,31 +673,32 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl //===--------------------------------------------------------------------===// class HashJoinOperatorState : public CachingOperatorState { public: - explicit HashJoinOperatorState(ClientContext &context) : probe_executor(context), initialized(false) { + explicit HashJoinOperatorState(ClientContext &context, HashJoinGlobalSinkState &sink) + : probe_executor(context), scan_structure(*sink.hash_table, join_key_state) { } DataChunk join_keys; TupleDataChunkState join_key_state; ExpressionExecutor probe_executor; - unique_ptr scan_structure; + JoinHashTable::ScanStructure scan_structure; unique_ptr perfect_hash_join_state; - bool initialized; JoinHashTable::ProbeSpillLocalAppendState spill_state; + JoinHashTable::ProbeState probe_state; //! Chunk to sink data into for external join DataChunk spill_chunk; public: void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - context.thread.profiler.Flush(op, probe_executor, "probe_executor", 0); + context.thread.profiler.Flush(op); } }; unique_ptr PhysicalHashJoin::GetOperatorState(ExecutionContext &context) const { auto &allocator = BufferAllocator::Get(context.client); auto &sink = sink_state->Cast(); - auto state = make_uniq(context.client); + auto state = make_uniq(context.client, sink); if (sink.perfect_join_executor) { state->perfect_hash_join_state = sink.perfect_join_executor->GetOperatorState(context); } else { @@ -570,17 +723,12 @@ OperatorResultType PhysicalHashJoin::ExecuteInternal(ExecutionContext &context, D_ASSERT(sink.finalized); D_ASSERT(!sink.scanned_data); - // some initialization for external hash join - if (sink.external && !state.initialized) { - if (!sink.probe_spill) { - sink.InitializeProbeSpill(); + if (sink.hash_table->Count() == 0) { + if (EmptyResultIfRHSIsEmpty()) { + return OperatorResultType::FINISHED; } - state.spill_state = sink.probe_spill->RegisterThread(); - state.initialized = true; - } - - if (sink.hash_table->Count() == 0 && EmptyResultIfRHSIsEmpty()) { - return OperatorResultType::FINISHED; + ConstructEmptyJoinResult(sink.hash_table->join_type, sink.hash_table->has_null, input, chunk); + return OperatorResultType::NEED_MORE_INPUT; } if (sink.perfect_join_executor) { @@ -588,34 +736,35 @@ OperatorResultType PhysicalHashJoin::ExecuteInternal(ExecutionContext &context, return sink.perfect_join_executor->ProbePerfectHashTable(context, input, chunk, *state.perfect_hash_join_state); } - if (state.scan_structure) { - // still have elements remaining (i.e. we got >STANDARD_VECTOR_SIZE elements in the previous probe) - state.scan_structure->Next(state.join_keys, input, chunk); - if (!state.scan_structure->PointersExhausted() || chunk.size() > 0) { - return OperatorResultType::HAVE_MORE_OUTPUT; + if (sink.external && !state.initialized) { + // some initialization for external hash join + if (!sink.probe_spill) { + sink.InitializeProbeSpill(); } - state.scan_structure = nullptr; - return OperatorResultType::NEED_MORE_INPUT; + state.spill_state = sink.probe_spill->RegisterThread(); + state.initialized = true; } - // probe the HT - if (sink.hash_table->Count() == 0) { - ConstructEmptyJoinResult(sink.hash_table->join_type, sink.hash_table->has_null, input, chunk); - return OperatorResultType::NEED_MORE_INPUT; - } + if (state.scan_structure.is_null) { + // probe the HT, start by resolving the join keys for the left chunk + state.join_keys.Reset(); + state.probe_executor.Execute(input, state.join_keys); - // resolve the join keys for the left chunk - state.join_keys.Reset(); - state.probe_executor.Execute(input, state.join_keys); + // perform the actual probe + if (sink.external) { + sink.hash_table->ProbeAndSpill(state.scan_structure, state.join_keys, state.join_key_state, + state.probe_state, input, *sink.probe_spill, state.spill_state, + state.spill_chunk); + } else { + sink.hash_table->Probe(state.scan_structure, state.join_keys, state.join_key_state, state.probe_state); + } + } + state.scan_structure.Next(state.join_keys, input, chunk); - // perform the actual probe - if (sink.external) { - state.scan_structure = sink.hash_table->ProbeAndSpill(state.join_keys, state.join_key_state, input, - *sink.probe_spill, state.spill_state, state.spill_chunk); - } else { - state.scan_structure = sink.hash_table->Probe(state.join_keys, state.join_key_state); + if (state.scan_structure.PointersExhausted() && chunk.size() == 0) { + state.scan_structure.is_null = true; + return OperatorResultType::NEED_MORE_INPUT; } - state.scan_structure->Next(state.join_keys, input, chunk); return OperatorResultType::HAVE_MORE_OUTPUT; } @@ -628,7 +777,7 @@ class HashJoinLocalSourceState; class HashJoinGlobalSourceState : public GlobalSourceState { public: - HashJoinGlobalSourceState(const PhysicalHashJoin &op, ClientContext &context); + HashJoinGlobalSourceState(const PhysicalHashJoin &op, const ClientContext &context); //! Initialize this source state using the info in the sink void Initialize(HashJoinGlobalSinkState &sink); @@ -661,13 +810,12 @@ class HashJoinGlobalSourceState : public GlobalSourceState { //! For synchronizing the external hash join atomic global_stage; - mutex lock; //! For HT build synchronization - idx_t build_chunk_idx; + idx_t build_chunk_idx = DConstants::INVALID_INDEX; idx_t build_chunk_count; idx_t build_chunk_done; - idx_t build_chunks_per_thread; + idx_t build_chunks_per_thread = DConstants::INVALID_INDEX; //! For probe synchronization atomic probe_chunk_count; @@ -678,22 +826,22 @@ class HashJoinGlobalSourceState : public GlobalSourceState { idx_t parallel_scan_chunk_count; //! For full/outer synchronization - idx_t full_outer_chunk_idx; + idx_t full_outer_chunk_idx = DConstants::INVALID_INDEX; atomic full_outer_chunk_count; atomic full_outer_chunk_done; - idx_t full_outer_chunks_per_thread; + idx_t full_outer_chunks_per_thread = DConstants::INVALID_INDEX; vector blocked_tasks; }; class HashJoinLocalSourceState : public LocalSourceState { public: - HashJoinLocalSourceState(const PhysicalHashJoin &op, Allocator &allocator); + HashJoinLocalSourceState(const PhysicalHashJoin &op, const HashJoinGlobalSinkState &sink, Allocator &allocator); //! Do the work this thread has been assigned void ExecuteTask(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); //! Whether this thread has finished the work it has been assigned - bool TaskFinished(); + bool TaskFinished() const; //! Build, probe and scan for external hash join void ExternalBuild(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate); void ExternalProbe(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); @@ -706,8 +854,8 @@ class HashJoinLocalSourceState : public LocalSourceState { Vector addresses; //! Chunks assigned to this thread for building the pointer table - idx_t build_chunk_idx_from; - idx_t build_chunk_idx_to; + idx_t build_chunk_idx_from = DConstants::INVALID_INDEX; + idx_t build_chunk_idx_to = DConstants::INVALID_INDEX; //! Local scan state for probe spill ColumnDataConsumerScanState probe_local_scan; @@ -716,16 +864,18 @@ class HashJoinLocalSourceState : public LocalSourceState { DataChunk join_keys; DataChunk payload; TupleDataChunkState join_key_state; + //! Column indices to easily reference the join keys/payload columns in probe_chunk vector join_key_indices; vector payload_indices; //! Scan structure for the external probe - unique_ptr scan_structure; - bool empty_ht_probe_in_progress; + JoinHashTable::ScanStructure scan_structure; + JoinHashTable::ProbeState probe_state; + bool empty_ht_probe_in_progress = false; //! Chunks assigned to this thread for a full/outer scan - idx_t full_outer_chunk_idx_from; - idx_t full_outer_chunk_idx_to; + idx_t full_outer_chunk_idx_from = DConstants::INVALID_INDEX; + idx_t full_outer_chunk_idx_to = DConstants::INVALID_INDEX; unique_ptr full_outer_scan_state; }; @@ -735,17 +885,18 @@ unique_ptr PhysicalHashJoin::GetGlobalSourceState(ClientConte unique_ptr PhysicalHashJoin::GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const { - return make_uniq(*this, BufferAllocator::Get(context.client)); + return make_uniq(*this, sink_state->Cast(), + BufferAllocator::Get(context.client)); } -HashJoinGlobalSourceState::HashJoinGlobalSourceState(const PhysicalHashJoin &op, ClientContext &context) +HashJoinGlobalSourceState::HashJoinGlobalSourceState(const PhysicalHashJoin &op, const ClientContext &context) : op(op), global_stage(HashJoinSourceStage::INIT), build_chunk_count(0), build_chunk_done(0), probe_chunk_count(0), probe_chunk_done(0), probe_count(op.children[0]->estimated_cardinality), parallel_scan_chunk_count(context.config.verify_parallelism ? 1 : 120) { } void HashJoinGlobalSourceState::Initialize(HashJoinGlobalSinkState &sink) { - lock_guard init_lock(lock); + auto guard = Lock(); if (global_stage != HashJoinSourceStage::INIT) { // Another thread initialized return; @@ -797,12 +948,12 @@ void HashJoinGlobalSourceState::PrepareBuild(HashJoinGlobalSinkState &sink) { auto &ht = *sink.hash_table; // Update remaining size - sink.temporary_memory_state->SetRemainingSize(sink.context, ht.GetRemainingSize()); + sink.temporary_memory_state->SetRemainingSizeAndUpdateReservation(sink.context, ht.GetRemainingSize()); // Try to put the next partitions in the block collection of the HT if (!sink.external || !ht.PrepareExternalFinalize(sink.temporary_memory_state->GetReservation())) { global_stage = HashJoinSourceStage::DONE; - sink.temporary_memory_state->SetRemainingSize(sink.context, 0); + sink.temporary_memory_state->SetZero(); return; } @@ -816,8 +967,7 @@ void HashJoinGlobalSourceState::PrepareBuild(HashJoinGlobalSinkState &sink) { build_chunk_count = data_collection.ChunkCount(); build_chunk_done = 0; - auto num_threads = NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads()); - build_chunks_per_thread = MaxValue((build_chunk_count + num_threads - 1) / num_threads, 1); + build_chunks_per_thread = MaxValue((build_chunk_count + sink.num_threads - 1) / sink.num_threads, 1); ht.InitializePointerTable(); @@ -847,8 +997,8 @@ void HashJoinGlobalSourceState::PrepareScanHT(HashJoinGlobalSinkState &sink) { full_outer_chunk_count = data_collection.ChunkCount(); full_outer_chunk_done = 0; - auto num_threads = NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads()); - full_outer_chunks_per_thread = MaxValue((full_outer_chunk_count + num_threads - 1) / num_threads, 1); + full_outer_chunks_per_thread = + MaxValue((full_outer_chunk_count + sink.num_threads - 1) / sink.num_threads, 1); global_stage = HashJoinSourceStage::SCAN_HT; } @@ -856,7 +1006,7 @@ void HashJoinGlobalSourceState::PrepareScanHT(HashJoinGlobalSinkState &sink) { bool HashJoinGlobalSourceState::AssignTask(HashJoinGlobalSinkState &sink, HashJoinLocalSourceState &lstate) { D_ASSERT(lstate.TaskFinished()); - lock_guard guard(lock); + auto guard = Lock(); switch (global_stage.load()) { case HashJoinSourceStage::BUILD: if (build_chunk_idx != build_chunk_count) { @@ -892,12 +1042,13 @@ bool HashJoinGlobalSourceState::AssignTask(HashJoinGlobalSinkState &sink, HashJo return false; } -HashJoinLocalSourceState::HashJoinLocalSourceState(const PhysicalHashJoin &op, Allocator &allocator) - : local_stage(HashJoinSourceStage::INIT), addresses(LogicalType::POINTER) { +HashJoinLocalSourceState::HashJoinLocalSourceState(const PhysicalHashJoin &op, const HashJoinGlobalSinkState &sink, + Allocator &allocator) + : local_stage(HashJoinSourceStage::INIT), addresses(LogicalType::POINTER), + scan_structure(*sink.hash_table, join_key_state) { auto &chunk_state = probe_local_scan.current_chunk_state; chunk_state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY; - auto &sink = op.sink_state->Cast(); probe_chunk.Initialize(allocator, sink.probe_types); join_keys.Initialize(allocator, op.condition_types); payload.Initialize(allocator, op.children[0]->types); @@ -930,13 +1081,13 @@ void HashJoinLocalSourceState::ExecuteTask(HashJoinGlobalSinkState &sink, HashJo } } -bool HashJoinLocalSourceState::TaskFinished() { +bool HashJoinLocalSourceState::TaskFinished() const { switch (local_stage) { case HashJoinSourceStage::INIT: case HashJoinSourceStage::BUILD: return true; case HashJoinSourceStage::PROBE: - return scan_structure == nullptr && !empty_ht_probe_in_progress; + return scan_structure.is_null && !empty_ht_probe_in_progress; case HashJoinSourceStage::SCAN_HT: return full_outer_scan_state == nullptr; default: @@ -950,7 +1101,7 @@ void HashJoinLocalSourceState::ExternalBuild(HashJoinGlobalSinkState &sink, Hash auto &ht = *sink.hash_table; ht.Finalize(build_chunk_idx_from, build_chunk_idx_to, true); - lock_guard guard(gstate.lock); + auto guard = gstate.Lock(); gstate.build_chunk_done += build_chunk_idx_to - build_chunk_idx_from; } @@ -958,20 +1109,20 @@ void HashJoinLocalSourceState::ExternalProbe(HashJoinGlobalSinkState &sink, Hash DataChunk &chunk) { D_ASSERT(local_stage == HashJoinSourceStage::PROBE && sink.hash_table->finalized); - if (scan_structure) { + if (!scan_structure.is_null) { // Still have elements remaining (i.e. we got >STANDARD_VECTOR_SIZE elements in the previous probe) - scan_structure->Next(join_keys, payload, chunk); - if (chunk.size() != 0 || !scan_structure->PointersExhausted()) { + scan_structure.Next(join_keys, payload, chunk); + if (chunk.size() != 0 || !scan_structure.PointersExhausted()) { return; } } - if (scan_structure || empty_ht_probe_in_progress) { + if (!scan_structure.is_null || empty_ht_probe_in_progress) { // Previous probe is done - scan_structure = nullptr; + scan_structure.is_null = true; empty_ht_probe_in_progress = false; sink.probe_spill->consumer->FinishChunk(probe_local_scan); - lock_guard lock(gstate.lock); + auto guard = gstate.Lock(); gstate.probe_chunk_done++; return; } @@ -991,8 +1142,8 @@ void HashJoinLocalSourceState::ExternalProbe(HashJoinGlobalSinkState &sink, Hash } // Perform the probe - scan_structure = sink.hash_table->Probe(join_keys, join_key_state, precomputed_hashes); - scan_structure->Next(join_keys, payload, chunk); + sink.hash_table->Probe(scan_structure, join_keys, join_key_state, probe_state, precomputed_hashes); + scan_structure.Next(join_keys, payload, chunk); } void HashJoinLocalSourceState::ExternalScanHT(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, @@ -1007,7 +1158,7 @@ void HashJoinLocalSourceState::ExternalScanHT(HashJoinGlobalSinkState &sink, Has if (chunk.size() == 0) { full_outer_scan_state = nullptr; - lock_guard guard(gstate.lock); + auto guard = gstate.Lock(); gstate.full_outer_chunk_done += full_outer_chunk_idx_to - full_outer_chunk_idx_from; } } @@ -1020,10 +1171,11 @@ SourceResultType PhysicalHashJoin::GetData(ExecutionContext &context, DataChunk sink.scanned_data = true; if (!sink.external && !PropagatesBuildSide(join_type)) { - lock_guard guard(gstate.lock); + auto guard = gstate.Lock(); if (gstate.global_stage != HashJoinSourceStage::DONE) { gstate.global_stage = HashJoinSourceStage::DONE; - sink.temporary_memory_state->SetRemainingSize(context.client, 0); + sink.hash_table->Reset(); + sink.temporary_memory_state->SetZero(); } return SourceResultType::FINISHED; } @@ -1038,15 +1190,11 @@ SourceResultType PhysicalHashJoin::GetData(ExecutionContext &context, DataChunk if (!lstate.TaskFinished() || gstate.AssignTask(sink, lstate)) { lstate.ExecuteTask(sink, gstate, chunk); } else { - lock_guard guard(gstate.lock); + auto guard = gstate.Lock(); if (gstate.TryPrepareNextStage(sink) || gstate.global_stage == HashJoinSourceStage::DONE) { - for (auto &state : gstate.blocked_tasks) { - state.Callback(); - } - gstate.blocked_tasks.clear(); + gstate.UnblockTasks(guard); } else { - gstate.blocked_tasks.push_back(input.interrupt_state); - return SourceResultType::BLOCKED; + return gstate.BlockSource(guard, input.interrupt_state); } } } @@ -1060,23 +1208,24 @@ double PhysicalHashJoin::GetProgress(ClientContext &context, GlobalSourceState & if (!sink.external) { if (PropagatesBuildSide(join_type)) { - return double(gstate.full_outer_chunk_done) / double(gstate.full_outer_chunk_count) * 100.0; + return static_cast(gstate.full_outer_chunk_done) / + static_cast(gstate.full_outer_chunk_count) * 100.0; } return 100.0; } - double num_partitions = RadixPartitioning::NumberOfPartitions(sink.hash_table->GetRadixBits()); - double partition_start = sink.hash_table->GetPartitionStart(); - double partition_end = sink.hash_table->GetPartitionEnd(); + auto num_partitions = static_cast(RadixPartitioning::NumberOfPartitions(sink.hash_table->GetRadixBits())); + auto partition_start = static_cast(sink.hash_table->GetPartitionStart()); + auto partition_end = static_cast(sink.hash_table->GetPartitionEnd()); // This many partitions are fully done - auto progress = partition_start / double(num_partitions); + auto progress = partition_start / num_partitions; - double probe_chunk_done = gstate.probe_chunk_done; - double probe_chunk_count = gstate.probe_chunk_count; + auto probe_chunk_done = static_cast(gstate.probe_chunk_done); + auto probe_chunk_count = static_cast(gstate.probe_chunk_count); if (probe_chunk_count != 0) { // Progress of the current round of probing, weighed by the number of partitions - auto probe_progress = double(probe_chunk_done) / double(probe_chunk_count); + auto probe_progress = probe_chunk_done / probe_chunk_count; // Add it to the progress, weighed by the number of partitions in the current round progress += (partition_end - partition_start) / num_partitions * probe_progress; } @@ -1084,20 +1233,28 @@ double PhysicalHashJoin::GetProgress(ClientContext &context, GlobalSourceState & return progress * 100.0; } -string PhysicalHashJoin::ParamsToString() const { - string result = EnumUtil::ToString(join_type) + "\n"; - for (auto &it : conditions) { - string op = ExpressionTypeToOperator(it.comparison); - result += it.left->GetName() + " " + op + " " + it.right->GetName() + "\n"; +InsertionOrderPreservingMap PhysicalHashJoin::ParamsToString() const { + InsertionOrderPreservingMap result; + result["Join Type"] = EnumUtil::ToString(join_type); + + string condition_info; + for (idx_t i = 0; i < conditions.size(); i++) { + auto &join_condition = conditions[i]; + if (i > 0) { + condition_info += "\n"; + } + condition_info += + StringUtil::Format("%s %s %s", join_condition.left->GetName(), + ExpressionTypeToOperator(join_condition.comparison), join_condition.right->GetName()); } - result += "\n[INFOSEPARATOR]\n"; + result["Conditions"] = condition_info; + if (perfect_join_statistics.is_build_small) { // perfect hash join - result += "Build Min: " + perfect_join_statistics.build_min.ToString() + "\n"; - result += "Build Max: " + perfect_join_statistics.build_max.ToString() + "\n"; - result += "\n[INFOSEPARATOR]\n"; + result["Build Min"] = perfect_join_statistics.build_min.ToString(); + result["Build Max"] = perfect_join_statistics.build_max.ToString(); } - result += StringUtil::Format("EC: %llu\n", estimated_cardinality); + SetEstimatedCardinality(result, estimated_cardinality); return result; } diff --git a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp index 8c89195b..143b1f20 100644 --- a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp +++ b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp @@ -11,8 +11,8 @@ #include "duckdb/parallel/meta_pipeline.hpp" #include "duckdb/parallel/thread_context.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" - -#include +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/thread.hpp" namespace duckdb { @@ -87,17 +87,17 @@ class IEJoinGlobalState : public GlobalSinkState { lhs_layout.Initialize(op.children[0]->types); vector lhs_order; lhs_order.emplace_back(op.lhs_orders[0].Copy()); - tables[0] = make_uniq(context, lhs_order, lhs_layout); + tables[0] = make_uniq(context, lhs_order, lhs_layout, op); RowLayout rhs_layout; rhs_layout.Initialize(op.children[1]->types); vector rhs_order; rhs_order.emplace_back(op.rhs_orders[0].Copy()); - tables[1] = make_uniq(context, rhs_order, rhs_layout); + tables[1] = make_uniq(context, rhs_order, rhs_layout, op); } - IEJoinGlobalState(IEJoinGlobalState &prev) - : GlobalSinkState(prev), tables(std::move(prev.tables)), child(prev.child + 1) { + IEJoinGlobalState(IEJoinGlobalState &prev) : tables(std::move(prev.tables)), child(prev.child + 1) { + state = prev.state; } void Sink(DataChunk &input, IEJoinLocalState &lstate) { @@ -147,7 +147,7 @@ SinkCombineResultType PhysicalIEJoin::Combine(ExecutionContext &context, Operato gstate.tables[gstate.child]->Combine(lstate.table); auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this, lstate.table.executor, gstate.child ? "rhs_executor" : "lhs_executor", 1); + context.thread.profiler.Flush(*this); client_profiler.Flush(context.thread.profiler); return SinkCombineResultType::FINISHED; @@ -389,7 +389,7 @@ IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, Sorte vector orders; orders.emplace_back(order1.type, order1.null_order, std::move(ref)); - l1 = make_uniq(context, orders, payload_layout); + l1 = make_uniq(context, orders, payload_layout, op); // LHS has positive rids ExpressionExecutor l_executor(context); @@ -432,7 +432,7 @@ IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, Sorte ExpressionExecutor executor(context); executor.AddExpression(*orders[0].expression); - l2 = make_uniq(context, orders, payload_layout); + l2 = make_uniq(context, orders, payload_layout, op); for (idx_t base = 0, block_idx = 0; block_idx < l1->BlockCount(); ++block_idx) { base += AppendKey(*l1, executor, *l2, 1, NumericCast(base), block_idx); } @@ -790,20 +790,20 @@ void PhysicalIEJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &re class IEJoinGlobalSourceState : public GlobalSourceState { public: - explicit IEJoinGlobalSourceState(const PhysicalIEJoin &op) - : op(op), initialized(false), next_pair(0), completed(0), left_outers(0), next_left(0), right_outers(0), - next_right(0) { + explicit IEJoinGlobalSourceState(const PhysicalIEJoin &op, IEJoinGlobalState &gsink) + : op(op), gsink(gsink), initialized(false), next_pair(0), completed(0), left_outers(0), next_left(0), + right_outers(0), next_right(0) { } - void Initialize(IEJoinGlobalState &sink_state) { - lock_guard initializing(lock); + void Initialize() { + auto guard = Lock(); if (initialized) { return; } // Compute the starting row for reach block // (In theory these are all the same size, but you never know...) - auto &left_table = *sink_state.tables[0]; + auto &left_table = *gsink.tables[0]; const auto left_blocks = left_table.BlockCount(); idx_t left_base = 0; @@ -812,7 +812,7 @@ class IEJoinGlobalSourceState : public GlobalSourceState { left_base += left_table.BlockSize(lhs); } - auto &right_table = *sink_state.tables[1]; + auto &right_table = *gsink.tables[1]; const auto right_blocks = right_table.BlockCount(); idx_t right_base = 0; for (size_t rhs = 0; rhs < right_blocks; ++rhs) { @@ -840,9 +840,9 @@ class IEJoinGlobalSourceState : public GlobalSourceState { return sink_state.tables[0]->BlockCount() * sink_state.tables[1]->BlockCount(); } - void GetNextPair(ClientContext &client, IEJoinGlobalState &gstate, IEJoinLocalSourceState &lstate) { - auto &left_table = *gstate.tables[0]; - auto &right_table = *gstate.tables[1]; + void GetNextPair(ClientContext &client, IEJoinLocalSourceState &lstate) { + auto &left_table = *gsink.tables[0]; + auto &right_table = *gsink.tables[1]; const auto left_blocks = left_table.BlockCount(); const auto right_blocks = right_table.BlockCount(); @@ -905,35 +905,53 @@ class IEJoinGlobalSourceState : public GlobalSourceState { } } - void PairCompleted(ClientContext &client, IEJoinGlobalState &gstate, IEJoinLocalSourceState &lstate) { + void PairCompleted(ClientContext &client, IEJoinLocalSourceState &lstate) { lstate.joiner.reset(); ++completed; - GetNextPair(client, gstate, lstate); + GetNextPair(client, lstate); + } + + double GetProgress() const { + auto &left_table = *gsink.tables[0]; + auto &right_table = *gsink.tables[1]; + + const auto left_blocks = left_table.BlockCount(); + const auto right_blocks = right_table.BlockCount(); + const auto pair_count = left_blocks * right_blocks; + + const auto count = pair_count + left_outers + right_outers; + + const auto l = MinValue(next_left.load(), left_outers.load()); + const auto r = MinValue(next_right.load(), right_outers.load()); + const auto returned = completed.load() + l + r; + + return count ? (double(returned) / double(count)) : -1; } const PhysicalIEJoin &op; + IEJoinGlobalState &gsink; - mutex lock; bool initialized; // Join queue state - std::atomic next_pair; - std::atomic completed; + atomic next_pair; + atomic completed; // Block base row number vector left_bases; vector right_bases; // Outer joins - idx_t left_outers; - std::atomic next_left; + atomic left_outers; + atomic next_left; - idx_t right_outers; - std::atomic next_right; + atomic right_outers; + atomic next_right; }; unique_ptr PhysicalIEJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); + auto &gsink = sink_state->Cast(); + return make_uniq(*this, gsink); } unique_ptr PhysicalIEJoin::GetLocalSourceState(ExecutionContext &context, @@ -941,16 +959,21 @@ unique_ptr PhysicalIEJoin::GetLocalSourceState(ExecutionContex return make_uniq(context.client, *this); } +double PhysicalIEJoin::GetProgress(ClientContext &context, GlobalSourceState &gsource_p) const { + auto &gsource = gsource_p.Cast(); + return gsource.GetProgress(); +} + SourceResultType PhysicalIEJoin::GetData(ExecutionContext &context, DataChunk &result, OperatorSourceInput &input) const { auto &ie_sink = sink_state->Cast(); auto &ie_gstate = input.global_state.Cast(); auto &ie_lstate = input.local_state.Cast(); - ie_gstate.Initialize(ie_sink); + ie_gstate.Initialize(); if (!ie_lstate.joiner && !ie_lstate.left_matches && !ie_lstate.right_matches) { - ie_gstate.GetNextPair(context.client, ie_sink, ie_lstate); + ie_gstate.GetNextPair(context.client, ie_lstate); } // Process INNER results @@ -961,7 +984,7 @@ SourceResultType PhysicalIEJoin::GetData(ExecutionContext &context, DataChunk &r return SourceResultType::HAVE_MORE_OUTPUT; } - ie_gstate.PairCompleted(context.client, ie_sink, ie_lstate); + ie_gstate.PairCompleted(context.client, ie_lstate); } // Process LEFT OUTER results @@ -969,7 +992,7 @@ SourceResultType PhysicalIEJoin::GetData(ExecutionContext &context, DataChunk &r while (ie_lstate.left_matches) { const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.left_matches); if (!count) { - ie_gstate.GetNextPair(context.client, ie_sink, ie_lstate); + ie_gstate.GetNextPair(context.client, ie_lstate); continue; } auto &chunk = ie_lstate.unprojected; @@ -994,7 +1017,7 @@ SourceResultType PhysicalIEJoin::GetData(ExecutionContext &context, DataChunk &r while (ie_lstate.right_matches) { const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.right_matches); if (!count) { - ie_gstate.GetNextPair(context.client, ie_sink, ie_lstate); + ie_gstate.GetNextPair(context.client, ie_lstate); continue; } diff --git a/src/duckdb/src/execution/operator/join/physical_join.cpp b/src/duckdb/src/execution/operator/join/physical_join.cpp index 6d9813c9..bb7011a3 100644 --- a/src/duckdb/src/execution/operator/join/physical_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_join.cpp @@ -42,15 +42,29 @@ void PhysicalJoin::BuildJoinPipelines(Pipeline ¤t, MetaPipeline &meta_pipe meta_pipeline.GetPipelines(pipelines_so_far, false); auto &last_pipeline = *pipelines_so_far.back(); + vector> dependencies; + optional_ptr last_child_ptr; if (build_rhs) { // on the RHS (build side), we construct a child MetaPipeline with this operator as its sink - auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, op); + auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, op, MetaPipelineType::JOIN_BUILD); child_meta_pipeline.Build(*op.children[1]); + if (op.children[1]->CanSaturateThreads(current.GetClientContext())) { + // if the build side can saturate all available threads, + // we don't just make the LHS pipeline depend on the RHS, but recursively all LHS children too. + // this prevents breadth-first plan evaluation + child_meta_pipeline.GetPipelines(dependencies, false); + last_child_ptr = meta_pipeline.GetLastChild(); + } } // continue building the current pipeline on the LHS (probe side) op.children[0]->BuildPipelines(current, meta_pipeline); + if (last_child_ptr) { + // the pointer was set, set up the dependencies + meta_pipeline.AddRecursiveDependencies(dependencies, *last_child_ptr); + } + switch (op.type) { case PhysicalOperatorType::POSITIONAL_JOIN: // Positional joins are always outer @@ -63,13 +77,7 @@ void PhysicalJoin::BuildJoinPipelines(Pipeline ¤t, MetaPipeline &meta_pipe } // Join can become a source operator if it's RIGHT/OUTER, or if the hash join goes out-of-core - bool add_child_pipeline = false; - auto &join_op = op.Cast(); - if (join_op.IsSource()) { - add_child_pipeline = true; - } - - if (add_child_pipeline) { + if (op.Cast().IsSource()) { meta_pipeline.CreateChildPipeline(current, op, last_pipeline); } } diff --git a/src/duckdb/src/execution/operator/join/physical_left_delim_join.cpp b/src/duckdb/src/execution/operator/join/physical_left_delim_join.cpp index 1d6972e0..49f259ab 100644 --- a/src/duckdb/src/execution/operator/join/physical_left_delim_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_left_delim_join.cpp @@ -1,20 +1,19 @@ #include "duckdb/execution/operator/join/physical_left_delim_join.hpp" #include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" +#include "duckdb/execution/operator/join/physical_join.hpp" #include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" #include "duckdb/parallel/meta_pipeline.hpp" #include "duckdb/parallel/pipeline.hpp" -#include "duckdb/parallel/thread_context.hpp" namespace duckdb { PhysicalLeftDelimJoin::PhysicalLeftDelimJoin(vector types, unique_ptr original_join, vector> delim_scans, - idx_t estimated_cardinality) + idx_t estimated_cardinality, optional_idx delim_idx) : PhysicalDelimJoin(PhysicalOperatorType::LEFT_DELIM_JOIN, std::move(types), std::move(original_join), - std::move(delim_scans), estimated_cardinality) { + std::move(delim_scans), estimated_cardinality, delim_idx) { D_ASSERT(join->children.size() == 2); // now for the original join // we take its left child, this is the side that we will duplicate eliminate @@ -24,6 +23,9 @@ PhysicalLeftDelimJoin::PhysicalLeftDelimJoin(vector types, unique_p // the actual chunk collection to scan will be created in the LeftDelimJoinGlobalState auto cached_chunk_scan = make_uniq( children[0]->GetTypes(), PhysicalOperatorType::COLUMN_DATA_SCAN, estimated_cardinality, nullptr); + if (delim_idx.IsValid()) { + cached_chunk_scan->cte_index = delim_idx.GetIndex(); + } join->children[0] = std::move(cached_chunk_scan); } @@ -101,6 +103,10 @@ SinkCombineResultType PhysicalLeftDelimJoin::Combine(ExecutionContext &context, return SinkCombineResultType::FINISHED; } +void PhysicalLeftDelimJoin::PrepareFinalize(ClientContext &context, GlobalSinkState &sink_state) const { + distinct->PrepareFinalize(context, *distinct->sink_state); +} + SinkFinalizeType PhysicalLeftDelimJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, OperatorSinkFinalizeInput &input) const { // finalize the distinct HT diff --git a/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp b/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp index 60e4eb2f..022337e5 100644 --- a/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp @@ -200,12 +200,9 @@ SinkResultType PhysicalNestedLoopJoin::Sink(ExecutionContext &context, DataChunk SinkCombineResultType PhysicalNestedLoopJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &state = input.local_state.Cast(); auto &client_profiler = QueryProfiler::Get(context.client); - - context.thread.profiler.Flush(*this, state.rhs_executor, "rhs_executor", 1); + context.thread.profiler.Flush(*this); client_profiler.Flush(context.thread.profiler); - return SinkCombineResultType::FINISHED; } @@ -266,7 +263,7 @@ class PhysicalNestedLoopJoinState : public CachingOperatorState { public: void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - context.thread.profiler.Flush(op, lhs_executor, "lhs_executor", 0); + context.thread.profiler.Flush(op); } }; diff --git a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp index 287c7971..d7b30423 100644 --- a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp @@ -76,7 +76,7 @@ class MergeJoinGlobalState : public GlobalSinkState { rhs_layout.Initialize(op.children[1]->types); vector rhs_order; rhs_order.emplace_back(op.rhs_orders[0].Copy()); - table = make_uniq(context, rhs_order, rhs_layout); + table = make_uniq(context, rhs_order, rhs_layout, op); } inline idx_t Count() const { @@ -125,7 +125,7 @@ SinkCombineResultType PhysicalPiecewiseMergeJoin::Combine(ExecutionContext &cont gstate.table->Combine(lstate.table); auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this, lstate.table.executor, "rhs_executor", 1); + context.thread.profiler.Flush(*this); client_profiler.Flush(context.thread.profiler); return SinkCombineResultType::FINISHED; @@ -250,7 +250,7 @@ class PiecewiseMergeJoinState : public CachingOperatorState { void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { if (lhs_local_table) { - context.thread.profiler.Flush(op, lhs_local_table->executor, "lhs_executor", 0); + context.thread.profiler.Flush(op); } } }; diff --git a/src/duckdb/src/execution/operator/join/physical_range_join.cpp b/src/duckdb/src/execution/operator/join/physical_range_join.cpp index 398242bd..6c7deb1e 100644 --- a/src/duckdb/src/execution/operator/join/physical_range_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_range_join.cpp @@ -59,9 +59,9 @@ void PhysicalRangeJoin::LocalSortedTable::Sink(DataChunk &input, GlobalSortState } PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &context, const vector &orders, - RowLayout &payload_layout) - : global_sort_state(BufferManager::GetBufferManager(context), orders, payload_layout), has_null(0), count(0), - memory_per_thread(0) { + RowLayout &payload_layout, const PhysicalOperator &op_p) + : op(op_p), global_sort_state(BufferManager::GetBufferManager(context), orders, payload_layout), has_null(0), + count(0), memory_per_thread(0) { D_ASSERT(orders.size() == 1); // Set external (can be forced with the PRAGMA) @@ -77,7 +77,7 @@ void PhysicalRangeJoin::GlobalSortedTable::Combine(LocalSortedTable <able) { } void PhysicalRangeJoin::GlobalSortedTable::IntializeMatches() { - found_match = make_unsafe_uniq_array(Count()); + found_match = make_unsafe_uniq_array_uninitialized(Count()); memset(found_match.get(), 0, sizeof(bool) * Count()); } @@ -91,7 +91,7 @@ class RangeJoinMergeTask : public ExecutorTask { public: RangeJoinMergeTask(shared_ptr event_p, ClientContext &context, GlobalSortedTable &table) - : ExecutorTask(context, std::move(event_p)), context(context), table(table) { + : ExecutorTask(context, std::move(event_p), table.op), context(context), table(table) { } TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { diff --git a/src/duckdb/src/execution/operator/join/physical_right_delim_join.cpp b/src/duckdb/src/execution/operator/join/physical_right_delim_join.cpp index aad86786..60aaeaca 100644 --- a/src/duckdb/src/execution/operator/join/physical_right_delim_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_right_delim_join.cpp @@ -12,9 +12,9 @@ namespace duckdb { PhysicalRightDelimJoin::PhysicalRightDelimJoin(vector types, unique_ptr original_join, vector> delim_scans, - idx_t estimated_cardinality) + idx_t estimated_cardinality, optional_idx delim_idx) : PhysicalDelimJoin(PhysicalOperatorType::RIGHT_DELIM_JOIN, std::move(types), std::move(original_join), - std::move(delim_scans), estimated_cardinality) { + std::move(delim_scans), estimated_cardinality, delim_idx) { D_ASSERT(join->children.size() == 2); // now for the original join // we take its right child, this is the side that we will duplicate eliminate @@ -79,6 +79,11 @@ SinkCombineResultType PhysicalRightDelimJoin::Combine(ExecutionContext &context, return SinkCombineResultType::FINISHED; } +void PhysicalRightDelimJoin::PrepareFinalize(ClientContext &context, GlobalSinkState &sink_state) const { + join->PrepareFinalize(context, *join->sink_state); + distinct->PrepareFinalize(context, *distinct->sink_state); +} + SinkFinalizeType PhysicalRightDelimJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, OperatorSinkFinalizeInput &input) const { D_ASSERT(join); diff --git a/src/duckdb/src/execution/operator/order/physical_order.cpp b/src/duckdb/src/execution/operator/order/physical_order.cpp index c916c7da..e0bb0c94 100644 --- a/src/duckdb/src/execution/operator/order/physical_order.cpp +++ b/src/duckdb/src/execution/operator/order/physical_order.cpp @@ -22,9 +22,10 @@ PhysicalOrder::PhysicalOrder(vector types, vector class OrderGlobalSinkState : public GlobalSinkState { public: OrderGlobalSinkState(BufferManager &buffer_manager, const PhysicalOrder &order, RowLayout &payload_layout) - : global_sort_state(buffer_manager, order.orders, payload_layout) { + : order(order), global_sort_state(buffer_manager, order.orders, payload_layout) { } + const PhysicalOrder ℴ //! Global sort state GlobalSortState global_sort_state; //! Memory usage per thread @@ -112,8 +113,9 @@ SinkCombineResultType PhysicalOrder::Combine(ExecutionContext &context, Operator class PhysicalOrderMergeTask : public ExecutorTask { public: - PhysicalOrderMergeTask(shared_ptr event_p, ClientContext &context, OrderGlobalSinkState &state) - : ExecutorTask(context, std::move(event_p)), context(context), state(state) { + PhysicalOrderMergeTask(shared_ptr event_p, ClientContext &context, OrderGlobalSinkState &state, + const PhysicalOperator &op_p) + : ExecutorTask(context, std::move(event_p), op_p), context(context), state(state) { } TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { @@ -132,11 +134,12 @@ class PhysicalOrderMergeTask : public ExecutorTask { class OrderMergeEvent : public BasePipelineEvent { public: - OrderMergeEvent(OrderGlobalSinkState &gstate_p, Pipeline &pipeline_p) - : BasePipelineEvent(pipeline_p), gstate(gstate_p) { + OrderMergeEvent(OrderGlobalSinkState &gstate_p, Pipeline &pipeline_p, const PhysicalOperator &op_p) + : BasePipelineEvent(pipeline_p), gstate(gstate_p), op(op_p) { } OrderGlobalSinkState &gstate; + const PhysicalOperator &op; public: void Schedule() override { @@ -148,7 +151,7 @@ class OrderMergeEvent : public BasePipelineEvent { vector> merge_tasks; for (idx_t tnum = 0; tnum < num_threads; tnum++) { - merge_tasks.push_back(make_uniq(shared_from_this(), context, gstate)); + merge_tasks.push_back(make_uniq(shared_from_this(), context, gstate, op)); } SetTasks(std::move(merge_tasks)); } @@ -187,7 +190,7 @@ SinkFinalizeType PhysicalOrder::Finalize(Pipeline &pipeline, Event &event, Clien void PhysicalOrder::ScheduleMergeTasks(Pipeline &pipeline, Event &event, OrderGlobalSinkState &state) { // Initialize global sort state for a round of merging state.global_sort_state.InitializeMergeRound(); - auto new_event = make_shared_ptr(state, pipeline); + auto new_event = make_shared_ptr(state, pipeline, state.order); event.InsertEvent(std::move(new_event)); } @@ -267,15 +270,17 @@ idx_t PhysicalOrder::GetBatchIndex(ExecutionContext &context, DataChunk &chunk, return lstate.batch_index; } -string PhysicalOrder::ParamsToString() const { - string result = "ORDERS:\n"; +InsertionOrderPreservingMap PhysicalOrder::ParamsToString() const { + InsertionOrderPreservingMap result; + string orders_info; for (idx_t i = 0; i < orders.size(); i++) { if (i > 0) { - result += "\n"; + orders_info += "\n"; } - result += orders[i].expression->ToString() + " "; - result += orders[i].type == OrderType::DESCENDING ? "DESC" : "ASC"; + orders_info += orders[i].expression->ToString() + " "; + orders_info += orders[i].type == OrderType::DESCENDING ? "DESC" : "ASC"; } + result["__order_by__"] = orders_info; return result; } diff --git a/src/duckdb/src/execution/operator/order/physical_top_n.cpp b/src/duckdb/src/execution/operator/order/physical_top_n.cpp index 2d5e9b63..aa686878 100644 --- a/src/duckdb/src/execution/operator/order/physical_top_n.cpp +++ b/src/duckdb/src/execution/operator/order/physical_top_n.cpp @@ -492,19 +492,22 @@ SourceResultType PhysicalTopN::GetData(ExecutionContext &context, DataChunk &chu return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; } -string PhysicalTopN::ParamsToString() const { - string result; - result += "Top " + to_string(limit); +InsertionOrderPreservingMap PhysicalTopN::ParamsToString() const { + InsertionOrderPreservingMap result; + result["Top"] = to_string(limit); if (offset > 0) { - result += "\n"; - result += "Offset " + to_string(offset); + result["Offset"] = to_string(offset); } - result += "\n[INFOSEPARATOR]"; + + string orders_info; for (idx_t i = 0; i < orders.size(); i++) { - result += "\n"; - result += orders[i].expression->ToString() + " "; - result += orders[i].type == OrderType::DESCENDING ? "DESC" : "ASC"; + if (i > 0) { + orders_info += "\n"; + } + orders_info += orders[i].expression->ToString() + " "; + orders_info += orders[i].type == OrderType::DESCENDING ? "DESC" : "ASC"; } + result["Order By"] = orders_info; return result; } diff --git a/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp b/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp index 17f4e0e4..d1fcd7b7 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp @@ -1,14 +1,16 @@ #include "duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp" -#include "duckdb/execution/operator/persistent/physical_copy_to_file.hpp" -#include "duckdb/parallel/base_pipeline_event.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/types/batched_data_collection.hpp" + #include "duckdb/common/allocator.hpp" #include "duckdb/common/queue.hpp" -#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/common/types/batched_data_collection.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/execution/operator/persistent/batch_memory_manager.hpp" #include "duckdb/execution/operator/persistent/batch_task_manager.hpp" +#include "duckdb/execution/operator/persistent/physical_copy_to_file.hpp" +#include "duckdb/parallel/base_pipeline_event.hpp" #include "duckdb/parallel/executor_task.hpp" +#include "duckdb/storage/buffer_manager.hpp" + #include namespace duckdb { @@ -158,15 +160,14 @@ SinkResultType PhysicalBatchCopyToFile::Sink(ExecutionContext &context, DataChun auto batch_index = state.partition_info.batch_index.GetIndex(); if (state.current_task == FixedBatchCopyState::PROCESSING_TASKS) { ExecuteTasks(context.client, gstate); - FlushBatchData(context.client, gstate, memory_manager.GetMinimumBatchIndex()); + FlushBatchData(context.client, gstate); if (!memory_manager.IsMinimumBatchIndex(batch_index) && memory_manager.OutOfMemory(batch_index)) { - lock_guard l(memory_manager.GetBlockedTaskLock()); + auto guard = memory_manager.Lock(); if (!memory_manager.IsMinimumBatchIndex(batch_index)) { // no tasks to process, we are not the minimum batch index and we have no memory available to buffer // block the task for now - memory_manager.BlockTask(input.interrupt_state); - return SinkResultType::BLOCKED; + return memory_manager.BlockSink(guard, input.interrupt_state); } } state.current_task = FixedBatchCopyState::SINKING_DATA; @@ -232,7 +233,7 @@ class ProcessRemainingBatchesTask : public ExecutorTask { TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { while (op.ExecuteTask(context, gstate)) { - op.FlushBatchData(context, gstate, 0); + op.FlushBatchData(context, gstate); } event->FinishTask(); return TaskExecutionResult::TASK_FINISHED; @@ -279,8 +280,8 @@ SinkFinalizeType PhysicalBatchCopyToFile::FinalFlush(ClientContext &context, Glo if (gstate.task_manager.TaskCount() != 0) { throw InternalException("Unexecuted tasks are remaining in PhysicalFixedBatchCopy::FinalFlush!?"); } - auto min_batch_index = idx_t(NumericLimits::Maximum()); - FlushBatchData(context, gstate_p, min_batch_index); + + FlushBatchData(context, gstate_p); if (gstate.scheduled_batch_index != gstate.flushed_batch_index) { throw InternalException("Not all batches were flushed to disk - incomplete file?"); } @@ -323,7 +324,7 @@ class RepartitionedFlushTask : public BatchCopyTask { } void Execute(const PhysicalBatchCopyToFile &op, ClientContext &context, GlobalSinkState &gstate_p) override { - op.FlushBatchData(context, gstate_p, 0); + op.FlushBatchData(context, gstate_p); } }; @@ -475,7 +476,7 @@ void PhysicalBatchCopyToFile::RepartitionBatches(ClientContext &context, GlobalS } } -void PhysicalBatchCopyToFile::FlushBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t min_index) const { +void PhysicalBatchCopyToFile::FlushBatchData(ClientContext &context, GlobalSinkState &gstate_p) const { auto &gstate = gstate_p.Cast(); auto &memory_manager = gstate.memory_manager; @@ -554,14 +555,18 @@ void PhysicalBatchCopyToFile::AddLocalBatch(ClientContext &context, GlobalSinkSt // attempt to repartition to our desired batch size RepartitionBatches(context, gstate, min_batch_index); // unblock tasks so they can help process batches (if any are blocked) - auto any_unblocked = memory_manager.UnblockTasks(); + bool any_unblocked; + { + auto guard = memory_manager.Lock(); + any_unblocked = memory_manager.UnblockTasks(guard); + } // if any threads were unblocked they can pick up execution of the tasks // otherwise we will execute a task and flush here if (!any_unblocked) { //! Execute a single repartition task ExecuteTask(context, gstate); //! Flush batch data to disk (if any is ready) - FlushBatchData(context, gstate, memory_manager.GetMinimumBatchIndex()); + FlushBatchData(context, gstate); } } @@ -605,7 +610,20 @@ SourceResultType PhysicalBatchCopyToFile::GetData(ExecutionContext &context, Dat auto &g = sink_state->Cast(); chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.rows_copied.load()))); + switch (return_type) { + case CopyFunctionReturnType::CHANGED_ROWS: + chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.rows_copied.load()))); + break; + case CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST: { + chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.rows_copied.load()))); + auto fp = use_tmp_file ? PhysicalCopyToFile::GetNonTmpFile(context.client, file_path) : file_path; + chunk.SetValue(1, 0, Value::LIST(LogicalType::VARCHAR, {fp})); + break; + } + default: + throw NotImplementedException("Unknown CopyFunctionReturnType"); + } + return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp index a65d1cca..8bad7cbb 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp @@ -1,15 +1,16 @@ #include "duckdb/execution/operator/persistent/physical_batch_insert.hpp" + +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" #include "duckdb/execution/operator/persistent/batch_memory_manager.hpp" #include "duckdb/execution/operator/persistent/batch_task_manager.hpp" #include "duckdb/parallel/thread_context.hpp" #include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/table/append_state.hpp" #include "duckdb/storage/table/row_group_collection.hpp" +#include "duckdb/storage/table/scan_state.hpp" #include "duckdb/storage/table_io_manager.hpp" -#include "duckdb/transaction/local_storage.hpp" -#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" #include "duckdb/transaction/duck_transaction.hpp" -#include "duckdb/storage/table/append_state.hpp" -#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/transaction/local_storage.hpp" namespace duckdb { @@ -438,7 +439,11 @@ SinkNextBatchType PhysicalBatchInsert::NextBatch(ExecutionContext &context, Oper gstate.AddCollection(context.client, lstate.current_index, lstate.partition_info.min_batch_index.GetIndex(), std::move(lstate.current_collection), lstate.writer); - auto any_unblocked = memory_manager.UnblockTasks(); + bool any_unblocked; + { + auto guard = memory_manager.Lock(); + any_unblocked = memory_manager.UnblockTasks(guard); + } if (!any_unblocked) { ExecuteTasks(context.client, gstate, lstate); } @@ -447,7 +452,8 @@ SinkNextBatchType PhysicalBatchInsert::NextBatch(ExecutionContext &context, Oper lstate.current_index = batch_index; // unblock any blocked tasks - memory_manager.UnblockTasks(); + auto guard = memory_manager.Lock(); + memory_manager.UnblockTasks(guard); return SinkNextBatchType::READY; } @@ -475,12 +481,11 @@ SinkResultType PhysicalBatchInsert::Sink(ExecutionContext &context, DataChunk &c // execute tasks while we wait (if any are available) ExecuteTasks(context.client, gstate, lstate); - lock_guard l(memory_manager.GetBlockedTaskLock()); + auto guard = memory_manager.Lock(); if (!memory_manager.IsMinimumBatchIndex(batch_index)) { // we are not the minimum batch index and we have no memory available to buffer - block the task for // now - memory_manager.BlockTask(input.interrupt_state); - return SinkResultType::BLOCKED; + return memory_manager.BlockSink(guard, input.interrupt_state); } } } @@ -518,7 +523,7 @@ SinkCombineResultType PhysicalBatchInsert::Combine(ExecutionContext &context, Op auto &lstate = input.local_state.Cast(); auto &memory_manager = gstate.memory_manager; auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this, lstate.default_executor, "default_executor", 1); + context.thread.profiler.Flush(*this); client_profiler.Flush(context.thread.profiler); memory_manager.UpdateMinBatchIndex(lstate.partition_info.min_batch_index.GetIndex()); @@ -537,7 +542,8 @@ SinkCombineResultType PhysicalBatchInsert::Combine(ExecutionContext &context, Op } // unblock any blocked tasks - memory_manager.UnblockTasks(); + auto guard = memory_manager.Lock(); + memory_manager.UnblockTasks(guard); return SinkCombineResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp b/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp index e66881be..353c091c 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp @@ -8,6 +8,7 @@ #include "duckdb/parser/parsed_data/create_table_info.hpp" #include "duckdb/parser/parsed_data/create_type_info.hpp" #include "duckdb/parser/parsed_data/create_view_info.hpp" +#include "duckdb/parser/parsed_data/create_index_info.hpp" namespace duckdb { @@ -49,7 +50,10 @@ SourceResultType PhysicalCopyDatabase::GetData(ExecutionContext &context, DataCh catalog.CreateTable(context.client, *bound_info); break; } - case CatalogType::INDEX_ENTRY: + case CatalogType::INDEX_ENTRY: { + catalog.CreateIndex(context.client, create_info->Cast()); + break; + } default: throw NotImplementedException("Entry type %s not supported in PhysicalCopyDatabase", CatalogTypeToString(create_info->type)); diff --git a/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp b/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp index 2280b770..fece3ef0 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp @@ -5,15 +5,17 @@ #include "duckdb/common/hive_partitioning.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/types/uuid.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/storage/storage_lock.hpp" #include "duckdb/common/value_operations/value_operations.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/operator/logical_copy_to_file.hpp" + #include namespace duckdb { struct PartitionWriteInfo { unique_ptr global_state; + idx_t active_writes = 0; }; struct VectorOfValuesHashFunction { @@ -45,8 +47,9 @@ using vector_of_value_map_t = unordered_map, T, VectorOfValuesHash class CopyToFunctionGlobalState : public GlobalSinkState { public: - explicit CopyToFunctionGlobalState(unique_ptr global_state) + explicit CopyToFunctionGlobalState(ClientContext &context, unique_ptr global_state) : rows_copied(0), last_file_offset(0), global_state(std::move(global_state)) { + max_open_files = ClientConfig::GetConfig(context).partitioned_write_max_open_files; } StorageLock lock; atomic rows_copied; @@ -56,6 +59,10 @@ class CopyToFunctionGlobalState : public GlobalSinkState { unordered_set created_directories; //! shared state for HivePartitionedColumnData shared_ptr partition_state; + //! File names + vector file_names; + //! Max open files + idx_t max_open_files; void CreateDir(const string &dir_path, FileSystem &fs) { if (created_directories.find(dir_path) != created_directories.end()) { @@ -74,13 +81,21 @@ class CopyToFunctionGlobalState : public GlobalSinkState { for (idx_t i = 0; i < cols.size(); i++) { const auto &partition_col_name = names[cols[i]]; const auto &partition_value = values[i]; - string p_dir = partition_col_name + "=" + partition_value.ToString(); + string p_dir; + p_dir += HivePartitioning::Escape(partition_col_name); + p_dir += "="; + p_dir += HivePartitioning::Escape(partition_value.ToString()); path = fs.JoinPath(path, p_dir); CreateDir(path, fs); } return path; } + void AddFileName(const StorageLockKey &l, const string &file_name) { + D_ASSERT(l.GetType() == StorageLockType::EXCLUSIVE); + file_names.emplace_back(file_name); + } + void FinalizePartition(ClientContext &context, const PhysicalCopyToFile &op, PartitionWriteInfo &info) { if (!info.global_state) { // already finalized @@ -100,30 +115,69 @@ class CopyToFunctionGlobalState : public GlobalSinkState { PartitionWriteInfo &GetPartitionWriteInfo(ExecutionContext &context, const PhysicalCopyToFile &op, const vector &values) { - auto l = lock.GetExclusiveLock(); + auto global_lock = lock.GetExclusiveLock(); // check if we have already started writing this partition - auto entry = active_partitioned_writes.find(values); - if (entry != active_partitioned_writes.end()) { + auto active_write_entry = active_partitioned_writes.find(values); + if (active_write_entry != active_partitioned_writes.end()) { // we have - continue writing in this partition - return *entry->second; + active_write_entry->second->active_writes++; + return *active_write_entry->second; + } + // check if we need to close any writers before we can continue + if (active_partitioned_writes.size() >= max_open_files) { + // we need to! try to close writers + for (auto &entry : active_partitioned_writes) { + if (entry.second->active_writes == 0) { + // we can evict this entry - evict the partition + FinalizePartition(context.client, op, *entry.second); + ++previous_partitions[entry.first]; + active_partitioned_writes.erase(entry.first); + break; + } + } + } + idx_t offset = 0; + auto prev_offset = previous_partitions.find(values); + if (prev_offset != previous_partitions.end()) { + offset = prev_offset->second; } auto &fs = FileSystem::GetFileSystem(context.client); // Create a writer for the current file auto trimmed_path = op.GetTrimmedPath(context.client); string hive_path = GetOrCreateDirectory(op.partition_columns, op.names, values, trimmed_path, fs); - string full_path(op.filename_pattern.CreateFilename(fs, hive_path, op.file_extension, 0)); + string full_path(op.filename_pattern.CreateFilename(fs, hive_path, op.file_extension, offset)); + if (op.overwrite_mode == CopyOverwriteMode::COPY_APPEND) { + // when appending, we first check if the file exists + while (fs.FileExists(full_path)) { + // file already exists - re-generate name + if (!op.filename_pattern.HasUUID()) { + throw InternalException("CopyOverwriteMode::COPY_APPEND without {uuid} - and file exists"); + } + full_path = op.filename_pattern.CreateFilename(fs, hive_path, op.file_extension, offset); + } + } + if (op.return_type == CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST) { + AddFileName(*global_lock, full_path); + } // initialize writes auto info = make_uniq(); info->global_state = op.function.copy_to_initialize_global(context.client, *op.bind_data, full_path); auto &result = *info; + info->active_writes = 1; // store in active write map active_partitioned_writes.insert(make_pair(values, std::move(info))); return result; } + void FinishPartitionWrite(PartitionWriteInfo &info) { + auto global_lock = lock.GetExclusiveLock(); + info.active_writes--; + } + private: //! The active writes per partition (for partitioned write) vector_of_value_map_t> active_partitioned_writes; + vector_of_value_map_t previous_partitions; }; string PhysicalCopyToFile::GetTrimmedPath(ClientContext &context) const { @@ -175,6 +229,22 @@ class CopyToFunctionLocalState : public LocalSinkState { append_count = 0; } + void SetDataWithoutPartitions(DataChunk &chunk, const DataChunk &source, const vector &col_types, + const vector &part_cols) { + D_ASSERT(source.ColumnCount() == col_types.size()); + auto types = LogicalCopyToFile::GetTypesWithoutPartitions(col_types, part_cols, false); + chunk.InitializeEmpty(types); + set part_col_set(part_cols.begin(), part_cols.end()); + idx_t new_col_id = 0; + for (idx_t col_idx = 0; col_idx < source.ColumnCount(); col_idx++) { + if (part_col_set.find(col_idx) == part_col_set.end()) { + chunk.data[new_col_id].Reference(source.data[col_idx]); + new_col_id++; + } + } + chunk.SetCardinality(source.size()); + } + void FlushPartitions(ExecutionContext &context, const PhysicalCopyToFile &op, CopyToFunctionGlobalState &g) { if (!part_buffer) { return; @@ -194,22 +264,33 @@ class CopyToFunctionLocalState : public LocalSinkState { auto local_copy_state = op.function.copy_to_initialize_local(context, *op.bind_data); // push the chunks into the write state for (auto &chunk : partitions[i]->Chunks()) { - op.function.copy_to_sink(context, *op.bind_data, *info.global_state, *local_copy_state, chunk); + if (op.write_partition_columns) { + op.function.copy_to_sink(context, *op.bind_data, *info.global_state, *local_copy_state, chunk); + } else { + DataChunk filtered_chunk; + SetDataWithoutPartitions(filtered_chunk, chunk, op.expected_types, op.partition_columns); + op.function.copy_to_sink(context, *op.bind_data, *info.global_state, *local_copy_state, + filtered_chunk); + } } op.function.copy_to_combine(context, *op.bind_data, *info.global_state, *local_copy_state); local_copy_state.reset(); partitions[i].reset(); + g.FinishPartitionWrite(info); } ResetAppendState(); } }; -unique_ptr PhysicalCopyToFile::CreateFileState(ClientContext &context, - GlobalSinkState &sink) const { +unique_ptr PhysicalCopyToFile::CreateFileState(ClientContext &context, GlobalSinkState &sink, + StorageLockKey &global_lock) const { auto &g = sink.Cast(); idx_t this_file_offset = g.last_file_offset++; auto &fs = FileSystem::GetFileSystem(context); string output_path(filename_pattern.CreateFilename(fs, file_path, file_extension, this_file_offset)); + if (return_type == CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST) { + g.AddFileName(global_lock, output_path); + } return function.copy_to_initialize_global(context, *bind_data, output_path); } @@ -222,14 +303,12 @@ unique_ptr PhysicalCopyToFile::GetLocalSinkState(ExecutionContex return std::move(state); } auto res = make_uniq(function.copy_to_initialize_local(context, *bind_data)); - if (per_thread_output) { - res->global_state = CreateFileState(context.client, *sink_state); - } return std::move(res); } void CheckDirectory(FileSystem &fs, const string &file_path, CopyOverwriteMode overwrite_mode) { - if (overwrite_mode == CopyOverwriteMode::COPY_OVERWRITE_OR_IGNORE) { + if (overwrite_mode == CopyOverwriteMode::COPY_OVERWRITE_OR_IGNORE || + overwrite_mode == CopyOverwriteMode::COPY_APPEND) { // with overwrite or ignore we fully ignore the presence of any files instead of erasing them return; } @@ -265,8 +344,7 @@ void CheckDirectory(FileSystem &fs, const string &file_path, CopyOverwriteMode o } unique_ptr PhysicalCopyToFile::GetGlobalSinkState(ClientContext &context) const { - - if (partition_output || per_thread_output || file_size_bytes.IsValid()) { + if (partition_output || per_thread_output || rotate) { auto &fs = FileSystem::GetFileSystem(context); if (fs.FileExists(file_path)) { // the target file exists AND is a file (not a directory) @@ -291,9 +369,10 @@ unique_ptr PhysicalCopyToFile::GetGlobalSinkState(ClientContext CheckDirectory(fs, file_path, overwrite_mode); } - auto state = make_uniq(nullptr); - if (!per_thread_output && file_size_bytes.IsValid()) { - state->global_state = CreateFileState(context, *state); + auto state = make_uniq(context, nullptr); + if (!per_thread_output && rotate) { + auto global_lock = state->lock.GetExclusiveLock(); + state->global_state = CreateFileState(context, *state, *global_lock); } if (partition_output) { @@ -303,7 +382,15 @@ unique_ptr PhysicalCopyToFile::GetGlobalSinkState(ClientContext return std::move(state); } - return make_uniq(function.copy_to_initialize_global(context, *bind_data, file_path)); + auto state = make_uniq( + context, function.copy_to_initialize_global(context, *bind_data, file_path)); + if (use_tmp_file) { + auto global_lock = state->lock.GetExclusiveLock(); + state->AddFileName(*global_lock, file_path); + } else { + state->file_names.emplace_back(file_path); + } + return std::move(state); } //===--------------------------------------------------------------------===// @@ -311,6 +398,15 @@ unique_ptr PhysicalCopyToFile::GetGlobalSinkState(ClientContext //===--------------------------------------------------------------------===// void PhysicalCopyToFile::MoveTmpFile(ClientContext &context, const string &tmp_file_path) { auto &fs = FileSystem::GetFileSystem(context); + auto file_path = GetNonTmpFile(context, tmp_file_path); + if (fs.FileExists(file_path)) { + fs.RemoveFile(file_path); + } + fs.MoveFile(tmp_file_path, file_path); +} + +string PhysicalCopyToFile::GetNonTmpFile(ClientContext &context, const string &tmp_file_path) { + auto &fs = FileSystem::GetFileSystem(context); auto path = StringUtil::GetFilePath(tmp_file_path); auto base = StringUtil::GetFileName(tmp_file_path); @@ -320,11 +416,7 @@ void PhysicalCopyToFile::MoveTmpFile(ClientContext &context, const string &tmp_f base = base.substr(4); } - auto file_path = fs.JoinPath(path, base); - if (fs.FileExists(file_path)) { - fs.RemoveFile(file_path); - } - fs.MoveFile(tmp_file_path, file_path); + return fs.JoinPath(path, base); } PhysicalCopyToFile::PhysicalCopyToFile(vector types, CopyFunction function_p, @@ -337,42 +429,46 @@ SinkResultType PhysicalCopyToFile::Sink(ExecutionContext &context, DataChunk &ch auto &g = input.global_state.Cast(); auto &l = input.local_state.Cast(); + g.rows_copied += chunk.size(); + if (partition_output) { l.AppendToPartition(context, *this, g, chunk); return SinkResultType::NEED_MORE_INPUT; } - g.rows_copied += chunk.size(); - if (per_thread_output) { auto &gstate = l.global_state; - function.copy_to_sink(context, *bind_data, *gstate, *l.local_state, chunk); - - if (file_size_bytes.IsValid() && function.file_size_bytes(*gstate) > file_size_bytes.GetIndex()) { + if (!gstate) { + // Lazily create file state here to prevent creating empty files + auto global_lock = g.lock.GetExclusiveLock(); + gstate = CreateFileState(context.client, *sink_state, *global_lock); + } else if (rotate && function.rotate_next_file(*gstate, *bind_data, file_size_bytes)) { function.copy_to_finalize(context.client, *bind_data, *gstate); - gstate = CreateFileState(context.client, *sink_state); + auto global_lock = g.lock.GetExclusiveLock(); + gstate = CreateFileState(context.client, *sink_state, *global_lock); } + function.copy_to_sink(context, *bind_data, *gstate, *l.local_state, chunk); return SinkResultType::NEED_MORE_INPUT; } - if (!file_size_bytes.IsValid()) { + if (!file_size_bytes.IsValid() && !rotate) { function.copy_to_sink(context, *bind_data, *g.global_state, *l.local_state, chunk); return SinkResultType::NEED_MORE_INPUT; } - // FILE_SIZE_BYTES is set, but threads write to the same file, synchronize using lock + // FILE_SIZE_BYTES/rotate is set, but threads write to the same file, synchronize using lock auto &gstate = g.global_state; - auto lock = g.lock.GetExclusiveLock(); - if (function.file_size_bytes(*gstate) > file_size_bytes.GetIndex()) { + auto global_lock = g.lock.GetExclusiveLock(); + if (rotate && function.rotate_next_file(*gstate, *bind_data, file_size_bytes)) { auto owned_gstate = std::move(gstate); - gstate = CreateFileState(context.client, *sink_state); - lock.reset(); + gstate = CreateFileState(context.client, *sink_state, *global_lock); + global_lock.reset(); function.copy_to_finalize(context.client, *bind_data, *owned_gstate); } else { - lock.reset(); + global_lock.reset(); } - lock = g.lock.GetSharedLock(); + global_lock = g.lock.GetSharedLock(); function.copy_to_sink(context, *bind_data, *gstate, *l.local_state, chunk); return SinkResultType::NEED_MORE_INPUT; @@ -387,11 +483,13 @@ SinkCombineResultType PhysicalCopyToFile::Combine(ExecutionContext &context, Ope l.FlushPartitions(context, *this, g); } else if (function.copy_to_combine) { if (per_thread_output) { - // For PER_THREAD_OUTPUT, we can combine/finalize immediately - function.copy_to_combine(context, *bind_data, *l.global_state, *l.local_state); - function.copy_to_finalize(context.client, *bind_data, *l.global_state); - } else if (file_size_bytes.IsValid()) { - // File in global state may change with FILE_SIZE_BYTES, need to grab lock + // For PER_THREAD_OUTPUT, we can combine/finalize immediately (if there is a gstate) + if (l.global_state) { + function.copy_to_combine(context, *bind_data, *l.global_state, *l.local_state); + function.copy_to_finalize(context.client, *bind_data, *l.global_state); + } + } else if (rotate) { + // File in global state may change with FILE_SIZE_BYTES/rotate, need to grab lock auto lock = g.lock.GetSharedLock(); function.copy_to_combine(context, *bind_data, *g.global_state, *l.local_state); } else { @@ -421,6 +519,7 @@ SinkFinalizeType PhysicalCopyToFile::Finalize(Pipeline &pipeline, Event &event, D_ASSERT(!per_thread_output); D_ASSERT(!partition_output); D_ASSERT(!file_size_bytes.IsValid()); + D_ASSERT(!rotate); MoveTmpFile(context, file_path); } } @@ -436,7 +535,17 @@ SourceResultType PhysicalCopyToFile::GetData(ExecutionContext &context, DataChun auto &g = sink_state->Cast(); chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.rows_copied.load()))); + switch (return_type) { + case CopyFunctionReturnType::CHANGED_ROWS: + chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.rows_copied.load()))); + break; + case CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST: + chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.rows_copied.load()))); + chunk.SetValue(1, 0, Value::LIST(LogicalType::VARCHAR, g.file_names)); + break; + default: + throw NotImplementedException("Unknown CopyFunctionReturnType"); + } return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/persistent/physical_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_insert.cpp index 25ab4819..626f28b8 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_insert.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_insert.cpp @@ -446,10 +446,18 @@ SinkResultType PhysicalInsert::Sink(ExecutionContext &context, DataChunk &chunk, gstate.initialized = true; } - if (return_chunk) { + if (action_type != OnConflictAction::NOTHING && return_chunk) { + // If the action is UPDATE or REPLACE, we will always create either an APPEND or an INSERT + // for NOTHING we don't create either an APPEND or an INSERT for the tuple + // so it should not be added to the RETURNING chunk gstate.return_collection.Append(lstate.insert_chunk); } idx_t updated_tuples = OnConflictHandling(table, context, lstate); + if (action_type == OnConflictAction::NOTHING && return_chunk) { + // Because we didn't add to the RETURNING chunk yet + // we add the tuples that did not get filtered out now + gstate.return_collection.Append(lstate.insert_chunk); + } gstate.insert_count += lstate.insert_chunk.size(); gstate.insert_count += updated_tuples; storage.LocalAppend(gstate.append_state, table, context.client, lstate.insert_chunk, true); @@ -488,7 +496,7 @@ SinkCombineResultType PhysicalInsert::Combine(ExecutionContext &context, Operato auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this, lstate.default_executor, "default_executor", 1); + context.thread.profiler.Flush(*this); client_profiler.Flush(context.thread.profiler); if (!parallel || !lstate.local_collection) { diff --git a/src/duckdb/src/execution/operator/persistent/physical_update.cpp b/src/duckdb/src/execution/operator/persistent/physical_update.cpp index 8dc66343..f314eb12 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_update.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_update.cpp @@ -166,11 +166,9 @@ unique_ptr PhysicalUpdate::GetLocalSinkState(ExecutionContext &c } SinkCombineResultType PhysicalUpdate::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &state = input.local_state.Cast(); auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this, state.default_executor, "default_executor", 1); + context.thread.profiler.Flush(*this); client_profiler.Flush(context.thread.profiler); - return SinkCombineResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/projection/physical_pivot.cpp b/src/duckdb/src/execution/operator/projection/physical_pivot.cpp index 6f333df9..d3def4c0 100644 --- a/src/duckdb/src/execution/operator/projection/physical_pivot.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_pivot.cpp @@ -20,8 +20,8 @@ PhysicalPivot::PhysicalPivot(vector types_p, unique_ptrCast(); // for each aggregate, initialize an empty aggregate state and finalize it immediately - auto state = make_unsafe_uniq_array(aggr.function.state_size()); - aggr.function.initialize(state.get()); + auto state = make_unsafe_uniq_array(aggr.function.state_size(aggr.function)); + aggr.function.initialize(aggr.function, state.get()); Vector state_vector(Value::POINTER(CastPointerToValue(state.get()))); Vector result_vector(aggr_expr->return_type); AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); diff --git a/src/duckdb/src/execution/operator/projection/physical_projection.cpp b/src/duckdb/src/execution/operator/projection/physical_projection.cpp index e6d915e3..5d6dcb13 100644 --- a/src/duckdb/src/execution/operator/projection/physical_projection.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_projection.cpp @@ -15,7 +15,7 @@ class ProjectionState : public OperatorState { public: void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - context.thread.profiler.Flush(op, executor, "projection", 0); + context.thread.profiler.Flush(op); } }; @@ -69,12 +69,19 @@ PhysicalProjection::CreateJoinProjection(vector proj_types, const v return make_uniq(std::move(proj_types), std::move(proj_selects), estimated_cardinality); } -string PhysicalProjection::ParamsToString() const { - string extra_info; - for (auto &expr : select_list) { - extra_info += expr->GetName() + "\n"; +InsertionOrderPreservingMap PhysicalProjection::ParamsToString() const { + InsertionOrderPreservingMap result; + string projections; + for (idx_t i = 0; i < select_list.size(); i++) { + if (i > 0) { + projections += "\n"; + } + auto &expr = select_list[i]; + projections += expr->GetName(); } - return extra_info; + result["__projections__"] = projections; + SetEstimatedCardinality(result, estimated_cardinality); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp b/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp index d32094f7..9fa89e51 100644 --- a/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp @@ -38,7 +38,16 @@ unique_ptr PhysicalTableInOutFunction::GetOperatorState(Execution result->local_state = function.init_local(context, input, gstate.global_state.get()); } if (!projected_input.empty()) { - result->input_chunk.Initialize(context.client, children[0]->types); + vector input_types; + auto &child_types = children[0]->types; + idx_t input_length = child_types.size() - projected_input.size(); + for (idx_t k = 0; k < input_length; k++) { + input_types.push_back(child_types[k]); + } + for (idx_t k = 0; k < projected_input.size(); k++) { + D_ASSERT(projected_input[k] >= input_length); + } + result->input_chunk.Initialize(context.client, input_types); } return std::move(result); } @@ -71,9 +80,8 @@ OperatorResultType PhysicalTableInOutFunction::Execute(ExecutionContext &context } // we are processing a new row: fetch the data for the current row state.input_chunk.Reset(); - D_ASSERT(input.ColumnCount() == state.input_chunk.ColumnCount()); // set up the input data to the table in-out function - for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { + for (idx_t col_idx = 0; col_idx < state.input_chunk.ColumnCount(); col_idx++) { ConstantVector::Reference(state.input_chunk.data[col_idx], input.data[col_idx], state.row_index, 1); } state.input_chunk.SetCardinality(1); @@ -100,6 +108,17 @@ OperatorResultType PhysicalTableInOutFunction::Execute(ExecutionContext &context return OperatorResultType::HAVE_MORE_OUTPUT; } +InsertionOrderPreservingMap PhysicalTableInOutFunction::ParamsToString() const { + InsertionOrderPreservingMap result; + if (function.to_string) { + result["__text__"] = function.to_string(bind_data.get()); + } else { + result["Name"] = function.name; + } + SetEstimatedCardinality(result, estimated_cardinality); + return result; +} + OperatorFinalizeResultType PhysicalTableInOutFunction::FinalExecute(ExecutionContext &context, DataChunk &chunk, GlobalOperatorState &gstate_p, OperatorState &state_p) const { diff --git a/src/duckdb/src/execution/operator/projection/physical_unnest.cpp b/src/duckdb/src/execution/operator/projection/physical_unnest.cpp index 2d5421a7..b964d91f 100644 --- a/src/duckdb/src/execution/operator/projection/physical_unnest.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_unnest.cpp @@ -88,11 +88,17 @@ static void UnnestNull(idx_t start, idx_t end, Vector &result) { for (idx_t i = start; i < end; i++) { validity.SetInvalid(i); } - if (result.GetType().InternalType() == PhysicalType::STRUCT) { - auto &struct_children = StructVector::GetEntries(result); + + const auto &logical_type = result.GetType(); + if (logical_type.InternalType() == PhysicalType::STRUCT) { + const auto &struct_children = StructVector::GetEntries(result); for (auto &child : struct_children) { UnnestNull(start, end, *child); } + } else if (logical_type.InternalType() == PhysicalType::ARRAY) { + auto &array_child = ArrayVector::GetEntry(result); + auto array_size = ArrayType::GetSize(logical_type); + UnnestNull(start * array_size, end * array_size, array_child); } } @@ -205,7 +211,17 @@ static void UnnestVector(UnifiedVectorFormat &child_vector_data, Vector &child_v break; } case PhysicalType::ARRAY: { - throw NotImplementedException("ARRAY type not supported for UNNEST."); + auto array_size = ArrayType::GetSize(child_vector.GetType()); + auto &source_array = ArrayVector::GetEntry(child_vector); + auto &target_array = ArrayVector::GetEntry(result); + + UnnestValidity(child_vector_data, start, end, result); + + UnifiedVectorFormat child_array_data; + source_array.ToUnifiedFormat(list_size * array_size, child_array_data); + UnnestVector(child_array_data, source_array, list_size * array_size, start * array_size, end * array_size, + target_array); + break; } default: throw InternalException("Unimplemented type for UNNEST."); diff --git a/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp index 1db1d5f3..e864f3a2 100644 --- a/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp @@ -11,7 +11,8 @@ namespace duckdb { PhysicalColumnDataScan::PhysicalColumnDataScan(vector types, PhysicalOperatorType op_type, idx_t estimated_cardinality, optionally_owned_ptr collection_p) - : PhysicalOperator(op_type, std::move(types), estimated_cardinality), collection(std::move(collection_p)) { + : PhysicalOperator(op_type, std::move(types), estimated_cardinality), collection(std::move(collection_p)), + cte_index(DConstants::INVALID_INDEX) { } PhysicalColumnDataScan::PhysicalColumnDataScan(vector types, PhysicalOperatorType op_type, @@ -19,32 +20,42 @@ PhysicalColumnDataScan::PhysicalColumnDataScan(vector types, Physic : PhysicalOperator(op_type, std::move(types), estimated_cardinality), collection(nullptr), cte_index(cte_index) { } -class PhysicalColumnDataScanState : public GlobalSourceState { +class PhysicalColumnDataGlobalScanState : public GlobalSourceState { public: - explicit PhysicalColumnDataScanState() : initialized(false) { + explicit PhysicalColumnDataGlobalScanState(const ColumnDataCollection &collection) + : max_threads(MaxValue(collection.ChunkCount(), 1)) { + collection.InitializeScan(global_scan_state); } - //! The current position in the scan - ColumnDataScanState scan_state; - bool initialized; + idx_t MaxThreads() override { + return max_threads; + } + +public: + ColumnDataParallelScanState global_scan_state; + + const idx_t max_threads; +}; + +class PhysicalColumnDataLocalScanState : public LocalSourceState { +public: + ColumnDataLocalScanState local_scan_state; }; unique_ptr PhysicalColumnDataScan::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(); + return make_uniq(*collection); +} + +unique_ptr PhysicalColumnDataScan::GetLocalSourceState(ExecutionContext &, + GlobalSourceState &) const { + return make_uniq(); } SourceResultType PhysicalColumnDataScan::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - auto &state = input.global_state.Cast(); - if (collection->Count() == 0) { - return SourceResultType::FINISHED; - } - if (!state.initialized) { - collection->InitializeScan(state.scan_state); - state.initialized = true; - } - collection->Scan(state.scan_state, chunk); - + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + collection->Scan(gstate.global_scan_state, lstate.local_scan_state, chunk); return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; } @@ -96,19 +107,23 @@ void PhysicalColumnDataScan::BuildPipelines(Pipeline ¤t, MetaPipeline &met state.SetPipelineSource(current, *this); } -string PhysicalColumnDataScan::ParamsToString() const { - string result = ""; +InsertionOrderPreservingMap PhysicalColumnDataScan::ParamsToString() const { + InsertionOrderPreservingMap result; switch (type) { + case PhysicalOperatorType::DELIM_SCAN: + if (delim_index.IsValid()) { + result["Delim Index"] = StringUtil::Format("%llu", delim_index.GetIndex()); + } + break; case PhysicalOperatorType::CTE_SCAN: case PhysicalOperatorType::RECURSIVE_CTE_SCAN: { - result += "\n[INFOSEPARATOR]\n"; - result += StringUtil::Format("idx: %llu", cte_index); + result["CTE Index"] = StringUtil::Format("%llu", cte_index); break; } default: break; } - + SetEstimatedCardinality(result, estimated_cardinality); return result; } diff --git a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp index ba64b293..27ba3982 100644 --- a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp @@ -13,18 +13,22 @@ PhysicalTableScan::PhysicalTableScan(vector types, TableFunction fu unique_ptr bind_data_p, vector returned_types_p, vector column_ids_p, vector projection_ids_p, vector names_p, unique_ptr table_filters_p, - idx_t estimated_cardinality, ExtraOperatorInfo extra_info) + idx_t estimated_cardinality, ExtraOperatorInfo extra_info, + vector parameters_p) : PhysicalOperator(PhysicalOperatorType::TABLE_SCAN, std::move(types), estimated_cardinality), function(std::move(function_p)), bind_data(std::move(bind_data_p)), returned_types(std::move(returned_types_p)), column_ids(std::move(column_ids_p)), projection_ids(std::move(projection_ids_p)), names(std::move(names_p)), - table_filters(std::move(table_filters_p)), extra_info(extra_info) { + table_filters(std::move(table_filters_p)), extra_info(extra_info), parameters(std::move(parameters_p)) { } class TableScanGlobalSourceState : public GlobalSourceState { public: TableScanGlobalSourceState(ClientContext &context, const PhysicalTableScan &op) { + if (op.dynamic_filters && op.dynamic_filters->HasFilters()) { + table_filters = op.dynamic_filters->GetFinalTableFilters(op, op.table_filters.get()); + } if (op.function.init_global) { - TableFunctionInitInput input(op.bind_data.get(), op.column_ids, op.projection_ids, op.table_filters.get()); + TableFunctionInitInput input(op.bind_data.get(), op.column_ids, op.projection_ids, GetTableFilters(op)); global_state = op.function.init_global(context, input); if (global_state) { max_threads = global_state->MaxThreads(); @@ -32,11 +36,30 @@ class TableScanGlobalSourceState : public GlobalSourceState { } else { max_threads = 1; } + if (op.function.in_out_function) { + // this is an in-out function, we need to setup the input chunk + vector input_types; + for (auto ¶m : op.parameters) { + input_types.push_back(param.type()); + } + input_chunk.Initialize(context, input_types); + for (idx_t c = 0; c < op.parameters.size(); c++) { + input_chunk.data[c].SetValue(0, op.parameters[c]); + } + input_chunk.SetCardinality(1); + } } idx_t max_threads = 0; unique_ptr global_state; + bool in_out_final = false; + DataChunk input_chunk; + //! Combined table filters, if we have dynamic filters + unique_ptr table_filters; + optional_ptr GetTableFilters(const PhysicalTableScan &op) const { + return table_filters ? table_filters.get() : op.table_filters.get(); + } idx_t MaxThreads() override { return max_threads; } @@ -47,7 +70,8 @@ class TableScanLocalSourceState : public LocalSourceState { TableScanLocalSourceState(ExecutionContext &context, TableScanGlobalSourceState &gstate, const PhysicalTableScan &op) { if (op.function.init_local) { - TableFunctionInitInput input(op.bind_data.get(), op.column_ids, op.projection_ids, op.table_filters.get()); + TableFunctionInitInput input(op.bind_data.get(), op.column_ids, op.projection_ids, + gstate.GetTableFilters(op)); local_state = op.function.init_local(context, input, gstate.global_state.get()); } } @@ -71,7 +95,18 @@ SourceResultType PhysicalTableScan::GetData(ExecutionContext &context, DataChunk auto &state = input.local_state.Cast(); TableFunctionInput data(bind_data.get(), state.local_state.get(), gstate.global_state.get()); - function.function(context.client, data, chunk); + if (function.function) { + function.function(context.client, data, chunk); + } else { + if (gstate.in_out_final) { + function.in_out_function_final(context, data, chunk); + } + function.in_out_function(context, data, gstate.input_chunk, chunk); + if (chunk.size() == 0 && function.in_out_function_final) { + function.in_out_function_final(context, data, chunk); + gstate.in_out_final = true; + } + } return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; } @@ -99,53 +134,65 @@ string PhysicalTableScan::GetName() const { return StringUtil::Upper(function.name + " " + function.extra_info); } -string PhysicalTableScan::ParamsToString() const { - string result; +InsertionOrderPreservingMap PhysicalTableScan::ParamsToString() const { + InsertionOrderPreservingMap result; if (function.to_string) { - result = function.to_string(bind_data.get()); - result += "\n[INFOSEPARATOR]\n"; + result["__text__"] = function.to_string(bind_data.get()); + } else { + result["Function"] = StringUtil::Upper(function.name); } if (function.projection_pushdown) { if (function.filter_prune) { + string projections; for (idx_t i = 0; i < projection_ids.size(); i++) { const auto &column_id = column_ids[projection_ids[i]]; if (column_id < names.size()) { if (i > 0) { - result += "\n"; + projections += "\n"; } - result += names[column_id]; + projections += names[column_id]; } } + result["Projections"] = projections; } else { + string projections; for (idx_t i = 0; i < column_ids.size(); i++) { const auto &column_id = column_ids[i]; if (column_id < names.size()) { if (i > 0) { - result += "\n"; + projections += "\n"; } - result += names[column_id]; + projections += names[column_id]; } } + result["Projections"] = projections; } } if (function.filter_pushdown && table_filters) { - result += "\n[INFOSEPARATOR]\n"; - result += "Filters: "; + string filters_info; + bool first_item = true; for (auto &f : table_filters->filters) { auto &column_index = f.first; auto &filter = f.second; if (column_index < names.size()) { - result += filter->ToString(names[column_ids[column_index]]); - result += "\n"; + if (!first_item) { + filters_info += "\n"; + } + first_item = false; + filters_info += filter->ToString(names[column_ids[column_index]]); } } + result["Filters"] = filters_info; } if (!extra_info.file_filters.empty()) { - result += "\n[INFOSEPARATOR]\n"; - result += "File Filters: " + extra_info.file_filters; + result["File Filters"] = extra_info.file_filters; + if (extra_info.filtered_files.IsValid() && extra_info.total_files.IsValid()) { + result["Scanning Files"] = StringUtil::Format("%llu/%llu", extra_info.filtered_files.GetIndex(), + extra_info.total_files.GetIndex()); + } } - result += "\n[INFOSEPARATOR]\n"; - result += StringUtil::Format("EC: %llu", estimated_cardinality); + + SetEstimatedCardinality(result, estimated_cardinality); return result; } @@ -166,4 +213,13 @@ bool PhysicalTableScan::Equals(const PhysicalOperator &other_p) const { return true; } +bool PhysicalTableScan::ParallelSource() const { + if (!function.function) { + // table in-out functions cannot be executed in parallel as part of a PhysicalTableScan + // since they have only a single input row + return false; + } + return true; +} + } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_attach.cpp b/src/duckdb/src/execution/operator/schema/physical_attach.cpp index 2c2b76a0..69af179f 100644 --- a/src/duckdb/src/execution/operator/schema/physical_attach.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_attach.cpp @@ -10,48 +10,6 @@ namespace duckdb { -//===--------------------------------------------------------------------===// -// Helper -//===--------------------------------------------------------------------===// - -void ParseOptions(const unique_ptr &info, AccessMode &access_mode, string &db_type, - string &unrecognized_option) { - - for (auto &entry : info->options) { - - if (entry.first == "readonly" || entry.first == "read_only") { - auto read_only = BooleanValue::Get(entry.second.DefaultCastAs(LogicalType::BOOLEAN)); - if (read_only) { - access_mode = AccessMode::READ_ONLY; - } else { - access_mode = AccessMode::READ_WRITE; - } - continue; - } - - if (entry.first == "readwrite" || entry.first == "read_write") { - auto read_only = !BooleanValue::Get(entry.second.DefaultCastAs(LogicalType::BOOLEAN)); - if (read_only) { - access_mode = AccessMode::READ_ONLY; - } else { - access_mode = AccessMode::READ_WRITE; - } - continue; - } - - if (entry.first == "type") { - // extract the database type - db_type = StringValue::Get(entry.second.DefaultCastAs(LogicalType::VARCHAR)); - continue; - } - - // we allow unrecognized options - if (unrecognized_option.empty()) { - unrecognized_option = entry.first; - } - } -} - //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// @@ -59,16 +17,13 @@ SourceResultType PhysicalAttach::GetData(ExecutionContext &context, DataChunk &c OperatorSourceInput &input) const { // parse the options auto &config = DBConfig::GetConfig(context.client); - AccessMode access_mode = config.options.access_mode; - string db_type; - string unrecognized_option; - ParseOptions(info, access_mode, db_type, unrecognized_option); + AttachOptions options(info, config.options.access_mode); // get the name and path of the database auto &name = info->name; auto &path = info->path; - if (db_type.empty()) { - DBPathAndType::ExtractExtensionPrefix(path, db_type); + if (options.db_type.empty()) { + DBPathAndType::ExtractExtensionPrefix(path, options.db_type); } if (name.empty()) { auto &fs = FileSystem::GetFileSystem(context.client); @@ -82,12 +37,12 @@ SourceResultType PhysicalAttach::GetData(ExecutionContext &context, DataChunk &c auto existing_db = db_manager.GetDatabase(context.client, name); if (existing_db) { - if ((existing_db->IsReadOnly() && access_mode == AccessMode::READ_WRITE) || - (!existing_db->IsReadOnly() && access_mode == AccessMode::READ_ONLY)) { + if ((existing_db->IsReadOnly() && options.access_mode == AccessMode::READ_WRITE) || + (!existing_db->IsReadOnly() && options.access_mode == AccessMode::READ_ONLY)) { auto existing_mode = existing_db->IsReadOnly() ? AccessMode::READ_ONLY : AccessMode::READ_WRITE; auto existing_mode_str = EnumUtil::ToString(existing_mode); - auto attached_mode = EnumUtil::ToString(access_mode); + auto attached_mode = EnumUtil::ToString(options.access_mode); throw BinderException("Database \"%s\" is already attached in %s mode, cannot re-attach in %s mode", name, existing_mode_str, attached_mode); } @@ -96,10 +51,27 @@ SourceResultType PhysicalAttach::GetData(ExecutionContext &context, DataChunk &c } } - // get the database type and attach the database - db_manager.GetDatabaseType(context.client, db_type, *info, config, unrecognized_option); - auto attached_db = db_manager.AttachDatabase(context.client, *info, db_type, access_mode); - attached_db->Initialize(); + string extension = ""; + if (FileSystem::IsRemoteFile(path, extension)) { + if (!ExtensionHelper::TryAutoLoadExtension(context.client, extension)) { + throw MissingExtensionException("Attaching path '%s' requires extension '%s' to be loaded", path, + extension); + } + if (options.access_mode == AccessMode::AUTOMATIC) { + // Attaching of remote files gets bumped to READ_ONLY + // This is due to the fact that on most (all?) remote files writes to DB are not available + // and having this raised later is not super helpful + options.access_mode = AccessMode::READ_ONLY; + } + } + + // Get the database type and attach the database. + db_manager.GetDatabaseType(context.client, *info, config, options); + auto attached_db = db_manager.AttachDatabase(context.client, *info, options); + + //! Initialize the database. + const auto block_alloc_size = info->GetBlockAllocSize(); + attached_db->Initialize(block_alloc_size); return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp b/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp index 90eceacf..7eceeecb 100644 --- a/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp @@ -21,7 +21,7 @@ PhysicalCreateARTIndex::PhysicalCreateARTIndex(LogicalOperator &op, TableCatalog table(table_p.Cast()), info(std::move(info)), unbound_expressions(std::move(unbound_expressions)), sorted(sorted) { - // convert virtual column ids to storage column ids + // Convert the virtual column ids to physical column ids. for (auto &column_id : column_ids) { storage_ids.push_back(table.GetColumns().LogicalToPhysical(LogicalIndex(column_id)).index); } @@ -33,7 +33,6 @@ PhysicalCreateARTIndex::PhysicalCreateARTIndex(LogicalOperator &op, TableCatalog class CreateARTIndexGlobalSinkState : public GlobalSinkState { public: - //! Global index to be added to the table unique_ptr global_index; }; @@ -43,53 +42,51 @@ class CreateARTIndexLocalSinkState : public LocalSinkState { unique_ptr local_index; ArenaAllocator arena_allocator; - vector keys; + DataChunk key_chunk; + unsafe_vector keys; vector key_column_ids; + + DataChunk row_id_chunk; + unsafe_vector row_ids; }; unique_ptr PhysicalCreateARTIndex::GetGlobalSinkState(ClientContext &context) const { + // Create the global sink state and add the global index. auto state = make_uniq(); - - // create the global index auto &storage = table.GetStorage(); state->global_index = make_uniq(info->index_name, info->constraint_type, storage_ids, TableIOManager::Get(storage), unbound_expressions, storage.db); - return (std::move(state)); } unique_ptr PhysicalCreateARTIndex::GetLocalSinkState(ExecutionContext &context) const { + // Create the local sink state and add the local index. auto state = make_uniq(context.client); - - // create the local index - auto &storage = table.GetStorage(); state->local_index = make_uniq(info->index_name, info->constraint_type, storage_ids, TableIOManager::Get(storage), unbound_expressions, storage.db); - state->keys = vector(STANDARD_VECTOR_SIZE); + // Initialize the local sink state. + state->keys.resize(STANDARD_VECTOR_SIZE); + state->row_ids.resize(STANDARD_VECTOR_SIZE); state->key_chunk.Initialize(Allocator::Get(context.client), state->local_index->logical_types); - + state->row_id_chunk.Initialize(Allocator::Get(context.client), vector {LogicalType::ROW_TYPE}); for (idx_t i = 0; i < state->key_chunk.ColumnCount(); i++) { state->key_column_ids.push_back(i); } return std::move(state); } -SinkResultType PhysicalCreateARTIndex::SinkUnsorted(Vector &row_identifiers, OperatorSinkInput &input) const { +SinkResultType PhysicalCreateARTIndex::SinkUnsorted(OperatorSinkInput &input) const { auto &l_state = input.local_state.Cast(); - auto count = l_state.key_chunk.size(); - - // get the corresponding row IDs - row_identifiers.Flatten(count); - auto row_ids = FlatVector::GetData(row_identifiers); + auto row_count = l_state.key_chunk.size(); - // insert the row IDs + // Insert each key and its corresponding row ID. auto &art = l_state.local_index->Cast(); - for (idx_t i = 0; i < count; i++) { - if (!art.Insert(art.tree, l_state.keys[i], 0, row_ids[i])) { + for (idx_t i = 0; i < row_count; i++) { + if (!art.Insert(art.tree, l_state.keys[i], 0, l_state.row_ids[i], art.tree.GetGateStatus())) { throw ConstraintException("Data contains duplicates on indexed column(s)"); } } @@ -97,21 +94,21 @@ SinkResultType PhysicalCreateARTIndex::SinkUnsorted(Vector &row_identifiers, Ope return SinkResultType::NEED_MORE_INPUT; } -SinkResultType PhysicalCreateARTIndex::SinkSorted(Vector &row_identifiers, OperatorSinkInput &input) const { +SinkResultType PhysicalCreateARTIndex::SinkSorted(OperatorSinkInput &input) const { auto &l_state = input.local_state.Cast(); auto &storage = table.GetStorage(); auto &l_index = l_state.local_index; - // create an ART from the chunk + // Construct an ART for this chunk. auto art = make_uniq(info->index_name, l_index->GetConstraintType(), l_index->GetColumnIds(), l_index->table_io_manager, l_index->unbound_expressions, storage.db, l_index->Cast().allocators); - if (!art->ConstructFromSorted(l_state.key_chunk.size(), l_state.keys, row_identifiers)) { + if (!art->Construct(l_state.keys, l_state.row_ids, l_state.key_chunk.size())) { throw ConstraintException("Data contains duplicates on indexed column(s)"); } - // merge into the local ART + // Merge the ART into the local ART. if (!l_index->MergeIndexes(*art)) { throw ConstraintException("Data contains duplicates on indexed column(s)"); } @@ -123,29 +120,26 @@ SinkResultType PhysicalCreateARTIndex::Sink(ExecutionContext &context, DataChunk OperatorSinkInput &input) const { D_ASSERT(chunk.ColumnCount() >= 2); - - // generate the keys for the given input auto &l_state = input.local_state.Cast(); - l_state.key_chunk.ReferenceColumns(chunk, l_state.key_column_ids); l_state.arena_allocator.Reset(); - ART::GenerateKeys(l_state.arena_allocator, l_state.key_chunk, l_state.keys); + l_state.key_chunk.ReferenceColumns(chunk, l_state.key_column_ids); + ART::GenerateKeyVectors(l_state.arena_allocator, l_state.key_chunk, chunk.data[chunk.ColumnCount() - 1], + l_state.keys, l_state.row_ids); - // insert the keys and their corresponding row IDs - auto &row_identifiers = chunk.data[chunk.ColumnCount() - 1]; if (sorted) { - return SinkSorted(row_identifiers, input); + return SinkSorted(input); } - return SinkUnsorted(row_identifiers, input); + return SinkUnsorted(input); } SinkCombineResultType PhysicalCreateARTIndex::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); + auto &g_state = input.global_state.Cast(); + auto &l_state = input.local_state.Cast(); // merge the local index into the global index - if (!gstate.global_index->MergeIndexes(*lstate.local_index)) { + if (!g_state.global_index->MergeIndexes(*l_state.local_index)) { throw ConstraintException("Data contains duplicates on indexed column(s)"); } @@ -161,6 +155,7 @@ SinkFinalizeType PhysicalCreateARTIndex::Finalize(Pipeline &pipeline, Event &eve // vacuum excess memory and verify state.global_index->Vacuum(); D_ASSERT(!state.global_index->VerifyAndToString(true).empty()); + state.global_index->VerifyAllocations(); auto &storage = table.GetStorage(); if (!storage.IsRoot()) { @@ -169,20 +164,22 @@ SinkFinalizeType PhysicalCreateARTIndex::Finalize(Pipeline &pipeline, Event &eve auto &schema = table.schema; info->column_ids = storage_ids; - auto index_entry = schema.CreateIndex(schema.GetCatalogTransaction(context), *info, table).get(); - if (!index_entry) { - D_ASSERT(info->on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT); - // index already exists, but error ignored because of IF NOT EXISTS + + // Ensure that the index does not yet exist. + // FIXME: We should early-out prior to creating the index. + if (schema.GetEntry(schema.GetCatalogTransaction(context), CatalogType::INDEX_ENTRY, info->index_name)) { + if (info->on_conflict != OnCreateConflict::IGNORE_ON_CONFLICT) { + throw CatalogException("Index with name \"%s\" already exists!", info->index_name); + } + // IF NOT EXISTS on existing index. We are done. return SinkFinalizeType::READY; } + + auto index_entry = schema.CreateIndex(schema.GetCatalogTransaction(context), *info, table).get(); + D_ASSERT(index_entry); auto &index = index_entry->Cast(); index.initial_index_size = state.global_index->GetInMemorySize(); - index.info = make_shared_ptr(storage.GetDataTableInfo(), index.name); - for (auto &parsed_expr : info->parsed_expressions) { - index.parsed_expressions.push_back(parsed_expr->Copy()); - } - // add index to storage storage.AddIndex(std::move(state.global_index)); return SinkFinalizeType::READY; diff --git a/src/duckdb/src/execution/operator/set/physical_cte.cpp b/src/duckdb/src/execution/operator/set/physical_cte.cpp index 8804b4a2..fad76bbd 100644 --- a/src/duckdb/src/execution/operator/set/physical_cte.cpp +++ b/src/duckdb/src/execution/operator/set/physical_cte.cpp @@ -102,12 +102,10 @@ vector> PhysicalCTE::GetSources() const { return children[1]->GetSources(); } -string PhysicalCTE::ParamsToString() const { - string result = ""; - result += "\n[INFOSEPARATOR]\n"; - result += ctename; - result += "\n[INFOSEPARATOR]\n"; - result += StringUtil::Format("idx: %llu", table_index); +InsertionOrderPreservingMap PhysicalCTE::ParamsToString() const { + InsertionOrderPreservingMap result; + result["CTE Name"] = ctename; + result["Table Index"] = StringUtil::Format("%llu", table_index); return result; } diff --git a/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp b/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp index 406299ac..328e0822 100644 --- a/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp +++ b/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp @@ -223,12 +223,10 @@ vector> PhysicalRecursiveCTE::GetSources() con return {*this}; } -string PhysicalRecursiveCTE::ParamsToString() const { - string result = ""; - result += "\n[INFOSEPARATOR]\n"; - result += ctename; - result += "\n[INFOSEPARATOR]\n"; - result += StringUtil::Format("idx: %llu", table_index); +InsertionOrderPreservingMap PhysicalRecursiveCTE::ParamsToString() const { + InsertionOrderPreservingMap result; + result["CTE Name"] = ctename; + result["Table Index"] = StringUtil::Format("%llu", table_index); return result; } diff --git a/src/duckdb/src/execution/operator/set/physical_union.cpp b/src/duckdb/src/execution/operator/set/physical_union.cpp index 81e5b49e..4954dc87 100644 --- a/src/duckdb/src/execution/operator/set/physical_union.cpp +++ b/src/duckdb/src/execution/operator/set/physical_union.cpp @@ -40,20 +40,34 @@ void PhysicalUnion::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipelin } } - // create a union pipeline that is identical to 'current' + // create a union pipeline that has identical dependencies to 'current' auto &union_pipeline = meta_pipeline.CreateUnionPipeline(current, order_matters); // continue with the current pipeline children[0]->BuildPipelines(current, meta_pipeline); - if (order_matters) { - // order matters, so 'union_pipeline' must come after all pipelines created by building out 'current' - meta_pipeline.AddDependenciesFrom(union_pipeline, union_pipeline, false); + vector> dependencies; + optional_ptr last_child_ptr; + const auto can_saturate_threads = children[0]->CanSaturateThreads(current.GetClientContext()); + if (order_matters || can_saturate_threads) { + // we add dependencies if order matters: union_pipeline comes after all pipelines created by building current + dependencies = meta_pipeline.AddDependenciesFrom(union_pipeline, union_pipeline, false); + // we also add dependencies if the LHS child can saturate all available threads + // in that case, we recursively make all RHS children depend on the LHS. + // This prevents breadth-first plan evaluation + if (can_saturate_threads) { + last_child_ptr = meta_pipeline.GetLastChild(); + } } // build the union pipeline children[1]->BuildPipelines(union_pipeline, meta_pipeline); + if (last_child_ptr) { + // the pointer was set, set up the dependencies + meta_pipeline.AddRecursiveDependencies(dependencies, *last_child_ptr); + } + // Assign proper batch index to the union pipeline // This needs to happen after the pipelines have been built because unions can be nested meta_pipeline.AssignNextBatchIndex(union_pipeline); diff --git a/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp b/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp index ba8e366e..4fa6f08f 100644 --- a/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp +++ b/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp @@ -1,4 +1,5 @@ #include "duckdb/execution/perfect_aggregate_hashtable.hpp" + #include "duckdb/common/numeric_utils.hpp" #include "duckdb/common/row_operations/row_operations.hpp" #include "duckdb/execution/expression_executor.hpp" @@ -25,11 +26,11 @@ PerfectAggregateHashTable::PerfectAggregateHashTable(ClientContext &context, All tuple_size = layout.GetRowWidth(); // allocate and null initialize the data - owned_data = make_unsafe_uniq_array(tuple_size * total_groups); + owned_data = make_unsafe_uniq_array_uninitialized(tuple_size * total_groups); data = owned_data.get(); // set up the empty payloads for every tuple, and initialize the "occupied" flag to false - group_is_set = make_unsafe_uniq_array(total_groups); + group_is_set = make_unsafe_uniq_array_uninitialized(total_groups); memset(group_is_set.get(), 0, total_groups * sizeof(bool)); // initialize the hash table for each entry diff --git a/src/duckdb/src/execution/physical_operator.cpp b/src/duckdb/src/execution/physical_operator.cpp index 4934789e..3fe08ddf 100644 --- a/src/duckdb/src/execution/physical_operator.cpp +++ b/src/duckdb/src/execution/physical_operator.cpp @@ -1,6 +1,7 @@ #include "duckdb/execution/physical_operator.hpp" #include "duckdb/common/printer.hpp" +#include "duckdb/common/render_tree.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/tree_renderer.hpp" #include "duckdb/execution/execution_context.hpp" @@ -9,8 +10,8 @@ #include "duckdb/parallel/meta_pipeline.hpp" #include "duckdb/parallel/pipeline.hpp" #include "duckdb/parallel/thread_context.hpp" -#include "duckdb/storage/buffer_manager.hpp" #include "duckdb/storage/buffer/buffer_pool.hpp" +#include "duckdb/storage/buffer_manager.hpp" namespace duckdb { @@ -18,9 +19,12 @@ string PhysicalOperator::GetName() const { return PhysicalOperatorToString(type); } -string PhysicalOperator::ToString() const { - TreeRenderer renderer; - return renderer.ToString(*this); +string PhysicalOperator::ToString(ExplainFormat format) const { + auto renderer = TreeRenderer::CreateRenderer(format); + stringstream ss; + auto tree = RenderTree::CreateRenderTree(*this); + renderer->ToStream(*tree, ss); + return ss.str(); } // LCOV_EXCL_START @@ -37,6 +41,40 @@ vector> PhysicalOperator::GetChildren() const return result; } +void PhysicalOperator::SetEstimatedCardinality(InsertionOrderPreservingMap &result, + idx_t estimated_cardinality) { + result[RenderTreeNode::ESTIMATED_CARDINALITY] = StringUtil::Format("%llu", estimated_cardinality); +} + +idx_t PhysicalOperator::EstimatedThreadCount() const { + idx_t result = 0; + if (children.empty()) { + // Terminal operator, e.g., base table, these decide the degree of parallelism of pipelines + result = MaxValue(estimated_cardinality / (Storage::ROW_GROUP_SIZE * 2), 1); + } else if (type == PhysicalOperatorType::UNION) { + // We can run union pipelines in parallel, so we sum up the thread count of the children + for (auto &child : children) { + result += child->EstimatedThreadCount(); + } + } else { + // For other operators we take the maximum of the children + for (auto &child : children) { + result = MaxValue(child->EstimatedThreadCount(), result); + } + } + return result; +} + +bool PhysicalOperator::CanSaturateThreads(ClientContext &context) const { +#ifdef DEBUG + // In debug mode we always return true here so that the code that depends on it is well-tested + return true; +#else + const auto num_threads = NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); + return EstimatedThreadCount() >= num_threads; +#endif +} + //===--------------------------------------------------------------------===// // Operator //===--------------------------------------------------------------------===// @@ -102,6 +140,9 @@ SinkCombineResultType PhysicalOperator::Combine(ExecutionContext &context, Opera return SinkCombineResultType::FINISHED; } +void PhysicalOperator::PrepareFinalize(ClientContext &context, GlobalSinkState &sink_state) const { +} + SinkFinalizeType PhysicalOperator::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, OperatorSinkFinalizeInput &input) const { return SinkFinalizeType::READY; diff --git a/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp b/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp index 743f8189..31fcf3ba 100644 --- a/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp @@ -91,6 +91,10 @@ void CheckForPerfectJoinOpt(LogicalComparisonJoin &op, PerfectHashJoinStats &joi !ExtractNumericValue(NumericStats::Max(stats_build), max_value)) { return; } + if (max_value < min_value) { + // empty table + return; + } int64_t build_range; if (!TrySubtractOperator::Operation(max_value, min_value, build_range)) { return; @@ -184,26 +188,33 @@ unique_ptr PhysicalPlanGenerator::PlanComparisonJoin(LogicalCo default: break; } + auto &client_config = ClientConfig::GetConfig(context); // TODO: Extend PWMJ to handle all comparisons and projection maps - const auto prefer_range_joins = (ClientConfig::GetConfig(context).prefer_range_joins && can_iejoin); + const auto prefer_range_joins = client_config.prefer_range_joins && can_iejoin; unique_ptr plan; if (has_equality && !prefer_range_joins) { // Equality join with small number of keys : possible perfect join optimization PerfectHashJoinStats perfect_join_stats; CheckForPerfectJoinOpt(op, perfect_join_stats); - plan = make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), - op.join_type, op.left_projection_map, op.right_projection_map, - std::move(op.mark_types), op.estimated_cardinality, perfect_join_stats); + plan = + make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), op.join_type, + op.left_projection_map, op.right_projection_map, std::move(op.mark_types), + op.estimated_cardinality, perfect_join_stats, std::move(op.filter_pushdown)); } else { - static constexpr const idx_t NESTED_LOOP_JOIN_THRESHOLD = 5; - if (left->estimated_cardinality <= NESTED_LOOP_JOIN_THRESHOLD || - right->estimated_cardinality <= NESTED_LOOP_JOIN_THRESHOLD) { + if (left->estimated_cardinality <= client_config.nested_loop_join_threshold || + right->estimated_cardinality <= client_config.nested_loop_join_threshold) { can_iejoin = false; can_merge = false; } + if (can_merge && can_iejoin) { + if (left->estimated_cardinality <= client_config.merge_join_threshold || + right->estimated_cardinality <= client_config.merge_join_threshold) { + can_iejoin = false; + } + } if (can_iejoin) { plan = make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), op.join_type, op.estimated_cardinality); diff --git a/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp b/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp index c3194e25..a2981f61 100644 --- a/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp @@ -1,5 +1,5 @@ -#include "duckdb/execution/operator/persistent/physical_copy_to_file.hpp" #include "duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp" +#include "duckdb/execution/operator/persistent/physical_copy_to_file.hpp" #include "duckdb/execution/physical_plan_generator.hpp" #include "duckdb/planner/operator/logical_copy_to_file.hpp" @@ -16,8 +16,8 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCopyToFile auto base = StringUtil::GetFileName(op.file_path); op.file_path = fs.JoinPath(path, "tmp_" + base); } - if (op.per_thread_output || op.file_size_bytes.IsValid() || op.partition_output || !op.partition_columns.empty() || - op.overwrite_mode != CopyOverwriteMode::COPY_ERROR_ON_CONFLICT) { + if (op.per_thread_output || op.file_size_bytes.IsValid() || op.rotate || op.partition_output || + !op.partition_columns.empty() || op.overwrite_mode != CopyOverwriteMode::COPY_ERROR_ON_CONFLICT) { // hive-partitioning/per-thread output does not care about insertion order, and does not support batch indexes preserve_insertion_order = false; supports_batch_index = false; @@ -36,8 +36,10 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCopyToFile copy->file_path = op.file_path; copy->use_tmp_file = op.use_tmp_file; copy->children.push_back(std::move(plan)); + copy->return_type = op.return_type; return std::move(copy); } + // COPY from select statement to file auto copy = make_uniq(op.types, op.function, std::move(op.bind_data), op.estimated_cardinality); copy->file_path = op.file_path; @@ -49,8 +51,11 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCopyToFile if (op.file_size_bytes.IsValid()) { copy->file_size_bytes = op.file_size_bytes; } + copy->rotate = op.rotate; + copy->return_type = op.return_type; copy->partition_output = op.partition_output; copy->partition_columns = op.partition_columns; + copy->write_partition_columns = op.write_partition_columns; copy->names = op.names; copy->expected_types = op.expected_types; copy->parallel = mode == CopyFunctionExecutionMode::PARALLEL_COPY_TO_FILE; diff --git a/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp b/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp index ba4246b5..e545586c 100644 --- a/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp @@ -7,14 +7,19 @@ #include "duckdb/execution/physical_plan_generator.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" + namespace duckdb { -static void GatherDelimScans(const PhysicalOperator &op, vector> &delim_scans) { +static void GatherDelimScans(PhysicalOperator &op, vector> &delim_scans, + idx_t delim_index) { if (op.type == PhysicalOperatorType::DELIM_SCAN) { + auto &scan = op.Cast(); + scan.delim_index = optional_idx(delim_index); delim_scans.push_back(op); } for (auto &child : op.children) { - GatherDelimScans(*child, delim_scans); + GatherDelimScans(*child, delim_scans, delim_index); } } @@ -27,7 +32,7 @@ unique_ptr PhysicalPlanGenerator::PlanDelimJoin(LogicalCompari // first gather the scans on the duplicate eliminated data set from the delim side const idx_t delim_idx = op.delim_flipped ? 0 : 1; vector> delim_scans; - GatherDelimScans(*plan->children[delim_idx], delim_scans); + GatherDelimScans(*plan->children[delim_idx], delim_scans, ++this->delim_index); if (delim_scans.empty()) { // no duplicate eliminated scans in the delim side! // in this case we don't need to create a delim join @@ -45,14 +50,16 @@ unique_ptr PhysicalPlanGenerator::PlanDelimJoin(LogicalCompari // now create the duplicate eliminated join unique_ptr delim_join; if (op.delim_flipped) { - delim_join = - make_uniq(op.types, std::move(plan), delim_scans, op.estimated_cardinality); + delim_join = make_uniq(op.types, std::move(plan), delim_scans, op.estimated_cardinality, + optional_idx(this->delim_index)); } else { - delim_join = make_uniq(op.types, std::move(plan), delim_scans, op.estimated_cardinality); + delim_join = make_uniq(op.types, std::move(plan), delim_scans, op.estimated_cardinality, + optional_idx(this->delim_index)); } // we still have to create the DISTINCT clause that is used to generate the duplicate eliminated chunk delim_join->distinct = make_uniq(context, delim_types, std::move(distinct_expressions), std::move(distinct_groups), op.estimated_cardinality); + return std::move(delim_join); } diff --git a/src/duckdb/src/execution/physical_plan/plan_explain.cpp b/src/duckdb/src/execution/physical_plan/plan_explain.cpp index 867aa4ac..3f9d2a4b 100644 --- a/src/duckdb/src/execution/physical_plan/plan_explain.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_explain.cpp @@ -10,15 +10,15 @@ namespace duckdb { unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalExplain &op) { D_ASSERT(op.children.size() == 1); - auto logical_plan_opt = op.children[0]->ToString(); + auto logical_plan_opt = op.children[0]->ToString(op.explain_format); auto plan = CreatePlan(*op.children[0]); if (op.explain_type == ExplainType::EXPLAIN_ANALYZE) { - auto result = make_uniq(op.types); + auto result = make_uniq(op.types, op.explain_format); result->children.push_back(std::move(plan)); return std::move(result); } - op.physical_plan = plan->ToString(); + op.physical_plan = plan->ToString(op.explain_format); // the output of the explain vector keys, values; switch (ClientConfig::GetConfig(context).explain_output_type) { diff --git a/src/duckdb/src/execution/physical_plan/plan_get.cpp b/src/duckdb/src/execution/physical_plan/plan_get.cpp index e7ec5437..056d291f 100644 --- a/src/duckdb/src/execution/physical_plan/plan_get.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_get.cpp @@ -1,15 +1,20 @@ +#include "duckdb/execution/operator/scan/physical_expression_scan.hpp" +#include "duckdb/execution/operator/scan/physical_dummy_scan.hpp" #include "duckdb/execution/operator/projection/physical_projection.hpp" #include "duckdb/execution/operator/projection/physical_tableinout_function.hpp" #include "duckdb/execution/operator/scan/physical_table_scan.hpp" #include "duckdb/execution/physical_plan_generator.hpp" #include "duckdb/function/table/table_scan.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/execution/operator/filter/physical_filter.hpp" namespace duckdb { -unique_ptr CreateTableFilterSet(TableFilterSet &table_filters, vector &column_ids) { +unique_ptr CreateTableFilterSet(TableFilterSet &table_filters, const vector &column_ids) { // create the table filter map auto table_filter_set = make_uniq(); for (auto &table_filter : table_filters.filters) { @@ -30,12 +35,43 @@ unique_ptr CreateTableFilterSet(TableFilterSet &table_filters, v } unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalGet &op) { + auto column_ids = op.GetColumnIds(); if (!op.children.empty()) { + auto child_node = CreatePlan(std::move(op.children[0])); // this is for table producing functions that consume subquery results - D_ASSERT(op.children.size() == 1); - auto node = make_uniq(op.types, op.function, std::move(op.bind_data), op.column_ids, + // push a projection node with casts if required + if (child_node->types.size() < op.input_table_types.size()) { + throw InternalException( + "Mismatch between input table types and child node types - expected %llu but got %llu", + op.input_table_types.size(), child_node->types.size()); + } + vector return_types; + vector> expressions; + bool any_cast_required = false; + for (idx_t proj_idx = 0; proj_idx < child_node->types.size(); proj_idx++) { + auto ref = make_uniq(child_node->types[proj_idx], proj_idx); + auto &target_type = + proj_idx < op.input_table_types.size() ? op.input_table_types[proj_idx] : child_node->types[proj_idx]; + if (child_node->types[proj_idx] != target_type) { + // cast is required - push a cast + any_cast_required = true; + auto cast = BoundCastExpression::AddCastToType(context, std::move(ref), target_type); + expressions.push_back(std::move(cast)); + } else { + expressions.push_back(std::move(ref)); + } + return_types.push_back(target_type); + } + if (any_cast_required) { + auto proj = make_uniq(std::move(return_types), std::move(expressions), + child_node->estimated_cardinality); + proj->children.push_back(std::move(child_node)); + child_node = std::move(proj); + } + + auto node = make_uniq(op.types, op.function, std::move(op.bind_data), column_ids, op.estimated_cardinality, std::move(op.projected_input)); - node->children.push_back(CreatePlan(std::move(op.children[0]))); + node->children.push_back(std::move(child_node)); return std::move(node); } if (!op.projected_input.empty()) { @@ -44,23 +80,66 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalGet &op) { unique_ptr table_filters; if (!op.table_filters.filters.empty()) { - table_filters = CreateTableFilterSet(op.table_filters, op.column_ids); + table_filters = CreateTableFilterSet(op.table_filters, column_ids); } if (op.function.dependency) { op.function.dependency(dependencies, op.bind_data.get()); } + unique_ptr filter; + + auto &projection_ids = op.projection_ids; + + if (table_filters && op.function.supports_pushdown_type) { + vector> select_list; + unique_ptr unsupported_filter; + unordered_set to_remove; + for (auto &entry : table_filters->filters) { + auto column_id = column_ids[entry.first]; + auto &type = op.returned_types[column_id]; + if (!op.function.supports_pushdown_type(type)) { + idx_t column_id_filter = entry.first; + bool found_projection = false; + for (idx_t i = 0; i < projection_ids.size(); i++) { + if (column_ids[projection_ids[i]] == column_ids[entry.first]) { + column_id_filter = i; + found_projection = true; + break; + } + } + if (!found_projection) { + projection_ids.push_back(entry.first); + column_id_filter = projection_ids.size() - 1; + } + auto column = make_uniq(type, column_id_filter); + select_list.push_back(entry.second->ToExpression(*column)); + to_remove.insert(entry.first); + } + } + for (auto &col : to_remove) { + table_filters->filters.erase(col); + } + + if (!select_list.empty()) { + vector filter_types; + for (auto &c : projection_ids) { + filter_types.push_back(op.returned_types[column_ids[c]]); + } + filter = make_uniq(filter_types, std::move(select_list), op.estimated_cardinality); + } + } + op.ResolveOperatorTypes(); // create the table scan node if (!op.function.projection_pushdown) { // function does not support projection pushdown - auto node = make_uniq(op.returned_types, op.function, std::move(op.bind_data), - op.returned_types, op.column_ids, vector(), op.names, - std::move(table_filters), op.estimated_cardinality, op.extra_info); + auto node = make_uniq( + op.returned_types, op.function, std::move(op.bind_data), op.returned_types, column_ids, vector(), + op.names, std::move(table_filters), op.estimated_cardinality, op.extra_info, std::move(op.parameters)); // first check if an additional projection is necessary - if (op.column_ids.size() == op.returned_types.size()) { + if (column_ids.size() == op.returned_types.size()) { bool projection_necessary = false; - for (idx_t i = 0; i < op.column_ids.size(); i++) { - if (op.column_ids[i] != i) { + for (idx_t i = 0; i < column_ids.size(); i++) { + if (column_ids[i] != i) { projection_necessary = true; break; } @@ -68,14 +147,17 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalGet &op) { if (!projection_necessary) { // a projection is not necessary if all columns have been requested in-order // in that case we just return the node - + if (filter) { + filter->children.push_back(std::move(node)); + return std::move(filter); + } return std::move(node); } } // push a projection on top that does the projection vector types; vector> expressions; - for (auto &column_id : op.column_ids) { + for (auto &column_id : column_ids) { if (column_id == COLUMN_IDENTIFIER_ROW_ID) { types.emplace_back(LogicalType::BIGINT); expressions.push_back(make_uniq(Value::BIGINT(0))); @@ -85,15 +167,25 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalGet &op) { expressions.push_back(make_uniq(type, column_id)); } } - - auto projection = + unique_ptr projection = make_uniq(std::move(types), std::move(expressions), op.estimated_cardinality); - projection->children.push_back(std::move(node)); + if (filter) { + filter->children.push_back(std::move(node)); + projection->children.push_back(std::move(filter)); + } else { + projection->children.push_back(std::move(node)); + } return std::move(projection); } else { - return make_uniq(op.types, op.function, std::move(op.bind_data), op.returned_types, - op.column_ids, op.projection_ids, op.names, std::move(table_filters), - op.estimated_cardinality, op.extra_info); + auto node = make_uniq(op.types, op.function, std::move(op.bind_data), op.returned_types, + column_ids, op.projection_ids, op.names, std::move(table_filters), + op.estimated_cardinality, op.extra_info, std::move(op.parameters)); + node->dynamic_filters = op.dynamic_filters; + if (filter) { + filter->children.push_back(std::move(node)); + return std::move(filter); + } + return std::move(node); } } diff --git a/src/duckdb/src/execution/physical_plan/plan_limit.cpp b/src/duckdb/src/execution/physical_plan/plan_limit.cpp index e10764cc..508f1c88 100644 --- a/src/duckdb/src/execution/physical_plan/plan_limit.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_limit.cpp @@ -7,10 +7,27 @@ namespace duckdb { -bool UseBatchLimit(BoundLimitNode &limit_val, BoundLimitNode &offset_val) { +bool UseBatchLimit(PhysicalOperator &child_node, BoundLimitNode &limit_val, BoundLimitNode &offset_val) { #ifdef DUCKDB_ALTERNATIVE_VERIFY return true; #else + // we only want to use the batch limit when we are executing a complex query (e.g. involving a filter or join) + // if we are doing a limit over a table scan we are otherwise scanning a lot of rows just to throw them away + reference current_ref(child_node); + bool finished = false; + while (!finished) { + auto ¤t_op = current_ref.get(); + switch (current_op.type) { + case PhysicalOperatorType::TABLE_SCAN: + return false; + case PhysicalOperatorType::PROJECTION: + current_ref = *current_op.children[0]; + break; + default: + finished = true; + break; + } + } // we only use batch limit when we are computing a small amount of values // as the batch limit materializes this many rows PER thread static constexpr const idx_t BATCH_LIMIT_THRESHOLD = 10000; @@ -48,7 +65,7 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalLimit &op) op.estimated_cardinality, true); } else { // maintaining insertion order is important - if (UseBatchIndex(*plan) && UseBatchLimit(op.limit_val, op.offset_val)) { + if (UseBatchIndex(*plan) && UseBatchLimit(*plan, op.limit_val, op.offset_val)) { // source supports batch index: use parallel batch limit limit = make_uniq(op.types, std::move(op.limit_val), std::move(op.offset_val), op.estimated_cardinality); diff --git a/src/duckdb/src/execution/physical_plan/plan_set.cpp b/src/duckdb/src/execution/physical_plan/plan_set.cpp index 9325719b..fe09cb31 100644 --- a/src/duckdb/src/execution/physical_plan/plan_set.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_set.cpp @@ -1,10 +1,19 @@ #include "duckdb/execution/physical_plan_generator.hpp" #include "duckdb/planner/operator/logical_set.hpp" #include "duckdb/execution/operator/helper/physical_set.hpp" +#include "duckdb/execution/operator/helper/physical_set_variable.hpp" namespace duckdb { unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalSet &op) { + if (!op.children.empty()) { + // set variable + auto child = CreatePlan(*op.children[0]); + auto set_variable = make_uniq(std::move(op.name), op.estimated_cardinality); + set_variable->children.push_back(std::move(child)); + return std::move(set_variable); + } + // set config setting return make_uniq(op.name, op.value, op.scope, op.estimated_cardinality); } diff --git a/src/duckdb/src/execution/physical_plan/plan_window.cpp b/src/duckdb/src/execution/physical_plan/plan_window.cpp index bab25f64..f294d932 100644 --- a/src/duckdb/src/execution/physical_plan/plan_window.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_window.cpp @@ -2,6 +2,7 @@ #include "duckdb/execution/operator/aggregate/physical_window.hpp" #include "duckdb/execution/operator/projection/physical_projection.hpp" #include "duckdb/execution/physical_plan_generator.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" #include "duckdb/planner/operator/logical_window.hpp" @@ -28,10 +29,11 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalWindow &op types.resize(input_width); // Identify streaming windows + const bool enable_optimizer = ClientConfig::GetConfig(context).enable_optimizer; vector blocking_windows; vector streaming_windows; for (idx_t expr_idx = 0; expr_idx < op.expressions.size(); expr_idx++) { - if (PhysicalStreamingWindow::IsStreamingFunction(op.expressions[expr_idx])) { + if (enable_optimizer && PhysicalStreamingWindow::IsStreamingFunction(context, op.expressions[expr_idx])) { streaming_windows.push_back(expr_idx); } else { blocking_windows.push_back(expr_idx); diff --git a/src/duckdb/src/execution/physical_plan_generator.cpp b/src/duckdb/src/execution/physical_plan_generator.cpp index 6b223f72..9dead090 100644 --- a/src/duckdb/src/execution/physical_plan_generator.cpp +++ b/src/duckdb/src/execution/physical_plan_generator.cpp @@ -41,13 +41,13 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(unique_ptrResolveOperatorTypes(); profiler.EndPhase(); @@ -56,7 +56,7 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(unique_ptr data_p) : state(AggregatePartitionState::READY_TO_FINALIZE), data(std::move(data_p)), progress(0) { } - mutex lock; AggregatePartitionState state; unique_ptr data; atomic progress; - - vector blocked_tasks; }; class RadixHTGlobalSinkState; @@ -172,8 +170,6 @@ class RadixHTGlobalSinkState : public GlobalSinkState { //! If any thread has called combine atomic any_combined; - //! Lock for uncombined_data/stored_allocators - mutex lock; //! Uncombined partitioned data that will be put into the AggregatePartitions unique_ptr uncombined_data; //! Allocators used during the Sink/Finalize @@ -199,20 +195,25 @@ RadixHTGlobalSinkState::RadixHTGlobalSinkState(ClientContext &context_p, const R any_combined(false), finalize_done(0), scan_pin_properties(TupleDataPinProperties::DESTROY_AFTER_DONE), count_before_combining(0), max_partition_size(0) { - auto tuples_per_block = Storage::BLOCK_ALLOC_SIZE / radix_ht.GetLayout().GetRowWidth(); + // Compute minimum reservation + auto block_alloc_size = BufferManager::GetBufferManager(context).GetBlockAllocSize(); + auto tuples_per_block = block_alloc_size / radix_ht.GetLayout().GetRowWidth(); idx_t ht_count = - NumericCast(static_cast(config.sink_capacity) / GroupedAggregateHashTable::LOAD_FACTOR); + LossyNumericCast(static_cast(config.sink_capacity) / GroupedAggregateHashTable::LOAD_FACTOR); auto num_partitions = RadixPartitioning::NumberOfPartitions(config.GetRadixBits()); auto count_per_partition = ht_count / num_partitions; auto blocks_per_partition = (count_per_partition + tuples_per_block) / tuples_per_block + 1; - auto ht_size = blocks_per_partition * Storage::BLOCK_ALLOC_SIZE + config.sink_capacity * sizeof(aggr_ht_entry_t); + if (!radix_ht.GetLayout().AllConstant()) { + blocks_per_partition += 2; + } + auto ht_size = blocks_per_partition * block_alloc_size + config.sink_capacity * sizeof(ht_entry_t); // This really is the minimum reservation that we can do auto num_threads = NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); auto minimum_reservation = num_threads * ht_size; temporary_memory_state->SetMinimumReservation(minimum_reservation); - temporary_memory_state->SetRemainingSize(context, minimum_reservation); + temporary_memory_state->SetRemainingSizeAndUpdateReservation(context, minimum_reservation); } RadixHTGlobalSinkState::~RadixHTGlobalSinkState() { @@ -233,7 +234,7 @@ void RadixHTGlobalSinkState::Destroy() { } // There are aggregates with destructors: Call the destructor for each of the aggregates - lock_guard guard(lock); + auto guard = Lock(); RowOperationsState row_state(*stored_allocators.back()); for (auto &partition : partitions) { auto &data_collection = *partition->data; @@ -274,7 +275,7 @@ void RadixHTConfig::SetRadixBitsInternal(const idx_t radix_bits_p, bool external return; } - lock_guard guard(sink.lock); + auto guard = sink.Lock(); if (sink_radix_bits >= radix_bits_p || sink.any_combined) { return; } @@ -309,9 +310,9 @@ idx_t RadixHTConfig::SinkCapacity(ClientContext &context) { const auto cache_per_active_thread = L1_CACHE_SIZE + L2_CACHE_SIZE + total_shared_cache_size / active_threads; // Divide cache per active thread by entry size, round up to next power of two, to get capacity - const auto size_per_entry = sizeof(aggr_ht_entry_t) * GroupedAggregateHashTable::LOAD_FACTOR; + const auto size_per_entry = sizeof(ht_entry_t) * GroupedAggregateHashTable::LOAD_FACTOR; const auto capacity = - NextPowerOfTwo(NumericCast(static_cast(cache_per_active_thread) / size_per_entry)); + NextPowerOfTwo(LossyNumericCast(static_cast(cache_per_active_thread) / size_per_entry)); // Capacity must be at least the minimum capacity return MaxValue(capacity, GroupedAggregateHashTable::InitialCapacity()); @@ -369,19 +370,19 @@ bool MaybeRepartition(ClientContext &context, RadixHTGlobalSinkState &gstate, Ra // Check if we're approaching the memory limit auto &temporary_memory_state = *gstate.temporary_memory_state; - const auto total_size = partitioned_data->SizeInBytes() + ht.Capacity() * sizeof(aggr_ht_entry_t); + const auto total_size = partitioned_data->SizeInBytes() + ht.Capacity() * sizeof(ht_entry_t); idx_t thread_limit = temporary_memory_state.GetReservation() / gstate.number_of_threads; if (total_size > thread_limit) { // We're over the thread memory limit if (!gstate.external) { // We haven't yet triggered out-of-core behavior, but maybe we don't have to, grab the lock and check again - lock_guard guard(gstate.lock); + auto guard = gstate.Lock(); thread_limit = temporary_memory_state.GetReservation() / gstate.number_of_threads; if (total_size > thread_limit) { // Out-of-core would be triggered below, try to increase the reservation auto remaining_size = MaxValue(gstate.number_of_threads * total_size, temporary_memory_state.GetRemainingSize()); - temporary_memory_state.SetRemainingSize(context, 2 * remaining_size); + temporary_memory_state.SetRemainingSizeAndUpdateReservation(context, 2 * remaining_size); thread_limit = temporary_memory_state.GetReservation() / gstate.number_of_threads; } } @@ -413,9 +414,10 @@ bool MaybeRepartition(ClientContext &context, RadixHTGlobalSinkState &gstate, Ra const auto current_radix_bits = RadixPartitioning::RadixBits(partition_count); D_ASSERT(current_radix_bits <= config.GetRadixBits()); + const auto block_size = BufferManager::GetBufferManager(context).GetBlockSize(); const auto row_size_per_partition = partitioned_data->Count() * partitioned_data->GetLayout().GetRowWidth() / partition_count; - if (row_size_per_partition > NumericCast(config.BLOCK_FILL_FACTOR * Storage::BLOCK_SIZE)) { + if (row_size_per_partition > LossyNumericCast(config.BLOCK_FILL_FACTOR * static_cast(block_size))) { // We crossed our block filling threshold, try to increment radix bits config.SetRadixBits(current_radix_bits + config.REPARTITION_RADIX_BITS); } @@ -498,7 +500,7 @@ void RadixPartitionedHashTable::Combine(ExecutionContext &context, GlobalSinkSta lstate.abandoned_data = std::move(ht.GetPartitionedData()); } - lock_guard guard(gstate.lock); + auto guard = gstate.Lock(); if (gstate.uncombined_data) { gstate.uncombined_data->Combine(*lstate.abandoned_data); } else { @@ -524,7 +526,7 @@ void RadixPartitionedHashTable::Finalize(ClientContext &context, GlobalSinkState auto &partition = uncombined_partition_data[i]; auto partition_size = partition->SizeInBytes() + - GroupedAggregateHashTable::GetCapacityForCount(partition->Count()) * sizeof(aggr_ht_entry_t); + GroupedAggregateHashTable::GetCapacityForCount(partition->Count()) * sizeof(ht_entry_t); gstate.max_partition_size = MaxValue(gstate.max_partition_size, partition_size); gstate.partitions.emplace_back(make_uniq(std::move(partition))); @@ -540,10 +542,8 @@ void RadixPartitionedHashTable::Finalize(ClientContext &context, GlobalSinkState // Minimum of combining one partition at a time gstate.temporary_memory_state->SetMinimumReservation(gstate.max_partition_size); - // Maximum of combining all partitions - auto max_threads = MinValue(NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()), - gstate.partitions.size()); - gstate.temporary_memory_state->SetRemainingSize(context, max_threads * gstate.max_partition_size); + // Set size to 0 until the scan actually starts + gstate.temporary_memory_state->SetZero(); gstate.finalized = true; } @@ -556,14 +556,17 @@ idx_t RadixPartitionedHashTable::MaxThreads(GlobalSinkState &sink_p) const { return 0; } + const auto max_threads = MinValue( + NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads()), sink.partitions.size()); + sink.temporary_memory_state->SetRemainingSizeAndUpdateReservation(sink.context, + max_threads * sink.max_partition_size); + // This many partitions will fit given our reservation (at least 1)) - auto partitions_fit = MaxValue(sink.temporary_memory_state->GetReservation() / sink.max_partition_size, 1); - // Maximum is either the number of partitions, or the number of threads - auto max_possible = MinValue( - sink.partitions.size(), NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())); + const auto partitions_fit = + MaxValue(sink.temporary_memory_state->GetReservation() / sink.max_partition_size, 1); // Mininum of the two - return MinValue(partitions_fit, max_possible); + return MinValue(partitions_fit, max_threads); } void RadixPartitionedHashTable::SetMultiScan(GlobalSinkState &sink_p) { @@ -652,18 +655,16 @@ RadixHTGlobalSourceState::RadixHTGlobalSourceState(ClientContext &context_p, con SourceResultType RadixHTGlobalSourceState::AssignTask(RadixHTGlobalSinkState &sink, RadixHTLocalSourceState &lstate, InterruptState &interrupt_state) { // First, try to get a partition index - lock_guard gstate_guard(sink.lock); - if (finished) { - return SourceResultType::FINISHED; - } - if (task_idx == sink.partitions.size()) { + auto guard = sink.Lock(); + if (finished || task_idx == sink.partitions.size()) { + lstate.ht.reset(); return SourceResultType::FINISHED; } lstate.task_idx = task_idx++; // We got a partition index auto &partition = *sink.partitions[lstate.task_idx]; - auto partition_lock = unique_lock(partition.lock); + auto partition_guard = partition.Lock(); switch (partition.state) { case AggregatePartitionState::READY_TO_FINALIZE: partition.state = AggregatePartitionState::FINALIZE_IN_PROGRESS; @@ -672,8 +673,7 @@ SourceResultType RadixHTGlobalSourceState::AssignTask(RadixHTGlobalSinkState &si case AggregatePartitionState::FINALIZE_IN_PROGRESS: lstate.task = RadixHTSourceTaskType::SCAN; lstate.scan_status = RadixHTScanStatus::INIT; - partition.blocked_tasks.push_back(interrupt_state); - return SourceResultType::BLOCKED; + return partition.BlockSource(partition_guard, interrupt_state); case AggregatePartitionState::READY_TO_SCAN: lstate.task = RadixHTSourceTaskType::SCAN; lstate.scan_status = RadixHTScanStatus::INIT; @@ -721,10 +721,10 @@ void RadixHTLocalSourceState::Finalize(RadixHTGlobalSinkState &sink, RadixHTGlob // However, we will limit the initial capacity so we don't do a huge over-allocation const auto n_threads = NumericCast(TaskScheduler::GetScheduler(gstate.context).NumberOfThreads()); const auto memory_limit = BufferManager::GetBufferManager(gstate.context).GetMaxMemory(); - const idx_t thread_limit = NumericCast(0.6 * double(memory_limit) / double(n_threads)); + const idx_t thread_limit = LossyNumericCast(0.6 * double(memory_limit) / double(n_threads)); const idx_t size_per_entry = partition.data->SizeInBytes() / MaxValue(partition.data->Count(), 1) + - idx_t(GroupedAggregateHashTable::LOAD_FACTOR * sizeof(aggr_ht_entry_t)); + idx_t(GroupedAggregateHashTable::LOAD_FACTOR * sizeof(ht_entry_t)); // but not lower than the initial capacity const auto capacity_limit = MaxValue(NextPowerOfTwo(thread_limit / size_per_entry), GroupedAggregateHashTable::InitialCapacity()); @@ -748,22 +748,22 @@ void RadixHTLocalSourceState::Finalize(RadixHTGlobalSinkState &sink, RadixHTGlob partition.data->Combine(*ht->GetPartitionedData()->GetPartitions()[0]); // Update thread-global state - lock_guard global_guard(sink.lock); + auto guard = sink.Lock(); sink.stored_allocators.emplace_back(ht->GetAggregateAllocator()); + if (task_idx == sink.partitions.size()) { + ht.reset(); + } const auto finalizes_done = ++sink.finalize_done; D_ASSERT(finalizes_done <= sink.partitions.size()); if (finalizes_done == sink.partitions.size()) { // All finalizes are done, set remaining size to 0 - sink.temporary_memory_state->SetRemainingSize(sink.context, 0); + sink.temporary_memory_state->SetZero(); } // Update partition state - lock_guard partition_guard(partition.lock); + auto partition_guard = partition.Lock(); partition.state = AggregatePartitionState::READY_TO_SCAN; - for (auto &blocked_task : partition.blocked_tasks) { - blocked_task.Callback(); - } - partition.blocked_tasks.clear(); + partition.UnblockTasks(partition_guard); // This thread will scan the partition task = RadixHTSourceTaskType::SCAN; @@ -788,7 +788,7 @@ void RadixHTLocalSourceState::Scan(RadixHTGlobalSinkState &sink, RadixHTGlobalSo data_collection.Reset(); } scan_status = RadixHTScanStatus::DONE; - lock_guard gstate_guard(sink.lock); + auto guard = sink.Lock(); if (++gstate.task_done == sink.partitions.size()) { gstate.finished = true; } @@ -865,8 +865,8 @@ SourceResultType RadixPartitionedHashTable::GetData(ExecutionContext &context, D for (idx_t i = 0; i < op.aggregates.size(); i++) { D_ASSERT(op.aggregates[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); auto &aggr = op.aggregates[i]->Cast(); - auto aggr_state = make_unsafe_uniq_array(aggr.function.state_size()); - aggr.function.initialize(aggr_state.get()); + auto aggr_state = make_unsafe_uniq_array_uninitialized(aggr.function.state_size(aggr.function)); + aggr.function.initialize(aggr.function, aggr_state.get()); AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); Vector state_vector(Value::POINTER(CastPointerToValue(aggr_state.get()))); diff --git a/src/duckdb/src/execution/reservoir_sample.cpp b/src/duckdb/src/execution/reservoir_sample.cpp index fc711d91..eb20982b 100644 --- a/src/duckdb/src/execution/reservoir_sample.cpp +++ b/src/duckdb/src/execution/reservoir_sample.cpp @@ -239,10 +239,10 @@ void ReservoirSamplePercentage::Finalize() { // Imagine sampling 70% of 100 rows (so 70 rows). We allocate sample_percentage * RESERVOIR_THRESHOLD // ----------------------------------------- auto sampled_more_than_required = - current_count > sample_percentage * RESERVOIR_THRESHOLD || finished_samples.empty(); + static_cast(current_count) > sample_percentage * RESERVOIR_THRESHOLD || finished_samples.empty(); if (current_count > 0 && sampled_more_than_required) { // create a new sample - auto new_sample_size = idx_t(round(sample_percentage * current_count)); + auto new_sample_size = idx_t(round(sample_percentage * static_cast(current_count))); auto new_sample = make_uniq(allocator, new_sample_size, random.NextRandomInteger()); while (true) { auto chunk = current_sample->GetChunk(); diff --git a/src/duckdb/src/execution/window_executor.cpp b/src/duckdb/src/execution/window_executor.cpp index 4d6b9b09..56397f7a 100644 --- a/src/duckdb/src/execution/window_executor.cpp +++ b/src/duckdb/src/execution/window_executor.cpp @@ -7,6 +7,75 @@ namespace duckdb { +//===--------------------------------------------------------------------===// +// WindowDataChunk +//===--------------------------------------------------------------------===// +bool WindowDataChunk::IsSimple(const Vector &v) { + switch (v.GetType().InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::UINT8: + case PhysicalType::INT8: + case PhysicalType::UINT16: + case PhysicalType::INT16: + case PhysicalType::UINT32: + case PhysicalType::INT32: + case PhysicalType::UINT64: + case PhysicalType::INT64: + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + case PhysicalType::INTERVAL: + case PhysicalType::UINT128: + case PhysicalType::INT128: + return true; + case PhysicalType::LIST: + case PhysicalType::STRUCT: + case PhysicalType::ARRAY: + case PhysicalType::VARCHAR: + case PhysicalType::BIT: + return false; + default: + break; + } + + throw InternalException("Unsupported type for WindowDataChunk"); +} + +WindowDataChunk::WindowDataChunk(DataChunk &chunk) : chunk(chunk) { +} + +void WindowDataChunk::Initialize(Allocator &allocator, const vector &types, idx_t capacity) { + vector new_locks(types.size()); + locks.swap(new_locks); + chunk.Initialize(allocator, types, capacity); + chunk.SetCardinality(capacity); + + is_simple.clear(); + for (const auto &v : chunk.data) { + is_simple.push_back(IsSimple(v)); + } +} + +void WindowDataChunk::Copy(DataChunk &input, idx_t begin) { + const auto source_count = input.size(); + const idx_t end = begin + source_count; + const idx_t count = chunk.size(); + D_ASSERT(end <= count); + // Can we overwrite the validity mask in parallel? + bool aligned = IsMaskAligned(begin, end, count); + for (column_t i = 0; i < chunk.data.size(); ++i) { + auto &src = input.data[i]; + auto &dst = chunk.data[i]; + UnifiedVectorFormat sdata; + src.ToUnifiedFormat(count, sdata); + if (is_simple[i] && aligned && sdata.validity.AllValid()) { + VectorOperations::Copy(src, dst, source_count, 0, begin); + } else { + lock_guard column_guard(locks[i]); + VectorOperations::Copy(src, dst, source_count, 0, begin); + } + } +} + static idx_t FindNextStart(const ValidityMask &mask, idx_t l, const idx_t r, idx_t &n) { if (mask.AllValid()) { auto start = MinValue(l + n - 1, r); @@ -93,6 +162,26 @@ static void CopyCell(const DataChunk &chunk, idx_t column, idx_t index, Vector & VectorOperations::Copy(source, target, index + 1, index, target_offset); } +//===--------------------------------------------------------------------===// +// WindowInputColumn +//===--------------------------------------------------------------------===// +WindowInputColumn::WindowInputColumn(optional_ptr expr_p, ClientContext &context, idx_t count) + : expr(expr_p), scalar(expr ? expr->IsScalar() : true), count(count), wtarget(target) { + + if (expr) { + vector types; + types.emplace_back(expr->return_type); + wtarget.Initialize(Allocator::Get(context), types, count); + ptype = expr->return_type.InternalType(); + } +} + +void WindowInputColumn::Copy(DataChunk &input_chunk, idx_t input_idx) { + if (expr && (!input_idx || !scalar)) { + wtarget.Copy(input_chunk, input_idx); + } +} + //===--------------------------------------------------------------------===// // WindowColumnIterator //===--------------------------------------------------------------------===// @@ -155,11 +244,11 @@ struct WindowColumnIterator { return iterator(a.coll, a.pos + n); } - friend inline iterator &operator-(const iterator &a, difference_type n) { + friend inline iterator operator-(const iterator &a, difference_type n) { return iterator(a.coll, a.pos - n); } - friend inline iterator &operator+(difference_type n, const iterator &a) { + friend inline iterator operator+(difference_type n, const iterator &a) { return a + n; } friend inline difference_type operator-(const iterator &a, const iterator &b) { @@ -193,7 +282,7 @@ struct WindowColumnIterator { template struct OperationCompare : public std::function { inline bool operator()(const T &lhs, const T &val) const { - return OP::template Operation(lhs, val); + return OP::template Operation(lhs, val); } }; @@ -208,13 +297,13 @@ static idx_t FindTypedRangeBound(const WindowInputColumn &over, const idx_t orde // Check that the value we are searching for is in range. if (range == WindowBoundary::EXPR_PRECEDING_RANGE) { - // Preceding but value past the end - const auto cur_val = over.GetCell(order_end); + // Preceding but value past the current value + const auto cur_val = over.GetCell(order_end - 1); if (comp(cur_val, val)) { throw OutOfRangeException("Invalid RANGE PRECEDING value"); } } else { - // Following but value before beginning + // Following but value before the current value D_ASSERT(range == WindowBoundary::EXPR_FOLLOWING_RANGE); const auto cur_val = over.GetCell(order_begin); if (comp(val, cur_val)) { @@ -257,9 +346,9 @@ static idx_t FindRangeBound(const WindowInputColumn &over, const idx_t order_beg const WindowBoundary range, WindowInputExpression &boundary, const idx_t chunk_idx, const FrameBounds &prev) { D_ASSERT(boundary.chunk.ColumnCount() == 1); - D_ASSERT(boundary.chunk.data[0].GetType().InternalType() == over.input_expr.ptype); + D_ASSERT(boundary.chunk.data[0].GetType().InternalType() == over.ptype); - switch (over.input_expr.ptype) { + switch (over.ptype) { case PhysicalType::INT8: return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); case PhysicalType::INT16: @@ -309,7 +398,7 @@ static idx_t FindOrderedRangeBound(const WindowInputColumn &over, const OrderTyp struct WindowBoundariesState { static inline bool IsScalar(const unique_ptr &expr) { - return expr ? expr->IsScalar() : true; + return !expr || expr->IsScalar(); } static inline bool BoundaryNeedsPeer(const WindowBoundary &boundary) { @@ -335,7 +424,7 @@ struct WindowBoundariesState { } } - WindowBoundariesState(BoundWindowExpression &wexpr, const idx_t input_size); + WindowBoundariesState(const BoundWindowExpression &wexpr, const idx_t input_size); void Update(const idx_t row_idx, const WindowInputColumn &range_collection, const idx_t chunk_idx, WindowInputExpression &boundary_start, WindowInputExpression &boundary_end, @@ -364,8 +453,8 @@ struct WindowBoundariesState { idx_t peer_end = 0; idx_t valid_start = 0; idx_t valid_end = 0; - int64_t window_start = -1; - int64_t window_end = -1; + idx_t window_start = NumericLimits::Maximum(); + idx_t window_end = NumericLimits::Maximum(); FrameBounds prev; }; @@ -447,49 +536,53 @@ void WindowBoundariesState::Update(const idx_t row_idx, const WindowInputColumn next_pos = row_idx + 1; // determine window boundaries depending on the type of expression - window_start = -1; - window_end = -1; - switch (start_boundary) { case WindowBoundary::UNBOUNDED_PRECEDING: - window_start = NumericCast(partition_start); + window_start = partition_start; break; case WindowBoundary::CURRENT_ROW_ROWS: - window_start = NumericCast(row_idx); + window_start = row_idx; break; case WindowBoundary::CURRENT_ROW_RANGE: - window_start = NumericCast(peer_start); + window_start = peer_start; break; case WindowBoundary::EXPR_PRECEDING_ROWS: { - if (!TrySubtractOperator::Operation(int64_t(row_idx), boundary_start.GetCell(chunk_idx), - window_start)) { - throw OutOfRangeException("Overflow computing ROWS PRECEDING start"); + int64_t computed_start; + if (!TrySubtractOperator::Operation(static_cast(row_idx), boundary_start.GetCell(chunk_idx), + computed_start)) { + window_start = partition_start; + } else { + window_start = UnsafeNumericCast(MaxValue(computed_start, 0)); } break; } case WindowBoundary::EXPR_FOLLOWING_ROWS: { - if (!TryAddOperator::Operation(int64_t(row_idx), boundary_start.GetCell(chunk_idx), window_start)) { - throw OutOfRangeException("Overflow computing ROWS FOLLOWING start"); + int64_t computed_start; + if (!TryAddOperator::Operation(static_cast(row_idx), boundary_start.GetCell(chunk_idx), + computed_start)) { + window_start = partition_start; + } else { + window_start = UnsafeNumericCast(MaxValue(computed_start, 0)); } break; } case WindowBoundary::EXPR_PRECEDING_RANGE: { if (boundary_start.CellIsNull(chunk_idx)) { - window_start = NumericCast(peer_start); + window_start = peer_start; } else { - prev.start = FindOrderedRangeBound(range_collection, range_sense, valid_start, row_idx, + prev.start = FindOrderedRangeBound(range_collection, range_sense, valid_start, row_idx + 1, start_boundary, boundary_start, chunk_idx, prev); - window_start = NumericCast(prev.start); + window_start = prev.start; } break; } case WindowBoundary::EXPR_FOLLOWING_RANGE: { if (boundary_start.CellIsNull(chunk_idx)) { - window_start = NumericCast(peer_start); + window_start = peer_start; } else { prev.start = FindOrderedRangeBound(range_collection, range_sense, row_idx, valid_end, start_boundary, boundary_start, chunk_idx, prev); - window_start = NumericCast(prev.start); + window_start = prev.start; } break; } @@ -499,42 +592,51 @@ void WindowBoundariesState::Update(const idx_t row_idx, const WindowInputColumn switch (end_boundary) { case WindowBoundary::CURRENT_ROW_ROWS: - window_end = NumericCast(row_idx + 1); + window_end = row_idx + 1; break; case WindowBoundary::CURRENT_ROW_RANGE: - window_end = NumericCast(peer_end); + window_end = peer_end; break; case WindowBoundary::UNBOUNDED_FOLLOWING: - window_end = NumericCast(partition_end); + window_end = partition_end; break; - case WindowBoundary::EXPR_PRECEDING_ROWS: + case WindowBoundary::EXPR_PRECEDING_ROWS: { + int64_t computed_start; if (!TrySubtractOperator::Operation(int64_t(row_idx + 1), boundary_end.GetCell(chunk_idx), - window_end)) { - throw OutOfRangeException("Overflow computing ROWS PRECEDING end"); + computed_start)) { + window_end = partition_end; + } else { + window_end = UnsafeNumericCast(MaxValue(computed_start, 0)); } break; - case WindowBoundary::EXPR_FOLLOWING_ROWS: - if (!TryAddOperator::Operation(int64_t(row_idx + 1), boundary_end.GetCell(chunk_idx), window_end)) { - throw OutOfRangeException("Overflow computing ROWS FOLLOWING end"); + } + case WindowBoundary::EXPR_FOLLOWING_ROWS: { + int64_t computed_start; + if (!TryAddOperator::Operation(int64_t(row_idx + 1), boundary_end.GetCell(chunk_idx), + computed_start)) { + window_end = partition_end; + } else { + window_end = UnsafeNumericCast(MaxValue(computed_start, 0)); } break; + } case WindowBoundary::EXPR_PRECEDING_RANGE: { if (boundary_end.CellIsNull(chunk_idx)) { - window_end = NumericCast(peer_end); + window_end = peer_end; } else { - prev.end = FindOrderedRangeBound(range_collection, range_sense, valid_start, row_idx, end_boundary, - boundary_end, chunk_idx, prev); - window_end = NumericCast(prev.end); + prev.end = FindOrderedRangeBound(range_collection, range_sense, valid_start, row_idx + 1, + end_boundary, boundary_end, chunk_idx, prev); + window_end = prev.end; } break; } case WindowBoundary::EXPR_FOLLOWING_RANGE: { if (boundary_end.CellIsNull(chunk_idx)) { - window_end = NumericCast(peer_end); + window_end = peer_end; } else { prev.end = FindOrderedRangeBound(range_collection, range_sense, row_idx, valid_end, end_boundary, boundary_end, chunk_idx, prev); - window_end = NumericCast(prev.end); + window_end = prev.end; } break; } @@ -543,33 +645,29 @@ void WindowBoundariesState::Update(const idx_t row_idx, const WindowInputColumn } // clamp windows to partitions if they should exceed - if (window_start < NumericCast(partition_start)) { - window_start = NumericCast(partition_start); - } - if (window_start > NumericCast(partition_end)) { - window_start = NumericCast(partition_end); + if (window_start < partition_start) { + window_start = partition_start; } - if (window_end < NumericCast(partition_start)) { - window_end = NumericCast(partition_start); + if (window_start > partition_end) { + window_start = partition_end; } - if (window_end > NumericCast(partition_end)) { - window_end = NumericCast(partition_end); + if (window_end < partition_start) { + window_end = partition_start; } - - if (window_start < 0 || window_end < 0) { - throw InternalException("Failed to compute window boundaries"); + if (window_end > partition_end) { + window_end = partition_end; } } -static bool HasPrecedingRange(BoundWindowExpression &wexpr) { +static bool HasPrecedingRange(const BoundWindowExpression &wexpr) { return (wexpr.start == WindowBoundary::EXPR_PRECEDING_RANGE || wexpr.end == WindowBoundary::EXPR_PRECEDING_RANGE); } -static bool HasFollowingRange(BoundWindowExpression &wexpr) { +static bool HasFollowingRange(const BoundWindowExpression &wexpr) { return (wexpr.start == WindowBoundary::EXPR_FOLLOWING_RANGE || wexpr.end == WindowBoundary::EXPR_FOLLOWING_RANGE); } -WindowBoundariesState::WindowBoundariesState(BoundWindowExpression &wexpr, const idx_t input_size) +WindowBoundariesState::WindowBoundariesState(const BoundWindowExpression &wexpr, const idx_t input_size) : type(wexpr.type), input_size(input_size), start_boundary(wexpr.start), end_boundary(wexpr.end), partition_count(wexpr.partitions.size()), order_count(wexpr.orders.size()), range_sense(wexpr.orders.empty() ? OrderType::INVALID : wexpr.orders[0].type), @@ -598,8 +696,8 @@ void WindowBoundariesState::Bounds(DataChunk &bounds, idx_t row_idx, const Windo *peer_begin_data++ = peer_start; *peer_end_data++ = peer_end; } - *window_begin_data++ = window_start; - *window_end_data++ = window_end; + *window_begin_data++ = UnsafeNumericCast(window_start); + *window_end_data++ = UnsafeNumericCast(window_end); } bounds.SetCardinality(count); } @@ -607,10 +705,9 @@ void WindowBoundariesState::Bounds(DataChunk &bounds, idx_t row_idx, const Windo //===--------------------------------------------------------------------===// // WindowExecutorBoundsState //===--------------------------------------------------------------------===// -class WindowExecutorBoundsState : public WindowExecutorState { +class WindowExecutorBoundsState : public WindowExecutorLocalState { public: - WindowExecutorBoundsState(BoundWindowExpression &wexpr, ClientContext &context, const idx_t count, - const ValidityMask &partition_mask_p, const ValidityMask &order_mask_p); + explicit WindowExecutorBoundsState(const WindowExecutorGlobalState &gstate); ~WindowExecutorBoundsState() override { } @@ -627,13 +724,13 @@ class WindowExecutorBoundsState : public WindowExecutorState { WindowInputExpression boundary_end; }; -WindowExecutorBoundsState::WindowExecutorBoundsState(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask_p, - const ValidityMask &order_mask_p) - : partition_mask(partition_mask_p), order_mask(order_mask_p), state(wexpr, payload_count), - boundary_start(wexpr.start_expr.get(), context), boundary_end(wexpr.end_expr.get(), context) { +WindowExecutorBoundsState::WindowExecutorBoundsState(const WindowExecutorGlobalState &gstate) + : WindowExecutorLocalState(gstate), partition_mask(gstate.partition_mask), order_mask(gstate.order_mask), + state(gstate.executor.wexpr, gstate.payload_count), + boundary_start(gstate.executor.wexpr.start_expr.get(), gstate.executor.context), + boundary_end(gstate.executor.wexpr.end_expr.get(), gstate.executor.context) { vector bounds_types(6, LogicalType(LogicalTypeId::UBIGINT)); - bounds.Initialize(Allocator::Get(context), bounds_types); + bounds.Initialize(Allocator::Get(gstate.executor.context), bounds_types); } void WindowExecutorBoundsState::UpdateBounds(idx_t row_idx, DataChunk &input_chunk, const WindowInputColumn &range) { @@ -745,39 +842,10 @@ void ExclusionFilter::ResetMask(idx_t row_idx, idx_t offset) { } } -//===--------------------------------------------------------------------===// -// WindowValueState -//===--------------------------------------------------------------------===// - -//! A class representing the state of the first_value, last_value and nth_value functions -class WindowValueState : public WindowExecutorBoundsState { -public: - WindowValueState(BoundWindowExpression &wexpr, ClientContext &context, const idx_t count, - const ValidityMask &partition_mask_p, const ValidityMask &order_mask_p, - const ValidityMask &ignore_nulls) - : WindowExecutorBoundsState(wexpr, context, count, partition_mask_p, order_mask_p) - - { - if (wexpr.exclude_clause == WindowExcludeMode::NO_OTHER) { - exclusion_filter = nullptr; - ignore_nulls_exclude = &ignore_nulls; - } else { - // create the exclusion filter based on ignore_nulls - exclusion_filter = make_uniq(wexpr.exclude_clause, count, ignore_nulls); - ignore_nulls_exclude = &exclusion_filter->mask; - } - } - - //! The exclusion filter handling exclusion - unique_ptr exclusion_filter; - //! The validity mask that combines both the NULLs and exclusion information - const ValidityMask *ignore_nulls_exclude; -}; - //===--------------------------------------------------------------------===// // WindowExecutor //===--------------------------------------------------------------------===// -static void PrepareInputExpressions(vector> &exprs, ExpressionExecutor &executor, +static void PrepareInputExpressions(const vector> &exprs, ExpressionExecutor &executor, DataChunk &chunk) { if (exprs.empty()) { return; @@ -795,31 +863,81 @@ static void PrepareInputExpressions(vector> &exprs, Expre } } -WindowExecutor::WindowExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : wexpr(wexpr), context(context), payload_count(payload_count), partition_mask(partition_mask), - order_mask(order_mask), payload_collection(), payload_executor(context), - range((HasPrecedingRange(wexpr) || HasFollowingRange(wexpr)) ? wexpr.orders[0].expression.get() : nullptr, - context, payload_count) { +WindowExecutor::WindowExecutor(BoundWindowExpression &wexpr, ClientContext &context) : wexpr(wexpr), context(context) { +} + +WindowExecutorGlobalState::WindowExecutorGlobalState(const WindowExecutor &executor, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask) + : executor(executor), payload_count(payload_count), partition_mask(partition_mask), order_mask(order_mask), + range((HasPrecedingRange(executor.wexpr) || HasFollowingRange(executor.wexpr)) + ? executor.wexpr.orders[0].expression.get() + : nullptr, + executor.context, payload_count) { + for (const auto &child : executor.wexpr.children) { + arg_types.emplace_back(child->return_type); + } +} + +WindowExecutorLocalState::WindowExecutorLocalState(const WindowExecutorGlobalState &gstate) + : payload_executor(gstate.executor.context), range_executor(gstate.executor.context) { // TODO: child may be a scalar, don't need to materialize the whole collection then // evaluate inner expressions of window functions, could be more complex - PrepareInputExpressions(wexpr.children, payload_executor, payload_chunk); + PrepareInputExpressions(gstate.executor.wexpr.children, payload_executor, payload_chunk); - auto types = payload_chunk.GetTypes(); - if (!types.empty()) { - payload_collection.Initialize(Allocator::Get(context), types); + if (gstate.range.expr) { + vector types; + types.emplace_back(gstate.range.expr->return_type); + range_executor.AddExpression(*gstate.range.expr); + + auto &allocator = range_executor.GetAllocator(); + range_chunk.Initialize(allocator, types); + } +} + +void WindowExecutorLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &input_chunk, idx_t input_idx) { + if (gstate.range.expr && (!input_idx || !gstate.range.scalar)) { + range_executor.Execute(input_chunk, range_chunk); + gstate.range.Copy(range_chunk, input_idx); } } -unique_ptr WindowExecutor::GetExecutorState() const { - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); +unique_ptr WindowExecutor::GetGlobalState(const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const { + return make_uniq(*this, payload_count, partition_mask, order_mask); +} + +unique_ptr WindowExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { + return make_uniq(gstate); +} + +void WindowExecutor::Sink(DataChunk &input_chunk, const idx_t input_idx, const idx_t total_count, + WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate) const { + lstate.Sink(gstate, input_chunk, input_idx); } //===--------------------------------------------------------------------===// // WindowAggregateExecutor //===--------------------------------------------------------------------===// -bool WindowAggregateExecutor::IsConstantAggregate() { +class WindowAggregateExecutorGlobalState : public WindowExecutorGlobalState { +public: + bool IsConstantAggregate(); + bool IsCustomAggregate(); + bool IsDistinctAggregate(); + + WindowAggregateExecutorGlobalState(const WindowAggregateExecutor &executor, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask); + + // aggregate computation algorithm + unique_ptr aggregator; + // aggregate global state + unique_ptr gsink; +}; + +bool WindowAggregateExecutorGlobalState::IsConstantAggregate() { + const auto &wexpr = executor.wexpr; + if (!wexpr.aggregate) { return false; } @@ -877,7 +995,9 @@ bool WindowAggregateExecutor::IsConstantAggregate() { return true; } -bool WindowAggregateExecutor::IsDistinctAggregate() { +bool WindowAggregateExecutorGlobalState::IsDistinctAggregate() { + const auto &wexpr = executor.wexpr; + if (!wexpr.aggregate) { return false; } @@ -885,7 +1005,10 @@ bool WindowAggregateExecutor::IsDistinctAggregate() { return wexpr.distinct; } -bool WindowAggregateExecutor::IsCustomAggregate() { +bool WindowAggregateExecutorGlobalState::IsCustomAggregate() { + const auto &wexpr = executor.wexpr; + const auto &mode = reinterpret_cast(executor).mode; + if (!wexpr.aggregate) { return false; } @@ -897,52 +1020,103 @@ bool WindowAggregateExecutor::IsCustomAggregate() { return (mode < WindowAggregationMode::COMBINE); } -void WindowExecutor::Evaluate(idx_t row_idx, DataChunk &input_chunk, Vector &result, - WindowExecutorState &lstate) const { +void WindowExecutor::Evaluate(idx_t row_idx, DataChunk &input_chunk, Vector &result, WindowExecutorLocalState &lstate, + WindowExecutorGlobalState &gstate) const { auto &lbstate = lstate.Cast(); - lbstate.UpdateBounds(row_idx, input_chunk, range); + lbstate.UpdateBounds(row_idx, input_chunk, gstate.range); const auto count = input_chunk.size(); - EvaluateInternal(lstate, result, count, row_idx); + EvaluateInternal(gstate, lstate, result, count, row_idx); result.Verify(count); } WindowAggregateExecutor::WindowAggregateExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t count, const ValidityMask &partition_mask, - const ValidityMask &order_mask, WindowAggregationMode mode) - : WindowExecutor(wexpr, context, count, partition_mask, order_mask), mode(mode), filter_executor(context) { + WindowAggregationMode mode) + : WindowExecutor(wexpr, context), mode(mode) { +} + +WindowAggregateExecutorGlobalState::WindowAggregateExecutorGlobalState(const WindowAggregateExecutor &executor, + const idx_t group_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : WindowExecutorGlobalState(executor, group_count, partition_mask, order_mask) { + auto &wexpr = executor.wexpr; + auto &context = executor.context; + auto return_type = wexpr.return_type; + const auto &mode = reinterpret_cast(executor).mode; // Force naive for SEPARATE mode or for (currently!) unsupported functionality const auto force_naive = !ClientConfig::GetConfig(context).enable_optimizer || mode == WindowAggregationMode::SEPARATE; AggregateObject aggr(wexpr); if (force_naive || (wexpr.distinct && wexpr.exclude_clause != WindowExcludeMode::NO_OTHER)) { - aggregator = make_uniq(aggr, wexpr.return_type, wexpr.exclude_clause, count); + aggregator = make_uniq(aggr, arg_types, return_type, wexpr.exclude_clause); } else if (IsDistinctAggregate()) { // build a merge sort tree // see https://dl.acm.org/doi/pdf/10.1145/3514221.3526184 - aggregator = make_uniq(aggr, wexpr.return_type, wexpr.exclude_clause, count, context); + aggregator = make_uniq(aggr, arg_types, return_type, wexpr.exclude_clause, context); } else if (IsConstantAggregate()) { - aggregator = - make_uniq(aggr, wexpr.return_type, partition_mask, wexpr.exclude_clause, count); + aggregator = make_uniq(aggr, arg_types, return_type, wexpr.exclude_clause); } else if (IsCustomAggregate()) { - aggregator = make_uniq(aggr, wexpr.return_type, wexpr.exclude_clause, count); + aggregator = make_uniq(aggr, arg_types, return_type, wexpr.exclude_clause); } else { // build a segment tree for frame-adhering aggregates // see http://www.vldb.org/pvldb/vol8/p1058-leis.pdf - aggregator = make_uniq(aggr, wexpr.return_type, mode, wexpr.exclude_clause, count); + aggregator = make_uniq(aggr, arg_types, return_type, mode, wexpr.exclude_clause); } - // evaluate the FILTER clause and stuff it into a large mask for compactness and reuse - if (wexpr.filter_expr) { - filter_executor.AddExpression(*wexpr.filter_expr); - filter_sel.Initialize(STANDARD_VECTOR_SIZE); + gsink = aggregator->GetGlobalState(group_count, partition_mask); +} + +unique_ptr WindowAggregateExecutor::GetGlobalState(const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const { + return make_uniq(*this, payload_count, partition_mask, order_mask); +} + +class WindowAggregateExecutorLocalState : public WindowExecutorBoundsState { +public: + WindowAggregateExecutorLocalState(const WindowExecutorGlobalState &gstate, const WindowAggregator &aggregator) + : WindowExecutorBoundsState(gstate), filter_executor(gstate.executor.context) { + + auto &gastate = gstate.Cast(); + aggregator_state = aggregator.GetLocalState(*gastate.gsink); + + // evaluate the FILTER clause and stuff it into a large mask for compactness and reuse + auto &wexpr = gstate.executor.wexpr; + if (wexpr.filter_expr) { + filter_executor.AddExpression(*wexpr.filter_expr); + filter_sel.Initialize(STANDARD_VECTOR_SIZE); + } } + +public: + // state of aggregator + unique_ptr aggregator_state; + //! Executor for any filter clause + ExpressionExecutor filter_executor; + //! Result of filtering + SelectionVector filter_sel; +}; + +unique_ptr +WindowAggregateExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { + auto &gastate = gstate.Cast(); + auto res = make_uniq(gstate, *gastate.aggregator); + return std::move(res); } -void WindowAggregateExecutor::Sink(DataChunk &input_chunk, const idx_t input_idx, const idx_t total_count) { - // TODO we could evaluate those expressions in parallel +void WindowAggregateExecutor::Sink(DataChunk &input_chunk, const idx_t input_idx, const idx_t total_count, + WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate) const { + auto &gastate = gstate.Cast(); + auto &lastate = lstate.Cast(); + auto &filter_sel = lastate.filter_sel; + auto &filter_executor = lastate.filter_executor; + auto &payload_executor = lastate.payload_executor; + auto &payload_chunk = lastate.payload_chunk; + auto &aggregator = gastate.aggregator; + idx_t filtered = 0; SelectionVector *filtering = nullptr; if (wexpr.filter_expr) { @@ -960,9 +1134,11 @@ void WindowAggregateExecutor::Sink(DataChunk &input_chunk, const idx_t input_idx } D_ASSERT(aggregator); - aggregator->Sink(payload_chunk, filtering, filtered); + auto &gestate = *gastate.gsink; + auto &lestate = *lastate.aggregator_state; + aggregator->Sink(gestate, lestate, payload_chunk, input_idx, filtering, filtered); - WindowExecutor::Sink(input_chunk, input_idx, total_count); + WindowExecutor::Sink(input_chunk, input_idx, total_count, gstate, lstate); } static void ApplyWindowStats(const WindowBoundary &boundary, FrameDelta &delta, BaseStatistics *base, bool is_start) { @@ -1023,13 +1199,16 @@ static void ApplyWindowStats(const WindowBoundary &boundary, FrameDelta &delta, } } -void WindowAggregateExecutor::Finalize() { +void WindowAggregateExecutor::Finalize(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate) const { + auto &gastate = gstate.Cast(); + auto &aggregator = gastate.aggregator; + auto &gsink = gastate.gsink; D_ASSERT(aggregator); // Estimate the frame statistics // Default to the entire partition if we don't know anything FrameStats stats; - const auto count = NumericCast(aggregator->GetInputs().size()); + const auto count = NumericCast(gastate.payload_count); // First entry is the frame start stats[0] = FrameDelta(-count, count); @@ -1041,51 +1220,32 @@ void WindowAggregateExecutor::Finalize() { base = wexpr.expr_stats.empty() ? nullptr : wexpr.expr_stats[1].get(); ApplyWindowStats(wexpr.end, stats[1], base, false); - aggregator->Finalize(stats); -} - -class WindowAggregateState : public WindowExecutorBoundsState { -public: - WindowAggregateState(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask, - const WindowAggregator &aggregator) - : WindowExecutorBoundsState(wexpr, context, payload_count, partition_mask, order_mask), - aggregator_state(aggregator.GetLocalState()) { - } - -public: - // state of aggregator - unique_ptr aggregator_state; - - void NextRank(idx_t partition_begin, idx_t peer_begin, idx_t row_idx); -}; - -unique_ptr WindowAggregateExecutor::GetExecutorState() const { - auto res = make_uniq(wexpr, context, payload_count, partition_mask, order_mask, *aggregator); - return std::move(res); + auto &lastate = lstate.Cast(); + aggregator->Finalize(*gsink, *lastate.aggregator_state, stats); } -void WindowAggregateExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - auto &lastate = lstate.Cast(); +void WindowAggregateExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + Vector &result, idx_t count, idx_t row_idx) const { + auto &gastate = gstate.Cast(); + auto &lastate = lstate.Cast(); + auto &aggregator = gastate.aggregator; + auto &gsink = gastate.gsink; D_ASSERT(aggregator); auto &agg_state = *lastate.aggregator_state; - aggregator->Evaluate(agg_state, lastate.bounds, result, count, row_idx); + aggregator->Evaluate(*gsink, agg_state, lastate.bounds, result, count, row_idx); } //===--------------------------------------------------------------------===// // WindowRowNumberExecutor //===--------------------------------------------------------------------===// -WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, ClientContext &context) + : WindowExecutor(wexpr, context) { } -void WindowRowNumberExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { +void WindowRowNumberExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + Vector &result, idx_t count, idx_t row_idx) const { auto &lbstate = lstate.Cast(); auto partition_begin = FlatVector::GetData(lbstate.bounds.data[PARTITION_BEGIN]); auto rdata = FlatVector::GetData(result); @@ -1099,9 +1259,7 @@ void WindowRowNumberExecutor::EvaluateInternal(WindowExecutorState &lstate, Vect //===--------------------------------------------------------------------===// class WindowPeerState : public WindowExecutorBoundsState { public: - WindowPeerState(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowExecutorBoundsState(wexpr, context, payload_count, partition_mask, order_mask) { + explicit WindowPeerState(const WindowExecutorGlobalState &gstate) : WindowExecutorBoundsState(gstate) { } public: @@ -1125,17 +1283,16 @@ void WindowPeerState::NextRank(idx_t partition_begin, idx_t peer_begin, idx_t ro rank_equal++; } -WindowRankExecutor::WindowRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +WindowRankExecutor::WindowRankExecutor(BoundWindowExpression &wexpr, ClientContext &context) + : WindowExecutor(wexpr, context) { } -unique_ptr WindowRankExecutor::GetExecutorState() const { - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); +unique_ptr WindowRankExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { + return make_uniq(gstate); } -void WindowRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { +void WindowRankExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + Vector &result, idx_t count, idx_t row_idx) const { auto &lpeer = lstate.Cast(); auto partition_begin = FlatVector::GetData(lpeer.bounds.data[PARTITION_BEGIN]); auto peer_begin = FlatVector::GetData(lpeer.bounds.data[PEER_BEGIN]); @@ -1151,19 +1308,20 @@ void WindowRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &r } } -WindowDenseRankExecutor::WindowDenseRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +WindowDenseRankExecutor::WindowDenseRankExecutor(BoundWindowExpression &wexpr, ClientContext &context) + : WindowExecutor(wexpr, context) { } -unique_ptr WindowDenseRankExecutor::GetExecutorState() const { - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); +unique_ptr +WindowDenseRankExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { + return make_uniq(gstate); } -void WindowDenseRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { +void WindowDenseRankExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + Vector &result, idx_t count, idx_t row_idx) const { auto &lpeer = lstate.Cast(); + + auto &order_mask = gstate.order_mask; auto partition_begin = FlatVector::GetData(lpeer.bounds.data[PARTITION_BEGIN]); auto peer_begin = FlatVector::GetData(lpeer.bounds.data[PEER_BEGIN]); auto rdata = FlatVector::GetData(result); @@ -1213,18 +1371,17 @@ void WindowDenseRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Vect } } -WindowPercentRankExecutor::WindowPercentRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +WindowPercentRankExecutor::WindowPercentRankExecutor(BoundWindowExpression &wexpr, ClientContext &context) + : WindowExecutor(wexpr, context) { } -unique_ptr WindowPercentRankExecutor::GetExecutorState() const { - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); +unique_ptr +WindowPercentRankExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { + return make_uniq(gstate); } -void WindowPercentRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { +void WindowPercentRankExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + Vector &result, idx_t count, idx_t row_idx) const { auto &lpeer = lstate.Cast(); auto partition_begin = FlatVector::GetData(lpeer.bounds.data[PARTITION_BEGIN]); auto partition_end = FlatVector::GetData(lpeer.bounds.data[PARTITION_END]); @@ -1237,7 +1394,7 @@ void WindowPercentRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Ve for (idx_t i = 0; i < count; ++i, ++row_idx) { lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); - auto denom = NumericCast(partition_end[i] - partition_begin[i] - 1); + auto denom = static_cast(NumericCast(partition_end[i] - partition_begin[i] - 1)); double percent_rank = denom > 0 ? ((double)lpeer.rank - 1) / denom : 0; rdata[i] = percent_rank; } @@ -1246,111 +1403,152 @@ void WindowPercentRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Ve //===--------------------------------------------------------------------===// // WindowCumeDistExecutor //===--------------------------------------------------------------------===// -WindowCumeDistExecutor::WindowCumeDistExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +WindowCumeDistExecutor::WindowCumeDistExecutor(BoundWindowExpression &wexpr, ClientContext &context) + : WindowExecutor(wexpr, context) { } -void WindowCumeDistExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { +void WindowCumeDistExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + Vector &result, idx_t count, idx_t row_idx) const { auto &lbstate = lstate.Cast(); auto partition_begin = FlatVector::GetData(lbstate.bounds.data[PARTITION_BEGIN]); auto partition_end = FlatVector::GetData(lbstate.bounds.data[PARTITION_END]); auto peer_end = FlatVector::GetData(lbstate.bounds.data[PEER_END]); auto rdata = FlatVector::GetData(result); for (idx_t i = 0; i < count; ++i, ++row_idx) { - auto denom = NumericCast(partition_end[i] - partition_begin[i]); + auto denom = static_cast(NumericCast(partition_end[i] - partition_begin[i])); double cume_dist = denom > 0 ? ((double)(peer_end[i] - partition_begin[i])) / denom : 0; rdata[i] = cume_dist; } } //===--------------------------------------------------------------------===// -// WindowValueExecutor +// WindowValueGlobalState //===--------------------------------------------------------------------===// -WindowValueExecutor::WindowValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowExecutor(wexpr, context, payload_count, partition_mask, order_mask) { -} - -WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { -} - -void WindowValueExecutor::Sink(DataChunk &input_chunk, const idx_t input_idx, const idx_t total_count) { - // Single pass over the input to produce the global data. - // Vectorisation for the win... - - // Set up a validity mask for IGNORE NULLS - bool check_nulls = false; - if (wexpr.ignore_nulls) { - switch (wexpr.type) { - case ExpressionType::WINDOW_LEAD: - case ExpressionType::WINDOW_LAG: - case ExpressionType::WINDOW_FIRST_VALUE: - case ExpressionType::WINDOW_LAST_VALUE: - case ExpressionType::WINDOW_NTH_VALUE: - check_nulls = true; - break; - default: - break; + +class WindowValueGlobalState : public WindowExecutorGlobalState { +public: + WindowValueGlobalState(const WindowExecutor &executor, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask) + : WindowExecutorGlobalState(executor, payload_count, partition_mask, order_mask), + payload_collection(payload_data), ignore_nulls(&no_nulls) + + { + if (!arg_types.empty()) { + payload_collection.Initialize(Allocator::Get(executor.context), arg_types, payload_count); } + + auto &wexpr = executor.wexpr; + if (wexpr.ignore_nulls) { + switch (wexpr.type) { + case ExpressionType::WINDOW_LEAD: + case ExpressionType::WINDOW_LAG: + case ExpressionType::WINDOW_FIRST_VALUE: + case ExpressionType::WINDOW_LAST_VALUE: + case ExpressionType::WINDOW_NTH_VALUE: + ignore_nulls = &FlatVector::Validity(payload_collection.chunk.data[0]); + break; + default: + break; + } + } + } + + // The partition values + DataChunk payload_data; + // The partition values + WindowDataChunk payload_collection; + // Mask to use for exclusion if we are not ignoring NULLs + ValidityMask no_nulls; + // IGNORE NULLS + optional_ptr ignore_nulls; +}; + +//===--------------------------------------------------------------------===// +// WindowValueLocalState +//===--------------------------------------------------------------------===// + +//! A class representing the state of the first_value, last_value and nth_value functions +class WindowValueLocalState : public WindowExecutorBoundsState { +public: + explicit WindowValueLocalState(const WindowValueGlobalState &gvstate) + : WindowExecutorBoundsState(gvstate), gvstate(gvstate) { } + //! Lazily initialize for value Execute + void Initialize(); + + //! The corresponding global value state + const WindowValueGlobalState &gvstate; + //! Lazy initialization flag + bool initialized = false; + //! The exclusion filter handler + unique_ptr exclusion_filter; + //! The validity mask that combines both the NULLs and exclusion information + optional_ptr ignore_nulls_exclude; +}; + +void WindowValueLocalState::Initialize() { + if (initialized) { + return; + } + auto ignore_nulls = gvstate.ignore_nulls; + if (gvstate.executor.wexpr.exclude_clause == WindowExcludeMode::NO_OTHER) { + exclusion_filter = nullptr; + ignore_nulls_exclude = ignore_nulls; + } else { + // create the exclusion filter based on ignore_nulls + exclusion_filter = + make_uniq(gvstate.executor.wexpr.exclude_clause, gvstate.payload_count, *ignore_nulls); + ignore_nulls_exclude = &exclusion_filter->mask; + } + + initialized = true; +} + +//===--------------------------------------------------------------------===// +// WindowValueExecutor +//===--------------------------------------------------------------------===// +WindowValueExecutor::WindowValueExecutor(BoundWindowExpression &wexpr, ClientContext &context) + : WindowExecutor(wexpr, context) { +} + +WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context) + : WindowValueExecutor(wexpr, context) { +} + +unique_ptr WindowValueExecutor::GetGlobalState(const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const { + return make_uniq(*this, payload_count, partition_mask, order_mask); +} + +void WindowValueExecutor::Sink(DataChunk &input_chunk, const idx_t input_idx, const idx_t total_count, + WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate) const { + auto &gvstate = gstate.Cast(); + auto &lvstate = lstate.Cast(); + auto &payload_chunk = lvstate.payload_chunk; + auto &payload_executor = lvstate.payload_executor; + auto &payload_collection = gvstate.payload_collection; + if (!wexpr.children.empty()) { payload_chunk.Reset(); payload_executor.Execute(input_chunk, payload_chunk); payload_chunk.Verify(); - payload_collection.Append(payload_chunk, true); - - // process payload chunks while they are still piping hot - if (check_nulls) { - const auto count = input_chunk.size(); - - payload_chunk.Flatten(); - UnifiedVectorFormat vdata; - payload_chunk.data[0].ToUnifiedFormat(count, vdata); - if (!vdata.validity.AllValid()) { - // Lazily materialise the contents when we find the first NULL - if (ignore_nulls.AllValid()) { - ignore_nulls.Initialize(total_count); - } - // Write to the current position - if (input_idx % ValidityMask::BITS_PER_VALUE == 0) { - // If we are at the edge of an output entry, just copy the entries - auto dst = ignore_nulls.GetData() + ignore_nulls.EntryCount(input_idx); - auto src = vdata.validity.GetData(); - for (auto entry_count = vdata.validity.EntryCount(count); entry_count-- > 0;) { - *dst++ = *src++; - } - } else { - // If not, we have ragged data and need to copy one bit at a time. - for (idx_t i = 0; i < count; ++i) { - ignore_nulls.Set(input_idx + i, vdata.validity.RowIsValid(i)); - } - } - } - } + payload_collection.Copy(payload_chunk, input_idx); } - WindowExecutor::Sink(input_chunk, input_idx, total_count); + WindowExecutor::Sink(input_chunk, input_idx, total_count, gstate, lstate); } -unique_ptr WindowValueExecutor::GetExecutorState() const { - if (wexpr.type == ExpressionType::WINDOW_FIRST_VALUE || wexpr.type == ExpressionType::WINDOW_LAST_VALUE || - wexpr.type == ExpressionType::WINDOW_NTH_VALUE) { - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask, ignore_nulls); - } else { - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); - } +unique_ptr WindowValueExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { + const auto &gvstate = gstate.Cast(); + return make_uniq(gvstate); } -void WindowNtileExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { +void WindowNtileExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + Vector &result, idx_t count, idx_t row_idx) const { + auto &gvstate = gstate.Cast(); + auto &payload_collection = gvstate.payload_collection.chunk; D_ASSERT(payload_collection.ColumnCount() == 1); auto &lbstate = lstate.Cast(); auto partition_begin = FlatVector::GetData(lbstate.bounds.data[PARTITION_BEGIN]); @@ -1395,14 +1593,14 @@ void WindowNtileExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector & } //===--------------------------------------------------------------------===// -// WindowLeadLagState +// WindowLeadLagLocalState //===--------------------------------------------------------------------===// -class WindowLeadLagState : public WindowExecutorBoundsState { +class WindowLeadLagLocalState : public WindowValueLocalState { public: - WindowLeadLagState(BoundWindowExpression &wexpr, ClientContext &context, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowExecutorBoundsState(wexpr, context, payload_count, partition_mask, order_mask), - leadlag_offset(wexpr.offset_expr.get(), context), leadlag_default(wexpr.default_expr.get(), context) { + explicit WindowLeadLagLocalState(const WindowValueGlobalState &gstate) + : WindowValueLocalState(gstate), + leadlag_offset(gstate.executor.wexpr.offset_expr.get(), gstate.executor.context), + leadlag_default(gstate.executor.wexpr.default_expr.get(), gstate.executor.context) { } void UpdateBounds(idx_t row_idx, DataChunk &input_chunk, const WindowInputColumn &range) override; @@ -1413,7 +1611,7 @@ class WindowLeadLagState : public WindowExecutorBoundsState { WindowInputExpression leadlag_default; }; -void WindowLeadLagState::UpdateBounds(idx_t row_idx, DataChunk &input_chunk, const WindowInputColumn &range) { +void WindowLeadLagLocalState::UpdateBounds(idx_t row_idx, DataChunk &input_chunk, const WindowInputColumn &range) { // Evaluate the row-level arguments leadlag_offset.Execute(input_chunk); leadlag_default.Execute(input_chunk); @@ -1421,23 +1619,35 @@ void WindowLeadLagState::UpdateBounds(idx_t row_idx, DataChunk &input_chunk, con WindowExecutorBoundsState::UpdateBounds(row_idx, input_chunk, range); } -WindowLeadLagExecutor::WindowLeadLagExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +WindowLeadLagExecutor::WindowLeadLagExecutor(BoundWindowExpression &wexpr, ClientContext &context) + : WindowValueExecutor(wexpr, context) { } -unique_ptr WindowLeadLagExecutor::GetExecutorState() const { - return make_uniq(wexpr, context, payload_count, partition_mask, order_mask); +unique_ptr +WindowLeadLagExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { + const auto &gvstate = gstate.Cast(); + return make_uniq(gvstate); } -void WindowLeadLagExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - auto &llstate = lstate.Cast(); +void WindowLeadLagExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + Vector &result, idx_t count, idx_t row_idx) const { + auto &gvstate = gstate.Cast(); + auto &payload_collection = gvstate.payload_collection.chunk; + auto &ignore_nulls = gvstate.ignore_nulls; + auto &llstate = lstate.Cast(); + + bool can_shift = ignore_nulls->AllValid(); + if (wexpr.offset_expr) { + can_shift = can_shift && wexpr.offset_expr->IsFoldable(); + } + if (wexpr.default_expr) { + can_shift = can_shift && wexpr.default_expr->IsFoldable(); + } auto partition_begin = FlatVector::GetData(llstate.bounds.data[PARTITION_BEGIN]); auto partition_end = FlatVector::GetData(llstate.bounds.data[PARTITION_END]); - for (idx_t i = 0; i < count; ++i, ++row_idx) { + const auto row_end = row_idx + count; + for (idx_t i = 0; i < count;) { int64_t offset = 1; if (wexpr.offset_expr) { offset = llstate.leadlag_offset.GetCell(i); @@ -1453,32 +1663,58 @@ void WindowLeadLagExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector if (val_idx < (int64_t)row_idx) { // Count backwards delta = idx_t(row_idx - idx_t(val_idx)); - val_idx = int64_t(FindPrevStart(ignore_nulls, partition_begin[i], row_idx, delta)); + val_idx = int64_t(FindPrevStart(*ignore_nulls, partition_begin[i], row_idx, delta)); } else if (val_idx > (int64_t)row_idx) { delta = idx_t(idx_t(val_idx) - row_idx); - val_idx = int64_t(FindNextStart(ignore_nulls, row_idx + 1, partition_end[i], delta)); + val_idx = int64_t(FindNextStart(*ignore_nulls, row_idx + 1, partition_end[i], delta)); } // else offset is zero, so don't move. - if (!delta) { - CopyCell(payload_collection, 0, NumericCast(val_idx), result, i); - } else if (wexpr.default_expr) { - llstate.leadlag_default.CopyCell(result, i); + if (can_shift) { + if (!delta) { + // Copy source[index:index+width] => result[i:] + const auto index = NumericCast(val_idx); + const auto source_limit = partition_end[i] - index; + const auto target_limit = MinValue(partition_end[i], row_end) - row_idx; + const auto width = MinValue(source_limit, target_limit); + auto &source = payload_collection.data[0]; + VectorOperations::Copy(source, result, index + width, index, i); + i += width; + row_idx += width; + } else if (wexpr.default_expr) { + const auto width = MinValue(delta, count - i); + llstate.leadlag_default.CopyCell(result, i, width); + i += width; + row_idx += width; + } else { + for (idx_t nulls = MinValue(delta, count - i); nulls--; ++i, ++row_idx) { + FlatVector::SetNull(result, i, true); + } + } } else { - FlatVector::SetNull(result, i, true); + if (!delta) { + CopyCell(payload_collection, 0, NumericCast(val_idx), result, i); + } else if (wexpr.default_expr) { + llstate.leadlag_default.CopyCell(result, i); + } else { + FlatVector::SetNull(result, i, true); + } + ++i; + ++row_idx; } } } -WindowFirstValueExecutor::WindowFirstValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +WindowFirstValueExecutor::WindowFirstValueExecutor(BoundWindowExpression &wexpr, ClientContext &context) + : WindowValueExecutor(wexpr, context) { } -void WindowFirstValueExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - auto &lvstate = lstate.Cast(); +void WindowFirstValueExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + Vector &result, idx_t count, idx_t row_idx) const { + auto &gvstate = gstate.Cast(); + auto &payload_collection = gvstate.payload_collection.chunk; + auto &lvstate = lstate.Cast(); + lvstate.Initialize(); auto window_begin = FlatVector::GetData(lvstate.bounds.data[WINDOW_BEGIN]); auto window_end = FlatVector::GetData(lvstate.bounds.data[WINDOW_END]); for (idx_t i = 0; i < count; ++i, ++row_idx) { @@ -1506,15 +1742,16 @@ void WindowFirstValueExecutor::EvaluateInternal(WindowExecutorState &lstate, Vec } } -WindowLastValueExecutor::WindowLastValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +WindowLastValueExecutor::WindowLastValueExecutor(BoundWindowExpression &wexpr, ClientContext &context) + : WindowValueExecutor(wexpr, context) { } -void WindowLastValueExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { - auto &lvstate = lstate.Cast(); +void WindowLastValueExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + Vector &result, idx_t count, idx_t row_idx) const { + auto &gvstate = gstate.Cast(); + auto &payload_collection = gvstate.payload_collection.chunk; + auto &lvstate = lstate.Cast(); + lvstate.Initialize(); auto window_begin = FlatVector::GetData(lvstate.bounds.data[WINDOW_BEGIN]); auto window_end = FlatVector::GetData(lvstate.bounds.data[WINDOW_END]); for (idx_t i = 0; i < count; ++i, ++row_idx) { @@ -1541,17 +1778,18 @@ void WindowLastValueExecutor::EvaluateInternal(WindowExecutorState &lstate, Vect } } -WindowNthValueExecutor::WindowNthValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, - const idx_t payload_count, const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowValueExecutor(wexpr, context, payload_count, partition_mask, order_mask) { +WindowNthValueExecutor::WindowNthValueExecutor(BoundWindowExpression &wexpr, ClientContext &context) + : WindowValueExecutor(wexpr, context) { } -void WindowNthValueExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &result, idx_t count, - idx_t row_idx) const { +void WindowNthValueExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + Vector &result, idx_t count, idx_t row_idx) const { + auto &gvstate = gstate.Cast(); + auto &payload_collection = gvstate.payload_collection.chunk; D_ASSERT(payload_collection.ColumnCount() == 2); - auto &lvstate = lstate.Cast(); + auto &lvstate = lstate.Cast(); + lvstate.Initialize(); auto window_begin = FlatVector::GetData(lvstate.bounds.data[WINDOW_BEGIN]); auto window_end = FlatVector::GetData(lvstate.bounds.data[WINDOW_END]); for (idx_t i = 0; i < count; ++i, ++row_idx) { diff --git a/src/duckdb/src/execution/window_segment_tree.cpp b/src/duckdb/src/execution/window_segment_tree.cpp index 906a62e9..be71e85c 100644 --- a/src/duckdb/src/execution/window_segment_tree.cpp +++ b/src/duckdb/src/execution/window_segment_tree.cpp @@ -2,12 +2,14 @@ #include "duckdb/common/algorithm.hpp" #include "duckdb/common/helper.hpp" +#include "duckdb/common/sort/partition_state.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/execution/merge_sort_tree.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/execution/window_executor.hpp" #include +#include #include namespace duckdb { @@ -18,50 +20,201 @@ namespace duckdb { WindowAggregatorState::WindowAggregatorState() : allocator(Allocator::DefaultAllocator()) { } -WindowAggregator::WindowAggregator(AggregateObject aggr_p, const LogicalType &result_type_p, - const WindowExcludeMode exclude_mode_p, idx_t partition_count_p) - : aggr(std::move(aggr_p)), result_type(result_type_p), partition_count(partition_count_p), - state_size(aggr.function.state_size()), filter_pos(0), exclude_mode(exclude_mode_p) { +class WindowAggregatorGlobalState : public WindowAggregatorState { +public: + WindowAggregatorGlobalState(const WindowAggregator &aggregator_p, idx_t group_count) + : aggregator(aggregator_p), winputs(inputs), locals(0), finalized(0) { + + if (!aggregator.arg_types.empty()) { + winputs.Initialize(Allocator::DefaultAllocator(), aggregator.arg_types, group_count); + } + if (aggregator.aggr.filter) { + // Start with all invalid and set the ones that pass + filter_mask.Initialize(group_count, false); + } + } + + //! The aggregator data + const WindowAggregator &aggregator; + + //! Partition data chunk + DataChunk inputs; + WindowDataChunk winputs; + + //! The filtered rows in inputs. + ValidityArray filter_mask; + + //! Lock for single threading + mutable mutex lock; + + //! Count of local tasks + mutable std::atomic locals; + + //! Number of finalised states + std::atomic finalized; +}; + +WindowAggregator::WindowAggregator(AggregateObject aggr_p, const vector &arg_types_p, + const LogicalType &result_type_p, const WindowExcludeMode exclude_mode_p) + : aggr(std::move(aggr_p)), arg_types(arg_types_p), result_type(result_type_p), + state_size(aggr.function.state_size(aggr.function)), exclude_mode(exclude_mode_p) { } WindowAggregator::~WindowAggregator() { } -void WindowAggregator::Sink(DataChunk &payload_chunk, SelectionVector *filter_sel, idx_t filtered) { - if (!inputs.ColumnCount() && payload_chunk.ColumnCount()) { - inputs.Initialize(Allocator::DefaultAllocator(), payload_chunk.GetTypes()); - } - if (inputs.ColumnCount()) { - inputs.Append(payload_chunk, true); +unique_ptr WindowAggregator::GetGlobalState(idx_t group_count, const ValidityMask &) const { + return make_uniq(*this, group_count); +} + +void WindowAggregator::Sink(WindowAggregatorState &gsink, WindowAggregatorState &lstate, DataChunk &arg_chunk, + idx_t input_idx, optional_ptr filter_sel, idx_t filtered) { + auto &gasink = gsink.Cast(); + auto &winputs = gasink.winputs; + auto &filter_mask = gasink.filter_mask; + if (winputs.chunk.ColumnCount()) { + winputs.Copy(arg_chunk, input_idx); } if (filter_sel) { - // Lazy instantiation - if (!filter_mask.IsMaskSet()) { - // Start with all invalid and set the ones that pass - filter_bits.resize(ValidityMask::ValidityMaskSize(partition_count), 0); - filter_mask.Initialize(filter_bits.data()); - } for (idx_t f = 0; f < filtered; ++f) { - filter_mask.SetValid(filter_pos + filter_sel->get_index(f)); + filter_mask.SetValid(input_idx + filter_sel->get_index(f)); } - filter_pos += payload_chunk.size(); } } -void WindowAggregator::Finalize(const FrameStats &stats) { +void WindowAggregator::Finalize(WindowAggregatorState &gstate, WindowAggregatorState &lstate, const FrameStats &stats) { } //===--------------------------------------------------------------------===// -// WindowConstantAggregate +// WindowConstantAggregator //===--------------------------------------------------------------------===// -WindowConstantAggregator::WindowConstantAggregator(AggregateObject aggr, const LogicalType &result_type, - const ValidityMask &partition_mask, - const WindowExcludeMode exclude_mode_p, const idx_t count) - : WindowAggregator(std::move(aggr), result_type, exclude_mode_p, count), partition(0), row(0), state(state_size), - statep(Value::POINTER(CastPointerToValue(state.data()))), - statef(Value::POINTER(CastPointerToValue(state.data()))) { +struct WindowAggregateStates { + explicit WindowAggregateStates(const AggregateObject &aggr); + ~WindowAggregateStates() { + Destroy(); + } - statef.SetVectorType(VectorType::FLAT_VECTOR); // Prevent conversion of results to constants + //! The number of states + idx_t GetCount() const { + return states.size() / state_size; + } + data_ptr_t *GetData() { + return FlatVector::GetData(*statef); + } + data_ptr_t GetStatePtr(idx_t idx) { + return states.data() + idx * state_size; + } + const_data_ptr_t GetStatePtr(idx_t idx) const { + return states.data() + idx * state_size; + } + //! Initialise all the states + void Initialize(idx_t count); + //! Combine the states into the target + void Combine(WindowAggregateStates &target, + AggregateCombineType combine_type = AggregateCombineType::PRESERVE_INPUT); + //! Finalize the states into an output vector + void Finalize(Vector &result); + //! Destroy the states + void Destroy(); + + //! A description of the aggregator + const AggregateObject aggr; + //! The size of each state + const idx_t state_size; + //! The allocator to use + ArenaAllocator allocator; + //! Data pointer that contains the state data + vector states; + //! Reused result state container for the window functions + unique_ptr statef; +}; + +WindowAggregateStates::WindowAggregateStates(const AggregateObject &aggr) + : aggr(aggr), state_size(aggr.function.state_size(aggr.function)), allocator(Allocator::DefaultAllocator()) { +} + +void WindowAggregateStates::Initialize(idx_t count) { + states.resize(count * state_size); + auto state_ptr = states.data(); + + statef = make_uniq(LogicalType::POINTER, count); + auto state_f_data = FlatVector::GetData(*statef); + + for (idx_t i = 0; i < count; ++i, state_ptr += state_size) { + state_f_data[i] = state_ptr; + aggr.function.initialize(aggr.function, state_ptr); + } + + // Prevent conversion of results to constants + statef->SetVectorType(VectorType::FLAT_VECTOR); +} + +void WindowAggregateStates::Combine(WindowAggregateStates &target, AggregateCombineType combine_type) { + AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); + aggr.function.combine(*statef, *target.statef, aggr_input_data, GetCount()); +} + +void WindowAggregateStates::Finalize(Vector &result) { + AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + aggr.function.finalize(*statef, aggr_input_data, result, GetCount(), 0); +} + +void WindowAggregateStates::Destroy() { + if (states.empty()) { + return; + } + + AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + if (aggr.function.destructor) { + aggr.function.destructor(*statef, aggr_input_data, GetCount()); + } + + states.clear(); +} + +class WindowConstantAggregatorGlobalState : public WindowAggregatorGlobalState { +public: + WindowConstantAggregatorGlobalState(const WindowConstantAggregator &aggregator, idx_t count, + const ValidityMask &partition_mask); + + void Finalize(const FrameStats &stats); + + //! Partition starts + vector partition_offsets; + //! Reused result state container for the window functions + WindowAggregateStates statef; + //! Aggregate results + unique_ptr results; +}; + +class WindowConstantAggregatorLocalState : public WindowAggregatorState { +public: + explicit WindowConstantAggregatorLocalState(const WindowConstantAggregatorGlobalState &gstate); + ~WindowConstantAggregatorLocalState() override { + } + + void Sink(DataChunk &payload_chunk, idx_t input_idx, optional_ptr filter_sel, idx_t filtered); + void Combine(WindowConstantAggregatorGlobalState &gstate); + +public: + //! The global state we are sharing + const WindowConstantAggregatorGlobalState &gstate; + //! Reusable chunk for sinking + DataChunk inputs; + //! A vector of pointers to "state", used for intermediate window segment aggregation + Vector statep; + //! Reused result state container for the window functions + WindowAggregateStates statef; + //! The current result partition being read + idx_t partition; + //! Shared SV for evaluation + SelectionVector matches; +}; + +WindowConstantAggregatorGlobalState::WindowConstantAggregatorGlobalState(const WindowConstantAggregator &aggregator, + idx_t group_count, + const ValidityMask &partition_mask) + : WindowAggregatorGlobalState(aggregator, STANDARD_VECTOR_SIZE), statef(aggregator.aggr) { // Locate the partition boundaries if (partition_mask.AllValid()) { @@ -69,7 +222,7 @@ WindowConstantAggregator::WindowConstantAggregator(AggregateObject aggr, const L } else { idx_t entry_idx; idx_t shift; - for (idx_t start = 0; start < count;) { + for (idx_t start = 0; start < group_count;) { partition_mask.GetEntryIndex(start, entry_idx, shift); // If start is aligned with the start of a block, @@ -81,7 +234,7 @@ WindowConstantAggregator::WindowConstantAggregator(AggregateObject aggr, const L } // Loop over the block - for (; shift < ValidityMask::BITS_PER_VALUE && start < count; ++shift, ++start) { + for (; shift < ValidityMask::BITS_PER_VALUE && start < group_count; ++shift, ++start) { if (partition_mask.RowIsValid(block, shift)) { partition_offsets.emplace_back(start); } @@ -90,45 +243,70 @@ WindowConstantAggregator::WindowConstantAggregator(AggregateObject aggr, const L } // Initialise the vector for caching the results - results = make_uniq(result_type, partition_offsets.size()); - partition_offsets.emplace_back(count); + results = make_uniq(aggregator.result_type, partition_offsets.size()); - // Create an aggregate state for intermediate aggregates - gstate = make_uniq(); + // Initialise the final states + statef.Initialize(partition_offsets.size()); - // Start the first aggregate - AggregateInit(); + // Add final guard + partition_offsets.emplace_back(group_count); } -void WindowConstantAggregator::AggregateInit() { - aggr.function.initialize(state.data()); +WindowConstantAggregatorLocalState::WindowConstantAggregatorLocalState( + const WindowConstantAggregatorGlobalState &gstate) + : gstate(gstate), statep(Value::POINTER(0)), statef(gstate.statef.aggr), partition(0) { + matches.Initialize(); + + // Start the aggregates + auto &partition_offsets = gstate.partition_offsets; + auto &aggregator = gstate.aggregator; + statef.Initialize(partition_offsets.size() - 1); + + // Set up shared buffer + inputs.Initialize(Allocator::DefaultAllocator(), aggregator.arg_types); + + gstate.locals++; } -void WindowConstantAggregator::AggegateFinal(Vector &result, idx_t rid) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), gstate->allocator); - aggr.function.finalize(statef, aggr_input_data, result, 1, rid); +WindowConstantAggregator::WindowConstantAggregator(AggregateObject aggr, const vector &arg_types, + const LogicalType &result_type, + const WindowExcludeMode exclude_mode_p) + : WindowAggregator(std::move(aggr), arg_types, result_type, exclude_mode_p) { +} - if (aggr.function.destructor) { - aggr.function.destructor(statef, aggr_input_data, 1); - } +unique_ptr WindowConstantAggregator::GetGlobalState(idx_t group_count, + const ValidityMask &partition_mask) const { + return make_uniq(*this, group_count, partition_mask); } -void WindowConstantAggregator::Sink(DataChunk &payload_chunk, SelectionVector *filter_sel, idx_t filtered) { +void WindowConstantAggregator::Sink(WindowAggregatorState &gsink, WindowAggregatorState &lstate, DataChunk &arg_chunk, + idx_t input_idx, optional_ptr filter_sel, idx_t filtered) { + auto &lastate = lstate.Cast(); + + lastate.Sink(arg_chunk, input_idx, filter_sel, filtered); +} + +void WindowConstantAggregatorLocalState::Sink(DataChunk &payload_chunk, idx_t row, + optional_ptr filter_sel, idx_t filtered) { + auto &partition_offsets = gstate.partition_offsets; + auto &aggregator = gstate.aggregator; + const auto &aggr = aggregator.aggr; const auto chunk_begin = row; const auto chunk_end = chunk_begin + payload_chunk.size(); + idx_t partition = + idx_t(std::upper_bound(partition_offsets.begin(), partition_offsets.end(), row) - partition_offsets.begin()) - + 1; - if (!inputs.ColumnCount() && payload_chunk.ColumnCount()) { - inputs.Initialize(Allocator::DefaultAllocator(), payload_chunk.GetTypes()); - } + auto state_f_data = statef.GetData(); + auto state_p_data = FlatVector::GetData(statep); - AggregateInputData aggr_input_data(aggr.GetFunctionData(), gstate->allocator); + AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); idx_t begin = 0; idx_t filter_idx = 0; auto partition_end = partition_offsets[partition + 1]; while (row < chunk_end) { if (row == partition_end) { - AggegateFinal(*results, partition++); - AggregateInit(); + ++partition; partition_end = partition_offsets[partition + 1]; } partition_end = MinValue(partition_end, chunk_end); @@ -174,9 +352,11 @@ void WindowConstantAggregator::Sink(DataChunk &payload_chunk, SelectionVector *f // Aggregate the filtered rows into a single state const auto count = inputs.size(); + auto state = state_f_data[partition]; if (aggr.function.simple_update) { - aggr.function.simple_update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), state.data(), count); + aggr.function.simple_update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), state, count); } else { + state_p_data[0] = state_f_data[partition]; aggr.function.update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), statep, count); } @@ -186,34 +366,36 @@ void WindowConstantAggregator::Sink(DataChunk &payload_chunk, SelectionVector *f } } -void WindowConstantAggregator::Finalize(const FrameStats &stats) { - AggegateFinal(*results, partition++); -} +void WindowConstantAggregator::Finalize(WindowAggregatorState &gstate, WindowAggregatorState &lstate, + const FrameStats &stats) { + auto &gastate = gstate.Cast(); + auto &lastate = lstate.Cast(); -class WindowConstantAggregatorState : public WindowAggregatorState { -public: - WindowConstantAggregatorState() : partition(0) { - matches.Initialize(); - } - ~WindowConstantAggregatorState() override { - } + // Single-threaded combine + lock_guard finalize_guard(gastate.lock); + lastate.statef.Combine(gastate.statef); + lastate.statef.Destroy(); -public: - //! The current result partition being read - idx_t partition; - //! Shared SV for evaluation - SelectionVector matches; -}; + // Last one out turns off the lights! + if (++gastate.finalized == gastate.locals) { + gastate.statef.Finalize(*gastate.results); + gastate.statef.Destroy(); + } +} -unique_ptr WindowConstantAggregator::GetLocalState() const { - return make_uniq(); +unique_ptr WindowConstantAggregator::GetLocalState(const WindowAggregatorState &gstate) const { + return make_uniq(gstate.Cast()); } -void WindowConstantAggregator::Evaluate(WindowAggregatorState &lstate, const DataChunk &bounds, Vector &target, - idx_t count, idx_t row_idx) const { +void WindowConstantAggregator::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { + auto &gasink = gsink.Cast(); + const auto &partition_offsets = gasink.partition_offsets; + const auto &results = *gasink.results; + auto begins = FlatVector::GetData(bounds.data[WINDOW_BEGIN]); // Chunk up the constants and copy them one at a time - auto &lcstate = lstate.Cast(); + auto &lcstate = lstate.Cast(); idx_t matched = 0; idx_t target_offset = 0; for (idx_t i = 0; i < count; ++i) { @@ -222,7 +404,7 @@ void WindowConstantAggregator::Evaluate(WindowAggregatorState &lstate, const Dat while (partition_offsets[lcstate.partition + 1] <= begin) { // Flush the previous partition's data if (matched) { - VectorOperations::Copy(*results, target, lcstate.matches, matched, 0, target_offset); + VectorOperations::Copy(results, result, lcstate.matches, matched, 0, target_offset); target_offset += matched; matched = 0; } @@ -234,16 +416,22 @@ void WindowConstantAggregator::Evaluate(WindowAggregatorState &lstate, const Dat // Flush the last partition if (matched) { - VectorOperations::Copy(*results, target, lcstate.matches, matched, 0, target_offset); + // Optimize constant result + if (target_offset == 0 && matched == count) { + VectorOperations::Copy(results, result, lcstate.matches, 1, 0, target_offset); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } else { + VectorOperations::Copy(results, result, lcstate.matches, matched, 0, target_offset); + } } } //===--------------------------------------------------------------------===// // WindowCustomAggregator //===--------------------------------------------------------------------===// -WindowCustomAggregator::WindowCustomAggregator(AggregateObject aggr, const LogicalType &result_type, - const WindowExcludeMode exclude_mode_p, idx_t count) - : WindowAggregator(std::move(aggr), result_type, exclude_mode_p, count) { +WindowCustomAggregator::WindowCustomAggregator(AggregateObject aggr, const vector &arg_types, + const LogicalType &result_type, const WindowExcludeMode exclude_mode) + : WindowAggregator(std::move(aggr), arg_types, result_type, exclude_mode) { } WindowCustomAggregator::~WindowCustomAggregator() { @@ -282,12 +470,28 @@ static void InitSubFrames(SubFrames &frames, const WindowExcludeMode exclude_mod frames.resize(nframes, {0, 0}); } +class WindowCustomAggregatorGlobalState : public WindowAggregatorGlobalState { +public: + explicit WindowCustomAggregatorGlobalState(const WindowCustomAggregator &aggregator, idx_t group_count) + : WindowAggregatorGlobalState(aggregator, group_count) { + + gcstate = make_uniq(aggregator.aggr, aggregator.exclude_mode); + } + + //! Traditional packed filter mask for API + ValidityMask filter_packed; + //! Data pointer that contains a single local state, used for global custom window execution state + unique_ptr gcstate; + //! Partition description for custom window APIs + unique_ptr partition_input; +}; + WindowCustomAggregatorState::WindowCustomAggregatorState(const AggregateObject &aggr, const WindowExcludeMode exclude_mode) - : aggr(aggr), state(aggr.function.state_size()), statef(Value::POINTER(CastPointerToValue(state.data()))), - frames(3, {0, 0}) { + : aggr(aggr), state(aggr.function.state_size(aggr.function)), + statef(Value::POINTER(CastPointerToValue(state.data()))), frames(3, {0, 0}) { // if we have a frame-by-frame method, share the single state - aggr.function.initialize(state.data()); + aggr.function.initialize(aggr.function, state.data()); InitSubFrames(frames, exclude_mode); } @@ -299,21 +503,41 @@ WindowCustomAggregatorState::~WindowCustomAggregatorState() { } } -void WindowCustomAggregator::Finalize(const FrameStats &stats) { - WindowAggregator::Finalize(stats); - partition_input = - make_uniq(inputs.data.data(), inputs.ColumnCount(), inputs.size(), filter_mask, stats); +unique_ptr WindowCustomAggregator::GetGlobalState(idx_t group_count, + const ValidityMask &) const { + return make_uniq(*this, group_count); +} + +void WindowCustomAggregator::Finalize(WindowAggregatorState &gsink, WindowAggregatorState &lstate, + const FrameStats &stats) { + // Single threaded Finalize for now + auto &gcsink = gsink.Cast(); + lock_guard gestate_guard(gcsink.lock); + if (gcsink.finalized) { + return; + } + + WindowAggregator::Finalize(gsink, lstate, stats); + + auto &inputs = gcsink.inputs; + auto &filter_mask = gcsink.filter_mask; + auto &filter_packed = gcsink.filter_packed; + filter_mask.Pack(filter_packed, filter_mask.target_count); + + gcsink.partition_input = + make_uniq(inputs.data.data(), inputs.ColumnCount(), inputs.size(), filter_packed, stats); if (aggr.function.window_init) { - gstate = GetLocalState(); - auto &gcstate = gstate->Cast(); + auto &gcstate = *gcsink.gcstate; AggregateInputData aggr_input_data(aggr.GetFunctionData(), gcstate.allocator); - aggr.function.window_init(aggr_input_data, *partition_input, gcstate.state.data()); + aggr.function.window_init(aggr_input_data, *gcsink.partition_input, gcstate.state.data()); } + + ++gcsink.finalized; } -unique_ptr WindowCustomAggregator::GetLocalState() const { +unique_ptr WindowCustomAggregator::GetLocalState(const WindowAggregatorState &gstate) const { return make_uniq(aggr, exclude_mode); } @@ -374,29 +598,30 @@ static void EvaluateSubFrames(const DataChunk &bounds, const WindowExcludeMode e } } -void WindowCustomAggregator::Evaluate(WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, - idx_t count, idx_t row_idx) const { +void WindowCustomAggregator::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { auto &lcstate = lstate.Cast(); auto &frames = lcstate.frames; const_data_ptr_t gstate_p = nullptr; - if (gstate) { - auto &gcstate = gstate->Cast(); - gstate_p = gcstate.state.data(); + auto &gcsink = gsink.Cast(); + if (gcsink.gcstate) { + gstate_p = gcsink.gcstate->state.data(); } EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t i) { // Extract the range AggregateInputData aggr_input_data(aggr.GetFunctionData(), lstate.allocator); - aggr.function.window(aggr_input_data, *partition_input, gstate_p, lcstate.state.data(), frames, result, i); + aggr.function.window(aggr_input_data, *gcsink.partition_input, gstate_p, lcstate.state.data(), frames, result, + i); }); } //===--------------------------------------------------------------------===// // WindowNaiveAggregator //===--------------------------------------------------------------------===// -WindowNaiveAggregator::WindowNaiveAggregator(AggregateObject aggr, const LogicalType &result_type, - const WindowExcludeMode exclude_mode_p, idx_t partition_count) - : WindowAggregator(std::move(aggr), result_type, exclude_mode_p, partition_count) { +WindowNaiveAggregator::WindowNaiveAggregator(AggregateObject aggr, const vector &arg_types, + const LogicalType &result_type, const WindowExcludeMode exclude_mode) + : WindowAggregator(std::move(aggr), arg_types, result_type, exclude_mode) { } WindowNaiveAggregator::~WindowNaiveAggregator() { @@ -405,44 +630,47 @@ WindowNaiveAggregator::~WindowNaiveAggregator() { class WindowNaiveState : public WindowAggregatorState { public: struct HashRow { - explicit HashRow(WindowNaiveState &state) : state(state) { + HashRow(WindowNaiveState &state, const DataChunk &inputs) : state(state), inputs(inputs) { } size_t operator()(const idx_t &i) const { - return state.Hash(i); + return state.Hash(inputs, i); } WindowNaiveState &state; + const DataChunk &inputs; }; struct EqualRow { - explicit EqualRow(WindowNaiveState &state) : state(state) { + EqualRow(WindowNaiveState &state, const DataChunk &inputs) : state(state), inputs(inputs) { } bool operator()(const idx_t &lhs, const idx_t &rhs) const { - return state.KeyEqual(lhs, rhs); + return state.KeyEqual(inputs, lhs, rhs); } WindowNaiveState &state; + const DataChunk &inputs; }; using RowSet = std::unordered_set; - explicit WindowNaiveState(const WindowNaiveAggregator &gstate); + explicit WindowNaiveState(const WindowNaiveAggregator &gsink); - void Evaluate(const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx); + void Evaluate(const WindowAggregatorGlobalState &gsink, const DataChunk &bounds, Vector &result, idx_t count, + idx_t row_idx); protected: //! Flush the accumulated intermediate states into the result states - void FlushStates(); + void FlushStates(const WindowAggregatorGlobalState &gsink); //! Hashes a value for the hash table - size_t Hash(idx_t rid); + size_t Hash(const DataChunk &inputs, idx_t rid); //! Compares two values for the hash table - bool KeyEqual(const idx_t &lhs, const idx_t &rhs); + bool KeyEqual(const DataChunk &inputs, const idx_t &lhs, const idx_t &rhs); //! The global state - const WindowNaiveAggregator &gstate; + const WindowNaiveAggregator &aggregator; //! Data pointer that contains a vector of states, used for row aggregation vector state; //! Reused result state container for the aggregate @@ -459,21 +687,12 @@ class WindowNaiveState : public WindowAggregatorState { SubFrames frames; //! The optional hash table used for DISTINCT Vector hashes; - HashRow hash_row; - EqualRow equal_row; - RowSet row_set; }; -WindowNaiveState::WindowNaiveState(const WindowNaiveAggregator &gstate) - : gstate(gstate), state(gstate.state_size * STANDARD_VECTOR_SIZE), statef(LogicalType::POINTER), - statep((LogicalType::POINTER)), flush_count(0), hashes(LogicalType::HASH), hash_row(*this), equal_row(*this), - row_set(STANDARD_VECTOR_SIZE, hash_row, equal_row) { - InitSubFrames(frames, gstate.exclude_mode); - - auto &inputs = gstate.GetInputs(); - if (inputs.ColumnCount() > 0) { - leaves.Initialize(Allocator::DefaultAllocator(), inputs.GetTypes()); - } +WindowNaiveState::WindowNaiveState(const WindowNaiveAggregator &aggregator_p) + : aggregator(aggregator_p), state(aggregator.state_size * STANDARD_VECTOR_SIZE), statef(LogicalType::POINTER), + statep((LogicalType::POINTER)), flush_count(0), hashes(LogicalType::HASH) { + InitSubFrames(frames, aggregator.exclude_mode); update_sel.Initialize(); @@ -485,28 +704,26 @@ WindowNaiveState::WindowNaiveState(const WindowNaiveAggregator &gstate) auto fdata = FlatVector::GetData(statef); for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; ++i) { fdata[i] = state_ptr; - state_ptr += gstate.state_size; + state_ptr += aggregator.state_size; } } -void WindowNaiveState::FlushStates() { +void WindowNaiveState::FlushStates(const WindowAggregatorGlobalState &gsink) { if (!flush_count) { return; } - auto &inputs = gstate.GetInputs(); + auto &inputs = gsink.inputs; leaves.Slice(inputs, update_sel, flush_count); - auto &aggr = gstate.aggr; + auto &aggr = aggregator.aggr; AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); aggr.function.update(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), statep, flush_count); flush_count = 0; } -size_t WindowNaiveState::Hash(idx_t rid) { - auto &inputs = gstate.GetInputs(); - +size_t WindowNaiveState::Hash(const DataChunk &inputs, idx_t rid) { auto s = UnsafeNumericCast(rid); SelectionVector sel(&s); leaves.Slice(inputs, sel, 1); @@ -515,9 +732,7 @@ size_t WindowNaiveState::Hash(idx_t rid) { return *FlatVector::GetData(hashes); } -bool WindowNaiveState::KeyEqual(const idx_t &lhs, const idx_t &rhs) { - auto &inputs = gstate.GetInputs(); - +bool WindowNaiveState::KeyEqual(const DataChunk &inputs, const idx_t &lhs, const idx_t &rhs) { auto l = UnsafeNumericCast(lhs); SelectionVector lsel(&l); @@ -538,16 +753,26 @@ bool WindowNaiveState::KeyEqual(const idx_t &lhs, const idx_t &rhs) { return true; } -void WindowNaiveState::Evaluate(const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) { - auto &aggr = gstate.aggr; - auto &filter_mask = gstate.GetFilterMask(); +void WindowNaiveState::Evaluate(const WindowAggregatorGlobalState &gsink, const DataChunk &bounds, Vector &result, + idx_t count, idx_t row_idx) { + auto &aggr = aggregator.aggr; + auto &filter_mask = gsink.filter_mask; + auto &inputs = gsink.inputs; + + if (leaves.ColumnCount() == 0 && inputs.ColumnCount() > 0) { + leaves.Initialize(Allocator::DefaultAllocator(), inputs.GetTypes()); + } auto fdata = FlatVector::GetData(statef); auto pdata = FlatVector::GetData(statep); - EvaluateSubFrames(bounds, gstate.exclude_mode, count, row_idx, frames, [&](idx_t rid) { + HashRow hash_row(*this, inputs); + EqualRow equal_row(*this, inputs); + RowSet row_set(STANDARD_VECTOR_SIZE, hash_row, equal_row); + + EvaluateSubFrames(bounds, aggregator.exclude_mode, count, row_idx, frames, [&](idx_t rid) { auto agg_state = fdata[rid]; - aggr.function.initialize(agg_state); + aggr.function.initialize(aggr.function, agg_state); // Just update the aggregate with the unfiltered input rows row_set.clear(); @@ -565,14 +790,14 @@ void WindowNaiveState::Evaluate(const DataChunk &bounds, Vector &result, idx_t c pdata[flush_count] = agg_state; update_sel[flush_count++] = UnsafeNumericCast(f); if (flush_count >= STANDARD_VECTOR_SIZE) { - FlushStates(); + FlushStates(gsink); } } } }); // Flush the final states - FlushStates(); + FlushStates(gsink); // Finalise the result aggregates and write to the result AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); @@ -584,55 +809,57 @@ void WindowNaiveState::Evaluate(const DataChunk &bounds, Vector &result, idx_t c } } -unique_ptr WindowNaiveAggregator::GetLocalState() const { +unique_ptr WindowNaiveAggregator::GetLocalState(const WindowAggregatorState &gstate) const { return make_uniq(*this); } -void WindowNaiveAggregator::Evaluate(WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, - idx_t count, idx_t row_idx) const { - auto &ldstate = lstate.Cast(); - ldstate.Evaluate(bounds, result, count, row_idx); +void WindowNaiveAggregator::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { + const auto &gnstate = gsink.Cast(); + auto &lnstate = lstate.Cast(); + lnstate.Evaluate(gnstate, bounds, result, count, row_idx); } //===--------------------------------------------------------------------===// // WindowSegmentTree //===--------------------------------------------------------------------===// -WindowSegmentTree::WindowSegmentTree(AggregateObject aggr, const LogicalType &result_type, WindowAggregationMode mode_p, - const WindowExcludeMode exclude_mode_p, idx_t count) - : WindowAggregator(std::move(aggr), result_type, exclude_mode_p, count), internal_nodes(0), mode(mode_p) { -} +class WindowSegmentTreeGlobalState : public WindowAggregatorGlobalState { +public: + using AtomicCounters = vector>; -void WindowSegmentTree::Finalize(const FrameStats &stats) { - WindowAggregator::Finalize(stats); + WindowSegmentTreeGlobalState(const WindowSegmentTree &aggregator, idx_t group_count); - gstate = GetLocalState(); - if (inputs.ColumnCount() > 0) { - if (aggr.function.combine && UseCombineAPI()) { - ConstructTree(); - } + ArenaAllocator &CreateTreeAllocator() { + lock_guard tree_lock(lock); + tree_allocators.emplace_back(make_uniq(Allocator::DefaultAllocator())); + return *tree_allocators.back(); } -} -WindowSegmentTree::~WindowSegmentTree() { - if (!aggr.function.destructor || !gstate) { - // nothing to destroy - return; - } - AggregateInputData aggr_input_data(aggr.GetFunctionData(), gstate->allocator); - // call the destructor for all the intermediate states - data_ptr_t address_data[STANDARD_VECTOR_SIZE]; - Vector addresses(LogicalType::POINTER, data_ptr_cast(address_data)); - idx_t count = 0; - for (idx_t i = 0; i < internal_nodes; i++) { - address_data[count++] = data_ptr_t(levels_flat_native.get() + i * state_size); - if (count == STANDARD_VECTOR_SIZE) { - aggr.function.destructor(addresses, aggr_input_data, count); - count = 0; - } - } - if (count > 0) { - aggr.function.destructor(addresses, aggr_input_data, count); - } + //! The owning aggregator + const WindowSegmentTree &tree; + //! The actual window segment tree: an array of aggregate states that represent all the intermediate nodes + WindowAggregateStates levels_flat_native; + //! For each level, the starting location in the levels_flat_native array + vector levels_flat_start; + //! The level being built (read) + std::atomic build_level; + //! The number of entries started so far at each level + unique_ptr build_started; + //! The number of entries completed so far at each level + unique_ptr build_completed; + //! The tree allocators. + //! We need to hold onto them for the tree lifetime, + //! not the lifetime of the local state that constructed part of the tree + vector> tree_allocators; + + // TREE_FANOUT needs to cleanly divide STANDARD_VECTOR_SIZE + static constexpr idx_t TREE_FANOUT = 16; +}; + +WindowSegmentTree::WindowSegmentTree(AggregateObject aggr, const vector &arg_types, + const LogicalType &result_type, WindowAggregationMode mode_p, + const WindowExcludeMode exclude_mode_p) + : WindowAggregator(std::move(aggr), arg_types, result_type, exclude_mode_p), mode(mode_p) { } class WindowSegmentTreePart { @@ -643,7 +870,7 @@ class WindowSegmentTreePart { enum FramePart : uint8_t { FULL = 0, LEFT = 1, RIGHT = 2 }; WindowSegmentTreePart(ArenaAllocator &allocator, const AggregateObject &aggr, const DataChunk &inputs, - const ValidityMask &filter_mask); + const ValidityArray &filter_mask); ~WindowSegmentTreePart(); unique_ptr Copy() const { @@ -652,23 +879,23 @@ class WindowSegmentTreePart { void FlushStates(bool combining); void ExtractFrame(idx_t begin, idx_t end, data_ptr_t current_state); - void WindowSegmentValue(const WindowSegmentTree &tree, idx_t l_idx, idx_t begin, idx_t end, + void WindowSegmentValue(const WindowSegmentTreeGlobalState &tree, idx_t l_idx, idx_t begin, idx_t end, data_ptr_t current_state); //! Writes result and calls destructors void Finalize(Vector &result, idx_t count); void Combine(WindowSegmentTreePart &other, idx_t count); - void Evaluate(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, Vector &result, idx_t count, - idx_t row_idx, FramePart frame_part); + void Evaluate(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, const idx_t *ends, Vector &result, + idx_t count, idx_t row_idx, FramePart frame_part); protected: //! Initialises the accumulation state vector (statef) void Initialize(idx_t count); //! Accumulate upper tree levels - void EvaluateUpperLevels(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, idx_t count, - idx_t row_idx, FramePart frame_part); - void EvaluateLeaves(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, idx_t count, + void EvaluateUpperLevels(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, const idx_t *ends, + idx_t count, idx_t row_idx, FramePart frame_part); + void EvaluateLeaves(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, const idx_t *ends, idx_t count, idx_t row_idx, FramePart frame_part, FramePart leaf_part); public: @@ -681,7 +908,7 @@ class WindowSegmentTreePart { //! The partition arguments const DataChunk &inputs; //! The filtered rows in inputs - const ValidityMask &filter_mask; + const ValidityArray &filter_mask; //! The size of a single aggregate state const idx_t state_size; //! Data pointer that contains a vector of states, used for intermediate window segment aggregation @@ -704,28 +931,41 @@ class WindowSegmentTreePart { class WindowSegmentTreeState : public WindowAggregatorState { public: - WindowSegmentTreeState(const AggregateObject &aggr, const DataChunk &inputs, const ValidityMask &filter_mask) - : aggr(aggr), inputs(inputs), filter_mask(filter_mask), part(allocator, aggr, inputs, filter_mask) { + WindowSegmentTreeState() { } - //! The aggregate function - const AggregateObject &aggr; - //! The aggregate function - const DataChunk &inputs; - //! The filtered rows in inputs - const ValidityMask &filter_mask; + void Finalize(WindowSegmentTreeGlobalState &gstate); + void Evaluate(const WindowSegmentTreeGlobalState &gsink, const DataChunk &bounds, Vector &result, idx_t count, + idx_t row_idx); //! The left (default) segment tree part - WindowSegmentTreePart part; + unique_ptr part; //! The right segment tree part (for EXCLUDE) unique_ptr right_part; }; +void WindowSegmentTree::Finalize(WindowAggregatorState &gsink, WindowAggregatorState &lstate, const FrameStats &stats) { + + auto &gasink = gsink.Cast(); + auto &inputs = gasink.inputs; + + WindowAggregator::Finalize(gsink, lstate, stats); + + if (inputs.ColumnCount() > 0) { + if (aggr.function.combine && UseCombineAPI()) { + lstate.Cast().Finalize(gasink); + } + } + + ++gasink.finalized; +} + WindowSegmentTreePart::WindowSegmentTreePart(ArenaAllocator &allocator, const AggregateObject &aggr, - const DataChunk &inputs, const ValidityMask &filter_mask) + const DataChunk &inputs, const ValidityArray &filter_mask) : allocator(allocator), aggr(aggr), order_insensitive(aggr.function.order_dependent == AggregateOrderDependent::NOT_ORDER_DEPENDENT), inputs(inputs), - filter_mask(filter_mask), state_size(aggr.function.state_size()), state(state_size * STANDARD_VECTOR_SIZE), - statep(LogicalType::POINTER), statel(LogicalType::POINTER), statef(LogicalType::POINTER), flush_count(0) { + filter_mask(filter_mask), state_size(aggr.function.state_size(aggr.function)), + state(state_size * STANDARD_VECTOR_SIZE), statep(LogicalType::POINTER), statel(LogicalType::POINTER), + statef(LogicalType::POINTER), flush_count(0) { if (inputs.ColumnCount() > 0) { leaves.Initialize(Allocator::DefaultAllocator(), inputs.GetTypes()); filter_sel.Initialize(); @@ -746,8 +986,13 @@ WindowSegmentTreePart::WindowSegmentTreePart(ArenaAllocator &allocator, const Ag WindowSegmentTreePart::~WindowSegmentTreePart() { } -unique_ptr WindowSegmentTree::GetLocalState() const { - return make_uniq(aggr, inputs, filter_mask); +unique_ptr WindowSegmentTree::GetGlobalState(idx_t group_count, + const ValidityMask &partition_mask) const { + return make_uniq(*this, group_count); +} + +unique_ptr WindowSegmentTree::GetLocalState(const WindowAggregatorState &gstate) const { + return make_uniq(); } void WindowSegmentTreePart::FlushStates(bool combining) { @@ -800,8 +1045,8 @@ void WindowSegmentTreePart::ExtractFrame(idx_t begin, idx_t end, data_ptr_t stat } } -void WindowSegmentTreePart::WindowSegmentValue(const WindowSegmentTree &tree, idx_t l_idx, idx_t begin, idx_t end, - data_ptr_t state_ptr) { +void WindowSegmentTreePart::WindowSegmentValue(const WindowSegmentTreeGlobalState &tree, idx_t l_idx, idx_t begin, + idx_t end, data_ptr_t state_ptr) { D_ASSERT(begin <= end); if (begin == end || inputs.ColumnCount() == 0) { return; @@ -812,9 +1057,9 @@ void WindowSegmentTreePart::WindowSegmentValue(const WindowSegmentTree &tree, id ExtractFrame(begin, end, state_ptr); } else { // find out where the states begin - auto begin_ptr = tree.levels_flat_native.get() + state_size * (begin + tree.levels_flat_start[l_idx - 1]); + auto begin_ptr = tree.levels_flat_native.GetStatePtr(begin + tree.levels_flat_start[l_idx - 1]); // set up a vector of pointers that point towards the set of states - auto ldata = FlatVector::GetData(statel); + auto ldata = FlatVector::GetData(statel); auto pdata = FlatVector::GetData(statep); for (idx_t i = 0; i < count; i++) { pdata[flush_count] = state_ptr; @@ -837,20 +1082,12 @@ void WindowSegmentTreePart::Finalize(Vector &result, idx_t count) { } } -void WindowSegmentTree::ConstructTree() { - D_ASSERT(inputs.ColumnCount() > 0); +WindowSegmentTreeGlobalState::WindowSegmentTreeGlobalState(const WindowSegmentTree &aggregator, idx_t group_count) + : WindowAggregatorGlobalState(aggregator, group_count), tree(aggregator), levels_flat_native(aggregator.aggr) { - // Use a temporary scan state to build the tree - auto >state = gstate->Cast().part; + D_ASSERT(inputs.ColumnCount() > 0); // compute space required to store internal nodes of segment tree - internal_nodes = 0; - idx_t level_nodes = inputs.size(); - do { - level_nodes = (level_nodes + (TREE_FANOUT - 1)) / TREE_FANOUT; - internal_nodes += level_nodes; - } while (level_nodes > 1); - levels_flat_native = make_unsafe_uniq_array(internal_nodes * state_size); levels_flat_start.push_back(0); idx_t levels_flat_offset = 0; @@ -861,12 +1098,6 @@ void WindowSegmentTree::ConstructTree() { while ((level_size = (level_current == 0 ? inputs.size() : levels_flat_offset - levels_flat_start[level_current - 1])) > 1) { for (idx_t pos = 0; pos < level_size; pos += TREE_FANOUT) { - // compute the aggregate for this entry in the segment tree - data_ptr_t state_ptr = levels_flat_native.get() + (levels_flat_offset * state_size); - aggr.function.initialize(state_ptr); - gtstate.WindowSegmentValue(*this, level_current, pos, MinValue(level_size, pos + TREE_FANOUT), state_ptr); - gtstate.FlushStates(level_current > 0); - levels_flat_offset++; } @@ -876,46 +1107,120 @@ void WindowSegmentTree::ConstructTree() { // Corner case: single element in the window if (levels_flat_offset == 0) { - aggr.function.initialize(levels_flat_native.get()); + ++levels_flat_offset; + } + + levels_flat_native.Initialize(levels_flat_offset); + + // Start by building from the bottom level + build_level = 0; + + build_started = make_uniq(levels_flat_start.size()); + for (auto &counter : *build_started) { + counter = 0; + } + + build_completed = make_uniq(levels_flat_start.size()); + for (auto &counter : *build_completed) { + counter = 0; } } -void WindowSegmentTree::Evaluate(WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, idx_t count, - idx_t row_idx) const { +void WindowSegmentTreeState::Finalize(WindowSegmentTreeGlobalState &gstate) { + // Single part for constructing the tree + auto &inputs = gstate.inputs; + auto &tree = gstate.tree; + auto &filter_mask = gstate.filter_mask; + WindowSegmentTreePart gtstate(gstate.CreateTreeAllocator(), tree.aggr, inputs, filter_mask); + + auto &levels_flat_native = gstate.levels_flat_native; + const auto &levels_flat_start = gstate.levels_flat_start; + // iterate over the levels of the segment tree + for (;;) { + const idx_t level_current = gstate.build_level.load(); + if (level_current >= levels_flat_start.size()) { + break; + } + + // level 0 is data itself + const auto level_size = + (level_current == 0 ? inputs.size() + : levels_flat_start[level_current] - levels_flat_start[level_current - 1]); + if (level_size <= 1) { + break; + } + const idx_t build_count = (level_size + gstate.TREE_FANOUT - 1) / gstate.TREE_FANOUT; + + // Build the next fan-in + const idx_t build_idx = (*gstate.build_started).at(level_current)++; + if (build_idx >= build_count) { + // Nothing left at this level, so wait until other threads are done. + // Since we are only building TREE_FANOUT values at a time, this will be quick. + while (level_current == gstate.build_level.load()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + continue; + } + + // compute the aggregate for this entry in the segment tree + const idx_t pos = build_idx * gstate.TREE_FANOUT; + const idx_t levels_flat_offset = levels_flat_start[level_current] + build_idx; + auto state_ptr = levels_flat_native.GetStatePtr(levels_flat_offset); + gtstate.WindowSegmentValue(gstate, level_current, pos, MinValue(level_size, pos + gstate.TREE_FANOUT), + state_ptr); + gtstate.FlushStates(level_current > 0); + + // If that was the last one, mark the level as complete. + const idx_t build_complete = ++(*gstate.build_completed).at(level_current); + if (build_complete == build_count) { + gstate.build_level++; + continue; + } + } +} +void WindowSegmentTree::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { + const auto >state = gsink.Cast(); auto <state = lstate.Cast(); + ltstate.Evaluate(gtstate, bounds, result, count, row_idx); +} + +void WindowSegmentTreeState::Evaluate(const WindowSegmentTreeGlobalState >state, const DataChunk &bounds, + Vector &result, idx_t count, idx_t row_idx) { auto window_begin = FlatVector::GetData(bounds.data[WINDOW_BEGIN]); auto window_end = FlatVector::GetData(bounds.data[WINDOW_END]); auto peer_begin = FlatVector::GetData(bounds.data[PEER_BEGIN]); auto peer_end = FlatVector::GetData(bounds.data[PEER_END]); - auto &part = ltstate.part; - if (exclude_mode != WindowExcludeMode::NO_OTHER) { + if (!part) { + part = + make_uniq(allocator, gtstate.aggregator.aggr, gtstate.inputs, gtstate.filter_mask); + } + + if (gtstate.aggregator.exclude_mode != WindowExcludeMode::NO_OTHER) { // 1. evaluate the tree left of the excluded part - part.Evaluate(*this, window_begin, peer_begin, result, count, row_idx, WindowSegmentTreePart::LEFT); + part->Evaluate(gtstate, window_begin, peer_begin, result, count, row_idx, WindowSegmentTreePart::LEFT); // 2. set up a second state for the right of the excluded part - if (!ltstate.right_part) { - ltstate.right_part = part.Copy(); + if (!right_part) { + right_part = part->Copy(); } - auto &right_part = *ltstate.right_part; // 3. evaluate the tree right of the excluded part - right_part.Evaluate(*this, peer_end, window_end, result, count, row_idx, WindowSegmentTreePart::RIGHT); + right_part->Evaluate(gtstate, peer_end, window_end, result, count, row_idx, WindowSegmentTreePart::RIGHT); // 4. combine the buffer state into the Segment Tree State - part.Combine(right_part, count); + part->Combine(*right_part, count); } else { - part.Evaluate(*this, window_begin, window_end, result, count, row_idx, WindowSegmentTreePart::FULL); + part->Evaluate(gtstate, window_begin, window_end, result, count, row_idx, WindowSegmentTreePart::FULL); } - part.Finalize(result, count); + part->Finalize(result, count); } -void WindowSegmentTreePart::Evaluate(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, +void WindowSegmentTreePart::Evaluate(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, const idx_t *ends, Vector &result, idx_t count, idx_t row_idx, FramePart frame_part) { - D_ASSERT(aggr.function.combine && tree.UseCombineAPI()); - Initialize(count); if (order_insensitive) { @@ -936,15 +1241,15 @@ void WindowSegmentTreePart::Initialize(idx_t count) { auto fdata = FlatVector::GetData(statef); for (idx_t rid = 0; rid < count; ++rid) { auto state_ptr = fdata[rid]; - aggr.function.initialize(state_ptr); + aggr.function.initialize(aggr.function, state_ptr); } } -void WindowSegmentTreePart::EvaluateUpperLevels(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, - idx_t count, idx_t row_idx, FramePart frame_part) { +void WindowSegmentTreePart::EvaluateUpperLevels(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, + const idx_t *ends, idx_t count, idx_t row_idx, FramePart frame_part) { auto fdata = FlatVector::GetData(statef); - const auto exclude_mode = tree.exclude_mode; + const auto exclude_mode = tree.tree.exclude_mode; const bool begin_on_curr_row = frame_part == FramePart::RIGHT && exclude_mode == WindowExcludeMode::CURRENT_ROW; const bool end_on_curr_row = frame_part == FramePart::LEFT && exclude_mode == WindowExcludeMode::CURRENT_ROW; @@ -1034,8 +1339,9 @@ void WindowSegmentTreePart::EvaluateUpperLevels(const WindowSegmentTree &tree, c FlushStates(true); } -void WindowSegmentTreePart::EvaluateLeaves(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, - idx_t count, idx_t row_idx, FramePart frame_part, FramePart leaf_part) { +void WindowSegmentTreePart::EvaluateLeaves(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, + const idx_t *ends, idx_t count, idx_t row_idx, FramePart frame_part, + FramePart leaf_part) { auto fdata = FlatVector::GetData(statef); @@ -1044,7 +1350,7 @@ void WindowSegmentTreePart::EvaluateLeaves(const WindowSegmentTree &tree, const // The current row is the leftmost value of the right hand side. const bool compute_left = leaf_part != FramePart::RIGHT; const bool compute_right = leaf_part != FramePart::LEFT; - const auto exclude_mode = tree.exclude_mode; + const auto exclude_mode = tree.tree.exclude_mode; const bool begin_on_curr_row = frame_part == FramePart::RIGHT && exclude_mode == WindowExcludeMode::CURRENT_ROW; const bool end_on_curr_row = frame_part == FramePart::LEFT && exclude_mode == WindowExcludeMode::CURRENT_ROW; // with EXCLUDE TIES, in addition to the frame part right of the peer group's end, we also need to consider the @@ -1087,81 +1393,236 @@ void WindowSegmentTreePart::EvaluateLeaves(const WindowSegmentTree &tree, const //===--------------------------------------------------------------------===// // WindowDistinctAggregator //===--------------------------------------------------------------------===// -WindowDistinctAggregator::WindowDistinctAggregator(AggregateObject aggr, const LogicalType &result_type, - const WindowExcludeMode exclude_mode_p, idx_t count, - ClientContext &context) - : WindowAggregator(std::move(aggr), result_type, exclude_mode_p, count), context(context), - allocator(Allocator::DefaultAllocator()) { +WindowDistinctAggregator::WindowDistinctAggregator(AggregateObject aggr, const vector &arg_types, + const LogicalType &result_type, + const WindowExcludeMode exclude_mode_p, ClientContext &context) + : WindowAggregator(std::move(aggr), arg_types, result_type, exclude_mode_p), context(context) { +} + +class WindowDistinctAggregatorLocalState; + +class WindowDistinctAggregatorGlobalState; + +class WindowDistinctSortTree : public MergeSortTree { +public: + // prev_idx, input_idx + using ZippedTuple = std::tuple; + using ZippedElements = vector; + + explicit WindowDistinctSortTree(WindowDistinctAggregatorGlobalState &gdastate, idx_t count) : gdastate(gdastate) { + // Set up for parallel build + build_level = 0; + build_complete = 0; + build_run = 0; + build_run_length = 1; + build_num_runs = count; + } + + void Build(WindowDistinctAggregatorLocalState &ldastate); + +protected: + bool TryNextRun(idx_t &level_idx, idx_t &run_idx); + void BuildRun(idx_t level_nr, idx_t i, WindowDistinctAggregatorLocalState &ldastate); + WindowDistinctAggregatorGlobalState &gdastate; +}; + +class WindowDistinctAggregatorGlobalState : public WindowAggregatorGlobalState { +public: + using GlobalSortStatePtr = unique_ptr; + using ZippedTuple = WindowDistinctSortTree::ZippedTuple; + using ZippedElements = WindowDistinctSortTree::ZippedElements; + + WindowDistinctAggregatorGlobalState(const WindowDistinctAggregator &aggregator, idx_t group_count); + + //! Compute the block starts + void MeasurePayloadBlocks(); + //! Patch up the previous index block boundaries + void PatchPrevIdcs(); + bool TryPrepareNextStage(WindowDistinctAggregatorLocalState &lstate); + + // Single threaded sorting for now + ClientContext &context; + idx_t memory_per_thread; + + //! Finalize guard + mutex lock; + //! Finalize stage + atomic stage; + //! Tasks launched + idx_t total_tasks = 0; + //! Tasks launched + idx_t tasks_assigned = 0; + //! Tasks landed + mutable atomic tasks_completed; + + //! The sorted payload data types (partition index) + vector payload_types; + //! The aggregate arguments + partition index + vector sort_types; + + //! Sorting operations + GlobalSortStatePtr global_sort; + //! The block starts (the scanner doesn't know this) plus the total count + vector block_starts; + + //! The block boundary seconds + mutable ZippedElements seconds; + //! The MST with the distinct back pointers + mutable MergeSortTree zipped_tree; + //! The merge sort tree for the aggregate. + WindowDistinctSortTree merge_sort_tree; + + //! The actual window segment tree: an array of aggregate states that represent all the intermediate nodes + WindowAggregateStates levels_flat_native; + //! For each level, the starting location in the levels_flat_native array + vector levels_flat_start; +}; + +WindowDistinctAggregatorGlobalState::WindowDistinctAggregatorGlobalState(const WindowDistinctAggregator &aggregator, + idx_t group_count) + : WindowAggregatorGlobalState(aggregator, group_count), context(aggregator.context), + stage(PartitionSortStage::INIT), tasks_completed(0), merge_sort_tree(*this, group_count), + levels_flat_native(aggregator.aggr) { payload_types.emplace_back(LogicalType::UBIGINT); - payload_chunk.Initialize(Allocator::DefaultAllocator(), payload_types); -} -WindowDistinctAggregator::~WindowDistinctAggregator() { - if (!aggr.function.destructor) { - // nothing to destroy - return; + // 1: functionComputePrevIdcs(𝑖𝑛) + // 2: sorted ← [] + // We sort the aggregate arguments and use the partition index as a tie-breaker. + // TODO: Use a hash table? + sort_types = aggregator.arg_types; + for (const auto &type : payload_types) { + sort_types.emplace_back(type); } - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - // call the destructor for all the intermediate states - data_ptr_t address_data[STANDARD_VECTOR_SIZE]; - Vector addresses(LogicalType::POINTER, data_ptr_cast(address_data)); - idx_t count = 0; - for (idx_t i = 0; i < internal_nodes; i++) { - address_data[count++] = data_ptr_t(levels_flat_native.get() + i * state_size); - if (count == STANDARD_VECTOR_SIZE) { - aggr.function.destructor(addresses, aggr_input_data, count); - count = 0; - } + + vector orders; + for (const auto &type : sort_types) { + auto expr = make_uniq(Value(type)); + orders.emplace_back(BoundOrderByNode(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, std::move(expr))); } - if (count > 0) { - aggr.function.destructor(addresses, aggr_input_data, count); + + RowLayout payload_layout; + payload_layout.Initialize(payload_types); + + global_sort = make_uniq(BufferManager::GetBufferManager(context), orders, payload_layout); + + memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context); + + // 6: prevIdcs ← [] + // 7: prevIdcs[0] ← “-” + auto &prev_idcs = zipped_tree.Allocate(group_count); + + // To handle FILTER clauses we make the missing elements + // point to themselves so they won't be counted. + for (idx_t i = 0; i < group_count; ++i) { + prev_idcs[i] = ZippedTuple(i + 1, i); + } + + // compute space required to store aggregation states of merge sort tree + // this is one aggregate state per entry per level + idx_t internal_nodes = 0; + levels_flat_start.push_back(internal_nodes); + for (idx_t level_nr = 0; level_nr < zipped_tree.tree.size(); ++level_nr) { + internal_nodes += zipped_tree.tree[level_nr].first.size(); + levels_flat_start.push_back(internal_nodes); + } + levels_flat_native.Initialize(internal_nodes); + + merge_sort_tree.tree.reserve(zipped_tree.tree.size()); + for (idx_t level_nr = 0; level_nr < zipped_tree.tree.size(); ++level_nr) { + auto &zipped_level = zipped_tree.tree[level_nr].first; + WindowDistinctSortTree::Elements level; + WindowDistinctSortTree::Offsets cascades; + level.resize(zipped_level.size()); + merge_sort_tree.tree.emplace_back(std::move(level), std::move(cascades)); } } -void WindowDistinctAggregator::Sink(DataChunk &arg_chunk, SelectionVector *filter_sel, idx_t filtered) { - WindowAggregator::Sink(arg_chunk, filter_sel, filtered); +class WindowDistinctAggregatorLocalState : public WindowAggregatorState { +public: + explicit WindowDistinctAggregatorLocalState(const WindowDistinctAggregatorGlobalState &aggregator); + + void Sink(DataChunk &arg_chunk, idx_t input_idx, optional_ptr filter_sel, idx_t filtered); + void Sorted(); + void ExecuteTask(); + void Evaluate(const WindowDistinctAggregatorGlobalState &gdstate, const DataChunk &bounds, Vector &result, + idx_t count, idx_t row_idx); + + //! Thread-local sorting data + LocalSortState local_sort; + //! Finalize stage + PartitionSortStage stage = PartitionSortStage::INIT; + //! Finalize scan block index + idx_t block_idx; + //! Thread-local tree aggregation + Vector update_v; + Vector source_v; + Vector target_v; + DataChunk leaves; + SelectionVector sel; - // We sort the arguments and use the partition index as a tie-breaker. - // TODO: Use a hash table? - if (!global_sort) { - // 1: functionComputePrevIdcs(𝑖𝑛) - // 2: sorted ← [] - vector sort_types; - for (const auto &col : arg_chunk.data) { - sort_types.emplace_back(col.GetType()); - } +protected: + //! Flush the accumulated intermediate states into the result states + void FlushStates(); - for (const auto &type : payload_types) { - sort_types.emplace_back(type); - } + //! The aggregator we are working with + const WindowDistinctAggregatorGlobalState &gastate; + DataChunk sort_chunk; + DataChunk payload_chunk; + //! Reused result state container for the window functions + WindowAggregateStates statef; + //! A vector of pointers to "state", used for buffering intermediate aggregates + Vector statep; + //! Reused state pointers for combining tree elements + Vector statel; + //! Count of buffered values + idx_t flush_count; + //! The frame boundaries, used for the window functions + SubFrames frames; +}; - vector orders; - for (const auto &type : sort_types) { - auto expr = make_uniq(Value(type)); - orders.emplace_back(BoundOrderByNode(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, std::move(expr))); - } +WindowDistinctAggregatorLocalState::WindowDistinctAggregatorLocalState( + const WindowDistinctAggregatorGlobalState &gastate) + : update_v(LogicalType::POINTER), source_v(LogicalType::POINTER), target_v(LogicalType::POINTER), gastate(gastate), + statef(gastate.aggregator.aggr), statep(LogicalType::POINTER), statel(LogicalType::POINTER), flush_count(0) { + InitSubFrames(frames, gastate.aggregator.exclude_mode); + payload_chunk.Initialize(Allocator::DefaultAllocator(), gastate.payload_types); - RowLayout payload_layout; - payload_layout.Initialize(payload_types); + auto &global_sort = gastate.global_sort; + local_sort.Initialize(*global_sort, global_sort->buffer_manager); - global_sort = make_uniq(BufferManager::GetBufferManager(context), orders, payload_layout); - local_sort.Initialize(*global_sort, global_sort->buffer_manager); + sort_chunk.Initialize(Allocator::DefaultAllocator(), gastate.sort_types); + sort_chunk.data.back().Reference(payload_chunk.data[0]); - sort_chunk.Initialize(Allocator::DefaultAllocator(), sort_types); - sort_chunk.data.back().Reference(payload_chunk.data[0]); - payload_pos = 0; - memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context); - } + //! Input data chunk, used for leaf segment aggregation + leaves.Initialize(Allocator::DefaultAllocator(), gastate.inputs.GetTypes()); + sel.Initialize(); + + gastate.locals++; +} + +unique_ptr WindowDistinctAggregator::GetGlobalState(idx_t group_count, + const ValidityMask &partition_mask) const { + return make_uniq(*this, group_count); +} + +void WindowDistinctAggregator::Sink(WindowAggregatorState &gsink, WindowAggregatorState &lstate, DataChunk &arg_chunk, + idx_t input_idx, optional_ptr filter_sel, idx_t filtered) { + WindowAggregator::Sink(gsink, lstate, arg_chunk, input_idx, filter_sel, filtered); + auto &ldstate = lstate.Cast(); + ldstate.Sink(arg_chunk, input_idx, filter_sel, filtered); +} + +void WindowDistinctAggregatorLocalState::Sink(DataChunk &arg_chunk, idx_t input_idx, + optional_ptr filter_sel, idx_t filtered) { // 3: for i ← 0 to in.size do // 4: sorted[i] ← (in[i], i) const auto count = arg_chunk.size(); payload_chunk.Reset(); auto &sorted_vec = payload_chunk.data[0]; auto sorted = FlatVector::GetData(sorted_vec); - std::iota(sorted, sorted + count, payload_pos); - payload_pos += count; + std::iota(sorted, sorted + count, input_idx); for (column_t c = 0; c < arg_chunk.ColumnCount(); ++c) { sort_chunk.data[c].Reference(arg_chunk.data[c]); @@ -1178,61 +1639,178 @@ void WindowDistinctAggregator::Sink(DataChunk &arg_chunk, SelectionVector *filte local_sort.SinkChunk(sort_chunk, payload_chunk); - if (local_sort.SizeInBytes() > memory_per_thread) { - local_sort.Sort(*global_sort, true); + if (local_sort.SizeInBytes() > gastate.memory_per_thread) { + local_sort.Sort(*gastate.global_sort, true); } } -class WindowDistinctAggregator::DistinctSortTree : public MergeSortTree { -public: - // prev_idx, input_idx - using ZippedTuple = std::tuple; - using ZippedElements = vector; +void WindowDistinctAggregatorLocalState::ExecuteTask() { + auto &global_sort = *gastate.global_sort; + switch (stage) { + case PartitionSortStage::INIT: + // AddLocalState is thread-safe + global_sort.AddLocalState(local_sort); + break; + case PartitionSortStage::MERGE: { + MergeSorter merge_sorter(global_sort, global_sort.buffer_manager); + merge_sorter.PerformInMergeRound(); + break; + } + case PartitionSortStage::SORTED: + Sorted(); + break; + default: + break; + } - DistinctSortTree(ZippedElements &&prev_idcs, WindowDistinctAggregator &wda); -}; + ++gastate.tasks_completed; +} -void WindowDistinctAggregator::Finalize(const FrameStats &stats) { - // 5: Sort sorted lexicographically increasing - global_sort->AddLocalState(local_sort); - global_sort->PrepareMergePhase(); - while (global_sort->sorted_blocks.size() > 1) { +void WindowDistinctAggregatorGlobalState::MeasurePayloadBlocks() { + const auto &blocks = global_sort->sorted_blocks[0]->payload_data->data_blocks; + idx_t count = 0; + for (const auto &block : blocks) { + block_starts.emplace_back(count); + count += block->count; + } + block_starts.emplace_back(count); +} + +bool WindowDistinctAggregatorGlobalState::TryPrepareNextStage(WindowDistinctAggregatorLocalState &lstate) { + lock_guard stage_guard(lock); + + switch (stage.load()) { + case PartitionSortStage::INIT: + // Wait for all the local sorts to be processed + if (tasks_completed < locals) { + return false; + } + global_sort->PrepareMergePhase(); + if (!(global_sort->sorted_blocks.size() / 2)) { + if (global_sort->sorted_blocks.empty()) { + lstate.stage = stage = PartitionSortStage::FINISHED; + return true; + } + MeasurePayloadBlocks(); + seconds.resize(block_starts.size() - 1); + total_tasks = seconds.size(); + tasks_completed = 0; + tasks_assigned = 0; + lstate.stage = stage = PartitionSortStage::SORTED; + lstate.block_idx = tasks_assigned++; + return true; + } global_sort->InitializeMergeRound(); - MergeSorter merge_sorter(*global_sort, global_sort->buffer_manager); - merge_sorter.PerformInMergeRound(); + lstate.stage = stage = PartitionSortStage::MERGE; + total_tasks = locals; + tasks_assigned = 1; + tasks_completed = 0; + return true; + case PartitionSortStage::MERGE: + if (tasks_assigned < total_tasks) { + lstate.stage = PartitionSortStage::MERGE; + ++tasks_assigned; + return true; + } else if (tasks_completed < tasks_assigned) { + return false; + } global_sort->CompleteMergeRound(true); + if (!(global_sort->sorted_blocks.size() / 2)) { + MeasurePayloadBlocks(); + seconds.resize(block_starts.size() - 1); + total_tasks = seconds.size(); + tasks_completed = 0; + tasks_assigned = 0; + lstate.stage = stage = PartitionSortStage::SORTED; + lstate.block_idx = tasks_assigned++; + return true; + } + global_sort->InitializeMergeRound(); + lstate.stage = PartitionSortStage::MERGE; + total_tasks = locals; + tasks_assigned = 1; + tasks_completed = 0; + return true; + case PartitionSortStage::SORTED: + if (tasks_assigned < total_tasks) { + lstate.stage = PartitionSortStage::SORTED; + lstate.block_idx = tasks_assigned++; + return true; + } else if (tasks_completed < tasks_assigned) { + lstate.stage = PartitionSortStage::FINISHED; + // Sleep while other tasks finish + return false; + } + // Last task patches the boundaries + PatchPrevIdcs(); + break; + default: + break; } - DataChunk scan_chunk; - scan_chunk.Initialize(Allocator::DefaultAllocator(), payload_types); + lstate.stage = stage = PartitionSortStage::FINISHED; - auto scanner = make_uniq(*global_sort); - const auto in_size = scanner->Remaining(); - scanner->Scan(scan_chunk); - idx_t scan_idx = 0; + return true; +} - // 6: prevIdcs ← [] - // 7: prevIdcs[0] ← “-” - const auto count = inputs.size(); - using ZippedTuple = DistinctSortTree::ZippedTuple; - DistinctSortTree::ZippedElements prev_idcs; - prev_idcs.resize(count); +void WindowDistinctAggregator::Finalize(WindowAggregatorState &gsink, WindowAggregatorState &lstate, + const FrameStats &stats) { + auto &gdsink = gsink.Cast(); + auto &ldstate = lstate.Cast(); - // To handle FILTER clauses we make the missing elements - // point to themselves so they won't be counted. - if (in_size < count) { - for (idx_t i = 0; i < count; ++i) { - prev_idcs[i] = ZippedTuple(i + 1, i); + // 5: Sort sorted lexicographically increasing + ldstate.ExecuteTask(); + + // Merge in parallel + while (gdsink.stage.load() != PartitionSortStage::FINISHED) { + if (gdsink.TryPrepareNextStage(ldstate)) { + ldstate.ExecuteTask(); + } else { + std::this_thread::yield(); } } + // These are a parallel implementations, + // so every thread can call them. + gdsink.zipped_tree.Build(); + gdsink.merge_sort_tree.Build(ldstate); + + ++gdsink.finalized; +} + +void WindowDistinctAggregatorLocalState::Sorted() { + using ZippedTuple = WindowDistinctAggregatorGlobalState::ZippedTuple; + auto &global_sort = gastate.global_sort; + auto &prev_idcs = gastate.zipped_tree.LowestLevel(); + auto &aggregator = gastate.aggregator; + auto &scan_chunk = payload_chunk; + + auto scanner = make_uniq(*global_sort, block_idx); + const auto in_size = gastate.block_starts.at(block_idx + 1); + scanner->Scan(scan_chunk); + idx_t scan_idx = 0; + auto *input_idx = FlatVector::GetData(scan_chunk.data[0]); - auto i = input_idx[scan_idx++]; - prev_idcs[i] = ZippedTuple(0, i); + idx_t i = 0; SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN); SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN); - auto prefix_layout = global_sort->sort_layout.GetPrefixComparisonLayout(sort_chunk.ColumnCount() - 1); + auto prefix_layout = global_sort->sort_layout.GetPrefixComparisonLayout(aggregator.arg_types.size()); + + const auto block_begin = gastate.block_starts.at(block_idx); + if (!block_begin) { + // First block, so set up initial sentinel + i = input_idx[scan_idx++]; + prev_idcs[i] = ZippedTuple(0, i); + std::get<0>(gastate.seconds[block_idx]) = i; + } else { + // Move to the to end of the previous block + // so we can record the comparison result for the first row + curr.SetIndex(block_begin - 1); + prev.SetIndex(block_begin - 1); + scan_idx = 0; + std::get<0>(gastate.seconds[block_idx]) = input_idx[scan_idx]; + } // 8: for i ← 1 to in.size do for (++curr; curr.GetIndex() < in_size; ++curr, ++prev) { @@ -1265,105 +1843,148 @@ void WindowDistinctAggregator::Finalize(const FrameStats &stats) { prev_idcs[i] = ZippedTuple(0, i); } } + + // Save the last value of i for patching up the block boundaries + std::get<1>(gastate.seconds[block_idx]) = i; +} + +void WindowDistinctAggregatorGlobalState::PatchPrevIdcs() { // 13: return prevIdcs - merge_sort_tree = make_uniq(std::move(prev_idcs), *this); + // Patch up the indices at block boundaries + // (We don't need to patch block 0.) + auto &prev_idcs = zipped_tree.LowestLevel(); + for (idx_t block_idx = 1; block_idx < seconds.size(); ++block_idx) { + // We only need to patch if the first index in the block + // was a back link to the previous block (10:) + auto i = std::get<0>(seconds.at(block_idx)); + if (std::get<0>(prev_idcs[i])) { + auto second = std::get<1>(seconds.at(block_idx - 1)); + prev_idcs[i] = ZippedTuple(second + 1, i); + } + } +} + +bool WindowDistinctSortTree::TryNextRun(idx_t &level_idx, idx_t &run_idx) { + const auto fanout = FANOUT; + + lock_guard stage_guard(build_lock); + + // Verify we are not done + if (build_level >= tree.size()) { + return false; + } + + // Finished with this level? + if (build_complete >= build_num_runs) { + auto &zipped_tree = gdastate.zipped_tree; + std::swap(tree[build_level].second, zipped_tree.tree[build_level].second); + + ++build_level; + if (build_level >= tree.size()) { + zipped_tree.tree.clear(); + return false; + } + + const auto count = LowestLevel().size(); + build_run_length *= fanout; + build_num_runs = (count + build_run_length - 1) / build_run_length; + build_run = 0; + build_complete = 0; + } + + // If all runs are in flight, + // yield until the next level is ready + if (build_run >= build_num_runs) { + return false; + } + + level_idx = build_level; + run_idx = build_run++; + + return true; +} + +void WindowDistinctSortTree::Build(WindowDistinctAggregatorLocalState &ldastate) { + // Fan in parent levels until we are at the top + // Note that we don't build the top layer as that would just be all the data. + while (build_level.load() < tree.size()) { + idx_t level_idx; + idx_t run_idx; + if (TryNextRun(level_idx, run_idx)) { + BuildRun(level_idx, run_idx, ldastate); + } else { + std::this_thread::yield(); + } + } } -WindowDistinctAggregator::DistinctSortTree::DistinctSortTree(ZippedElements &&prev_idcs, - WindowDistinctAggregator &wda) { - auto &inputs = wda.inputs; - auto &aggr = wda.aggr; - auto &allocator = wda.allocator; - const auto state_size = wda.state_size; - auto &internal_nodes = wda.internal_nodes; - auto &levels_flat_native = wda.levels_flat_native; - auto &levels_flat_start = wda.levels_flat_start; +void WindowDistinctSortTree::BuildRun(idx_t level_nr, idx_t run_idx, WindowDistinctAggregatorLocalState &ldastate) { + auto &aggr = gdastate.aggregator.aggr; + auto &allocator = gdastate.allocator; + auto &inputs = gdastate.inputs; + auto &levels_flat_native = gdastate.levels_flat_native; //! Input data chunk, used for leaf segment aggregation - DataChunk leaves; - leaves.Initialize(Allocator::DefaultAllocator(), inputs.GetTypes()); - SelectionVector sel; - sel.Initialize(); + auto &leaves = ldastate.leaves; + auto &sel = ldastate.sel; AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); //! The states to update - Vector update_v(LogicalType::POINTER); + auto &update_v = ldastate.update_v; auto updates = FlatVector::GetData(update_v); - idx_t nupdate = 0; - Vector source_v(LogicalType::POINTER); + auto &source_v = ldastate.source_v; auto sources = FlatVector::GetData(source_v); - Vector target_v(LogicalType::POINTER); + auto &target_v = ldastate.target_v; auto targets = FlatVector::GetData(target_v); - idx_t ncombine = 0; - - // compute space required to store aggregation states of merge sort tree - // this is one aggregate state per entry per level - MergeSortTree zipped_tree(std::move(prev_idcs)); - internal_nodes = 0; - for (idx_t level_nr = 0; level_nr < zipped_tree.tree.size(); ++level_nr) { - internal_nodes += zipped_tree.tree[level_nr].first.size(); - } - levels_flat_native = make_unsafe_uniq_array(internal_nodes * state_size); - levels_flat_start.push_back(0); - idx_t levels_flat_offset = 0; - // Walk the distinct value tree building the intermediate aggregates - tree.reserve(zipped_tree.tree.size()); - idx_t level_width = 1; - for (idx_t level_nr = 0; level_nr < zipped_tree.tree.size(); ++level_nr) { - auto &zipped_level = zipped_tree.tree[level_nr].first; - vector level; - level.reserve(zipped_level.size()); - - for (idx_t i = 0; i < zipped_level.size(); i += level_width) { - // Reset the combine state - data_ptr_t prev_state = nullptr; - auto next_limit = MinValue(zipped_level.size(), i + level_width); - for (auto j = i; j < next_limit; ++j) { - // Initialise the next aggregate - auto curr_state = levels_flat_native.get() + (levels_flat_offset++ * state_size); - aggr.function.initialize(curr_state); - - // Update this state (if it matches) - const auto prev_idx = std::get<0>(zipped_level[j]); - level.emplace_back(prev_idx); - if (prev_idx < i + 1) { - updates[nupdate] = curr_state; - // input_idx - sel[nupdate] = UnsafeNumericCast(std::get<1>(zipped_level[j])); - ++nupdate; - } + auto &zipped_tree = gdastate.zipped_tree; + auto &zipped_level = zipped_tree.tree[level_nr].first; + auto &level = tree[level_nr].first; - // Merge the previous state (if any) - if (prev_state) { - sources[ncombine] = prev_state; - targets[ncombine] = curr_state; - ++ncombine; - } - prev_state = curr_state; - - // Flush the states if one is maxed out. - if (MaxValue(ncombine, nupdate) >= STANDARD_VECTOR_SIZE) { - // Push the updates first so they propagate - leaves.Reference(inputs); - leaves.Slice(sel, nupdate); - aggr.function.update(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), update_v, nupdate); - nupdate = 0; - - // Combine the states sequentially - aggr.function.combine(source_v, target_v, aggr_input_data, ncombine); - ncombine = 0; - } - } + // Reset the combine state + idx_t nupdate = 0; + idx_t ncombine = 0; + data_ptr_t prev_state = nullptr; + idx_t i = run_idx * build_run_length; + auto next_limit = MinValue(zipped_level.size(), i + build_run_length); + idx_t levels_flat_offset = level_nr * zipped_level.size() + i; + for (auto j = i; j < next_limit; ++j) { + // Initialise the next aggregate + auto curr_state = levels_flat_native.GetStatePtr(levels_flat_offset++); + + // Update this state (if it matches) + const auto prev_idx = std::get<0>(zipped_level[j]); + level[j] = prev_idx; + if (prev_idx < i + 1) { + updates[nupdate] = curr_state; + // input_idx + sel[nupdate] = UnsafeNumericCast(std::get<1>(zipped_level[j])); + ++nupdate; } - tree.emplace_back(std::move(level), std::move(zipped_tree.tree[level_nr].second)); - - levels_flat_start.push_back(levels_flat_offset); - level_width *= FANOUT; + // Merge the previous state (if any) + if (prev_state) { + sources[ncombine] = prev_state; + targets[ncombine] = curr_state; + ++ncombine; + } + prev_state = curr_state; + + // Flush the states if one is maxed out. + if (MaxValue(ncombine, nupdate) >= STANDARD_VECTOR_SIZE) { + // Push the updates first so they propagate + leaves.Reference(inputs); + leaves.Slice(sel, nupdate); + aggr.function.update(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), update_v, nupdate); + nupdate = 0; + + // Combine the states sequentially + aggr.function.combine(source_v, target_v, aggr_input_data, ncombine); + ncombine = 0; + } } // Flush any remaining states @@ -1378,64 +1999,16 @@ WindowDistinctAggregator::DistinctSortTree::DistinctSortTree(ZippedElements &&pr aggr.function.combine(source_v, target_v, aggr_input_data, ncombine); ncombine = 0; } -} - -class WindowDistinctState : public WindowAggregatorState { -public: - WindowDistinctState(const AggregateObject &aggr, const DataChunk &inputs, const WindowDistinctAggregator &tree); - - void Evaluate(const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx); - -protected: - //! Flush the accumulated intermediate states into the result states - void FlushStates(); - - //! The aggregate function - const AggregateObject &aggr; - //! The aggregate function - const DataChunk &inputs; - //! The merge sort tree data - const WindowDistinctAggregator &tree; - //! The size of a single aggregate state - const idx_t state_size; - //! Data pointer that contains a vector of states, used for row aggregation - vector state; - //! Reused result state container for the window functions - Vector statef; - //! A vector of pointers to "state", used for buffering intermediate aggregates - Vector statep; - //! Reused state pointers for combining tree elements - Vector statel; - //! Count of buffered values - idx_t flush_count; - //! The frame boundaries, used for the window functions - SubFrames frames; -}; -WindowDistinctState::WindowDistinctState(const AggregateObject &aggr, const DataChunk &inputs, - const WindowDistinctAggregator &tree) - : aggr(aggr), inputs(inputs), tree(tree), state_size(aggr.function.state_size()), - state((state_size * STANDARD_VECTOR_SIZE)), statef(LogicalType::POINTER), statep(LogicalType::POINTER), - statel(LogicalType::POINTER), flush_count(0) { - InitSubFrames(frames, tree.exclude_mode); - - // Build the finalise vector that just points to the result states - data_ptr_t state_ptr = state.data(); - D_ASSERT(statef.GetVectorType() == VectorType::FLAT_VECTOR); - statef.SetVectorType(VectorType::CONSTANT_VECTOR); - statef.Flatten(STANDARD_VECTOR_SIZE); - auto fdata = FlatVector::GetData(statef); - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; ++i) { - fdata[i] = state_ptr; - state_ptr += state_size; - } + ++build_complete; } -void WindowDistinctState::FlushStates() { +void WindowDistinctAggregatorLocalState::FlushStates() { if (!flush_count) { return; } + const auto &aggr = gastate.aggregator.aggr; AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); statel.Verify(flush_count); aggr.function.combine(statel, statep, aggr_input_data, flush_count); @@ -1443,17 +2016,20 @@ void WindowDistinctState::FlushStates() { flush_count = 0; } -void WindowDistinctState::Evaluate(const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) { - auto fdata = FlatVector::GetData(statef); - auto ldata = FlatVector::GetData(statel); +void WindowDistinctAggregatorLocalState::Evaluate(const WindowDistinctAggregatorGlobalState &gdstate, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) { + auto ldata = FlatVector::GetData(statel); auto pdata = FlatVector::GetData(statep); - const auto &merge_sort_tree = *tree.merge_sort_tree; - const auto running_aggs = tree.levels_flat_native.get(); + const auto &merge_sort_tree = gdstate.merge_sort_tree; + const auto &levels_flat_native = gdstate.levels_flat_native; + const auto exclude_mode = gdstate.aggregator.exclude_mode; - EvaluateSubFrames(bounds, tree.exclude_mode, count, row_idx, frames, [&](idx_t rid) { - auto agg_state = fdata[rid]; - aggr.function.initialize(agg_state); + // Build the finalise vector that just points to the result states + statef.Initialize(count); + + EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t rid) { + auto agg_state = statef.GetStatePtr(rid); // TODO: Extend AggregateLowerBound to handle subframes, just like SelectNth. const auto lower = frames[0].start; @@ -1463,8 +2039,8 @@ void WindowDistinctState::Evaluate(const DataChunk &bounds, Vector &result, idx_ if (run_pos != run_begin) { // Find the source aggregate // Buffer a merge of the indicated state into the current state - const auto agg_idx = tree.levels_flat_start[level] + run_pos - 1; - const auto running_agg = running_aggs + agg_idx * state_size; + const auto agg_idx = gdstate.levels_flat_start[level] + run_pos - 1; + const auto running_agg = levels_flat_native.GetStatePtr(agg_idx); pdata[flush_count] = agg_state; ldata[flush_count++] = running_agg; if (flush_count >= STANDARD_VECTOR_SIZE) { @@ -1478,23 +2054,20 @@ void WindowDistinctState::Evaluate(const DataChunk &bounds, Vector &result, idx_ FlushStates(); // Finalise the result aggregates and write to the result - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.finalize(statef, aggr_input_data, result, count, 0); - - // Destruct the result aggregates - if (aggr.function.destructor) { - aggr.function.destructor(statef, aggr_input_data, count); - } + statef.Finalize(result); + statef.Destroy(); } -unique_ptr WindowDistinctAggregator::GetLocalState() const { - return make_uniq(aggr, inputs, *this); +unique_ptr WindowDistinctAggregator::GetLocalState(const WindowAggregatorState &gstate) const { + return make_uniq(gstate.Cast()); } -void WindowDistinctAggregator::Evaluate(WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, - idx_t count, idx_t row_idx) const { - auto &ldstate = lstate.Cast(); - ldstate.Evaluate(bounds, result, count, row_idx); +void WindowDistinctAggregator::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { + + const auto &gdstate = gsink.Cast(); + auto &ldstate = lstate.Cast(); + ldstate.Evaluate(gdstate, bounds, result, count, row_idx); } } // namespace duckdb diff --git a/src/duckdb/src/function/aggregate/distributive/count.cpp b/src/duckdb/src/function/aggregate/distributive/count.cpp index 10ff1c46..8b5c3be2 100644 --- a/src/duckdb/src/function/aggregate/distributive/count.cpp +++ b/src/duckdb/src/function/aggregate/distributive/count.cpp @@ -65,7 +65,7 @@ struct CountFunction : public BaseCountFunction { } static void ConstantOperation(STATE &state, idx_t count) { - state += count; + state += UnsafeNumericCast(count); } static bool IgnoreNull() { @@ -147,7 +147,7 @@ struct CountFunction : public BaseCountFunction { idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); if (ValidityMask::AllValid(validity_entry)) { // all valid - result += next - base_idx; + result += UnsafeNumericCast(next - base_idx); base_idx = next; } else if (ValidityMask::NoneValid(validity_entry)) { // nothing valid: skip all @@ -169,7 +169,7 @@ struct CountFunction : public BaseCountFunction { const SelectionVector &sel_vector) { if (mask.AllValid()) { // no NULL values - result += count; + result += UnsafeNumericCast(count); return; } for (idx_t i = 0; i < count; i++) { @@ -187,7 +187,7 @@ struct CountFunction : public BaseCountFunction { case VectorType::CONSTANT_VECTOR: { if (!ConstantVector::IsNull(input)) { // if the constant is not null increment the state - result += count; + result += UnsafeNumericCast(count); } break; } @@ -197,7 +197,7 @@ struct CountFunction : public BaseCountFunction { } case VectorType::SEQUENCE_VECTOR: { // sequence vectors cannot have NULL values - result += count; + result += UnsafeNumericCast(count); break; } default: { diff --git a/src/duckdb/src/function/aggregate/distributive/first.cpp b/src/duckdb/src/function/aggregate/distributive/first.cpp index 143cf431..8fed2190 100644 --- a/src/duckdb/src/function/aggregate/distributive/first.cpp +++ b/src/duckdb/src/function/aggregate/distributive/first.cpp @@ -1,5 +1,6 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/core_functions/create_sort_key.hpp" #include "duckdb/function/aggregate/distributive_functions.hpp" #include "duckdb/planner/expression.hpp" @@ -66,7 +67,7 @@ struct FirstFunction : public FirstFunctionBase { }; template -struct FirstFunctionString : public FirstFunctionBase { +struct FirstFunctionStringBase : public FirstFunctionBase { template static void SetValue(STATE &state, AggregateInputData &input_data, string_t value, bool is_null) { if (LAST && state.is_set) { @@ -93,10 +94,28 @@ struct FirstFunctionString : public FirstFunctionBase { } } + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { + if (source.is_set && (LAST || !target.is_set)) { + SetValue(target, input_data, source.value, source.is_null); + } + } + + template + static void Destroy(STATE &state, AggregateInputData &) { + if (state.is_set && !state.is_null && !state.value.IsInlined()) { + delete[] state.value.GetData(); + } + } +}; + +template +struct FirstFunctionString : FirstFunctionStringBase { template static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { if (LAST || !state.is_set) { - SetValue(state, unary_input.input, input, !unary_input.RowIsValid()); + FirstFunctionStringBase::template SetValue(state, unary_input.input, input, + !unary_input.RowIsValid()); } } @@ -106,13 +125,6 @@ struct FirstFunctionString : public FirstFunctionBase { Operation(state, input, unary_input); } - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { - if (source.is_set && (LAST || !target.is_set)) { - SetValue(target, input_data, source.value, source.is_null); - } - } - template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { if (!state.is_set || state.is_null) { @@ -121,48 +133,13 @@ struct FirstFunctionString : public FirstFunctionBase { target = StringVector::AddStringOrBlob(finalize_data.result, state.value); } } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.is_set && !state.is_null && !state.value.IsInlined()) { - delete[] state.value.GetData(); - } - } -}; - -struct FirstStateVector { - Vector *value; }; template -struct FirstVectorFunction { - template - static void Initialize(STATE &state) { - state.value = nullptr; - } +struct FirstVectorFunction : FirstFunctionStringBase { + using STATE = FirstState; - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.value) { - delete state.value; - } - } - static bool IgnoreNull() { - return SKIP_NULLS; - } - - template - static void SetValue(STATE &state, Vector &input, const idx_t idx) { - if (!state.value) { - state.value = new Vector(input.GetType()); - state.value->SetVectorType(VectorType::CONSTANT_VECTOR); - } - sel_t selv = UnsafeNumericCast(idx); - SelectionVector sel(&selv); - VectorOperations::Copy(input, *state.value, sel, 1, 0, 0); - } - - static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { + static void Update(Vector inputs[], AggregateInputData &input_data, idx_t, Vector &state_vector, idx_t count) { auto &input = inputs[0]; UnifiedVectorFormat idata; input.ToUnifiedFormat(count, idata); @@ -170,32 +147,61 @@ struct FirstVectorFunction { UnifiedVectorFormat sdata; state_vector.ToUnifiedFormat(count, sdata); - auto states = UnifiedVectorFormat::GetData(sdata); + sel_t assign_sel[STANDARD_VECTOR_SIZE]; + idx_t assign_count = 0; + + auto states = UnifiedVectorFormat::GetData(sdata); for (idx_t i = 0; i < count; i++) { const auto idx = idata.sel->get_index(i); - if (SKIP_NULLS && !idata.validity.RowIsValid(idx)) { + bool is_null = !idata.validity.RowIsValid(idx); + if (SKIP_NULLS && is_null) { continue; } auto &state = *states[sdata.sel->get_index(i)]; - if (LAST || !state.value) { - SetValue(state, input, i); + if (!LAST && state.is_set) { + continue; } + assign_sel[assign_count++] = NumericCast(i); + } + if (assign_count == 0) { + // fast path - nothing to set + return; } - } - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (source.value && (LAST || !target.value)) { - SetValue(target, *source.value, 0); + Vector sort_key(LogicalType::BLOB); + OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + // slice with a selection vector and generate sort keys + if (assign_count == count) { + CreateSortKeyHelpers::CreateSortKey(input, count, modifiers, sort_key); + } else { + SelectionVector sel(assign_sel); + Vector sliced_input(input, sel, assign_count); + CreateSortKeyHelpers::CreateSortKey(sliced_input, assign_count, modifiers, sort_key); + } + auto sort_key_data = FlatVector::GetData(sort_key); + + // now assign sort keys + for (idx_t i = 0; i < assign_count; i++) { + const auto state_idx = sdata.sel->get_index(assign_sel[i]); + auto &state = *states[state_idx]; + if (!LAST && state.is_set) { + continue; + } + + const auto idx = idata.sel->get_index(assign_sel[i]); + bool is_null = !idata.validity.RowIsValid(idx); + FirstFunctionStringBase::template SetValue(state, input_data, sort_key_data[i], + is_null); } } template static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { - if (!state.value) { + if (!state.is_set || state.is_null) { finalize_data.ReturnNull(); } else { - VectorOperations::Copy(*state.value, finalize_data.result, 1, 0, finalize_data.result_idx); + CreateSortKeyHelpers::DecodeSortKey(state.value, finalize_data.result, finalize_data.result_idx, + OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); } } @@ -229,45 +235,44 @@ AggregateFunction GetDecimalFirstFunction(const LogicalType &type) { return GetFirstFunction(LogicalType::HUGEINT); } } - template static AggregateFunction GetFirstFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - return GetFirstAggregateTemplated(type); - case LogicalTypeId::TINYINT: + if (type.id() == LogicalTypeId::DECIMAL) { + type.Verify(); + AggregateFunction function = GetDecimalFirstFunction(type); + function.arguments[0] = type; + function.return_type = type; + return function; + } + switch (type.InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: return GetFirstAggregateTemplated(type); - case LogicalTypeId::SMALLINT: + case PhysicalType::INT16: return GetFirstAggregateTemplated(type); - case LogicalTypeId::INTEGER: - case LogicalTypeId::DATE: + case PhysicalType::INT32: return GetFirstAggregateTemplated(type); - case LogicalTypeId::BIGINT: - case LogicalTypeId::TIME: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIME_TZ: - case LogicalTypeId::TIMESTAMP_TZ: + case PhysicalType::INT64: return GetFirstAggregateTemplated(type); - case LogicalTypeId::UTINYINT: + case PhysicalType::UINT8: return GetFirstAggregateTemplated(type); - case LogicalTypeId::USMALLINT: + case PhysicalType::UINT16: return GetFirstAggregateTemplated(type); - case LogicalTypeId::UINTEGER: + case PhysicalType::UINT32: return GetFirstAggregateTemplated(type); - case LogicalTypeId::UBIGINT: + case PhysicalType::UINT64: return GetFirstAggregateTemplated(type); - case LogicalTypeId::HUGEINT: + case PhysicalType::INT128: return GetFirstAggregateTemplated(type); - case LogicalTypeId::UHUGEINT: + case PhysicalType::UINT128: return GetFirstAggregateTemplated(type); - case LogicalTypeId::FLOAT: + case PhysicalType::FLOAT: return GetFirstAggregateTemplated(type); - case LogicalTypeId::DOUBLE: + case PhysicalType::DOUBLE: return GetFirstAggregateTemplated(type); - case LogicalTypeId::INTERVAL: + case PhysicalType::INTERVAL: return GetFirstAggregateTemplated(type); - case LogicalTypeId::VARCHAR: - case LogicalTypeId::BLOB: + case PhysicalType::VARCHAR: if (LAST) { return AggregateFunction::UnaryAggregateDestructor, string_t, string_t, FirstFunctionString>(type, type); @@ -275,21 +280,13 @@ static AggregateFunction GetFirstFunction(const LogicalType &type) { return AggregateFunction::UnaryAggregate, string_t, string_t, FirstFunctionString>(type, type); } - case LogicalTypeId::DECIMAL: { - type.Verify(); - AggregateFunction function = GetDecimalFirstFunction(type); - function.arguments[0] = type; - function.return_type = type; - // TODO set_key here? - return function; - } default: { using OP = FirstVectorFunction; - return AggregateFunction({type}, type, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, OP::Update, - AggregateFunction::StateCombine, - AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, - AggregateFunction::StateDestroy, nullptr, nullptr); + using STATE = FirstState; + return AggregateFunction( + {type}, type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, + OP::Update, AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, + nullptr, OP::Bind, LAST ? AggregateFunction::StateDestroy : nullptr, nullptr, nullptr); } } } diff --git a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp index 78964dca..4e86f930 100644 --- a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp +++ b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp @@ -560,7 +560,8 @@ struct SortedAggregateFunction { sliced.Initialize(Allocator::DefaultAllocator(), order_bind.arg_types); // Reusable inner state - vector agg_state(order_bind.function.state_size()); + auto &aggr = order_bind.function; + vector agg_state(aggr.state_size(aggr)); Vector agg_state_vec(Value::POINTER(CastPointerToValue(agg_state.data()))); // State variables @@ -568,11 +569,11 @@ struct SortedAggregateFunction { AggregateInputData aggr_bind_info(bind_info, aggr_input_data.allocator); // Inner aggregate APIs - auto initialize = order_bind.function.initialize; - auto destructor = order_bind.function.destructor; - auto simple_update = order_bind.function.simple_update; - auto update = order_bind.function.update; - auto finalize = order_bind.function.finalize; + auto initialize = aggr.initialize; + auto destructor = aggr.destructor; + auto simple_update = aggr.simple_update; + auto update = aggr.update; + auto finalize = aggr.finalize; auto sdata = FlatVector::GetData(states); @@ -631,7 +632,7 @@ struct SortedAggregateFunction { } auto scanner = make_uniq(*global_sort); - initialize(agg_state.data()); + initialize(aggr, agg_state.data()); while (scanner->Remaining()) { chunk.Reset(); scanner->Scan(chunk); @@ -648,7 +649,7 @@ struct SortedAggregateFunction { destructor(agg_state_vec, aggr_bind_info, 1); } - initialize(agg_state.data()); + initialize(aggr, agg_state.data()); } const auto input_count = MinValue(state_unprocessed[sorted], chunk.size() - consumed); for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { @@ -694,7 +695,7 @@ struct SortedAggregateFunction { } for (; sorted < count; ++sorted) { - initialize(agg_state.data()); + initialize(aggr, agg_state.data()); // Finalize a single value at the next offset agg_state_vec.SetVectorType(states.GetVectorType()); diff --git a/src/duckdb/src/function/aggregate_function.cpp b/src/duckdb/src/function/aggregate_function.cpp new file mode 100644 index 00000000..dd3bc001 --- /dev/null +++ b/src/duckdb/src/function/aggregate_function.cpp @@ -0,0 +1,8 @@ +#include "duckdb/function/aggregate_function.hpp" + +namespace duckdb { + +AggregateFunctionInfo::~AggregateFunctionInfo() { +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/cast_function_set.cpp b/src/duckdb/src/function/cast/cast_function_set.cpp index 5e30e4c3..48b8bef7 100644 --- a/src/duckdb/src/function/cast/cast_function_set.cpp +++ b/src/duckdb/src/function/cast/cast_function_set.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/pair.hpp" #include "duckdb/common/types/type_map.hpp" #include "duckdb/function/cast_rules.hpp" +#include "duckdb/planner/collation_binding.hpp" #include "duckdb/main/config.hpp" namespace duckdb { @@ -34,10 +35,18 @@ CastFunctionSet &CastFunctionSet::Get(ClientContext &context) { return DBConfig::GetConfig(context).GetCastFunctions(); } +CollationBinding &CollationBinding::Get(ClientContext &context) { + return DBConfig::GetConfig(context).GetCollationBinding(); +} + CastFunctionSet &CastFunctionSet::Get(DatabaseInstance &db) { return DBConfig::GetConfig(db).GetCastFunctions(); } +CollationBinding &CollationBinding::Get(DatabaseInstance &db) { + return DBConfig::GetConfig(db).GetCollationBinding(); +} + BoundCastInfo CastFunctionSet::GetCastFunction(const LogicalType &source, const LogicalType &target, GetCastFunctionInput &get_input) { if (source == target) { @@ -97,7 +106,7 @@ static auto RelaxedTypeMatch(type_map_t &map, const LogicalType case LogicalTypeId::UNION: return map.find(LogicalType::UNION({{"any", LogicalType::ANY}})); case LogicalTypeId::ARRAY: - return map.find(LogicalType::ARRAY(LogicalType::ANY)); + return map.find(LogicalType::ARRAY(LogicalType::ANY, optional_idx())); default: return map.find(LogicalType::ANY); } diff --git a/src/duckdb/src/function/cast/decimal_cast.cpp b/src/duckdb/src/function/cast/decimal_cast.cpp index 3af81e0f..d1bc6e77 100644 --- a/src/duckdb/src/function/cast/decimal_cast.cpp +++ b/src/duckdb/src/function/cast/decimal_cast.cpp @@ -100,8 +100,17 @@ bool TemplatedDecimalScaleUp(Vector &source, Vector &result, idx_t count, CastPa struct DecimalScaleDownOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { + // We need to round here, not truncate. auto data = (DecimalScaleInput *)dataptr; - return Cast::Operation(input / data->factor); + // Scale first so we don't overflow when rounding. + const auto scaling = data->factor / 2; + input /= scaling; + if (input < 0) { + input -= 1; + } else { + input += 1; + } + return Cast::Operation(input / 2); } }; diff --git a/src/duckdb/src/function/cast/default_casts.cpp b/src/duckdb/src/function/cast/default_casts.cpp index 7769e355..d0da4432 100644 --- a/src/duckdb/src/function/cast/default_casts.cpp +++ b/src/duckdb/src/function/cast/default_casts.cpp @@ -143,6 +143,8 @@ BoundCastInfo DefaultCasts::GetDefaultCastFunction(BindCastInput &input, const L return EnumCastSwitch(input, source, target); case LogicalTypeId::ARRAY: return ArrayCastSwitch(input, source, target); + case LogicalTypeId::VARINT: + return VarintCastSwitch(input, source, target); case LogicalTypeId::AGGREGATE_STATE: return AggregateStateToBlobCast; default: diff --git a/src/duckdb/src/function/cast/numeric_casts.cpp b/src/duckdb/src/function/cast/numeric_casts.cpp index 0438e6e1..bdb999ff 100644 --- a/src/duckdb/src/function/cast/numeric_casts.cpp +++ b/src/duckdb/src/function/cast/numeric_casts.cpp @@ -2,6 +2,7 @@ #include "duckdb/function/cast/vector_cast_helpers.hpp" #include "duckdb/common/operator/string_cast.hpp" #include "duckdb/common/operator/numeric_cast.hpp" +#include "duckdb/common/types/varint.hpp" namespace duckdb { @@ -41,6 +42,8 @@ static BoundCastInfo InternalNumericCastSwitch(const LogicalType &source, const return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::BIT: return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::VARINT: + return Varint::NumericToVarintCastSwitch(source); default: return DefaultCasts::TryVectorNullCast; } diff --git a/src/duckdb/src/function/cast/string_cast.cpp b/src/duckdb/src/function/cast/string_cast.cpp index 4645dd34..9f8b5ee2 100644 --- a/src/duckdb/src/function/cast/string_cast.cpp +++ b/src/duckdb/src/function/cast/string_cast.cpp @@ -6,6 +6,7 @@ #include "duckdb/common/vector.hpp" #include "duckdb/function/scalar/nested_functions.hpp" #include "duckdb/function/cast/bound_cast_data.hpp" +#include "duckdb/common/types/varint.hpp" namespace duckdb { @@ -477,7 +478,7 @@ BoundCastInfo DefaultCasts::StringCastSwitch(BindCastInput &input, const Logical return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); case LogicalTypeId::TIMESTAMP_NS: return BoundCastInfo( - &VectorCastHelpers::TryCastStrictLoop); + &VectorCastHelpers::TryCastStrictLoop); case LogicalTypeId::TIMESTAMP_SEC: return BoundCastInfo( &VectorCastHelpers::TryCastStrictLoop); @@ -502,10 +503,10 @@ BoundCastInfo DefaultCasts::StringCastSwitch(BindCastInput &input, const Logical ListBoundCastData::InitListLocalState); case LogicalTypeId::ARRAY: // the second argument allows for a secondary casting function to be passed in the CastParameters - return BoundCastInfo( - &StringToNestedTypeCast, - ArrayBoundCastData::BindArrayToArrayCast(input, LogicalType::ARRAY(LogicalType::VARCHAR), target), - ArrayBoundCastData::InitArrayLocalState); + return BoundCastInfo(&StringToNestedTypeCast, + ArrayBoundCastData::BindArrayToArrayCast( + input, LogicalType::ARRAY(LogicalType::VARCHAR, optional_idx()), target), + ArrayBoundCastData::InitArrayLocalState); case LogicalTypeId::STRUCT: return BoundCastInfo(&StringToNestedTypeCast, StructBoundCastData::BindStructToStructCast(input, InitVarcharStructType(target), target), @@ -515,6 +516,8 @@ BoundCastInfo DefaultCasts::StringCastSwitch(BindCastInput &input, const Logical MapBoundCastData::BindMapToMapCast( input, LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR), target), InitMapCastLocalState); + case LogicalTypeId::VARINT: + return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); default: return VectorStringCastNumericSwitch(input, source, target); } diff --git a/src/duckdb/src/function/cast/time_casts.cpp b/src/duckdb/src/function/cast/time_casts.cpp index 4ac38d12..b9587ad1 100644 --- a/src/duckdb/src/function/cast/time_casts.cpp +++ b/src/duckdb/src/function/cast/time_casts.cpp @@ -14,7 +14,7 @@ BoundCastInfo DefaultCasts::DateCastSwitch(BindCastInput &input, const LogicalTy // date to timestamp return BoundCastInfo(&VectorCastHelpers::TryCastLoop); case LogicalTypeId::TIMESTAMP_NS: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); case LogicalTypeId::TIMESTAMP_SEC: return BoundCastInfo(&VectorCastHelpers::TryCastLoop); case LogicalTypeId::TIMESTAMP_MS: @@ -113,7 +113,7 @@ BoundCastInfo DefaultCasts::TimestampNsCastSwitch(BindCastInput &input, const Lo switch (target.id()) { case LogicalTypeId::VARCHAR: // timestamp (ns) to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::DATE: // timestamp (ns) to date return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); diff --git a/src/duckdb/src/function/cast/union_casts.cpp b/src/duckdb/src/function/cast/union_casts.cpp index 10b79f01..72b17174 100644 --- a/src/duckdb/src/function/cast/union_casts.cpp +++ b/src/duckdb/src/function/cast/union_casts.cpp @@ -171,7 +171,7 @@ unique_ptr BindUnionToUnionCast(BindCastInput &input, const Logic auto &target_member_name = UnionType::GetMemberName(target, target_idx); // found a matching member - if (source_member_name == target_member_name) { + if (StringUtil::CIEquals(source_member_name, target_member_name)) { auto &target_member_type = UnionType::GetMemberType(target, target_idx); tag_map[source_idx] = target_idx; member_casts.push_back(input.GetCastFunction(source_member_type, target_member_type)); diff --git a/src/duckdb/src/function/cast/varint_casts.cpp b/src/duckdb/src/function/cast/varint_casts.cpp new file mode 100644 index 00000000..0f434688 --- /dev/null +++ b/src/duckdb/src/function/cast/varint_casts.cpp @@ -0,0 +1,283 @@ +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" +#include "duckdb/common/types/varint.hpp" +#include + +namespace duckdb { + +template +string_t IntToVarInt(Vector &result, T int_value) { + // Determine if the number is negative + bool is_negative = int_value < 0; + // Determine the number of data bytes + uint64_t abs_value; + if (is_negative) { + if (int_value == std::numeric_limits::min()) { + abs_value = static_cast(std::numeric_limits::max()) + 1; + } else { + abs_value = static_cast(std::abs(static_cast(int_value))); + } + } else { + abs_value = static_cast(int_value); + } + uint32_t data_byte_size; + if (abs_value != NumericLimits::Maximum()) { + data_byte_size = (abs_value == 0) ? 1 : static_cast(std::ceil(std::log2(abs_value + 1) / 8.0)); + } else { + data_byte_size = static_cast(std::ceil(std::log2(abs_value) / 8.0)); + } + + uint32_t blob_size = data_byte_size + Varint::VARINT_HEADER_SIZE; + auto blob = StringVector::EmptyString(result, blob_size); + auto writable_blob = blob.GetDataWriteable(); + Varint::SetHeader(writable_blob, data_byte_size, is_negative); + + // Add data bytes to the blob, starting off after header bytes + idx_t wb_idx = Varint::VARINT_HEADER_SIZE; + for (int i = static_cast(data_byte_size) - 1; i >= 0; --i) { + if (is_negative) { + writable_blob[wb_idx++] = static_cast(~(abs_value >> i * 8 & 0xFF)); + } else { + writable_blob[wb_idx++] = static_cast(abs_value >> i * 8 & 0xFF); + } + } + blob.Finalize(); + return blob; +} + +template <> +string_t HugeintCastToVarInt::Operation(uhugeint_t int_value, Vector &result) { + uint32_t data_byte_size; + if (int_value.upper != NumericLimits::Maximum()) { + data_byte_size = + (int_value.upper == 0) ? 0 : static_cast(std::ceil(std::log2(int_value.upper + 1) / 8.0)); + } else { + data_byte_size = static_cast(std::ceil(std::log2(int_value.upper) / 8.0)); + } + + uint32_t upper_byte_size = data_byte_size; + if (data_byte_size > 0) { + // If we have at least one byte on the upper side, the bottom side is complete + data_byte_size += 8; + } else { + if (int_value.lower != NumericLimits::Maximum()) { + data_byte_size += static_cast(std::ceil(std::log2(int_value.lower + 1) / 8.0)); + } else { + data_byte_size += static_cast(std::ceil(std::log2(int_value.lower) / 8.0)); + } + } + if (data_byte_size == 0) { + data_byte_size++; + } + uint32_t blob_size = data_byte_size + Varint::VARINT_HEADER_SIZE; + auto blob = StringVector::EmptyString(result, blob_size); + auto writable_blob = blob.GetDataWriteable(); + Varint::SetHeader(writable_blob, data_byte_size, false); + + // Add data bytes to the blob, starting off after header bytes + idx_t wb_idx = Varint::VARINT_HEADER_SIZE; + for (int i = static_cast(upper_byte_size) - 1; i >= 0; --i) { + writable_blob[wb_idx++] = static_cast(int_value.upper >> i * 8 & 0xFF); + } + for (int i = static_cast(data_byte_size - upper_byte_size) - 1; i >= 0; --i) { + writable_blob[wb_idx++] = static_cast(int_value.lower >> i * 8 & 0xFF); + } + blob.Finalize(); + return blob; +} + +template <> +string_t HugeintCastToVarInt::Operation(hugeint_t int_value, Vector &result) { + // Determine if the number is negative + bool is_negative = int_value.upper >> 63 & 1; + if (is_negative) { + // We must check if it's -170141183460469231731687303715884105728, since it's not possible to negate it + // without overflowing + if (int_value == NumericLimits::Minimum()) { + uhugeint_t u_int_value {0x8000000000000000, 0}; + auto cast_value = Operation(u_int_value, result); + // We have to do all the bit flipping. + auto writable_value_ptr = cast_value.GetDataWriteable(); + Varint::SetHeader(writable_value_ptr, cast_value.GetSize() - Varint::VARINT_HEADER_SIZE, is_negative); + for (idx_t i = Varint::VARINT_HEADER_SIZE; i < cast_value.GetSize(); i++) { + writable_value_ptr[i] = static_cast(~writable_value_ptr[i]); + } + cast_value.Finalize(); + return cast_value; + } + int_value = -int_value; + } + // Determine the number of data bytes + uint64_t abs_value_upper = static_cast(int_value.upper); + + uint32_t data_byte_size; + if (abs_value_upper != NumericLimits::Maximum()) { + data_byte_size = + (abs_value_upper == 0) ? 0 : static_cast(std::ceil(std::log2(abs_value_upper + 1) / 8.0)); + } else { + data_byte_size = static_cast(std::ceil(std::log2(abs_value_upper) / 8.0)); + } + + uint32_t upper_byte_size = data_byte_size; + if (data_byte_size > 0) { + // If we have at least one byte on the upper side, the bottom side is complete + data_byte_size += 8; + } else { + if (int_value.lower != NumericLimits::Maximum()) { + data_byte_size += static_cast(std::ceil(std::log2(int_value.lower + 1) / 8.0)); + } else { + data_byte_size += static_cast(std::ceil(std::log2(int_value.lower) / 8.0)); + } + } + + if (data_byte_size == 0) { + data_byte_size++; + } + uint32_t blob_size = data_byte_size + Varint::VARINT_HEADER_SIZE; + auto blob = StringVector::EmptyString(result, blob_size); + auto writable_blob = blob.GetDataWriteable(); + Varint::SetHeader(writable_blob, data_byte_size, is_negative); + + // Add data bytes to the blob, starting off after header bytes + idx_t wb_idx = Varint::VARINT_HEADER_SIZE; + for (int i = static_cast(upper_byte_size) - 1; i >= 0; --i) { + if (is_negative) { + writable_blob[wb_idx++] = static_cast(~(abs_value_upper >> i * 8 & 0xFF)); + } else { + writable_blob[wb_idx++] = static_cast(abs_value_upper >> i * 8 & 0xFF); + } + } + for (int i = static_cast(data_byte_size - upper_byte_size) - 1; i >= 0; --i) { + if (is_negative) { + writable_blob[wb_idx++] = static_cast(~(int_value.lower >> i * 8 & 0xFF)); + } else { + writable_blob[wb_idx++] = static_cast(int_value.lower >> i * 8 & 0xFF); + } + } + blob.Finalize(); + return blob; +} + +// Varchar to Varint +// TODO: This is a slow quadratic algorithm, we can still optimize it further. +template <> +bool TryCastToVarInt::Operation(string_t input_value, string_t &result_value, Vector &result, + CastParameters ¶meters) { + auto blob_string = Varint::VarcharToVarInt(input_value); + + uint32_t blob_size = static_cast(blob_string.size()); + result_value = StringVector::EmptyString(result, blob_size); + auto writable_blob = result_value.GetDataWriteable(); + + // Write string_blob into blob + for (idx_t i = 0; i < blob_string.size(); i++) { + writable_blob[i] = blob_string[i]; + } + result_value.Finalize(); + return true; +} + +template +bool DoubleToVarInt(T double_value, string_t &result_value, Vector &result) { + // Check if we can cast it + if (!std::isfinite(double_value)) { + // We can't cast inf -inf nan + return false; + } + // Determine if the number is negative + bool is_negative = double_value < 0; + // Determine the number of data bytes + double abs_value = std::abs(double_value); + + if (abs_value == 0) { + // Return Value 0 + result_value = Varint::InitializeVarintZero(result); + return true; + } + vector value; + while (abs_value > 0) { + double quotient = abs_value / 256; + double truncated = floor(quotient); + uint8_t byte = static_cast(abs_value - truncated * 256); + abs_value = truncated; + if (is_negative) { + value.push_back(static_cast(~byte)); + } else { + value.push_back(static_cast(byte)); + } + } + uint32_t data_byte_size = static_cast(value.size()); + uint32_t blob_size = data_byte_size + Varint::VARINT_HEADER_SIZE; + result_value = StringVector::EmptyString(result, blob_size); + auto writable_blob = result_value.GetDataWriteable(); + Varint::SetHeader(writable_blob, data_byte_size, is_negative); + // Add data bytes to the blob, starting off after header bytes + idx_t blob_string_idx = value.size() - 1; + for (idx_t i = Varint::VARINT_HEADER_SIZE; i < blob_size; i++) { + writable_blob[i] = value[blob_string_idx--]; + } + result_value.Finalize(); + return true; +} + +template <> +bool TryCastToVarInt::Operation(double double_value, string_t &result_value, Vector &result, + CastParameters ¶meters) { + return DoubleToVarInt(double_value, result_value, result); +} + +template <> +bool TryCastToVarInt::Operation(float double_value, string_t &result_value, Vector &result, + CastParameters ¶meters) { + return DoubleToVarInt(double_value, result_value, result); +} + +BoundCastInfo Varint::NumericToVarintCastSwitch(const LogicalType &source) { + // now switch on the result type + switch (source.id()) { + case LogicalTypeId::TINYINT: + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::UTINYINT: + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::SMALLINT: + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::USMALLINT: + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::INTEGER: + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::UINTEGER: + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::BIGINT: + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::UBIGINT: + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::UHUGEINT: + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::FLOAT: + return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); + case LogicalTypeId::DOUBLE: + return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); + case LogicalTypeId::HUGEINT: + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::DECIMAL: + default: + return DefaultCasts::TryVectorNullCast; + } +} + +BoundCastInfo DefaultCasts::VarintCastSwitch(BindCastInput &input, const LogicalType &source, + const LogicalType &target) { + D_ASSERT(source.id() == LogicalTypeId::VARINT); + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + return BoundCastInfo(&VectorCastHelpers::StringCast); + case LogicalTypeId::DOUBLE: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + default: + return TryVectorNullCast; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/vector_cast_helpers.cpp b/src/duckdb/src/function/cast/vector_cast_helpers.cpp index a064b4ac..c7aa523e 100644 --- a/src/duckdb/src/function/cast/vector_cast_helpers.cpp +++ b/src/duckdb/src/function/cast/vector_cast_helpers.cpp @@ -4,7 +4,9 @@ namespace duckdb { // ------- Helper functions for splitting string nested types ------- static bool IsNull(const char *buf, idx_t start_pos, Vector &child, idx_t row_idx) { - if (buf[start_pos] == 'N' && buf[start_pos + 1] == 'U' && buf[start_pos + 2] == 'L' && buf[start_pos + 3] == 'L') { + if ((buf[start_pos] == 'N' || buf[start_pos] == 'n') && (buf[start_pos + 1] == 'U' || buf[start_pos + 1] == 'u') && + (buf[start_pos + 2] == 'L' || buf[start_pos + 2] == 'l') && + (buf[start_pos + 3] == 'L' || buf[start_pos + 3] == 'l')) { FlatVector::SetNull(child, row_idx, true); return true; } diff --git a/src/duckdb/src/function/cast_rules.cpp b/src/duckdb/src/function/cast_rules.cpp index acbe2eb0..951ecc93 100644 --- a/src/duckdb/src/function/cast_rules.cpp +++ b/src/duckdb/src/function/cast_rules.cpp @@ -1,5 +1,7 @@ #include "duckdb/function/cast_rules.hpp" +#include "duckdb/common/helper.hpp" #include "duckdb/common/numeric_utils.hpp" +#include "duckdb/common/case_insensitive_map.hpp" namespace duckdb { @@ -212,6 +214,9 @@ static int64_t ImplicitCastDate(const LogicalType &to) { switch (to.id()) { case LogicalTypeId::TIMESTAMP: case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::TIMESTAMP_SEC: return TargetTypeCost(to); default: return -1; @@ -269,6 +274,15 @@ static int64_t ImplicitCastTimestamp(const LogicalType &to) { } } +static int64_t ImplicitCastVarint(const LogicalType &to) { + switch (to.id()) { + case LogicalTypeId::DOUBLE: + return TargetTypeCost(to); + default: + return -1; + } +} + bool LogicalTypeIsValid(const LogicalType &type) { switch (type.id()) { case LogicalTypeId::STRUCT: @@ -400,15 +414,17 @@ int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) // TODO: if we can access the expression we could resolve the size if the list is constant. return ImplicitCast(ListType::GetChildType(from), ArrayType::GetChildType(to)); } - if (from.id() == to.id()) { - // arguments match: do nothing - return 0; - } - if (from.id() == LogicalTypeId::UNION && to.id() == LogicalTypeId::UNION) { + // Check that the target union type is fully resolved. + if (to.AuxInfo() == nullptr) { + // If not, try anyway and let the actual cast logic handle it. + // This is to allow passing unions into functions that take a generic union type (without specifying member + // types) as an argument. + return 0; + } // Unions can be cast if the source tags are a subset of the target tags // in which case the most expensive cost is used - int cost = -1; + int64_t cost = -1; for (idx_t from_member_idx = 0; from_member_idx < UnionType::GetMemberCount(from); from_member_idx++) { auto &from_member_name = UnionType::GetMemberName(from, from_member_idx); @@ -416,14 +432,12 @@ int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) for (idx_t to_member_idx = 0; to_member_idx < UnionType::GetMemberCount(to); to_member_idx++) { auto &to_member_name = UnionType::GetMemberName(to, to_member_idx); - if (from_member_name == to_member_name) { + if (StringUtil::CIEquals(from_member_name, to_member_name)) { auto &from_member_type = UnionType::GetMemberType(from, from_member_idx); auto &to_member_type = UnionType::GetMemberType(to, to_member_idx); - int child_cost = NumericCast(ImplicitCast(from_member_type, to_member_type)); - if (child_cost > cost) { - cost = child_cost; - } + auto child_cost = ImplicitCast(from_member_type, to_member_type); + cost = MaxValue(cost, child_cost); found = true; break; } @@ -434,19 +448,92 @@ int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) } return cost; } + if (from.id() == LogicalTypeId::STRUCT && to.id() == LogicalTypeId::STRUCT) { + if (to.AuxInfo() == nullptr) { + // If this struct is not fully resolved, we'll leave it to the actual cast logic to handle it. + return 0; + } + + auto &source_children = StructType::GetChildTypes(from); + auto &target_children = StructType::GetChildTypes(to); + + if (source_children.size() != target_children.size()) { + // different number of children: not possible + return -1; + } + + auto target_is_unnamed = StructType::IsUnnamed(to); + auto source_is_unnamed = StructType::IsUnnamed(from); + auto named_struct_cast = !source_is_unnamed && !target_is_unnamed; + + int64_t cost = -1; + if (named_struct_cast) { + + // Collect the target members in a map for easy lookup + case_insensitive_map_t target_members; + for (idx_t target_idx = 0; target_idx < target_children.size(); target_idx++) { + auto &target_name = target_children[target_idx].first; + if (target_members.find(target_name) != target_members.end()) { + // duplicate name in target struct + return -1; + } + target_members[target_name] = target_idx; + } + // Match the source members to the target members by name + for (idx_t source_idx = 0; source_idx < source_children.size(); source_idx++) { + auto &source_child = source_children[source_idx]; + auto entry = target_members.find(source_child.first); + if (entry == target_members.end()) { + // element in source struct was not found in target struct + return -1; + } + auto target_idx = entry->second; + target_members.erase(entry); + auto child_cost = ImplicitCast(source_child.second, target_children[target_idx].second); + if (child_cost == -1) { + return -1; + } + cost = MaxValue(cost, child_cost); + } + } else { + // Match the source members to the target members by position + for (idx_t i = 0; i < source_children.size(); i++) { + auto &source_child = source_children[i]; + auto &target_child = target_children[i]; + auto child_cost = ImplicitCast(source_child.second, target_child.second); + if (child_cost == -1) { + return -1; + } + cost = MaxValue(cost, child_cost); + } + } + return cost; + } + + if (from.id() == to.id()) { + // arguments match: do nothing + return 0; + } + // Special case: Anything can be cast to a union if the source type is a member of the union if (to.id() == LogicalTypeId::UNION) { // check that the union type is fully resolved. if (to.AuxInfo() == nullptr) { return -1; } - // every type can be implicitly be cast to a union if the source type is a member of the union + // check if the union contains something castable from the source type + // in which case the least expensive (most specific) cast should be used + bool found = false; + auto cost = NumericLimits::Maximum(); for (idx_t i = 0; i < UnionType::GetMemberCount(to); i++) { - auto member = UnionType::GetMemberType(to, i); - if (from == member) { - return 0; + auto target_member = UnionType::GetMemberType(to, i); + auto target_cost = ImplicitCast(from, target_member); + if (target_cost != -1) { + found = true; + cost = MinValue(cost, target_cost); } } + return found ? cost : -1; } switch (from.id()) { @@ -488,6 +575,8 @@ int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) return ImplicitCastTimestampNS(to); case LogicalTypeId::TIMESTAMP: return ImplicitCastTimestamp(to); + case LogicalTypeId::VARINT: + return ImplicitCastVarint(to); default: return -1; } diff --git a/src/duckdb/src/function/compression_config.cpp b/src/duckdb/src/function/compression_config.cpp index 62ba1ce6..a5e686f3 100644 --- a/src/duckdb/src/function/compression_config.cpp +++ b/src/duckdb/src/function/compression_config.cpp @@ -1,12 +1,12 @@ -#include "duckdb/main/config.hpp" -#include "duckdb/function/compression_function.hpp" -#include "duckdb/function/compression/compression.hpp" #include "duckdb/common/pair.hpp" +#include "duckdb/function/compression/compression.hpp" +#include "duckdb/function/compression_function.hpp" +#include "duckdb/main/config.hpp" namespace duckdb { typedef CompressionFunction (*get_compression_function_t)(PhysicalType type); -typedef bool (*compression_supports_type_t)(PhysicalType type); +typedef bool (*compression_supports_type_t)(const PhysicalType physical_type); struct DefaultCompressionMethod { CompressionType type; @@ -29,12 +29,12 @@ static const DefaultCompressionMethod internal_compression_methods[] = { {CompressionType::COMPRESSION_AUTO, nullptr, nullptr}}; static optional_ptr FindCompressionFunction(CompressionFunctionSet &set, CompressionType type, - PhysicalType data_type) { + const PhysicalType physical_type) { auto &functions = set.functions; auto comp_entry = functions.find(type); if (comp_entry != functions.end()) { auto &type_functions = comp_entry->second; - auto type_entry = type_functions.find(data_type); + auto type_entry = type_functions.find(physical_type); if (type_entry != type_functions.end()) { return &type_entry->second; } @@ -43,56 +43,58 @@ static optional_ptr FindCompressionFunction(CompressionFunc } static optional_ptr LoadCompressionFunction(CompressionFunctionSet &set, CompressionType type, - PhysicalType data_type) { - for (idx_t index = 0; internal_compression_methods[index].get_function; index++) { - const auto &method = internal_compression_methods[index]; + const PhysicalType physical_type) { + for (idx_t i = 0; internal_compression_methods[i].get_function; i++) { + const auto &method = internal_compression_methods[i]; if (method.type == type) { - // found the correct compression type - if (!method.supports_type(data_type)) { - // but it does not support this data type: bail out + if (!method.supports_type(physical_type)) { return nullptr; } - // the type is supported: create the function and insert it into the set - auto function = method.get_function(data_type); - set.functions[type].insert(make_pair(data_type, function)); - return FindCompressionFunction(set, type, data_type); + // The type is supported. We create the function and insert it into the set of available functions. + auto function = method.get_function(physical_type); + set.functions[type].insert(make_pair(physical_type, function)); + return FindCompressionFunction(set, type, physical_type); } } throw InternalException("Unsupported compression function type"); } static void TryLoadCompression(DBConfig &config, vector> &result, CompressionType type, - PhysicalType data_type) { - auto function = config.GetCompressionFunction(type, data_type); + const PhysicalType physical_type) { + auto function = config.GetCompressionFunction(type, physical_type); if (!function) { return; } result.push_back(*function); } -vector> DBConfig::GetCompressionFunctions(PhysicalType data_type) { +vector> DBConfig::GetCompressionFunctions(const PhysicalType physical_type) { vector> result; - TryLoadCompression(*this, result, CompressionType::COMPRESSION_UNCOMPRESSED, data_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_RLE, data_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_BITPACKING, data_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_DICTIONARY, data_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_CHIMP, data_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_PATAS, data_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_ALP, data_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_ALPRD, data_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_FSST, data_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_UNCOMPRESSED, physical_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_RLE, physical_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_BITPACKING, physical_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_DICTIONARY, physical_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_CHIMP, physical_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_PATAS, physical_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_ALP, physical_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_ALPRD, physical_type); + TryLoadCompression(*this, result, CompressionType::COMPRESSION_FSST, physical_type); return result; } -optional_ptr DBConfig::GetCompressionFunction(CompressionType type, PhysicalType data_type) { +optional_ptr DBConfig::GetCompressionFunction(CompressionType type, + const PhysicalType physical_type) { lock_guard l(compression_functions->lock); - // check if the function is already loaded - auto function = FindCompressionFunction(*compression_functions, type, data_type); + + // Check if the function is already loaded into the global compression functions. + auto function = FindCompressionFunction(*compression_functions, type, physical_type); if (function) { return function; } - // else load the function - return LoadCompressionFunction(*compression_functions, type, data_type); + + // We could not find the function in the global compression functions, + // so we attempt loading it. + return LoadCompressionFunction(*compression_functions, type, physical_type); } } // namespace duckdb diff --git a/src/duckdb/src/function/copy_function.cpp b/src/duckdb/src/function/copy_function.cpp new file mode 100644 index 00000000..ac2bc754 --- /dev/null +++ b/src/duckdb/src/function/copy_function.cpp @@ -0,0 +1,27 @@ +#include "duckdb/function/copy_function.hpp" + +namespace duckdb { + +vector GetCopyFunctionReturnNames(CopyFunctionReturnType return_type) { + switch (return_type) { + case CopyFunctionReturnType::CHANGED_ROWS: + return {"Count"}; + case CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST: + return {"Count", "Files"}; + default: + throw NotImplementedException("Unknown CopyFunctionReturnType"); + } +} + +vector GetCopyFunctionReturnLogicalTypes(CopyFunctionReturnType return_type) { + switch (return_type) { + case CopyFunctionReturnType::CHANGED_ROWS: + return {LogicalType::BIGINT}; + case CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST: + return {LogicalType::BIGINT, LogicalType::LIST(LogicalType::VARCHAR)}; + default: + throw NotImplementedException("Unknown CopyFunctionReturnType"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/function_binder.cpp b/src/duckdb/src/function/function_binder.cpp index 1ddf7bbd..6bc8dcfa 100644 --- a/src/duckdb/src/function/function_binder.cpp +++ b/src/duckdb/src/function/function_binder.cpp @@ -12,6 +12,7 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression_binder.hpp" +#include "duckdb/function/scalar/generic_functions.hpp" namespace duckdb { @@ -238,21 +239,38 @@ LogicalTypeComparisonResult RequiresCast(const LogicalType &source_type, const L return LogicalTypeComparisonResult::DIFFERENT_TYPES; } -LogicalType PrepareTypeForCast(const LogicalType &type) { +bool TypeRequiresPrepare(const LogicalType &type) { + if (type.id() == LogicalTypeId::ANY) { + return true; + } + if (type.id() == LogicalTypeId::LIST) { + return TypeRequiresPrepare(ListType::GetChildType(type)); + } + return false; +} + +LogicalType PrepareTypeForCastRecursive(const LogicalType &type) { if (type.id() == LogicalTypeId::ANY) { return AnyType::GetTargetType(type); } if (type.id() == LogicalTypeId::LIST) { - return LogicalType::LIST(PrepareTypeForCast(ListType::GetChildType(type))); + return LogicalType::LIST(PrepareTypeForCastRecursive(ListType::GetChildType(type))); } return type; } +void PrepareTypeForCast(LogicalType &type) { + if (!TypeRequiresPrepare(type)) { + return; + } + type = PrepareTypeForCastRecursive(type); +} + void FunctionBinder::CastToFunctionArguments(SimpleFunction &function, vector> &children) { for (auto &arg : function.arguments) { - arg = PrepareTypeForCast(arg); + PrepareTypeForCast(arg); } - function.varargs = PrepareTypeForCast(function.varargs); + PrepareTypeForCast(function.varargs); for (idx_t i = 0; i < children.size(); i++) { auto target_type = i < function.arguments.size() ? function.arguments[i] : function.varargs; @@ -338,25 +356,35 @@ unique_ptr FunctionBinder::BindScalarFunction(ScalarFunctionCatalogE return BindScalarFunction(bound_function, std::move(children), is_operator, binder); } -unique_ptr FunctionBinder::BindScalarFunction(ScalarFunction bound_function, - vector> children, - bool is_operator, optional_ptr binder) { +unique_ptr FunctionBinder::BindScalarFunction(ScalarFunction bound_function, + vector> children, bool is_operator, + optional_ptr binder) { unique_ptr bind_info; if (bound_function.bind) { bind_info = bound_function.bind(context, bound_function, children); } if (bound_function.get_modified_databases && binder) { auto &properties = binder->GetStatementProperties(); - FunctionModifiedDatabasesInput input(bind_info, properties.modified_databases); - bound_function.get_modified_databases(input); + FunctionModifiedDatabasesInput input(bind_info, properties); + bound_function.get_modified_databases(context, input); } // check if we need to add casts to the children CastToFunctionArguments(bound_function, children); // now create the function auto return_type = bound_function.return_type; - return make_uniq(std::move(return_type), std::move(bound_function), std::move(children), - std::move(bind_info), is_operator); + unique_ptr result; + auto result_func = make_uniq(std::move(return_type), std::move(bound_function), + std::move(children), std::move(bind_info), is_operator); + if (result_func->function.bind_expression) { + // if a bind_expression callback is registered - call it and emit the resulting expression + FunctionBindExpressionInput input(context, result_func->bind_info.get(), *result_func); + result = result_func->function.bind_expression(input); + } + if (!result) { + result = std::move(result_func); + } + return result; } unique_ptr FunctionBinder::BindAggregateFunction(AggregateFunction bound_function, diff --git a/src/duckdb/src/function/macro_function.cpp b/src/duckdb/src/function/macro_function.cpp index ad92ee9f..cb9313b0 100644 --- a/src/duckdb/src/function/macro_function.cpp +++ b/src/duckdb/src/function/macro_function.cpp @@ -12,51 +12,96 @@ namespace duckdb { -// MacroFunction::MacroFunction(unique_ptr expression) : expression(std::move(expression)) {} - MacroFunction::MacroFunction(MacroType type) : type(type) { } -string MacroFunction::ValidateArguments(MacroFunction ¯o_def, const string &name, FunctionExpression &function_expr, - vector> &positionals, - unordered_map> &defaults) { +string FormatMacroFunction(MacroFunction &function, const string &name) { + string result; + result = name + "("; + string parameters; + for (auto ¶m : function.parameters) { + if (!parameters.empty()) { + parameters += ", "; + } + parameters += param->Cast().GetColumnName(); + } + for (auto &named_param : function.default_parameters) { + if (!parameters.empty()) { + parameters += ", "; + } + parameters += named_param.first; + parameters += " := "; + parameters += named_param.second->ToString(); + } + result += parameters + ")"; + return result; +} +MacroBindResult MacroFunction::BindMacroFunction(const vector> &functions, const string &name, + FunctionExpression &function_expr, + vector> &positionals, + unordered_map> &defaults) { // separate positional and default arguments for (auto &arg : function_expr.children) { if (!arg->alias.empty()) { // default argument - if (!macro_def.default_parameters.count(arg->alias)) { - return StringUtil::Format("Macro %s does not have default parameter %s!", name, arg->alias); - } else if (defaults.count(arg->alias)) { - return StringUtil::Format("Duplicate default parameters %s!", arg->alias); + if (defaults.count(arg->alias)) { + return MacroBindResult(StringUtil::Format("Duplicate default parameters %s!", arg->alias)); } defaults[arg->alias] = std::move(arg); } else if (!defaults.empty()) { - return "Positional parameters cannot come after parameters with a default value!"; + return MacroBindResult("Positional parameters cannot come after parameters with a default value!"); } else { // positional argument positionals.push_back(std::move(arg)); } } - // validate if the right number of arguments was supplied - string error; - auto ¶meters = macro_def.parameters; - if (parameters.size() != positionals.size()) { - error = StringUtil::Format( - "Macro function '%s(%s)' requires ", name, - StringUtil::Join(parameters, parameters.size(), ", ", [](const unique_ptr &p) { - return (p->Cast()).column_names[0]; - })); - error += parameters.size() == 1 ? "a single positional argument" - : StringUtil::Format("%i positional arguments", parameters.size()); - error += ", but "; - error += positionals.size() == 1 ? "a single positional argument was" - : StringUtil::Format("%i positional arguments were", positionals.size()); - error += " provided."; - return error; + // check for each macro function if it matches the number of positional arguments + optional_idx result_idx; + for (idx_t function_idx = 0; function_idx < functions.size(); function_idx++) { + if (functions[function_idx]->parameters.size() == positionals.size()) { + // found a matching function + result_idx = function_idx; + break; + } + } + if (!result_idx.IsValid()) { + // no matching function found + string error; + if (functions.size() == 1) { + // we only have one function - print the old more detailed error message + auto ¯o_def = *functions[0]; + auto ¶meters = macro_def.parameters; + error = StringUtil::Format("Macro function %s requires ", FormatMacroFunction(macro_def, name)); + error += parameters.size() == 1 ? "a single positional argument" + : StringUtil::Format("%i positional arguments", parameters.size()); + error += ", but "; + error += positionals.size() == 1 ? "a single positional argument was" + : StringUtil::Format("%i positional arguments were", positionals.size()); + error += " provided."; + } else { + // we have multiple functions - list all candidates + error += StringUtil::Format("Macro \"%s\" does not support %llu parameters.\n", name, positionals.size()); + error += "Candidate macros:"; + for (auto &function : functions) { + error += "\n\t" + FormatMacroFunction(*function, name); + } + } + return MacroBindResult(error); + } + // found a matching index - check if the default values exist within the macro + auto macro_idx = result_idx.GetIndex(); + auto ¯o_def = *functions[macro_idx]; + for (auto &default_val : defaults) { + auto entry = macro_def.default_parameters.find(default_val.first); + if (entry == macro_def.default_parameters.end()) { + string error = + StringUtil::Format("Macro \"%s\" does not have a named parameter \"%s\"\n", name, default_val.first); + error += "\nMacro definition: " + FormatMacroFunction(macro_def, name); + return MacroBindResult(error); + } } - // Add the default values for parameters that have defaults, that were not explicitly assigned to for (auto it = macro_def.default_parameters.begin(); it != macro_def.default_parameters.end(); it++) { auto ¶meter_name = it->first; @@ -66,8 +111,7 @@ string MacroFunction::ValidateArguments(MacroFunction ¯o_def, const string & defaults[parameter_name] = parameter_default->Copy(); } } - - return error; + return MacroBindResult(macro_idx); } void MacroFunction::CopyProperties(MacroFunction &other) const { @@ -80,7 +124,7 @@ void MacroFunction::CopyProperties(MacroFunction &other) const { } } -string MacroFunction::ToSQL(const string &schema, const string &name) const { +string MacroFunction::ToSQL() const { vector param_strings; for (auto ¶m : parameters) { param_strings.push_back(param->ToString()); @@ -88,8 +132,7 @@ string MacroFunction::ToSQL(const string &schema, const string &name) const { for (auto &named_param : default_parameters) { param_strings.push_back(StringUtil::Format("%s := %s", named_param.first, named_param.second->ToString())); } - - return StringUtil::Format("CREATE MACRO %s.%s(%s) AS ", schema, name, StringUtil::Join(param_strings, ", ")); + return StringUtil::Format("(%s) AS ", StringUtil::Join(param_strings, ", ")); } } // namespace duckdb diff --git a/src/duckdb/src/function/pragma/pragma_queries.cpp b/src/duckdb/src/function/pragma/pragma_queries.cpp index fdb47ae3..cc52edba 100644 --- a/src/duckdb/src/function/pragma/pragma_queries.cpp +++ b/src/duckdb/src/function/pragma/pragma_queries.cpp @@ -43,6 +43,7 @@ string PragmaShowTables() { ORDER BY "name";)EOF"; // clang-format on } + string PragmaShowTables(ClientContext &context, const FunctionParameters ¶meters) { return PragmaShowTables(); } @@ -91,6 +92,9 @@ string PragmaShowDatabases(ClientContext &context, const FunctionParameters &par return PragmaShowDatabases(); } +string PragmaShowVariables() { + return "SELECT * FROM duckdb_variables() ORDER BY name"; +} string PragmaAllProfiling(ClientContext &context, const FunctionParameters ¶meters) { return "SELECT * FROM pragma_last_profiling_output() JOIN pragma_detailed_profiling_output() ON " "(pragma_last_profiling_output.operator_id);"; @@ -124,6 +128,11 @@ string PragmaVersion(ClientContext &context, const FunctionParameters ¶meter return "SELECT * FROM pragma_version();"; } +string PragmaExtensionVersions(ClientContext &context, const FunctionParameters ¶meters) { + return "select extension_name, extension_version, install_mode, installed_from from duckdb_extensions() where " + "installed"; +} + string PragmaPlatform(ClientContext &context, const FunctionParameters ¶meters) { return "SELECT * FROM pragma_platform();"; } @@ -203,6 +212,7 @@ void PragmaQueries::RegisterFunction(BuiltinFunctions &set) { set.AddFunction(PragmaFunction::PragmaStatement("collations", PragmaCollations)); set.AddFunction(PragmaFunction::PragmaCall("show", PragmaShow, {LogicalType::VARCHAR})); set.AddFunction(PragmaFunction::PragmaStatement("version", PragmaVersion)); + set.AddFunction(PragmaFunction::PragmaStatement("extension_versions", PragmaExtensionVersions)); set.AddFunction(PragmaFunction::PragmaStatement("platform", PragmaPlatform)); set.AddFunction(PragmaFunction::PragmaStatement("database_size", PragmaDatabaseSize)); set.AddFunction(PragmaFunction::PragmaStatement("functions", PragmaFunctionsQuery)); diff --git a/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp index a8e046b6..907b8c45 100644 --- a/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp +++ b/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp @@ -223,6 +223,7 @@ static void CMStringDecompressSerialize(Serializer &serializer, const optional_p unique_ptr CMStringDecompressDeserialize(Deserializer &deserializer, ScalarFunction &function) { function.arguments = deserializer.ReadProperty>(100, "arguments"); function.function = GetStringDecompressFunctionSwitch(function.arguments[0]); + function.return_type = deserializer.Get(); return nullptr; } diff --git a/src/duckdb/src/function/scalar/generic/binning.cpp b/src/duckdb/src/function/scalar/generic/binning.cpp new file mode 100644 index 00000000..aaa9d19d --- /dev/null +++ b/src/duckdb/src/function/scalar/generic/binning.cpp @@ -0,0 +1,507 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/hugeint.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/vector_operations/generic_executor.hpp" +#include "duckdb/core_functions/scalar/generic_functions.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +static hugeint_t GetPreviousPowerOfTen(hugeint_t input) { + hugeint_t power_of_ten = 1; + while (power_of_ten < input) { + power_of_ten *= 10; + } + return power_of_ten / 10; +} + +enum class NiceRounding { CEILING, ROUND }; + +hugeint_t RoundToNumber(hugeint_t input, hugeint_t num, NiceRounding rounding) { + if (rounding == NiceRounding::ROUND) { + return (input + (num / 2)) / num * num; + } else { + return (input + (num - 1)) / num * num; + } +} + +hugeint_t MakeNumberNice(hugeint_t input, hugeint_t step, NiceRounding rounding) { + // we consider numbers nice if they are divisible by 2 or 5 times the power-of-ten one lower than the current + // e.g. 120 is a nice number because it is divisible by 20 + // 122 is not a nice number -> we make it nice by turning it into 120 [/20] + // 153 is not a nice number -> we make it nice by turning it into 150 [/50] + // 1220 is not a nice number -> we turn it into 1200 [/200] + // first figure out the previous power of 10 (i.e. for 67 we return 10) + // now the power of ten is the power BELOW the current number + // i.e. for 67, it is not 10 + // now we can get the 2 or 5 divisors + hugeint_t power_of_ten = GetPreviousPowerOfTen(step); + hugeint_t two = power_of_ten * 2; + hugeint_t five = power_of_ten; + if (power_of_ten * 3 <= step) { + two *= 5; + } + if (power_of_ten * 2 <= step) { + five *= 5; + } + + // compute the closest round number by adding the divisor / 2 and truncating + // do this for both divisors + hugeint_t round_to_two = RoundToNumber(input, two, rounding); + hugeint_t round_to_five = RoundToNumber(input, five, rounding); + // now pick the closest number of the two (i.e. for 147 we pick 150, not 140) + if (AbsValue(input - round_to_two) < AbsValue(input - round_to_five)) { + return round_to_two; + } else { + return round_to_five; + } +} + +static double GetPreviousPowerOfTen(double input) { + double power_of_ten = 1; + if (input < 1) { + while (power_of_ten > input) { + power_of_ten /= 10; + } + return power_of_ten; + } + while (power_of_ten < input) { + power_of_ten *= 10; + } + return power_of_ten / 10; +} + +double RoundToNumber(double input, double num, NiceRounding rounding) { + double result; + if (rounding == NiceRounding::ROUND) { + result = std::round(input / num) * num; + } else { + result = std::ceil(input / num) * num; + } + if (!Value::IsFinite(result)) { + return input; + } + return result; +} + +double MakeNumberNice(double input, const double step, NiceRounding rounding) { + if (input == 0) { + return 0; + } + // now the power of ten is the power BELOW the current number + // i.e. for 67, it is not 10 + // now we can get the 2 or 5 divisors + double power_of_ten = GetPreviousPowerOfTen(step); + double two = power_of_ten * 2; + double five = power_of_ten; + if (power_of_ten * 3 <= step) { + two *= 5; + } + if (power_of_ten * 2 <= step) { + five *= 5; + } + + double round_to_two = RoundToNumber(input, two, rounding); + double round_to_five = RoundToNumber(input, five, rounding); + // now pick the closest number of the two (i.e. for 147 we pick 150, not 140) + if (AbsValue(input - round_to_two) < AbsValue(input - round_to_five)) { + return round_to_two; + } else { + return round_to_five; + } +} + +struct EquiWidthBinsInteger { + static constexpr LogicalTypeId LOGICAL_TYPE = LogicalTypeId::BIGINT; + + static vector> Operation(const Expression &expr, int64_t input_min, int64_t input_max, + idx_t bin_count, bool nice_rounding) { + vector> result; + // to prevent integer truncation from affecting the bin boundaries we calculate them with numbers multiplied by + // 1000 we then divide to get the actual boundaries + const auto FACTOR = hugeint_t(1000); + auto min = hugeint_t(input_min) * FACTOR; + auto max = hugeint_t(input_max) * FACTOR; + + const hugeint_t span = max - min; + hugeint_t step = span / Hugeint::Convert(bin_count); + if (nice_rounding) { + // when doing nice rounding we try to make the max/step values nicer + hugeint_t new_step = MakeNumberNice(step, step, NiceRounding::ROUND); + hugeint_t new_max = RoundToNumber(max, new_step, NiceRounding::CEILING); + if (new_max != min && new_step != 0) { + max = new_max; + step = new_step; + } + // we allow for more bins when doing nice rounding since the bin count is approximate + bin_count *= 2; + } + for (hugeint_t bin_boundary = max; bin_boundary > min; bin_boundary -= step) { + const hugeint_t target_boundary = bin_boundary / FACTOR; + int64_t real_boundary = Hugeint::Cast(target_boundary); + if (!result.empty()) { + if (real_boundary < input_min || result.size() >= bin_count) { + // we can never generate input_min + break; + } + if (real_boundary == result.back().val) { + // we cannot generate the same value multiple times in a row - skip this step + continue; + } + } + result.push_back(real_boundary); + } + return result; + } +}; + +struct EquiWidthBinsDouble { + static constexpr LogicalTypeId LOGICAL_TYPE = LogicalTypeId::DOUBLE; + + static vector> Operation(const Expression &expr, double min, double input_max, + idx_t bin_count, bool nice_rounding) { + double max = input_max; + if (!Value::IsFinite(min) || !Value::IsFinite(max)) { + throw InvalidInputException("equi_width_bucket does not support infinite or nan as min/max value"); + } + vector> result; + const double span = max - min; + double step; + if (!Value::IsFinite(span)) { + // max - min does not fit + step = max / static_cast(bin_count) - min / static_cast(bin_count); + } else { + step = span / static_cast(bin_count); + } + const double step_power_of_ten = GetPreviousPowerOfTen(step); + if (nice_rounding) { + // when doing nice rounding we try to make the max/step values nicer + step = MakeNumberNice(step, step, NiceRounding::ROUND); + max = RoundToNumber(input_max, step, NiceRounding::CEILING); + // we allow for more bins when doing nice rounding since the bin count is approximate + bin_count *= 2; + } + if (step == 0) { + throw InternalException("step is 0!?"); + } + + const double round_multiplication = 10 / step_power_of_ten; + for (double bin_boundary = max; bin_boundary > min; bin_boundary -= step) { + // because floating point addition adds inaccuracies, we add rounding at every step + double real_boundary = bin_boundary; + if (nice_rounding) { + real_boundary = std::round(bin_boundary * round_multiplication) / round_multiplication; + } + if (!result.empty() && result.back().val == real_boundary) { + // skip this step + continue; + } + if (real_boundary <= min || result.size() >= bin_count) { + // we can never generate below input_min + break; + } + result.push_back(real_boundary); + } + return result; + } +}; + +void NextMonth(int32_t &year, int32_t &month) { + month++; + if (month == 13) { + year++; + month = 1; + } +} + +void NextDay(int32_t &year, int32_t &month, int32_t &day) { + day++; + if (!Date::IsValid(year, month, day)) { + // day is out of range for month, move to next month + NextMonth(year, month); + day = 1; + } +} + +void NextHour(int32_t &year, int32_t &month, int32_t &day, int32_t &hour) { + hour++; + if (hour >= 24) { + NextDay(year, month, day); + hour = 0; + } +} + +void NextMinute(int32_t &year, int32_t &month, int32_t &day, int32_t &hour, int32_t &minute) { + minute++; + if (minute >= 60) { + NextHour(year, month, day, hour); + minute = 0; + } +} + +void NextSecond(int32_t &year, int32_t &month, int32_t &day, int32_t &hour, int32_t &minute, int32_t &sec) { + sec++; + if (sec >= 60) { + NextMinute(year, month, day, hour, minute); + sec = 0; + } +} + +timestamp_t MakeTimestampNice(int32_t year, int32_t month, int32_t day, int32_t hour, int32_t minute, int32_t sec, + int32_t micros, interval_t step) { + // how to make a timestamp nice depends on the step + if (step.months >= 12) { + // if the step involves one year or more, ceil to months + // set time component to 00:00:00.00 + if (day > 1 || hour > 0 || minute > 0 || sec > 0 || micros > 0) { + // move to next month + NextMonth(year, month); + hour = minute = sec = micros = 0; + day = 1; + } + } else if (step.months > 0 || step.days >= 1) { + // if the step involves more than one day, ceil to days + if (hour > 0 || minute > 0 || sec > 0 || micros > 0) { + NextDay(year, month, day); + hour = minute = sec = micros = 0; + } + } else if (step.days > 0 || step.micros >= Interval::MICROS_PER_HOUR) { + // if the step involves more than one hour, ceil to hours + if (minute > 0 || sec > 0 || micros > 0) { + NextHour(year, month, day, hour); + minute = sec = micros = 0; + } + } else if (step.micros >= Interval::MICROS_PER_MINUTE) { + // if the step involves more than one minute, ceil to minutes + if (sec > 0 || micros > 0) { + NextMinute(year, month, day, hour, minute); + sec = micros = 0; + } + } else if (step.micros >= Interval::MICROS_PER_SEC) { + // if the step involves more than one second, ceil to seconds + if (micros > 0) { + NextSecond(year, month, day, hour, minute, sec); + micros = 0; + } + } + return Timestamp::FromDatetime(Date::FromDate(year, month, day), Time::FromTime(hour, minute, sec, micros)); +} + +int64_t RoundNumberToDivisor(int64_t number, int64_t divisor) { + return (number + (divisor / 2)) / divisor * divisor; +} + +interval_t MakeIntervalNice(interval_t interval) { + if (interval.months >= 6) { + // if we have more than 6 months, we don't care about days + interval.days = 0; + interval.micros = 0; + } else if (interval.months > 0 || interval.days >= 5) { + // if we have any months or more than 5 days, we don't care about micros + interval.micros = 0; + } else if (interval.days > 0 || interval.micros >= 6 * Interval::MICROS_PER_HOUR) { + // if we any days or more than 6 hours, we want micros to be roundable by hours at least + interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_HOUR); + } else if (interval.micros >= Interval::MICROS_PER_HOUR) { + // if we have more than an hour, we want micros to be divisible by quarter hours + interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_MINUTE * 15); + } else if (interval.micros >= Interval::MICROS_PER_MINUTE * 10) { + // if we have more than 10 minutes, we want micros to be divisible by minutes + interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_MINUTE); + } else if (interval.micros >= Interval::MICROS_PER_MINUTE) { + // if we have more than a minute, we want micros to be divisible by quarter minutes + interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_SEC * 15); + } else if (interval.micros >= Interval::MICROS_PER_SEC * 10) { + // if we have more than 10 seconds, we want micros to be divisible by seconds + interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_SEC); + } + return interval; +} + +void GetTimestampComponents(timestamp_t input, int32_t &year, int32_t &month, int32_t &day, int32_t &hour, + int32_t &minute, int32_t &sec, int32_t µs) { + date_t date; + dtime_t time; + + Timestamp::Convert(input, date, time); + Date::Convert(date, year, month, day); + Time::Convert(time, hour, minute, sec, micros); +} + +struct EquiWidthBinsTimestamp { + static constexpr LogicalTypeId LOGICAL_TYPE = LogicalTypeId::TIMESTAMP; + + static vector> Operation(const Expression &expr, timestamp_t input_min, + timestamp_t input_max, idx_t bin_count, bool nice_rounding) { + if (!Value::IsFinite(input_min) || !Value::IsFinite(input_max)) { + throw InvalidInputException(expr, "equi_width_bucket does not support infinite or nan as min/max value"); + } + + if (!nice_rounding) { + // if we are not doing nice rounding it is pretty simple - just interpolate between the timestamp values + auto interpolated_values = + EquiWidthBinsInteger::Operation(expr, input_min.value, input_max.value, bin_count, false); + + vector> result; + for (auto &val : interpolated_values) { + result.push_back(timestamp_t(val.val)); + } + return result; + } + // fetch the components of the timestamps + int32_t min_year, min_month, min_day, min_hour, min_minute, min_sec, min_micros; + int32_t max_year, max_month, max_day, max_hour, max_minute, max_sec, max_micros; + GetTimestampComponents(input_min, min_year, min_month, min_day, min_hour, min_minute, min_sec, min_micros); + GetTimestampComponents(input_max, max_year, max_month, max_day, max_hour, max_minute, max_sec, max_micros); + + // get the interval differences per component + // note: these can be negative (except for the largest non-zero difference) + interval_t interval_diff; + interval_diff.months = (max_year - min_year) * Interval::MONTHS_PER_YEAR + (max_month - min_month); + interval_diff.days = max_day - min_day; + interval_diff.micros = (max_hour - min_hour) * Interval::MICROS_PER_HOUR + + (max_minute - min_minute) * Interval::MICROS_PER_MINUTE + + (max_sec - min_sec) * Interval::MICROS_PER_SEC + (max_micros - min_micros); + + double step_months = static_cast(interval_diff.months) / static_cast(bin_count); + double step_days = static_cast(interval_diff.days) / static_cast(bin_count); + double step_micros = static_cast(interval_diff.micros) / static_cast(bin_count); + // since we truncate the months/days, propagate any fractional component to the unit below (i.e. 0.2 months + // becomes 6 days) + if (step_months > 0) { + double overflow_months = step_months - std::floor(step_months); + step_days += overflow_months * Interval::DAYS_PER_MONTH; + } + if (step_days > 0) { + double overflow_days = step_days - std::floor(step_days); + step_micros += overflow_days * Interval::MICROS_PER_DAY; + } + interval_t step; + step.months = static_cast(step_months); + step.days = static_cast(step_days); + step.micros = static_cast(step_micros); + + // now we make the max, and the step nice + step = MakeIntervalNice(step); + timestamp_t timestamp_val = + MakeTimestampNice(max_year, max_month, max_day, max_hour, max_minute, max_sec, max_micros, step); + if (step.months <= 0 && step.days <= 0 && step.micros <= 0) { + // interval must be at least one microsecond + step.months = step.days = 0; + step.micros = 1; + } + + vector> result; + while (timestamp_val.value >= input_min.value && result.size() < bin_count) { + result.push_back(timestamp_val); + timestamp_val = SubtractOperator::Operation(timestamp_val, step); + } + return result; + } +}; + +unique_ptr BindEquiWidthFunction(ClientContext &, ScalarFunction &bound_function, + vector> &arguments) { + // while internally the bins are computed over a unified type + // the equi_width_bins function returns the same type as the input MAX + LogicalType child_type; + switch (arguments[1]->return_type.id()) { + case LogicalTypeId::UNKNOWN: + case LogicalTypeId::SQLNULL: + return nullptr; + case LogicalTypeId::DECIMAL: + // for decimals we promote to double because + child_type = LogicalType::DOUBLE; + break; + default: + child_type = arguments[1]->return_type; + break; + } + bound_function.return_type = LogicalType::LIST(child_type); + return nullptr; +} + +template +static void EquiWidthBinFunction(DataChunk &args, ExpressionState &state, Vector &result) { + static constexpr int64_t MAX_BIN_COUNT = 1000000; + auto &min_arg = args.data[0]; + auto &max_arg = args.data[1]; + auto &bin_count = args.data[2]; + auto &nice_rounding = args.data[3]; + + Vector intermediate_result(LogicalType::LIST(OP::LOGICAL_TYPE)); + GenericExecutor::ExecuteQuaternary, PrimitiveType, PrimitiveType, PrimitiveType, + GenericListType>>( + min_arg, max_arg, bin_count, nice_rounding, intermediate_result, args.size(), + [&](PrimitiveType min_p, PrimitiveType max_p, PrimitiveType bins_p, + PrimitiveType nice_rounding_p) { + if (max_p.val < min_p.val) { + throw InvalidInputException(state.expr, + "Invalid input for bin function - max value is smaller than min value"); + } + if (bins_p.val <= 0) { + throw InvalidInputException(state.expr, "Invalid input for bin function - there must be > 0 bins"); + } + if (bins_p.val > MAX_BIN_COUNT) { + throw InvalidInputException(state.expr, "Invalid input for bin function - max bin count of %d exceeded", + MAX_BIN_COUNT); + } + GenericListType> result_bins; + if (max_p.val == min_p.val) { + // if max = min return a single bucket + result_bins.values.push_back(max_p.val); + } else { + result_bins.values = OP::Operation(state.expr, min_p.val, max_p.val, static_cast(bins_p.val), + nice_rounding_p.val); + // last bin should always be the input max + if (result_bins.values[0].val < max_p.val) { + result_bins.values[0].val = max_p.val; + } + std::reverse(result_bins.values.begin(), result_bins.values.end()); + } + return result_bins; + }); + VectorOperations::DefaultCast(intermediate_result, result, args.size()); +} + +static void UnsupportedEquiWidth(DataChunk &args, ExpressionState &state, Vector &) { + throw BinderException(state.expr, "Unsupported type \"%s\" for equi_width_bins", args.data[0].GetType()); +} + +void EquiWidthBinSerialize(Serializer &, const optional_ptr, const ScalarFunction &) { + return; +} + +unique_ptr EquiWidthBinDeserialize(Deserializer &deserializer, ScalarFunction &function) { + function.return_type = deserializer.Get(); + return nullptr; +} + +ScalarFunctionSet EquiWidthBinsFun::GetFunctions() { + ScalarFunctionSet functions("equi_width_bins"); + functions.AddFunction( + ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BOOLEAN}, + LogicalType::LIST(LogicalType::ANY), EquiWidthBinFunction, + BindEquiWidthFunction)); + functions.AddFunction(ScalarFunction( + {LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::BIGINT, LogicalType::BOOLEAN}, + LogicalType::LIST(LogicalType::ANY), EquiWidthBinFunction, BindEquiWidthFunction)); + functions.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::BIGINT, LogicalType::BOOLEAN}, + LogicalType::LIST(LogicalType::ANY), EquiWidthBinFunction, + BindEquiWidthFunction)); + functions.AddFunction( + ScalarFunction({LogicalType::ANY_PARAMS(LogicalType::ANY, 150), LogicalType::ANY_PARAMS(LogicalType::ANY, 150), + LogicalType::BIGINT, LogicalType::BOOLEAN}, + LogicalType::LIST(LogicalType::ANY), UnsupportedEquiWidth, BindEquiWidthFunction)); + for (auto &function : functions.functions) { + function.serialize = EquiWidthBinSerialize; + function.deserialize = EquiWidthBinDeserialize; + } + return functions; +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/generic/getvariable.cpp b/src/duckdb/src/function/scalar/generic/getvariable.cpp new file mode 100644 index 00000000..b46ab60d --- /dev/null +++ b/src/duckdb/src/function/scalar/generic/getvariable.cpp @@ -0,0 +1,58 @@ +#include "duckdb/function/scalar/generic_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/transaction/meta_transaction.hpp" + +namespace duckdb { + +struct GetVariableBindData : FunctionData { + explicit GetVariableBindData(Value value_p) : value(std::move(value_p)) { + } + + Value value; + + bool Equals(const FunctionData &other_p) const override { + const auto &other = other_p.Cast(); + return Value::NotDistinctFrom(value, other.value); + } + + unique_ptr Copy() const override { + return make_uniq(value); + } +}; + +static unique_ptr GetVariableBind(ClientContext &context, ScalarFunction &function, + vector> &arguments) { + if (!arguments[0]->IsFoldable()) { + throw NotImplementedException("getvariable requires a constant input"); + } + if (arguments[0]->HasParameter()) { + throw ParameterNotResolvedException(); + } + Value value; + auto variable_name = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); + if (!variable_name.IsNull()) { + ClientConfig::GetConfig(context).GetUserVariable(variable_name.ToString(), value); + } + function.return_type = value.type(); + return make_uniq(std::move(value)); +} + +unique_ptr BindGetVariableExpression(FunctionBindExpressionInput &input) { + if (!input.bind_data) { + // unknown type + throw InternalException("input.bind_data should be set"); + } + auto &bind_data = input.bind_data->Cast(); + // emit a constant expression + return make_uniq(bind_data.value); +} + +void GetVariableFun::RegisterFunction(BuiltinFunctions &set) { + ScalarFunction getvar("getvariable", {LogicalType::VARCHAR}, LogicalType::ANY, nullptr, GetVariableBind, nullptr); + getvar.bind_expression = BindGetVariableExpression; + set.AddFunction(getvar); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/generic_functions.cpp b/src/duckdb/src/function/scalar/generic_functions.cpp index a128aa56..27330ab1 100644 --- a/src/duckdb/src/function/scalar/generic_functions.cpp +++ b/src/duckdb/src/function/scalar/generic_functions.cpp @@ -5,6 +5,7 @@ namespace duckdb { void BuiltinFunctions::RegisterGenericFunctions() { Register(); Register(); + Register(); } } // namespace duckdb diff --git a/src/duckdb/src/function/scalar/list/contains_or_position.cpp b/src/duckdb/src/function/scalar/list/contains_or_position.cpp index b7f45458..c10e467b 100644 --- a/src/duckdb/src/function/scalar/list/contains_or_position.cpp +++ b/src/duckdb/src/function/scalar/list/contains_or_position.cpp @@ -1,83 +1,69 @@ #include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression_binder.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/function/scalar/list/contains_or_position.hpp" namespace duckdb { -static void ListContainsFunction(DataChunk &args, ExpressionState &state, Vector &result) { - (void)state; - return ListContainsOrPosition(args, result); -} +template +static void ListSearchFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto target_count = input.size(); + auto &list_vec = input.data[0]; + auto &source_vec = ListVector::GetEntry(list_vec); + auto &target_vec = input.data[1]; -static void ListPositionFunction(DataChunk &args, ExpressionState &state, Vector &result) { - (void)state; - return ListContainsOrPosition(args, result); + ListSearchOp(list_vec, source_vec, target_vec, result, target_count); + + if (target_count == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } } -template -static unique_ptr ListContainsOrPositionBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { +static unique_ptr ListSearchBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { D_ASSERT(bound_function.arguments.size() == 2); // If the first argument is an array, cast it to a list arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - const auto &list = arguments[0]->return_type; // change to list + const auto &list = arguments[0]->return_type; const auto &value = arguments[1]->return_type; - if (list.id() == LogicalTypeId::UNKNOWN) { - bound_function.return_type = RETURN_TYPE; - if (value.id() != LogicalTypeId::UNKNOWN) { + + const auto list_is_param = list.id() == LogicalTypeId::UNKNOWN; + const auto value_is_param = value.id() == LogicalTypeId::UNKNOWN; + + if (list_is_param) { + if (!value_is_param) { // only list is a parameter, cast it to a list of value type bound_function.arguments[0] = LogicalType::LIST(value); bound_function.arguments[1] = value; } - } else if (value.id() == LogicalTypeId::UNKNOWN) { + } else if (value_is_param) { // only value is a parameter: we expect the child type of list - auto const &child_type = ListType::GetChildType(list); bound_function.arguments[0] = list; - bound_function.arguments[1] = child_type; - bound_function.return_type = RETURN_TYPE; + bound_function.arguments[1] = ListType::GetChildType(list); } else { - auto const &child_type = ListType::GetChildType(list); LogicalType max_child_type; - if (!LogicalType::TryGetMaxLogicalType(context, child_type, value, max_child_type)) { + if (!LogicalType::TryGetMaxLogicalType(context, ListType::GetChildType(list), value, max_child_type)) { throw BinderException( - "Cannot get list_position of element of type %s in a list of type %s[] - an explicit cast is required", - value.ToString(), child_type.ToString()); + "%s: Cannot match element of type '%s' in a list of type '%s' - an explicit cast is required", + bound_function.name, value.ToString(), list.ToString()); } - auto list_type = LogicalType::LIST(max_child_type); - - bound_function.arguments[0] = list_type; - bound_function.arguments[1] = value == max_child_type ? value : max_child_type; - // list_contains and list_position only differ in their return type - bound_function.return_type = RETURN_TYPE; + bound_function.arguments[0] = LogicalType::LIST(max_child_type); + bound_function.arguments[1] = max_child_type; } return make_uniq(bound_function.return_type); } -static unique_ptr ListContainsBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return ListContainsOrPositionBind(context, bound_function, arguments); -} - -static unique_ptr ListPositionBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return ListContainsOrPositionBind(context, bound_function, arguments); -} - ScalarFunction ListContainsFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::ANY}, // argument list - LogicalType::BOOLEAN, // return type - ListContainsFunction, ListContainsBind, nullptr); + return ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::ANY}, LogicalType::BOOLEAN, + ListSearchFunction, ListSearchBind); } ScalarFunction ListPositionFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::ANY}, // argument list - LogicalType::INTEGER, // return type - ListPositionFunction, ListPositionBind, nullptr); + return ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::ANY}, LogicalType::INTEGER, + ListSearchFunction, ListSearchBind); } void ListContainsFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/scalar/list/list_extract.cpp b/src/duckdb/src/function/scalar/list/list_extract.cpp index aa8278ce..37176f23 100644 --- a/src/duckdb/src/function/scalar/list/list_extract.cpp +++ b/src/duckdb/src/function/scalar/list/list_extract.cpp @@ -1,171 +1,105 @@ #include "duckdb/common/pair.hpp" #include "duckdb/common/string_util.hpp" - #include "duckdb/common/types/data_chunk.hpp" #include "duckdb/common/uhugeint.hpp" #include "duckdb/common/vector_operations/binary_executor.hpp" #include "duckdb/function/scalar/nested_functions.hpp" #include "duckdb/function/scalar/string_functions.hpp" #include "duckdb/parser/expression/bound_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/storage/statistics/list_stats.hpp" namespace duckdb { -template -void ListExtractTemplate(idx_t count, UnifiedVectorFormat &list_data, UnifiedVectorFormat &offsets_data, - Vector &child_vector, idx_t list_size, Vector &result) { - UnifiedVectorFormat child_format; - child_vector.ToUnifiedFormat(list_size, child_format); - - T *result_data; +static optional_idx TryGetChildOffset(const list_entry_t &list_entry, const int64_t offset) { + // 1-based indexing + if (offset == 0) { + return optional_idx::Invalid(); + } - result.SetVectorType(VectorType::FLAT_VECTOR); - if (!VALIDITY_ONLY) { - result_data = FlatVector::GetData(result); + const auto index_offset = (offset > 0) ? offset - 1 : offset; + if (index_offset < 0) { + const auto signed_list_length = UnsafeNumericCast(list_entry.length); + if (signed_list_length + index_offset < 0) { + return optional_idx::Invalid(); + } + return optional_idx(list_entry.offset + UnsafeNumericCast(signed_list_length + index_offset)); } - auto &result_mask = FlatVector::Validity(result); - // heap-ref once - if (HEAP_REF) { - StringVector::AddHeapReference(result, child_vector); + const auto unsigned_offset = UnsafeNumericCast(index_offset); + + // Check that the offset is within the list + if (unsigned_offset >= list_entry.length) { + return optional_idx::Invalid(); } - // this is lifted from ExecuteGenericLoop because we can't push the list child data into this otherwise - // should have gone with GetValue perhaps - auto child_data = UnifiedVectorFormat::GetData(child_format); + return optional_idx(list_entry.offset + unsigned_offset); +} + +static void ExecuteListExtract(Vector &result, Vector &list, Vector &offsets, const idx_t count) { + D_ASSERT(list.GetType().id() == LogicalTypeId::LIST); + UnifiedVectorFormat list_data; + UnifiedVectorFormat offsets_data; + + list.ToUnifiedFormat(count, list_data); + offsets.ToUnifiedFormat(count, offsets_data); + + const auto list_ptr = UnifiedVectorFormat::GetData(list_data); + const auto offsets_ptr = UnifiedVectorFormat::GetData(offsets_data); + + UnifiedVectorFormat child_data; + auto &child_vector = ListVector::GetEntry(list); + auto child_count = ListVector::GetListSize(list); + child_vector.ToUnifiedFormat(child_count, child_data); + + SelectionVector sel(count); + vector invalid_offsets; + + optional_idx first_valid_child_idx; for (idx_t i = 0; i < count; i++) { - auto list_index = list_data.sel->get_index(i); - auto offsets_index = offsets_data.sel->get_index(i); - if (!list_data.validity.RowIsValid(list_index)) { - result_mask.SetInvalid(i); - continue; - } - if (!offsets_data.validity.RowIsValid(offsets_index)) { - result_mask.SetInvalid(i); + const auto list_index = list_data.sel->get_index(i); + const auto offsets_index = offsets_data.sel->get_index(i); + + if (!list_data.validity.RowIsValid(list_index) || !offsets_data.validity.RowIsValid(offsets_index)) { + invalid_offsets.push_back(i); continue; } - auto list_entry = (UnifiedVectorFormat::GetData(list_data))[list_index]; - auto offsets_entry = (UnifiedVectorFormat::GetData(offsets_data))[offsets_index]; - // 1-based indexing - if (offsets_entry == 0) { - result_mask.SetInvalid(i); + const auto child_offset = TryGetChildOffset(list_ptr[list_index], offsets_ptr[offsets_index]); + + if (!child_offset.IsValid()) { + invalid_offsets.push_back(i); continue; } - offsets_entry = (offsets_entry > 0) ? offsets_entry - 1 : offsets_entry; - - idx_t child_offset; - if (offsets_entry < 0) { - if (offsets_entry < -int64_t(list_entry.length)) { - result_mask.SetInvalid(i); - continue; - } - child_offset = UnsafeNumericCast(UnsafeNumericCast(list_entry.offset + list_entry.length) + - offsets_entry); - } else { - if ((idx_t)offsets_entry >= list_entry.length) { - result_mask.SetInvalid(i); - continue; - } - child_offset = UnsafeNumericCast(UnsafeNumericCast(list_entry.offset) + offsets_entry); - } - auto child_index = child_format.sel->get_index(child_offset); - if (child_format.validity.RowIsValid(child_index)) { - if (!VALIDITY_ONLY) { - result_data[i] = child_data[child_index]; - } - } else { - result_mask.SetInvalid(i); + + const auto child_idx = child_data.sel->get_index(child_offset.GetIndex()); + sel.set_index(i, child_idx); + + if (!first_valid_child_idx.IsValid()) { + // Save the first valid child as a dummy index to copy in VectorOperations::Copy later + first_valid_child_idx = child_idx; } } - if (count == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} -static void ExecuteListExtractInternal(const idx_t count, UnifiedVectorFormat &list, UnifiedVectorFormat &offsets, - Vector &child_vector, idx_t list_size, Vector &result) { - D_ASSERT(child_vector.GetType() == result.GetType()); - switch (result.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::INT16: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::INT32: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::INT64: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::INT128: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::UINT8: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::UINT16: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::UINT32: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::UINT64: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::UINT128: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::FLOAT: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::DOUBLE: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::VARCHAR: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::INTERVAL: - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - case PhysicalType::STRUCT: { - auto &entries = StructVector::GetEntries(child_vector); - auto &result_entries = StructVector::GetEntries(result); - D_ASSERT(entries.size() == result_entries.size()); - // extract the child entries of the struct - for (idx_t i = 0; i < entries.size(); i++) { - ExecuteListExtractInternal(count, list, offsets, *entries[i], list_size, *result_entries[i]); + + if (first_valid_child_idx.IsValid()) { + // Only copy if we found at least one valid child + for (const auto &invalid_offset : invalid_offsets) { + sel.set_index(invalid_offset, first_valid_child_idx.GetIndex()); } - // extract the validity mask - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; + VectorOperations::Copy(child_vector, result, sel, count, 0, 0); } - case PhysicalType::LIST: { - // nested list: we have to reference the child - auto &child_child_list = ListVector::GetEntry(child_vector); - ListVector::GetEntry(result).Reference(child_child_list); - ListVector::SetListSize(result, ListVector::GetListSize(child_vector)); - ListExtractTemplate(count, list, offsets, child_vector, list_size, result); - break; - } - default: - throw NotImplementedException("Unimplemented type for LIST_EXTRACT"); + // Copy:ing the vectors also copies the validity mask, so we set the rows with invalid offsets (0) to false here. + for (const auto &invalid_idx : invalid_offsets) { + FlatVector::SetNull(result, invalid_idx, true); } -} -static void ExecuteListExtract(Vector &result, Vector &list, Vector &offsets, const idx_t count) { - D_ASSERT(list.GetType().id() == LogicalTypeId::LIST); - UnifiedVectorFormat list_data; - UnifiedVectorFormat offsets_data; + if (count == 1 || (list.GetVectorType() == VectorType::CONSTANT_VECTOR && + offsets.GetVectorType() == VectorType::CONSTANT_VECTOR)) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } - list.ToUnifiedFormat(count, list_data); - offsets.ToUnifiedFormat(count, offsets_data); - ExecuteListExtractInternal(count, list_data, offsets_data, ListVector::GetEntry(list), - ListVector::GetListSize(list), result); result.Verify(count); } @@ -180,13 +114,6 @@ static void ListExtractFunction(DataChunk &args, ExpressionState &state, Vector D_ASSERT(args.ColumnCount() == 2); auto count = args.size(); - result.SetVectorType(VectorType::CONSTANT_VECTOR); - for (idx_t i = 0; i < args.ColumnCount(); i++) { - if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::FLAT_VECTOR); - } - } - Vector &base = args.data[0]; Vector &subscript = args.data[1]; diff --git a/src/duckdb/src/function/scalar/list/list_resize.cpp b/src/duckdb/src/function/scalar/list/list_resize.cpp index 792e0f6b..86d21fd5 100644 --- a/src/duckdb/src/function/scalar/list/list_resize.cpp +++ b/src/duckdb/src/function/scalar/list/list_resize.cpp @@ -6,105 +6,114 @@ namespace duckdb { -void ListResizeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.data[1].GetType().id() == LogicalTypeId::UBIGINT); +void ListResizeFunction(DataChunk &args, ExpressionState &, Vector &result) { + + // Early-out, if the return value is a constant NULL. if (result.GetType().id() == LogicalTypeId::SQLNULL) { result.SetVectorType(VectorType::CONSTANT_VECTOR); ConstantVector::SetNull(result, true); return; } - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - auto count = args.size(); - - result.SetVectorType(VectorType::FLAT_VECTOR); auto &lists = args.data[0]; - auto &child = ListVector::GetEntry(args.data[0]); auto &new_sizes = args.data[1]; + auto row_count = args.size(); - UnifiedVectorFormat list_data; - lists.ToUnifiedFormat(count, list_data); - auto list_entries = UnifiedVectorFormat::GetData(list_data); - - UnifiedVectorFormat new_size_data; - new_sizes.ToUnifiedFormat(count, new_size_data); - auto new_size_entries = UnifiedVectorFormat::GetData(new_size_data); + UnifiedVectorFormat lists_data; + lists.ToUnifiedFormat(row_count, lists_data); + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + auto list_entries = UnifiedVectorFormat::GetData(lists_data); + auto &child_vector = ListVector::GetEntry(lists); UnifiedVectorFormat child_data; - child.ToUnifiedFormat(count, child_data); - - // Find the new size of the result child vector - idx_t new_child_size = 0; - for (idx_t i = 0; i < count; i++) { - auto index = new_size_data.sel->get_index(i); - if (new_size_data.validity.RowIsValid(index)) { - new_child_size += new_size_entries[index]; + child_vector.ToUnifiedFormat(row_count, child_data); + + UnifiedVectorFormat new_sizes_data; + new_sizes.ToUnifiedFormat(row_count, new_sizes_data); + D_ASSERT(new_sizes.GetType().id() == LogicalTypeId::UBIGINT); + auto new_size_entries = UnifiedVectorFormat::GetData(new_sizes_data); + + // Get the new size of the result child vector. + // We skip rows with NULL values in the input lists. + idx_t child_vector_size = 0; + for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { + auto list_idx = lists_data.sel->get_index(row_idx); + auto new_size_idx = new_sizes_data.sel->get_index(row_idx); + + if (lists_data.validity.RowIsValid(list_idx) && new_sizes_data.validity.RowIsValid(new_size_idx)) { + child_vector_size += new_size_entries[new_size_idx]; } } + ListVector::Reserve(result, child_vector_size); + ListVector::SetListSize(result, child_vector_size); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_entries = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + auto &result_child_vector = ListVector::GetEntry(result); - // Create the default vector if it exists + // Get the default values, if provided. UnifiedVectorFormat default_data; optional_ptr default_vector; if (args.ColumnCount() == 3) { default_vector = &args.data[2]; - default_vector->Flatten(count); - default_vector->ToUnifiedFormat(count, default_data); - default_vector->SetVectorType(VectorType::CONSTANT_VECTOR); + default_vector->ToUnifiedFormat(row_count, default_data); } - ListVector::Reserve(result, new_child_size); - ListVector::SetListSize(result, new_child_size); + idx_t offset = 0; + for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { - auto result_entries = FlatVector::GetData(result); - auto &result_child = ListVector::GetEntry(result); - - // for each lists in the args - idx_t result_child_offset = 0; - for (idx_t args_index = 0; args_index < count; args_index++) { - auto l_index = list_data.sel->get_index(args_index); - auto new_index = new_size_data.sel->get_index(args_index); + auto list_idx = lists_data.sel->get_index(row_idx); + auto new_size_idx = new_sizes_data.sel->get_index(row_idx); - // set null if lists is null - if (!list_data.validity.RowIsValid(l_index)) { - FlatVector::SetNull(result, args_index, true); + // Set to NULL, if the list is NULL. + if (!lists_data.validity.RowIsValid(list_idx)) { + result_validity.SetInvalid(row_idx); continue; } - idx_t new_size_entry = 0; - if (new_size_data.validity.RowIsValid(new_index)) { - new_size_entry = new_size_entries[new_index]; + idx_t new_size = 0; + if (new_sizes_data.validity.RowIsValid(new_size_idx)) { + new_size = new_size_entries[new_size_idx]; } - // find the smallest size between lists and new_sizes - auto values_to_copy = MinValue(list_entries[l_index].length, new_size_entry); - - // set the result entry - result_entries[args_index].offset = result_child_offset; - result_entries[args_index].length = new_size_entry; - - // copy the values from the child vector - VectorOperations::Copy(child, result_child, list_entries[l_index].offset + values_to_copy, - list_entries[l_index].offset, result_child_offset); - result_child_offset += values_to_copy; - - // set default value if it exists - idx_t def_index = 0; - if (args.ColumnCount() == 3) { - def_index = default_data.sel->get_index(args_index); - } - - // if the new size is larger than the old size, fill in the default value - if (values_to_copy < new_size_entry) { - if (default_vector && default_data.validity.RowIsValid(def_index)) { - VectorOperations::Copy(*default_vector, result_child, new_size_entry - values_to_copy, def_index, - result_child_offset); - result_child_offset += new_size_entry - values_to_copy; - } else { - for (idx_t j = values_to_copy; j < new_size_entry; j++) { - FlatVector::SetNull(result_child, result_child_offset, true); - result_child_offset++; + // If new_size >= length, then we copy [0, length) values. + // If new_size < length, then we copy [0, new_size) values. + auto copy_count = MinValue(list_entries[list_idx].length, new_size); + + // Set the result entry. + result_entries[row_idx].offset = offset; + result_entries[row_idx].length = new_size; + + // Copy the child vector's values. + // The number of elements to copy is later determined like so: source_count - source_offset. + idx_t source_offset = list_entries[list_idx].offset; + idx_t source_count = source_offset + copy_count; + VectorOperations::Copy(child_vector, result_child_vector, source_count, source_offset, offset); + offset += copy_count; + + // Fill the remaining space with the default values. + if (copy_count < new_size) { + idx_t remaining_count = new_size - copy_count; + + if (default_vector) { + auto default_idx = default_data.sel->get_index(row_idx); + if (default_data.validity.RowIsValid(default_idx)) { + SelectionVector sel(remaining_count); + for (idx_t j = 0; j < remaining_count; j++) { + sel.set_index(j, row_idx); + } + VectorOperations::Copy(*default_vector, result_child_vector, sel, remaining_count, 0, offset); + offset += remaining_count; + continue; } } + + // Fill the remaining space with NULL. + for (idx_t j = copy_count; j < new_size; j++) { + FlatVector::SetNull(result_child_vector, offset, true); + offset++; + } } } @@ -118,23 +127,23 @@ static unique_ptr ListResizeBind(ClientContext &context, ScalarFun D_ASSERT(bound_function.arguments.size() == 2 || arguments.size() == 3); bound_function.arguments[1] = LogicalType::UBIGINT; - // If the first argument is an array, cast it to a list + // If the first argument is an array, cast it to a list. arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - // first argument is constant NULL + // Early-out, if the first argument is a constant NULL. if (arguments[0]->return_type == LogicalType::SQLNULL) { bound_function.arguments[0] = LogicalType::SQLNULL; bound_function.return_type = LogicalType::SQLNULL; return make_uniq(bound_function.return_type); } - // prepared statements + // Early-out, if the first argument is a prepared statement. if (arguments[0]->return_type == LogicalType::UNKNOWN) { bound_function.return_type = arguments[0]->return_type; return make_uniq(bound_function.return_type); } - // default type does not match list type + // Attempt implicit casting, if the default type does not match list the list child type. if (bound_function.arguments.size() == 3 && ListType::GetChildType(arguments[0]->return_type) != arguments[2]->return_type && arguments[2]->return_type != LogicalTypeId::SQLNULL) { @@ -146,22 +155,22 @@ static unique_ptr ListResizeBind(ClientContext &context, ScalarFun } void ListResizeFun::RegisterFunction(BuiltinFunctions &set) { - ScalarFunction sfun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY}, - LogicalType::LIST(LogicalTypeId::ANY), ListResizeFunction, ListResizeBind); - sfun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + ScalarFunction simple_fun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY}, + LogicalType::LIST(LogicalTypeId::ANY), ListResizeFunction, ListResizeBind); + simple_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - ScalarFunction dfun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY, LogicalTypeId::ANY}, - LogicalType::LIST(LogicalTypeId::ANY), ListResizeFunction, ListResizeBind); - dfun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + ScalarFunction default_value_fun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY, LogicalTypeId::ANY}, + LogicalType::LIST(LogicalTypeId::ANY), ListResizeFunction, ListResizeBind); + default_value_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; ScalarFunctionSet list_resize("list_resize"); - list_resize.AddFunction(sfun); - list_resize.AddFunction(dfun); + list_resize.AddFunction(simple_fun); + list_resize.AddFunction(default_value_fun); set.AddFunction(list_resize); ScalarFunctionSet array_resize("array_resize"); - array_resize.AddFunction(sfun); - array_resize.AddFunction(dfun); + array_resize.AddFunction(simple_fun); + array_resize.AddFunction(default_value_fun); set.AddFunction(array_resize); } diff --git a/src/duckdb/src/function/scalar/list/list_zip.cpp b/src/duckdb/src/function/scalar/list/list_zip.cpp index 7e050382..6e24689c 100644 --- a/src/duckdb/src/function/scalar/list/list_zip.cpp +++ b/src/duckdb/src/function/scalar/list/list_zip.cpp @@ -1,5 +1,6 @@ #include "duckdb/common/types/data_chunk.hpp" #include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression_binder.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression/bound_parameter_expression.hpp" @@ -138,6 +139,8 @@ static unique_ptr ListZipBind(ClientContext &context, ScalarFuncti auto &child = arguments[i]; switch (child->return_type.id()) { case LogicalTypeId::LIST: + case LogicalTypeId::ARRAY: + child = BoundCastExpression::AddArrayCastToList(context, std::move(child)); struct_children.push_back(make_pair(string(), ListType::GetChildType(child->return_type))); break; case LogicalTypeId::SQLNULL: diff --git a/src/duckdb/src/function/scalar/operators/arithmetic.cpp b/src/duckdb/src/function/scalar/operators/arithmetic.cpp index 828395ff..03789783 100644 --- a/src/duckdb/src/function/scalar/operators/arithmetic.cpp +++ b/src/duckdb/src/function/scalar/operators/arithmetic.cpp @@ -1,3 +1,4 @@ +#include "duckdb/common/enum_util.hpp" #include "duckdb/common/operator/add.hpp" #include "duckdb/common/operator/multiply.hpp" #include "duckdb/common/operator/numeric_binary_operators.hpp" @@ -9,10 +10,10 @@ #include "duckdb/common/types/time.hpp" #include "duckdb/common/types/timestamp.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/enum_util.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" #include "duckdb/function/scalar/operators.hpp" +#include "duckdb/function/scalar/string_functions.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" #include @@ -857,8 +858,8 @@ hugeint_t DivideOperator::Operation(hugeint_t left, hugeint_t right) { template <> interval_t DivideOperator::Operation(interval_t left, int64_t right) { - left.days /= right; - left.months /= right; + left.days = UnsafeNumericCast(left.days / right); + left.months = UnsafeNumericCast(left.months / right); left.micros /= right; return left; } @@ -952,12 +953,24 @@ static scalar_function_t GetBinaryFunctionIgnoreZero(PhysicalType type) { } } +template +unique_ptr BindBinaryFloatingPoint(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + auto &config = ClientConfig::GetConfig(context); + if (config.ieee_floating_point_ops) { + bound_function.function = GetScalarBinaryFunction(bound_function.return_type.InternalType()); + } else { + bound_function.function = GetBinaryFunctionIgnoreZero(bound_function.return_type.InternalType()); + } + return nullptr; +} + void DivideFun::RegisterFunction(BuiltinFunctions &set) { ScalarFunctionSet fp_divide("/"); - fp_divide.AddFunction(ScalarFunction({LogicalType::FLOAT, LogicalType::FLOAT}, LogicalType::FLOAT, - GetBinaryFunctionIgnoreZero(PhysicalType::FLOAT))); - fp_divide.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, - GetBinaryFunctionIgnoreZero(PhysicalType::DOUBLE))); + fp_divide.AddFunction(ScalarFunction({LogicalType::FLOAT, LogicalType::FLOAT}, LogicalType::FLOAT, nullptr, + BindBinaryFloatingPoint)); + fp_divide.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, nullptr, + BindBinaryFloatingPoint)); fp_divide.AddFunction( ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL, BinaryScalarFunctionIgnoreZero)); @@ -1000,14 +1013,12 @@ unique_ptr BindDecimalModulo(ClientContext &context, ScalarFunctio template <> float ModuloOperator::Operation(float left, float right) { - D_ASSERT(right != 0); auto result = std::fmod(left, right); return result; } template <> double ModuloOperator::Operation(double left, double right) { - D_ASSERT(right != 0); auto result = std::fmod(left, right); return result; } @@ -1023,7 +1034,9 @@ hugeint_t ModuloOperator::Operation(hugeint_t left, hugeint_t right) { void ModFun::RegisterFunction(BuiltinFunctions &set) { ScalarFunctionSet functions("%"); for (auto &type : LogicalType::Numeric()) { - if (type.id() == LogicalTypeId::DECIMAL) { + if (type.id() == LogicalTypeId::FLOAT || type.id() == LogicalTypeId::DOUBLE) { + functions.AddFunction(ScalarFunction({type, type}, type, nullptr, BindBinaryFloatingPoint)); + } else if (type.id() == LogicalTypeId::DECIMAL) { functions.AddFunction(ScalarFunction({type, type}, type, nullptr, BindDecimalModulo)); } else { functions.AddFunction( diff --git a/src/duckdb/src/function/scalar/sequence/nextval.cpp b/src/duckdb/src/function/scalar/sequence/nextval.cpp index 4e80ed43..f25fac67 100644 --- a/src/duckdb/src/function/scalar/sequence/nextval.cpp +++ b/src/duckdb/src/function/scalar/sequence/nextval.cpp @@ -110,8 +110,8 @@ void Serialize(Serializer &serializer, const optional_ptr bind_dat } unique_ptr Deserialize(Deserializer &deserializer, ScalarFunction &) { - auto create_info = deserializer.ReadPropertyWithDefault>(100, "sequence_create_info", - unique_ptr()); + auto create_info = deserializer.ReadPropertyWithExplicitDefault>(100, "sequence_create_info", + unique_ptr()); if (!create_info) { return nullptr; } @@ -121,12 +121,12 @@ unique_ptr Deserialize(Deserializer &deserializer, ScalarFunction return make_uniq(sequence); } -void NextValModifiedDatabases(FunctionModifiedDatabasesInput &input) { +void NextValModifiedDatabases(ClientContext &context, FunctionModifiedDatabasesInput &input) { if (!input.bind_data) { return; } auto &seq = input.bind_data->Cast(); - input.modified_databases.insert(seq.sequence.ParentCatalog().GetName()); + input.properties.RegisterDBModify(seq.sequence.ParentCatalog(), context); } void NextvalFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/scalar/strftime_format.cpp b/src/duckdb/src/function/scalar/strftime_format.cpp index d4cf4067..3525519a 100644 --- a/src/duckdb/src/function/scalar/strftime_format.cpp +++ b/src/duckdb/src/function/scalar/strftime_format.cpp @@ -1,10 +1,14 @@ #include "duckdb/function/scalar/strftime_format.hpp" +#include "duckdb/common/exception/conversion_exception.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/to_string.hpp" #include "duckdb/common/types/cast_helpers.hpp" #include "duckdb/common/types/date.hpp" #include "duckdb/common/types/time.hpp" #include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/operator/add.hpp" +#include "duckdb/common/operator/multiply.hpp" + #include namespace duckdb { @@ -67,15 +71,15 @@ void StrfTimeFormat::AddFormatSpecifier(string preceding_literal, StrTimeSpecifi StrTimeFormat::AddFormatSpecifier(std::move(preceding_literal), specifier); } -idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date, dtime_t time, int32_t utc_offset, +idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date, int32_t data[8], const char *tz_name) { switch (specifier) { case StrTimeSpecifier::FULL_WEEKDAY_NAME: return Date::DAY_NAMES[Date::ExtractISODayOfTheWeek(date) % 7].GetSize(); case StrTimeSpecifier::FULL_MONTH_NAME: - return Date::MONTH_NAMES[Date::ExtractMonth(date) - 1].GetSize(); + return Date::MONTH_NAMES[data[1] - 1].GetSize(); case StrTimeSpecifier::YEAR_DECIMAL: { - auto year = Date::ExtractYear(date); + auto year = data[0]; // Be consistent with WriteStandardSpecifier if (0 <= year && year <= 9999) { return 4; @@ -85,13 +89,13 @@ idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date } case StrTimeSpecifier::MONTH_DECIMAL: { idx_t len = 1; - auto month = Date::ExtractMonth(date); + auto month = data[1]; len += month >= 10; return len; } case StrTimeSpecifier::UTC_OFFSET: // ±HH or ±HH:MM - return (utc_offset % 60) ? 6 : 3; + return (data[7] % 60) ? 6 : 3; case StrTimeSpecifier::TZ_NAME: if (tz_name) { return strlen(tz_name); @@ -104,8 +108,7 @@ idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date case StrTimeSpecifier::SECOND_DECIMAL: { // time specifiers idx_t len = 1; - int32_t hour, min, sec, msec; - Time::Convert(time, hour, min, sec, msec); + int32_t hour = data[3], min = data[4], sec = data[5]; switch (specifier) { case StrTimeSpecifier::HOUR_24_DECIMAL: len += hour >= 10; @@ -129,38 +132,49 @@ idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date return len; } case StrTimeSpecifier::DAY_OF_MONTH: - return UnsafeNumericCast( - NumericHelper::UnsignedLength(UnsafeNumericCast(Date::ExtractDay(date)))); + return UnsafeNumericCast(NumericHelper::UnsignedLength(UnsafeNumericCast(data[2]))); case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: return UnsafeNumericCast( NumericHelper::UnsignedLength(UnsafeNumericCast(Date::ExtractDayOfTheYear(date)))); case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: - return UnsafeNumericCast(NumericHelper::UnsignedLength( - UnsafeNumericCast(AbsValue(Date::ExtractYear(date)) % 100))); + return UnsafeNumericCast( + NumericHelper::UnsignedLength(UnsafeNumericCast(AbsValue(data[0]) % 100))); default: throw InternalException("Unimplemented specifier for GetSpecifierLength"); } } //! Returns the total length of the date formatted by this format specifier -idx_t StrfTimeFormat::GetLength(date_t date, dtime_t time, int32_t utc_offset, const char *tz_name) { +idx_t StrfTimeFormat::GetLength(date_t date, int32_t data[8], const char *tz_name) const { idx_t size = constant_size; if (!var_length_specifiers.empty()) { for (auto &specifier : var_length_specifiers) { - size += GetSpecifierLength(specifier, date, time, utc_offset, tz_name); + size += GetSpecifierLength(specifier, date, data, tz_name); } } return size; } -char *StrfTimeFormat::WriteString(char *target, const string_t &str) { +idx_t StrfTimeFormat::GetLength(date_t date, dtime_t time, int32_t utc_offset, const char *tz_name) { + if (!var_length_specifiers.empty()) { + int32_t data[8]; + Date::Convert(date, data[0], data[1], data[2]); + Time::Convert(time, data[3], data[4], data[5], data[6]); + data[6] *= Interval::NANOS_PER_MICRO; + data[7] = utc_offset; + return GetLength(date, data, tz_name); + } + return constant_size; +} + +char *StrfTimeFormat::WriteString(char *target, const string_t &str) const { idx_t size = str.GetSize(); memcpy(target, str.GetData(), size); return target + size; } // write a value in the range of 0..99 unpadded (e.g. "1", "2", ... "98", "99") -char *StrfTimeFormat::Write2(char *target, uint8_t value) { +char *StrfTimeFormat::Write2(char *target, uint8_t value) const { D_ASSERT(value < 100); if (value >= 10) { return WritePadded2(target, value); @@ -171,7 +185,7 @@ char *StrfTimeFormat::Write2(char *target, uint8_t value) { } // write a value in the range of 0..99 padded to 2 digits -char *StrfTimeFormat::WritePadded2(char *target, uint32_t value) { +char *StrfTimeFormat::WritePadded2(char *target, uint32_t value) const { D_ASSERT(value < 100); auto index = static_cast(value * 2); *target++ = duckdb_fmt::internal::data::digits[index]; @@ -180,7 +194,7 @@ char *StrfTimeFormat::WritePadded2(char *target, uint32_t value) { } // write a value in the range of 0..999 padded -char *StrfTimeFormat::WritePadded3(char *target, uint32_t value) { +char *StrfTimeFormat::WritePadded3(char *target, uint32_t value) const { D_ASSERT(value < 1000); if (value >= 100) { WritePadded2(target + 1, value % 100); @@ -194,17 +208,17 @@ char *StrfTimeFormat::WritePadded3(char *target, uint32_t value) { } // write a value in the range of 0..999999... padded to the given number of digits -char *StrfTimeFormat::WritePadded(char *target, uint32_t value, size_t padding) { +char *StrfTimeFormat::WritePadded(char *target, uint32_t value, size_t padding) const { D_ASSERT(padding > 1); if (padding % 2) { - int decimals = value % 1000; - WritePadded3(target + padding - 3, UnsafeNumericCast(decimals)); + uint32_t decimals = value % 1000u; + WritePadded3(target + padding - 3, decimals); value /= 1000; padding -= 3; } for (size_t i = 0; i < padding / 2; i++) { - int decimals = value % 100; - WritePadded2(target + padding - 2 * (i + 1), UnsafeNumericCast(decimals)); + uint32_t decimals = value % 100u; + WritePadded2(target + padding - 2 * (i + 1), decimals); value /= 100; } return target + padding; @@ -228,7 +242,7 @@ bool StrfTimeFormat::IsDateSpecifier(StrTimeSpecifier specifier) { } } -char *StrfTimeFormat::WriteDateSpecifier(StrTimeSpecifier specifier, date_t date, char *target) { +char *StrfTimeFormat::WriteDateSpecifier(StrTimeSpecifier specifier, date_t date, char *target) const { switch (specifier) { case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: { auto dow = Date::ExtractISODayOfTheWeek(date); @@ -280,8 +294,8 @@ char *StrfTimeFormat::WriteDateSpecifier(StrTimeSpecifier specifier, date_t date } char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t data[], const char *tz_name, - size_t tz_len, char *target) { - // data contains [0] year, [1] month, [2] day, [3] hour, [4] minute, [5] second, [6] msec, [7] utc + size_t tz_len, char *target) const { + // data contains [0] year, [1] month, [2] day, [3] hour, [4] minute, [5] second, [6] ns, [7] utc switch (specifier) { case StrTimeSpecifier::DAY_OF_MONTH_PADDED: target = WritePadded2(target, UnsafeNumericCast(data[2])); @@ -339,13 +353,13 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t target = WritePadded2(target, UnsafeNumericCast(data[5])); break; case StrTimeSpecifier::NANOSECOND_PADDED: - target = WritePadded(target, UnsafeNumericCast(data[6] * Interval::NANOS_PER_MICRO), 9); + target = WritePadded(target, UnsafeNumericCast(data[6]), 9); break; case StrTimeSpecifier::MICROSECOND_PADDED: - target = WritePadded(target, UnsafeNumericCast(data[6]), 6); + target = WritePadded(target, UnsafeNumericCast(data[6] / Interval::NANOS_PER_MICRO), 6); break; case StrTimeSpecifier::MILLISECOND_PADDED: - target = WritePadded3(target, UnsafeNumericCast(data[6] / Interval::MICROS_PER_MSEC)); + target = WritePadded3(target, UnsafeNumericCast(data[6] / Interval::NANOS_PER_MSEC)); break; case StrTimeSpecifier::UTC_OFFSET: { *target++ = (data[7] < 0) ? '-' : '+'; @@ -404,7 +418,7 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t return target; } -void StrfTimeFormat::FormatString(date_t date, int32_t data[8], const char *tz_name, char *target) { +void StrfTimeFormat::FormatStringNS(date_t date, int32_t data[8], const char *tz_name, char *target) const { D_ASSERT(specifiers.size() + 1 == literals.size()); idx_t i; for (i = 0; i < specifiers.size(); i++) { @@ -423,6 +437,12 @@ void StrfTimeFormat::FormatString(date_t date, int32_t data[8], const char *tz_n memcpy(target, literals[i].c_str(), literals[i].size()); } +void StrfTimeFormat::FormatString(date_t date, int32_t data[8], const char *tz_name, char *target) { + data[6] *= Interval::NANOS_PER_MICRO; + FormatStringNS(date, data, tz_name, target); + data[6] /= Interval::NANOS_PER_MICRO; +} + void StrfTimeFormat::FormatString(date_t date, dtime_t time, char *target) { int32_t data[8]; // year, month, day, hour, min, sec, µs, offset Date::Convert(date, data[0], data[1], data[2]); @@ -440,7 +460,7 @@ string StrfTimeFormat::Format(timestamp_t timestamp, const string &format_str) { auto time = Timestamp::GetTime(timestamp); auto len = format.GetLength(date, time, 0, nullptr); - auto result = make_unsafe_uniq_array(len); + auto result = make_unsafe_uniq_array_uninitialized(len); format.FormatString(date, time, result.get()); return string(result.get(), len); } @@ -650,24 +670,68 @@ void StrfTimeFormat::ConvertDateVector(Vector &input, Vector &result, idx_t coun }); } +string_t StrfTimeFormat::ConvertTimestampValue(const timestamp_t &input, Vector &result) const { + if (Timestamp::IsFinite(input)) { + date_t date; + dtime_t time; + Timestamp::Convert(input, date, time); + + int32_t data[8]; // year, month, day, hour, min, sec, ns, offset + Date::Convert(date, data[0], data[1], data[2]); + Time::Convert(time, data[3], data[4], data[5], data[6]); + data[6] *= Interval::NANOS_PER_MICRO; + data[7] = 0; + const char *tz_name = nullptr; + + idx_t len = GetLength(date, data, tz_name); + string_t target = StringVector::EmptyString(result, len); + FormatStringNS(date, data, tz_name, target.GetDataWriteable()); + target.Finalize(); + return target; + } else { + return StringVector::AddString(result, Timestamp::ToString(input)); + } +} + +string_t StrfTimeFormat::ConvertTimestampValue(const timestamp_ns_t &input, Vector &result) const { + if (Timestamp::IsFinite(input)) { + date_t date; + dtime_t time; + int32_t nanos; + Timestamp::Convert(input, date, time, nanos); + + int32_t data[8]; // year, month, day, hour, min, sec, ns, offset + Date::Convert(date, data[0], data[1], data[2]); + Time::Convert(time, data[3], data[4], data[5], data[6]); + data[6] *= Interval::NANOS_PER_MICRO; + data[6] += nanos; + data[7] = 0; + const char *tz_name = nullptr; + + idx_t len = GetLength(date, data, tz_name); + string_t target = StringVector::EmptyString(result, len); + FormatStringNS(date, data, tz_name, target.GetDataWriteable()); + target.Finalize(); + return target; + } else { + return StringVector::AddString(result, Timestamp::ToString(input)); + } +} + void StrfTimeFormat::ConvertTimestampVector(Vector &input, Vector &result, idx_t count) { D_ASSERT(input.GetType().id() == LogicalTypeId::TIMESTAMP || input.GetType().id() == LogicalTypeId::TIMESTAMP_TZ); D_ASSERT(result.GetType().id() == LogicalTypeId::VARCHAR); UnaryExecutor::ExecuteWithNulls( - input, result, count, [&](timestamp_t input, ValidityMask &mask, idx_t idx) { - if (Timestamp::IsFinite(input)) { - date_t date; - dtime_t time; - Timestamp::Convert(input, date, time); - idx_t len = GetLength(date, time, 0, nullptr); - string_t target = StringVector::EmptyString(result, len); - FormatString(date, time, target.GetDataWriteable()); - target.Finalize(); - return target; - } else { - return StringVector::AddString(result, Timestamp::ToString(input)); - } - }); + input, result, count, + [&](timestamp_t input, ValidityMask &mask, idx_t idx) { return ConvertTimestampValue(input, result); }); +} + +void StrfTimeFormat::ConvertTimestampNSVector(Vector &input, Vector &result, idx_t count) { + D_ASSERT(input.GetType().id() == LogicalTypeId::TIMESTAMP_NS); + D_ASSERT(result.GetType().id() == LogicalTypeId::VARCHAR); + UnaryExecutor::ExecuteWithNulls( + input, result, count, + [&](timestamp_ns_t input, ValidityMask &mask, idx_t idx) { return ConvertTimestampValue(input, result); }); } void StrpTimeFormat::AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) { @@ -743,7 +807,7 @@ int32_t StrpTimeFormat::TryParseCollection(const char *data, idx_t &pos, idx_t s return -1; } -bool StrpTimeFormat::Parse(const char *data, size_t size, ParseResult &result) const { +bool StrpTimeFormat::Parse(const char *data, size_t size, ParseResult &result, bool strict) const { auto &result_data = result.data; auto &error_message = result.error_message; auto &error_position = result.error_position; @@ -892,6 +956,9 @@ bool StrpTimeFormat::Parse(const char *data, size_t size, ParseResult &result) c } // year without century.. // Python uses 69 as a crossover point (i.e. >= 69 is 19.., < 69 is 20..) + if (pos - start_pos < 2 && strict) { + return false; + } if (number >= 100) { // %y only supports numbers between [0..99] error_message = "Year without century out of range, expected a value between 0 and 99"; @@ -916,6 +983,9 @@ bool StrpTimeFormat::Parse(const char *data, size_t size, ParseResult &result) c default: break; } + if (pos - start_pos < 2 && strict) { + return false; + } // year as full number result_data[0] = UnsafeNumericCast(number); break; @@ -1000,19 +1070,18 @@ bool StrpTimeFormat::Parse(const char *data, size_t size, ParseResult &result) c break; case StrTimeSpecifier::NANOSECOND_PADDED: D_ASSERT(number < Interval::NANOS_PER_SEC); // enforced by the length of the number - // microseconds (rounded) - result_data[6] = - UnsafeNumericCast((number + Interval::NANOS_PER_MICRO / 2) / Interval::NANOS_PER_MICRO); + // nanoseconds + result_data[6] = UnsafeNumericCast(number); break; case StrTimeSpecifier::MICROSECOND_PADDED: D_ASSERT(number < Interval::MICROS_PER_SEC); // enforced by the length of the number - // microseconds - result_data[6] = UnsafeNumericCast(number); + // nanoseconds + result_data[6] = UnsafeNumericCast(number * Interval::NANOS_PER_MICRO); break; case StrTimeSpecifier::MILLISECOND_PADDED: D_ASSERT(number < Interval::MSECS_PER_SEC); // enforced by the length of the number - // microseconds - result_data[6] = UnsafeNumericCast(number * Interval::MICROS_PER_MSEC); + // nanoseconds + result_data[6] = UnsafeNumericCast(number * Interval::NANOS_PER_MSEC); break; case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: @@ -1324,10 +1393,10 @@ bool StrpTimeFormat::Parse(const char *data, size_t size, ParseResult &result) c } //! Parses a timestamp using the given specifier -bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const { +bool StrpTimeFormat::Parse(string_t str, ParseResult &result, bool strict) const { auto data = str.GetData(); idx_t size = str.GetSize(); - return Parse(data, size, result); + return Parse(data, size, result, strict); } StrpTimeFormat::ParseResult StrpTimeFormat::Parse(const string &format_string, const string &text) { @@ -1366,17 +1435,27 @@ bool StrpTimeFormat::ParseResult::TryToDate(date_t &result) { return Date::TryFromDate(data[0], data[1], data[2], result); } +int32_t StrpTimeFormat::ParseResult::GetMicros() const { + return UnsafeNumericCast((data[6] + Interval::NANOS_PER_MICRO / 2) / Interval::NANOS_PER_MICRO); +} + dtime_t StrpTimeFormat::ParseResult::ToTime() { const auto hour_offset = data[7] / Interval::MINS_PER_HOUR; const auto mins_offset = data[7] % Interval::MINS_PER_HOUR; - return Time::FromTime(data[3] - hour_offset, data[4] - mins_offset, data[5], data[6]); + return Time::FromTime(data[3] - hour_offset, data[4] - mins_offset, data[5], GetMicros()); +} + +int64_t StrpTimeFormat::ParseResult::ToTimeNS() { + const int32_t hour_offset = data[7] / Interval::MINS_PER_HOUR; + const int32_t mins_offset = data[7] % Interval::MINS_PER_HOUR; + return Time::ToNanoTime(data[3] - hour_offset, data[4] - mins_offset, data[5], data[6]); } bool StrpTimeFormat::ParseResult::TryToTime(dtime_t &result) { if (data[7]) { return false; } - result = Time::FromTime(data[3], data[4], data[5], data[6]); + result = Time::FromTime(data[3], data[4], data[5], GetMicros()); return true; } @@ -1390,7 +1469,7 @@ timestamp_t StrpTimeFormat::ParseResult::ToTimestamp() { return Timestamp::FromDatetime(special, dtime_t(0)); } - date_t date = Date::FromDate(data[0], data[1], data[2]); + date_t date = ToDate(); dtime_t time = ToTime(); return Timestamp::FromDatetime(date, time); } @@ -1404,6 +1483,49 @@ bool StrpTimeFormat::ParseResult::TryToTimestamp(timestamp_t &result) { return Timestamp::TryFromDatetime(date, time, result); } +timestamp_ns_t StrpTimeFormat::ParseResult::ToTimestampNS() { + timestamp_ns_t result; + if (is_special) { + if (special == date_t::infinity()) { + result.value = timestamp_t::infinity().value; + } else if (special == date_t::ninfinity()) { + result.value = timestamp_t::ninfinity().value; + } else { + result.value = special.days * Interval::NANOS_PER_DAY; + } + } else { + // Don't use rounded µs + const auto date = ToDate(); + const auto time = ToTimeNS(); + if (!TryMultiplyOperator::Operation(date.days, Interval::NANOS_PER_DAY, + result.value)) { + throw ConversionException("Date out of nanosecond range: %d-%d-%d", data[0], data[1], data[2]); + } + if (!TryAddOperator::Operation(result.value, time, result.value)) { + throw ConversionException("Overflow exception in date/time -> timestamp_ns conversion"); + } + } + + return result; +} + +bool StrpTimeFormat::ParseResult::TryToTimestampNS(timestamp_ns_t &result) { + date_t date; + if (!TryToDate(date)) { + return false; + } + + // Don't use rounded µs + const auto time = ToTimeNS(); + if (!TryMultiplyOperator::Operation(date.days, Interval::NANOS_PER_DAY, result.value)) { + return false; + } + if (!TryAddOperator::Operation(result.value, time, result.value)) { + return false; + } + return Timestamp::IsFinite(result); +} + string StrpTimeFormat::ParseResult::FormatError(string_t input, const string &format_specifier) { return StringUtil::Format("Could not parse string \"%s\" according to format specifier \"%s\"\n%s\nError: %s", input.GetString(), format_specifier, @@ -1453,4 +1575,21 @@ bool StrpTimeFormat::TryParseTimestamp(const char *data, size_t size, timestamp_ return parse_result.TryToTimestamp(result); } +bool StrpTimeFormat::TryParseTimestampNS(string_t input, timestamp_ns_t &result, string &error_message) const { + ParseResult parse_result; + if (!Parse(input, parse_result)) { + error_message = parse_result.FormatError(input, format_specifier); + return false; + } + return parse_result.TryToTimestampNS(result); +} + +bool StrpTimeFormat::TryParseTimestampNS(const char *data, size_t size, timestamp_ns_t &result) const { + ParseResult parse_result; + if (!Parse(data, size, parse_result)) { + return false; + } + return parse_result.TryToTimestampNS(result); +} + } // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/caseconvert.cpp b/src/duckdb/src/function/scalar/string/caseconvert.cpp index fa5b612f..b6240d06 100644 --- a/src/duckdb/src/function/scalar/string/caseconvert.cpp +++ b/src/duckdb/src/function/scalar/string/caseconvert.cpp @@ -5,7 +5,7 @@ #include "duckdb/common/vector_operations/unary_executor.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" -#include "utf8proc.hpp" +#include "utf8proc_wrapper.hpp" #include @@ -58,9 +58,10 @@ static idx_t GetResultLength(const char *input_data, idx_t input_length) { if (input_data[i] & 0x80) { // unicode int sz = 0; - auto codepoint = utf8proc_codepoint(input_data + i, sz); - auto converted_codepoint = IS_UPPER ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint); - auto new_sz = utf8proc_codepoint_length(converted_codepoint); + auto codepoint = Utf8Proc::UTF8ToCodepoint(input_data + i, sz); + auto converted_codepoint = + IS_UPPER ? Utf8Proc::CodepointToUpper(codepoint) : Utf8Proc::CodepointToLower(codepoint); + auto new_sz = Utf8Proc::CodepointLength(converted_codepoint); D_ASSERT(new_sz >= 0); output_length += UnsafeNumericCast(new_sz); i += UnsafeNumericCast(sz); @@ -79,9 +80,10 @@ static void CaseConvert(const char *input_data, idx_t input_length, char *result if (input_data[i] & 0x80) { // non-ascii character int sz = 0, new_sz = 0; - auto codepoint = utf8proc_codepoint(input_data + i, sz); - auto converted_codepoint = IS_UPPER ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint); - auto success = utf8proc_codepoint_to_utf8(converted_codepoint, new_sz, result_data); + auto codepoint = Utf8Proc::UTF8ToCodepoint(input_data + i, sz); + auto converted_codepoint = + IS_UPPER ? Utf8Proc::CodepointToUpper(codepoint) : Utf8Proc::CodepointToLower(codepoint); + auto success = Utf8Proc::CodepointToUtf8(converted_codepoint, new_sz, result_data); D_ASSERT(success); (void)success; result_data += new_sz; diff --git a/src/duckdb/src/function/scalar/string/concat.cpp b/src/duckdb/src/function/scalar/string/concat.cpp index 5ad0c9a3..18619a5b 100644 --- a/src/duckdb/src/function/scalar/string/concat.cpp +++ b/src/duckdb/src/function/scalar/string/concat.cpp @@ -4,12 +4,41 @@ #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/function/scalar/nested_functions.hpp" #include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" #include namespace duckdb { -static void ConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { +struct ConcatFunctionData : public FunctionData { + ConcatFunctionData(const LogicalType &return_type_p, bool is_operator_p) + : return_type(return_type_p), is_operator(is_operator_p) { + } + ~ConcatFunctionData() override; + + LogicalType return_type; + + bool is_operator = false; + +public: + bool Equals(const FunctionData &other_p) const override; + unique_ptr Copy() const override; +}; + +ConcatFunctionData::~ConcatFunctionData() { +} + +bool ConcatFunctionData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return return_type == other.return_type && is_operator == other.is_operator; +} + +unique_ptr ConcatFunctionData::Copy() const { + return make_uniq(return_type, is_operator); +} + +static void StringConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { result.SetVectorType(VectorType::CONSTANT_VECTOR); // iterate over the vectors to count how large the final string will be idx_t constant_lengths = 0; @@ -114,130 +143,234 @@ static void ConcatOperator(DataChunk &args, ExpressionState &state, Vector &resu }); } -static void TemplatedConcatWS(DataChunk &args, const string_t *sep_data, const SelectionVector &sep_sel, - const SelectionVector &rsel, idx_t count, Vector &result) { - vector result_lengths(args.size(), 0); - vector has_results(args.size(), false); - - // we overallocate here, but this is important for static analysis - auto orrified_data = make_unsafe_uniq_array(args.ColumnCount()); +static void ListConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 2); + auto count = args.size(); - for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { - args.data[col_idx].ToUnifiedFormat(args.size(), orrified_data[col_idx - 1]); + Vector &lhs = args.data[0]; + Vector &rhs = args.data[1]; + if (lhs.GetType().id() == LogicalTypeId::SQLNULL) { + result.Reference(rhs); + return; + } + if (rhs.GetType().id() == LogicalTypeId::SQLNULL) { + result.Reference(lhs); + return; } - // first figure out the lengths - for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { - auto &idata = orrified_data[col_idx - 1]; + UnifiedVectorFormat lhs_data; + UnifiedVectorFormat rhs_data; + lhs.ToUnifiedFormat(count, lhs_data); + rhs.ToUnifiedFormat(count, rhs_data); + auto lhs_entries = UnifiedVectorFormat::GetData(lhs_data); + auto rhs_entries = UnifiedVectorFormat::GetData(rhs_data); - auto input_data = UnifiedVectorFormat::GetData(idata); - for (idx_t i = 0; i < count; i++) { - auto ridx = rsel.get_index(i); - auto sep_idx = sep_sel.get_index(ridx); - auto idx = idata.sel->get_index(ridx); - if (!idata.validity.RowIsValid(idx)) { - continue; - } - if (has_results[ridx]) { - result_lengths[ridx] += sep_data[sep_idx].GetSize(); - } - result_lengths[ridx] += input_data[idx].GetSize(); - has_results[ridx] = true; + auto lhs_list_size = ListVector::GetListSize(lhs); + auto rhs_list_size = ListVector::GetListSize(rhs); + auto &lhs_child = ListVector::GetEntry(lhs); + auto &rhs_child = ListVector::GetEntry(rhs); + UnifiedVectorFormat lhs_child_data; + UnifiedVectorFormat rhs_child_data; + lhs_child.ToUnifiedFormat(lhs_list_size, lhs_child_data); + rhs_child.ToUnifiedFormat(rhs_list_size, rhs_child_data); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_entries = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + + idx_t offset = 0; + for (idx_t i = 0; i < count; i++) { + auto lhs_list_index = lhs_data.sel->get_index(i); + auto rhs_list_index = rhs_data.sel->get_index(i); + if (!lhs_data.validity.RowIsValid(lhs_list_index) && !rhs_data.validity.RowIsValid(rhs_list_index)) { + result_validity.SetInvalid(i); + continue; + } + result_entries[i].offset = offset; + result_entries[i].length = 0; + if (lhs_data.validity.RowIsValid(lhs_list_index)) { + const auto &lhs_entry = lhs_entries[lhs_list_index]; + result_entries[i].length += lhs_entry.length; + ListVector::Append(result, lhs_child, *lhs_child_data.sel, lhs_entry.offset + lhs_entry.length, + lhs_entry.offset); } + if (rhs_data.validity.RowIsValid(rhs_list_index)) { + const auto &rhs_entry = rhs_entries[rhs_list_index]; + result_entries[i].length += rhs_entry.length; + ListVector::Append(result, rhs_child, *rhs_child_data.sel, rhs_entry.offset + rhs_entry.length, + rhs_entry.offset); + } + offset += result_entries[i].length; } + D_ASSERT(ListVector::GetListSize(result) == offset); - // first we allocate the empty strings for each of the values - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < count; i++) { - auto ridx = rsel.get_index(i); - // allocate an empty string of the required size - result_data[ridx] = StringVector::EmptyString(result, result_lengths[ridx]); - // we reuse the result_lengths vector to store the currently appended size - result_lengths[ridx] = 0; - has_results[ridx] = false; + if (lhs.GetVectorType() == VectorType::CONSTANT_VECTOR && rhs.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); } +} - // now that the empty space for the strings has been allocated, perform the concatenation - for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { - auto &idata = orrified_data[col_idx - 1]; - auto input_data = UnifiedVectorFormat::GetData(idata); - for (idx_t i = 0; i < count; i++) { - auto ridx = rsel.get_index(i); - auto sep_idx = sep_sel.get_index(ridx); - auto idx = idata.sel->get_index(ridx); - if (!idata.validity.RowIsValid(idx)) { - continue; - } - if (has_results[ridx]) { - auto sep_size = sep_data[sep_idx].GetSize(); - auto sep_ptr = sep_data[sep_idx].GetData(); - memcpy(result_data[ridx].GetDataWriteable() + result_lengths[ridx], sep_ptr, sep_size); - result_lengths[ridx] += sep_size; - } - auto input_ptr = input_data[idx].GetData(); - auto input_len = input_data[idx].GetSize(); - memcpy(result_data[ridx].GetDataWriteable() + result_lengths[ridx], input_ptr, input_len); - result_lengths[ridx] += input_len; - has_results[ridx] = true; - } +static void ConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + if (info.return_type.id() == LogicalTypeId::LIST) { + return ListConcatFunction(args, state, result); + } else if (info.is_operator) { + return ConcatOperator(args, state, result); } - for (idx_t i = 0; i < count; i++) { - auto ridx = rsel.get_index(i); - result_data[ridx].Finalize(); + return StringConcatFunction(args, state, result); +} + +static void SetArgumentType(ScalarFunction &bound_function, const LogicalType &type, bool is_operator) { + if (is_operator) { + bound_function.arguments[0] = type; + bound_function.arguments[1] = type; + bound_function.return_type = type; + return; } + + for (auto &arg : bound_function.arguments) { + arg = type; + } + bound_function.varargs = type; + bound_function.return_type = type; } -static void ConcatWSFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &separator = args.data[0]; - UnifiedVectorFormat vdata; - separator.ToUnifiedFormat(args.size(), vdata); +static void HandleArrayBinding(ClientContext &context, vector> &arguments) { + if (arguments[1]->return_type.id() != LogicalTypeId::ARRAY && + arguments[1]->return_type.id() != LogicalTypeId::SQLNULL) { + throw BinderException("Cannot concatenate types %s and %s", arguments[0]->return_type.ToString(), + arguments[1]->return_type.ToString()); + } - result.SetVectorType(VectorType::CONSTANT_VECTOR); - for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { - if (args.data[col_idx].GetVectorType() != VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::FLAT_VECTOR); - break; - } + // if either argument is an array, we cast it to a list + arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); + arguments[1] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[1])); +} + +static unique_ptr HandleListBinding(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments, bool is_operator) { + // list_concat only accepts two arguments + D_ASSERT(arguments.size() == 2); + + auto &lhs = arguments[0]->return_type; + auto &rhs = arguments[1]->return_type; + + if (lhs.id() == LogicalTypeId::UNKNOWN || rhs.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } else if (lhs.id() == LogicalTypeId::SQLNULL || rhs.id() == LogicalTypeId::SQLNULL) { + // we mimic postgres behaviour: list_concat(NULL, my_list) = my_list + auto return_type = rhs.id() == LogicalTypeId::SQLNULL ? lhs : rhs; + SetArgumentType(bound_function, return_type, is_operator); + return make_uniq(bound_function.return_type, is_operator); } - switch (separator.GetVectorType()) { - case VectorType::CONSTANT_VECTOR: { - if (ConstantVector::IsNull(separator)) { - // constant NULL as separator: return constant NULL vector - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - // no null values - auto sel = FlatVector::IncrementalSelectionVector(); - TemplatedConcatWS(args, UnifiedVectorFormat::GetData(vdata), *vdata.sel, *sel, args.size(), result); - return; + if (lhs.id() != LogicalTypeId::LIST || rhs.id() != LogicalTypeId::LIST) { + throw BinderException("Cannot concatenate types %s and %s", lhs.ToString(), rhs.ToString()); } - default: { - // default case: loop over nullmask and create a non-null selection vector - idx_t not_null_count = 0; - SelectionVector not_null_vector(STANDARD_VECTOR_SIZE); - auto &result_mask = FlatVector::Validity(result); - for (idx_t i = 0; i < args.size(); i++) { - if (!vdata.validity.RowIsValid(vdata.sel->get_index(i))) { - result_mask.SetInvalid(i); - } else { - not_null_vector.set_index(not_null_count++, i); - } + + // Resolve list type + LogicalType child_type = LogicalType::SQLNULL; + for (const auto &argument : arguments) { + auto &next_type = ListType::GetChildType(argument->return_type); + if (!LogicalType::TryGetMaxLogicalType(context, child_type, next_type, child_type)) { + throw BinderException("Cannot concatenate lists of types %s[] and %s[] - an explicit cast is required", + child_type.ToString(), next_type.ToString()); } - TemplatedConcatWS(args, UnifiedVectorFormat::GetData(vdata), *vdata.sel, not_null_vector, - not_null_count, result); - return; } + auto list_type = LogicalType::LIST(child_type); + + SetArgumentType(bound_function, list_type, is_operator); + return make_uniq(bound_function.return_type, is_operator); +} + +static void FindFirstTwoArguments(vector> &arguments, LogicalTypeId &first_arg, + LogicalTypeId &second_arg) { + first_arg = arguments[0]->return_type.id(); + second_arg = first_arg; + if (arguments.size() > 1) { + second_arg = arguments[1]->return_type.id(); } } static unique_ptr BindConcatFunction(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - for (auto &arg : bound_function.arguments) { - arg = LogicalType::VARCHAR; + LogicalTypeId first_arg; + LogicalTypeId second_arg; + FindFirstTwoArguments(arguments, first_arg, second_arg); + + if (arguments.size() > 2 && (first_arg == LogicalTypeId::ARRAY || first_arg == LogicalTypeId::LIST)) { + throw BinderException("list_concat only accepts two arguments"); + } + + if (first_arg == LogicalTypeId::ARRAY || second_arg == LogicalTypeId::ARRAY) { + HandleArrayBinding(context, arguments); + FindFirstTwoArguments(arguments, first_arg, second_arg); + } + + if (first_arg == LogicalTypeId::LIST || second_arg == LogicalTypeId::LIST) { + return HandleListBinding(context, bound_function, arguments, false); + } + + // we can now assume that the input is a string or castable to a string + SetArgumentType(bound_function, LogicalType::VARCHAR, false); + return make_uniq(bound_function.return_type, false); +} + +static unique_ptr BindConcatOperator(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(arguments.size() == 2); + + LogicalTypeId lhs; + LogicalTypeId rhs; + FindFirstTwoArguments(arguments, lhs, rhs); + + if (lhs == LogicalTypeId::UNKNOWN || rhs == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + if (lhs == LogicalTypeId::ARRAY || rhs == LogicalTypeId::ARRAY) { + HandleArrayBinding(context, arguments); + FindFirstTwoArguments(arguments, lhs, rhs); + } + + if (lhs == LogicalTypeId::LIST || rhs == LogicalTypeId::LIST) { + return HandleListBinding(context, bound_function, arguments, true); + } + + LogicalType return_type; + if (lhs == LogicalTypeId::BLOB && rhs == LogicalTypeId::BLOB) { + return_type = LogicalType::BLOB; + } else { + return_type = LogicalType::VARCHAR; } - bound_function.varargs = LogicalType::VARCHAR; - return nullptr; + + // we can now assume that the input is a string or castable to a string + SetArgumentType(bound_function, return_type, true); + return make_uniq(bound_function.return_type, true); +} + +static unique_ptr ListConcatStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + D_ASSERT(child_stats.size() == 2); + + auto &left_stats = child_stats[0]; + auto &right_stats = child_stats[1]; + + auto stats = left_stats.ToUnique(); + stats->Merge(right_stats); + + return stats; +} + +ScalarFunction ListConcatFun::GetFunction() { + // The arguments and return types are set in the binder function. + auto fun = ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::LIST(LogicalType::ANY)}, + LogicalType::LIST(LogicalType::ANY), ConcatFunction, BindConcatFunction, nullptr, + ListConcatStats); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +void ListConcatFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction({"list_concat", "list_cat", "array_concat", "array_cat"}, GetFunction()); } void ConcatFun::RegisterFunction(BuiltinFunctions &set) { @@ -249,34 +382,17 @@ void ConcatFun::RegisterFunction(BuiltinFunctions &set) { // i.e. NULL || 'hello' = NULL // the concat function, however, treats NULL values as an empty string // i.e. concat(NULL, 'hello') = 'hello' - // concat_ws functions similarly to the concat function, except the result is NULL if the separator is NULL - // if the separator is not NULL, however, NULL values are counted as empty string - // there is one separate rule: there are no separators added between NULL values - // so the NULL value and empty string are different! - // e.g.: - // concat_ws(',', NULL, NULL) = "" - // concat_ws(',', '', '') = "," + ScalarFunction concat = - ScalarFunction("concat", {LogicalType::ANY}, LogicalType::VARCHAR, ConcatFunction, BindConcatFunction); + ScalarFunction("concat", {LogicalType::ANY}, LogicalType::ANY, ConcatFunction, BindConcatFunction); concat.varargs = LogicalType::ANY; concat.null_handling = FunctionNullHandling::SPECIAL_HANDLING; set.AddFunction(concat); - ScalarFunctionSet concat_op("||"); - concat_op.AddFunction( - ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::VARCHAR, ConcatOperator, BindConcatFunction)); - concat_op.AddFunction(ScalarFunction({LogicalType::BLOB, LogicalType::BLOB}, LogicalType::BLOB, ConcatOperator)); - concat_op.AddFunction(ListConcatFun::GetFunction()); - for (auto &fun : concat_op.functions) { - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - } + ScalarFunction concat_op = ScalarFunction("||", {LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, + ConcatFunction, BindConcatOperator); + concat.null_handling = FunctionNullHandling::SPECIAL_HANDLING; set.AddFunction(concat_op); - - ScalarFunction concat_ws = ScalarFunction("concat_ws", {LogicalType::VARCHAR, LogicalType::ANY}, - LogicalType::VARCHAR, ConcatWSFunction, BindConcatFunction); - concat_ws.varargs = LogicalType::ANY; - concat_ws.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - set.AddFunction(concat_ws); } } // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/concat_ws.cpp b/src/duckdb/src/function/scalar/string/concat_ws.cpp new file mode 100644 index 00000000..7689738c --- /dev/null +++ b/src/duckdb/src/function/scalar/string/concat_ws.cpp @@ -0,0 +1,149 @@ +#include "duckdb/function/scalar/string_functions.hpp" + +#include + +namespace duckdb { + +static void TemplatedConcatWS(DataChunk &args, const string_t *sep_data, const SelectionVector &sep_sel, + const SelectionVector &rsel, idx_t count, Vector &result) { + vector result_lengths(args.size(), 0); + vector has_results(args.size(), false); + + // we overallocate here, but this is important for static analysis + auto orrified_data = make_unsafe_uniq_array_uninitialized(args.ColumnCount()); + + for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { + args.data[col_idx].ToUnifiedFormat(args.size(), orrified_data[col_idx - 1]); + } + + // first figure out the lengths + for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { + auto &idata = orrified_data[col_idx - 1]; + + auto input_data = UnifiedVectorFormat::GetData(idata); + for (idx_t i = 0; i < count; i++) { + auto ridx = rsel.get_index(i); + auto sep_idx = sep_sel.get_index(ridx); + auto idx = idata.sel->get_index(ridx); + if (!idata.validity.RowIsValid(idx)) { + continue; + } + if (has_results[ridx]) { + result_lengths[ridx] += sep_data[sep_idx].GetSize(); + } + result_lengths[ridx] += input_data[idx].GetSize(); + has_results[ridx] = true; + } + } + + // first we allocate the empty strings for each of the values + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto ridx = rsel.get_index(i); + // allocate an empty string of the required size + result_data[ridx] = StringVector::EmptyString(result, result_lengths[ridx]); + // we reuse the result_lengths vector to store the currently appended size + result_lengths[ridx] = 0; + has_results[ridx] = false; + } + + // now that the empty space for the strings has been allocated, perform the concatenation + for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { + auto &idata = orrified_data[col_idx - 1]; + auto input_data = UnifiedVectorFormat::GetData(idata); + for (idx_t i = 0; i < count; i++) { + auto ridx = rsel.get_index(i); + auto sep_idx = sep_sel.get_index(ridx); + auto idx = idata.sel->get_index(ridx); + if (!idata.validity.RowIsValid(idx)) { + continue; + } + if (has_results[ridx]) { + auto sep_size = sep_data[sep_idx].GetSize(); + auto sep_ptr = sep_data[sep_idx].GetData(); + memcpy(result_data[ridx].GetDataWriteable() + result_lengths[ridx], sep_ptr, sep_size); + result_lengths[ridx] += sep_size; + } + auto input_ptr = input_data[idx].GetData(); + auto input_len = input_data[idx].GetSize(); + memcpy(result_data[ridx].GetDataWriteable() + result_lengths[ridx], input_ptr, input_len); + result_lengths[ridx] += input_len; + has_results[ridx] = true; + } + } + for (idx_t i = 0; i < count; i++) { + auto ridx = rsel.get_index(i); + result_data[ridx].Finalize(); + } +} + +static void ConcatWSFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &separator = args.data[0]; + UnifiedVectorFormat vdata; + separator.ToUnifiedFormat(args.size(), vdata); + + result.SetVectorType(VectorType::CONSTANT_VECTOR); + for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { + if (args.data[col_idx].GetVectorType() != VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::FLAT_VECTOR); + break; + } + } + switch (separator.GetVectorType()) { + case VectorType::CONSTANT_VECTOR: { + if (ConstantVector::IsNull(separator)) { + // constant NULL as separator: return constant NULL vector + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + // no null values + auto sel = FlatVector::IncrementalSelectionVector(); + TemplatedConcatWS(args, UnifiedVectorFormat::GetData(vdata), *vdata.sel, *sel, args.size(), result); + return; + } + default: { + // default case: loop over nullmask and create a non-null selection vector + idx_t not_null_count = 0; + SelectionVector not_null_vector(STANDARD_VECTOR_SIZE); + auto &result_mask = FlatVector::Validity(result); + for (idx_t i = 0; i < args.size(); i++) { + if (!vdata.validity.RowIsValid(vdata.sel->get_index(i))) { + result_mask.SetInvalid(i); + } else { + not_null_vector.set_index(not_null_count++, i); + } + } + TemplatedConcatWS(args, UnifiedVectorFormat::GetData(vdata), *vdata.sel, not_null_vector, + not_null_count, result); + return; + } + } +} + +static unique_ptr BindConcatWSFunction(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + for (auto &arg : bound_function.arguments) { + arg = LogicalType::VARCHAR; + } + bound_function.varargs = LogicalType::VARCHAR; + return nullptr; +} + +void ConcatWSFun::RegisterFunction(BuiltinFunctions &set) { + // concat_ws functions similarly to the concat function, except the result is NULL if the separator is NULL + // if the separator is not NULL, however, NULL values are counted as empty string + // there is one separate rule: there are no separators added between NULL values, + // so the NULL value and empty string are different! + // e.g.: + // concat_ws(',', NULL, NULL) = "" + // concat_ws(',', '', '') = "," + + ScalarFunction concat_ws = ScalarFunction("concat_ws", {LogicalType::VARCHAR, LogicalType::ANY}, + LogicalType::VARCHAR, ConcatWSFunction, BindConcatWSFunction); + concat_ws.varargs = LogicalType::ANY; + concat_ws.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + set.AddFunction(concat_ws); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/contains.cpp b/src/duckdb/src/function/scalar/string/contains.cpp index 3e24ed6a..f1c3f8dc 100644 --- a/src/duckdb/src/function/scalar/string/contains.cpp +++ b/src/duckdb/src/function/scalar/string/contains.cpp @@ -3,7 +3,8 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" - +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/core_functions/scalar/map_functions.hpp" namespace duckdb { template @@ -151,15 +152,25 @@ struct ContainsOperator { } }; -ScalarFunction ContainsFun::GetFunction() { - return ScalarFunction("contains", // name of the function - {LogicalType::VARCHAR, LogicalType::VARCHAR}, // argument list - LogicalType::BOOLEAN, // return type - ScalarFunction::BinaryFunction); +ScalarFunctionSet ContainsFun::GetFunctions() { + auto string_fun = GetStringContains(); + auto list_fun = ListContainsFun::GetFunction(); + auto map_fun = MapContainsFun::GetFunction(); + ScalarFunctionSet set("contains"); + set.AddFunction(string_fun); + set.AddFunction(list_fun); + set.AddFunction(map_fun); + return set; +} + +ScalarFunction ContainsFun::GetStringContains() { + ScalarFunction string_fun("contains", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, + ScalarFunction::BinaryFunction); + return string_fun; } void ContainsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(GetFunction()); + set.AddFunction(GetFunctions()); } } // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/like.cpp b/src/duckdb/src/function/scalar/string/like.cpp index edca8d8a..b8271178 100644 --- a/src/duckdb/src/function/scalar/string/like.cpp +++ b/src/duckdb/src/function/scalar/string/like.cpp @@ -403,11 +403,11 @@ bool ILikeOperatorFunction(string_t &str, string_t &pattern, char escape = '\0') // lowercase both the str and the pattern idx_t str_llength = LowerFun::LowerLength(str_data, str_size); - auto str_ldata = make_unsafe_uniq_array(str_llength); + auto str_ldata = make_unsafe_uniq_array_uninitialized(str_llength); LowerFun::LowerCase(str_data, str_size, str_ldata.get()); idx_t pat_llength = LowerFun::LowerLength(pat_data, pat_size); - auto pat_ldata = make_unsafe_uniq_array(pat_llength); + auto pat_ldata = make_unsafe_uniq_array_uninitialized(pat_llength); LowerFun::LowerCase(pat_data, pat_size, pat_ldata.get()); string_t str_lcase(str_ldata.get(), UnsafeNumericCast(str_llength)); string_t pat_lcase(pat_ldata.get(), UnsafeNumericCast(pat_llength)); diff --git a/src/duckdb/src/function/scalar/string/substring.cpp b/src/duckdb/src/function/scalar/string/substring.cpp index 76f4859f..f7ff13f3 100644 --- a/src/duckdb/src/function/scalar/string/substring.cpp +++ b/src/duckdb/src/function/scalar/string/substring.cpp @@ -224,11 +224,7 @@ string_t SubstringFun::SubstringGrapheme(Vector &result, string_t input, int64_t if (offset < 0) { // negative offset, this case is more difficult // we first need to count the number of characters in the string - idx_t num_characters = 0; - utf8proc_grapheme_callback(input_data, input_size, [&](size_t start, size_t end) { - num_characters++; - return true; - }); + idx_t num_characters = Utf8Proc::GraphemeCount(input_data, input_size); // now call substring start and end again, but with the number of unicode characters this time SubstringStartEnd(UnsafeNumericCast(num_characters), offset, length, start, end); } @@ -236,16 +232,15 @@ string_t SubstringFun::SubstringGrapheme(Vector &result, string_t input, int64_t // now scan the graphemes of the string to find the positions of the start and end characters int64_t current_character = 0; idx_t start_pos = DConstants::INVALID_INDEX, end_pos = input_size; - utf8proc_grapheme_callback(input_data, input_size, [&](size_t gstart, size_t gend) { + for (auto cluster : Utf8Proc::GraphemeClusters(input_data, input_size)) { if (current_character == start) { - start_pos = gstart; + start_pos = cluster.start; } else if (current_character == end) { - end_pos = gstart; - return false; + end_pos = cluster.start; + break; } current_character++; - return true; - }); + } if (start_pos == DConstants::INVALID_INDEX) { return SubstringEmptyString(result); } diff --git a/src/duckdb/src/function/scalar/string_functions.cpp b/src/duckdb/src/function/scalar/string_functions.cpp index 88d7b716..e3ccbe0a 100644 --- a/src/duckdb/src/function/scalar/string_functions.cpp +++ b/src/duckdb/src/function/scalar/string_functions.cpp @@ -7,6 +7,7 @@ void BuiltinFunctions::RegisterStringFunctions() { Register(); Register(); Register(); + Register(); Register(); Register(); Register(); diff --git a/src/duckdb/src/function/scalar/struct/struct_extract.cpp b/src/duckdb/src/function/scalar/struct/struct_extract.cpp index 2572cb51..b4ef1c8d 100644 --- a/src/duckdb/src/function/scalar/struct/struct_extract.cpp +++ b/src/duckdb/src/function/scalar/struct/struct_extract.cpp @@ -92,13 +92,13 @@ static unique_ptr StructExtractBind(ClientContext &context, Scalar for (auto &struct_child : struct_children) { candidates.push_back(struct_child.first); } - auto closest_settings = StringUtil::TopNLevenshtein(candidates, key); + auto closest_settings = StringUtil::TopNJaroWinkler(candidates, key); auto message = StringUtil::CandidatesMessage(closest_settings, "Candidate Entries"); throw BinderException("Could not find key \"%s\" in struct\n%s", key, message); } bound_function.return_type = std::move(return_type); - return make_uniq(key_index); + return StructExtractFun::GetBindData(key_index); } static unique_ptr StructExtractBindIndex(ClientContext &context, ScalarFunction &bound_function, @@ -134,7 +134,7 @@ static unique_ptr StructExtractBindIndex(ClientContext &context, S index, struct_children.size()); } bound_function.return_type = struct_children[NumericCast(index - 1)].second; - return make_uniq(NumericCast(index - 1)); + return StructExtractFun::GetBindData(NumericCast(index - 1)); } static unique_ptr PropagateStructExtractStats(ClientContext &context, FunctionStatisticsInput &input) { @@ -146,6 +146,10 @@ static unique_ptr PropagateStructExtractStats(ClientContext &con return struct_child_stats[info.index].ToUnique(); } +unique_ptr StructExtractFun::GetBindData(idx_t index) { + return make_uniq(index); +} + ScalarFunction StructExtractFun::KeyExtractFunction() { return ScalarFunction("struct_extract", {LogicalTypeId::STRUCT, LogicalType::VARCHAR}, LogicalType::ANY, StructExtractFunction, StructExtractBind, nullptr, PropagateStructExtractStats); diff --git a/src/duckdb/src/function/scalar/system/aggregate_export.cpp b/src/duckdb/src/function/scalar/system/aggregate_export.cpp index 035a1e33..92fea109 100644 --- a/src/duckdb/src/function/scalar/system/aggregate_export.cpp +++ b/src/duckdb/src/function/scalar/system/aggregate_export.cpp @@ -81,7 +81,7 @@ static void AggregateStateFinalize(DataChunk &input, ExpressionState &state_p, V auto &local_state = ExecuteFunctionState::GetFunctionState(state_p)->Cast(); local_state.allocator.Reset(); - D_ASSERT(bind_data.state_size == bind_data.aggr.state_size()); + D_ASSERT(bind_data.state_size == bind_data.aggr.state_size(bind_data.aggr)); D_ASSERT(input.data.size() == 1); D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE); auto aligned_state_size = AlignValue(bind_data.state_size); @@ -101,7 +101,7 @@ static void AggregateStateFinalize(DataChunk &input, ExpressionState &state_p, V } else { // create a dummy state because finalize does not understand NULLs in its input // we put the NULL back in explicitly below - bind_data.aggr.initialize(data_ptr_cast(target_ptr)); + bind_data.aggr.initialize(bind_data.aggr, data_ptr_cast(target_ptr)); } state_vec_ptr[i] = data_ptr_cast(target_ptr); } @@ -122,7 +122,7 @@ static void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Ve auto &local_state = ExecuteFunctionState::GetFunctionState(state_p)->Cast(); local_state.allocator.Reset(); - D_ASSERT(bind_data.state_size == bind_data.aggr.state_size()); + D_ASSERT(bind_data.state_size == bind_data.aggr.state_size(bind_data.aggr)); D_ASSERT(input.data.size() == 2); D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE); @@ -248,14 +248,14 @@ static unique_ptr BindAggregateState(ClientContext &context, Scala bound_function.return_type = arg_return_type; } - return make_uniq(bound_aggr, bound_aggr.state_size()); + return make_uniq(bound_aggr, bound_aggr.state_size(bound_aggr)); } static void ExportAggregateFinalize(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { D_ASSERT(offset == 0); auto &bind_data = aggr_input_data.bind_data->Cast(); - auto state_size = bind_data.aggregate->function.state_size(); + auto state_size = bind_data.aggregate->function.state_size(bind_data.aggregate->function); auto blob_ptr = FlatVector::GetData(result); auto addresses_ptr = FlatVector::GetData(state); for (idx_t row_idx = 0; row_idx < count; row_idx++) { diff --git a/src/duckdb/src/function/scalar_function.cpp b/src/duckdb/src/function/scalar_function.cpp index 6b02a00a..75d74cf5 100644 --- a/src/duckdb/src/function/scalar_function.cpp +++ b/src/duckdb/src/function/scalar_function.cpp @@ -5,6 +5,9 @@ namespace duckdb { FunctionLocalState::~FunctionLocalState() { } +ScalarFunctionInfo::~ScalarFunctionInfo() { +} + ScalarFunction::ScalarFunction(string name, vector arguments, LogicalType return_type, scalar_function_t function, bind_scalar_function_t bind, dependency_function_t dependency, function_statistics_t statistics, @@ -13,8 +16,8 @@ ScalarFunction::ScalarFunction(string name, vector arguments, Logic : BaseScalarFunction(std::move(name), std::move(arguments), std::move(return_type), side_effects, std::move(varargs), null_handling), function(std::move(function)), bind(bind), init_local_state(init_local_state), dependency(dependency), - statistics(statistics), bind_lambda(bind_lambda), get_modified_databases(nullptr), serialize(nullptr), - deserialize(nullptr) { + statistics(statistics), bind_lambda(bind_lambda), bind_expression(nullptr), get_modified_databases(nullptr), + serialize(nullptr), deserialize(nullptr) { } ScalarFunction::ScalarFunction(vector arguments, LogicalType return_type, scalar_function_t function, diff --git a/src/duckdb/src/function/scalar_macro_function.cpp b/src/duckdb/src/function/scalar_macro_function.cpp index 07f7b788..9f3d79c3 100644 --- a/src/duckdb/src/function/scalar_macro_function.cpp +++ b/src/duckdb/src/function/scalar_macro_function.cpp @@ -42,11 +42,11 @@ void RemoveQualificationRecursive(unique_ptr &expr) { } } -string ScalarMacroFunction::ToSQL(const string &schema, const string &name) const { +string ScalarMacroFunction::ToSQL() const { // In case of nested macro's we need to fix it a bit auto expression_copy = expression->Copy(); RemoveQualificationRecursive(expression_copy); - return MacroFunction::ToSQL(schema, name) + StringUtil::Format("(%s);", expression_copy->ToString()); + return MacroFunction::ToSQL() + StringUtil::Format("(%s)", expression_copy->ToString()); } } // namespace duckdb diff --git a/src/duckdb/src/function/table/arrow.cpp b/src/duckdb/src/function/table/arrow.cpp index b7054fd0..16b5c5e4 100644 --- a/src/duckdb/src/function/table/arrow.cpp +++ b/src/duckdb/src/function/table/arrow.cpp @@ -7,16 +7,103 @@ #include "duckdb/common/types/date.hpp" #include "duckdb/common/types/vector_buffer.hpp" #include "duckdb/function/table/arrow.hpp" +#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" +#include "duckdb/function/table/arrow/arrow_type_info.hpp" #include "duckdb/function/table_function.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" -#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" #include "utf8proc_wrapper.hpp" +#include "duckdb/common/arrow/schema_metadata.hpp" namespace duckdb { +static unique_ptr CreateListType(ArrowSchema &child, ArrowVariableSizeType size_type, bool view) { + auto child_type = ArrowTableFunction::GetArrowLogicalType(child); + + unique_ptr type_info; + auto type = LogicalType::LIST(child_type->GetDuckType()); + if (view) { + type_info = ArrowListInfo::ListView(std::move(child_type), size_type); + } else { + type_info = ArrowListInfo::List(std::move(child_type), size_type); + } + return make_uniq(type, std::move(type_info)); +} + +static unique_ptr GetArrowExtensionType(const ArrowSchemaMetadata &extension_type, const string &format) { + auto arrow_extension = extension_type.GetExtensionName(); + // Check for arrow canonical extensions + if (arrow_extension == "arrow.uuid") { + if (format != "w:16") { + throw InvalidInputException( + "arrow.uuid must be a fixed-size binary of 16 bytes (i.e., \'w:16\'). It is incorrectly defined as: %s", + format); + } + return make_uniq(LogicalType::UUID); + } else if (arrow_extension == "arrow.json") { + if (format == "u") { + return make_uniq(LogicalType::JSON(), make_uniq(ArrowVariableSizeType::NORMAL)); + } else if (format == "U") { + return make_uniq(LogicalType::JSON(), + make_uniq(ArrowVariableSizeType::SUPER_SIZE)); + } else if (format == "vu") { + return make_uniq(LogicalType::JSON(), make_uniq(ArrowVariableSizeType::VIEW)); + } else { + throw InvalidInputException("arrow.json must be of a varchar format (i.e., \'u\',\'U\' or \'vu\'). It is " + "incorrectly defined as: %s", + format); + } + } + // Check for DuckDB canonical extensions + else if (arrow_extension == "duckdb.hugeint") { + if (format != "w:16") { + throw InvalidInputException("duckdb.hugeint must be a fixed-size binary of 16 bytes (i.e., \'w:16\'). It " + "is incorrectly defined as: %s", + format); + } + return make_uniq(LogicalType::HUGEINT); + + } else if (arrow_extension == "duckdb.uhugeint") { + if (format != "w:16") { + throw InvalidInputException("duckdb.hugeint must be a fixed-size binary of 16 bytes (i.e., \'w:16\'). It " + "is incorrectly defined as: %s", + format); + } + return make_uniq(LogicalType::UHUGEINT); + } else if (arrow_extension == "duckdb.time_tz") { + if (format != "w:8") { + throw InvalidInputException("duckdb.time_tz must be a fixed-size binary of 8 bytes (i.e., \'w:8\'). It " + "is incorrectly defined as: %s", + format); + } + return make_uniq(LogicalType::TIME_TZ, + make_uniq(ArrowDateTimeType::MICROSECONDS)); + } else if (arrow_extension == "duckdb.bit") { + if (format != "z" && format != "Z") { + throw InvalidInputException("duckdb.bit must be a blob (i.e., \'z\' or \'Z\'). It " + "is incorrectly defined as: %s", + format); + } else if (format == "z") { + auto type_info = make_uniq(ArrowVariableSizeType::NORMAL); + return make_uniq(LogicalType::BIT, std::move(type_info)); + } + auto type_info = make_uniq(ArrowVariableSizeType::SUPER_SIZE); + return make_uniq(LogicalType::BIT, std::move(type_info)); + + } else { + throw NotImplementedException( + "Arrow Type with extension name: %s and format: %s, is not currently supported in DuckDB ", arrow_extension, + format); + } +} static unique_ptr GetArrowLogicalTypeNoDictionary(ArrowSchema &schema) { auto format = string(schema.format); + // Let's first figure out if this type is an extension type + ArrowSchemaMetadata schema_metadata(schema.metadata); + if (schema_metadata.HasExtension()) { + return GetArrowExtensionType(schema_metadata, format); + } + // If not, we just check the format itself if (format == "n") { return make_uniq(LogicalType::SQLNULL); } else if (format == "b") { @@ -67,9 +154,12 @@ static unique_ptr GetArrowLogicalTypeNoDictionary(ArrowSchema &schema } return make_uniq(LogicalType::DECIMAL(NumericCast(width), NumericCast(scale))); } else if (format == "u") { - return make_uniq(LogicalType::VARCHAR, ArrowVariableSizeType::NORMAL); + return make_uniq(LogicalType::VARCHAR, make_uniq(ArrowVariableSizeType::NORMAL)); } else if (format == "U") { - return make_uniq(LogicalType::VARCHAR, ArrowVariableSizeType::SUPER_SIZE); + return make_uniq(LogicalType::VARCHAR, + make_uniq(ArrowVariableSizeType::SUPER_SIZE)); + } else if (format == "vu") { + return make_uniq(LogicalType::VARCHAR, make_uniq(ArrowVariableSizeType::VIEW)); } else if (format == "tsn:") { return make_uniq(LogicalTypeId::TIMESTAMP_NS); } else if (format == "tsu:") { @@ -79,50 +169,51 @@ static unique_ptr GetArrowLogicalTypeNoDictionary(ArrowSchema &schema } else if (format == "tss:") { return make_uniq(LogicalTypeId::TIMESTAMP_SEC); } else if (format == "tdD") { - return make_uniq(LogicalType::DATE, ArrowDateTimeType::DAYS); + return make_uniq(LogicalType::DATE, make_uniq(ArrowDateTimeType::DAYS)); } else if (format == "tdm") { - return make_uniq(LogicalType::DATE, ArrowDateTimeType::MILLISECONDS); + return make_uniq(LogicalType::DATE, make_uniq(ArrowDateTimeType::MILLISECONDS)); } else if (format == "tts") { - return make_uniq(LogicalType::TIME, ArrowDateTimeType::SECONDS); + return make_uniq(LogicalType::TIME, make_uniq(ArrowDateTimeType::SECONDS)); } else if (format == "ttm") { - return make_uniq(LogicalType::TIME, ArrowDateTimeType::MILLISECONDS); + return make_uniq(LogicalType::TIME, make_uniq(ArrowDateTimeType::MILLISECONDS)); } else if (format == "ttu") { - return make_uniq(LogicalType::TIME, ArrowDateTimeType::MICROSECONDS); + return make_uniq(LogicalType::TIME, make_uniq(ArrowDateTimeType::MICROSECONDS)); } else if (format == "ttn") { - return make_uniq(LogicalType::TIME, ArrowDateTimeType::NANOSECONDS); + return make_uniq(LogicalType::TIME, make_uniq(ArrowDateTimeType::NANOSECONDS)); } else if (format == "tDs") { - return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::SECONDS); + return make_uniq(LogicalType::INTERVAL, make_uniq(ArrowDateTimeType::SECONDS)); } else if (format == "tDm") { - return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::MILLISECONDS); + return make_uniq(LogicalType::INTERVAL, + make_uniq(ArrowDateTimeType::MILLISECONDS)); } else if (format == "tDu") { - return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::MICROSECONDS); + return make_uniq(LogicalType::INTERVAL, + make_uniq(ArrowDateTimeType::MICROSECONDS)); } else if (format == "tDn") { - return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::NANOSECONDS); + return make_uniq(LogicalType::INTERVAL, + make_uniq(ArrowDateTimeType::NANOSECONDS)); } else if (format == "tiD") { - return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::DAYS); + return make_uniq(LogicalType::INTERVAL, make_uniq(ArrowDateTimeType::DAYS)); } else if (format == "tiM") { - return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::MONTHS); + return make_uniq(LogicalType::INTERVAL, make_uniq(ArrowDateTimeType::MONTHS)); } else if (format == "tin") { - return make_uniq(LogicalType::INTERVAL, ArrowDateTimeType::MONTH_DAY_NANO); + return make_uniq(LogicalType::INTERVAL, + make_uniq(ArrowDateTimeType::MONTH_DAY_NANO)); } else if (format == "+l") { - auto child_type = ArrowTableFunction::GetArrowLogicalType(*schema.children[0]); - auto list_type = - make_uniq(LogicalType::LIST(child_type->GetDuckType()), ArrowVariableSizeType::NORMAL); - list_type->AddChild(std::move(child_type)); - return list_type; + return CreateListType(*schema.children[0], ArrowVariableSizeType::NORMAL, false); } else if (format == "+L") { - auto child_type = ArrowTableFunction::GetArrowLogicalType(*schema.children[0]); - auto list_type = - make_uniq(LogicalType::LIST(child_type->GetDuckType()), ArrowVariableSizeType::SUPER_SIZE); - list_type->AddChild(std::move(child_type)); - return list_type; + return CreateListType(*schema.children[0], ArrowVariableSizeType::SUPER_SIZE, false); + } else if (format == "+vl") { + return CreateListType(*schema.children[0], ArrowVariableSizeType::NORMAL, true); + } else if (format == "+vL") { + return CreateListType(*schema.children[0], ArrowVariableSizeType::SUPER_SIZE, true); } else if (format[0] == '+' && format[1] == 'w') { std::string parameters = format.substr(format.find(':') + 1); auto fixed_size = NumericCast(std::stoi(parameters)); auto child_type = ArrowTableFunction::GetArrowLogicalType(*schema.children[0]); - auto list_type = make_uniq(LogicalType::ARRAY(child_type->GetDuckType(), fixed_size), fixed_size); - list_type->AddChild(std::move(child_type)); - return list_type; + + auto array_type = LogicalType::ARRAY(child_type->GetDuckType(), fixed_size); + auto type_info = make_uniq(std::move(child_type), fixed_size); + return make_uniq(array_type, std::move(type_info)); } else if (format == "+s") { child_list_t child_types; vector> children; @@ -134,8 +225,8 @@ static unique_ptr GetArrowLogicalTypeNoDictionary(ArrowSchema &schema children.emplace_back(ArrowTableFunction::GetArrowLogicalType(*schema.children[type_idx])); child_types.emplace_back(schema.children[type_idx]->name, children.back()->GetDuckType()); } - auto struct_type = make_uniq(LogicalType::STRUCT(std::move(child_types))); - struct_type->AssignChildren(std::move(children)); + auto type_info = make_uniq(std::move(children)); + auto struct_type = make_uniq(LogicalType::STRUCT(std::move(child_types)), std::move(type_info)); return struct_type; } else if (format[0] == '+' && format[1] == 'u') { if (format[2] != 's') { @@ -159,8 +250,8 @@ static unique_ptr GetArrowLogicalTypeNoDictionary(ArrowSchema &schema members.emplace_back(type->name, children.back()->GetDuckType()); } - auto union_type = make_uniq(LogicalType::UNION(members)); - union_type->AssignChildren(std::move(children)); + auto type_info = make_uniq(std::move(children)); + auto union_type = make_uniq(LogicalType::UNION(members), std::move(type_info)); return union_type; } else if (format == "+r") { child_list_t members; @@ -175,8 +266,8 @@ static unique_ptr GetArrowLogicalTypeNoDictionary(ArrowSchema &schema members.emplace_back(type->name, children.back()->GetDuckType()); } - auto struct_type = make_uniq(LogicalType::STRUCT(members)); - struct_type->AssignChildren(std::move(children)); + auto type_info = make_uniq(std::move(children)); + auto struct_type = make_uniq(LogicalType::STRUCT(members), std::move(type_info)); struct_type->SetRunEndEncoded(); return struct_type; } else if (format == "+m") { @@ -184,43 +275,46 @@ static unique_ptr GetArrowLogicalTypeNoDictionary(ArrowSchema &schema D_ASSERT(arrow_struct_type.n_children == 2); auto key_type = ArrowTableFunction::GetArrowLogicalType(*arrow_struct_type.children[0]); auto value_type = ArrowTableFunction::GetArrowLogicalType(*arrow_struct_type.children[1]); - auto map_type = make_uniq(LogicalType::MAP(key_type->GetDuckType(), value_type->GetDuckType()), - ArrowVariableSizeType::NORMAL); child_list_t key_value; key_value.emplace_back(std::make_pair("key", key_type->GetDuckType())); key_value.emplace_back(std::make_pair("value", value_type->GetDuckType())); - auto inner_struct = - make_uniq(LogicalType::STRUCT(std::move(key_value)), ArrowVariableSizeType::NORMAL); + auto map_type = LogicalType::MAP(key_type->GetDuckType(), value_type->GetDuckType()); vector> children; children.reserve(2); children.push_back(std::move(key_type)); children.push_back(std::move(value_type)); - inner_struct->AssignChildren(std::move(children)); - map_type->AddChild(std::move(inner_struct)); - return map_type; + auto inner_struct = make_uniq(LogicalType::STRUCT(std::move(key_value)), + make_uniq(std::move(children))); + auto map_type_info = ArrowListInfo::List(std::move(inner_struct), ArrowVariableSizeType::NORMAL); + return make_uniq(map_type, std::move(map_type_info)); } else if (format == "z") { - return make_uniq(LogicalType::BLOB, ArrowVariableSizeType::NORMAL); + auto type_info = make_uniq(ArrowVariableSizeType::NORMAL); + return make_uniq(LogicalType::BLOB, std::move(type_info)); } else if (format == "Z") { - return make_uniq(LogicalType::BLOB, ArrowVariableSizeType::SUPER_SIZE); + auto type_info = make_uniq(ArrowVariableSizeType::SUPER_SIZE); + return make_uniq(LogicalType::BLOB, std::move(type_info)); } else if (format[0] == 'w') { - std::string parameters = format.substr(format.find(':') + 1); + string parameters = format.substr(format.find(':') + 1); auto fixed_size = NumericCast(std::stoi(parameters)); - return make_uniq(LogicalType::BLOB, fixed_size); + auto type_info = make_uniq(fixed_size); + return make_uniq(LogicalType::BLOB, std::move(type_info)); } else if (format[0] == 't' && format[1] == 's') { // Timestamp with Timezone // TODO right now we just get the UTC value. We probably want to support this properly in the future + unique_ptr type_info; if (format[2] == 'n') { - return make_uniq(LogicalType::TIMESTAMP_TZ, ArrowDateTimeType::NANOSECONDS); + type_info = make_uniq(ArrowDateTimeType::NANOSECONDS); } else if (format[2] == 'u') { - return make_uniq(LogicalType::TIMESTAMP_TZ, ArrowDateTimeType::MICROSECONDS); + type_info = make_uniq(ArrowDateTimeType::MICROSECONDS); } else if (format[2] == 'm') { - return make_uniq(LogicalType::TIMESTAMP_TZ, ArrowDateTimeType::MILLISECONDS); + type_info = make_uniq(ArrowDateTimeType::MILLISECONDS); } else if (format[2] == 's') { - return make_uniq(LogicalType::TIMESTAMP_TZ, ArrowDateTimeType::SECONDS); + type_info = make_uniq(ArrowDateTimeType::SECONDS); } else { throw NotImplementedException(" Timestamptz precision of not accepted"); } + return make_uniq(LogicalType::TIMESTAMP_TZ, std::move(type_info)); } else { throw NotImplementedException("Unsupported Internal Arrow Type %s", format); } @@ -335,7 +429,7 @@ unique_ptr ArrowTableFunction::ArrowScanInitGlobal(Cli auto result = make_uniq(); result->stream = ProduceArrowScan(bind_data, input.column_ids, input.filters.get()); result->max_threads = ArrowScanMaxThreads(context, input.bind_data.get()); - if (input.CanRemoveFilterColumns()) { + if (!input.projection_ids.empty()) { result->projection_ids = input.projection_ids; for (const auto &col_idx : input.column_ids) { if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { @@ -356,7 +450,7 @@ ArrowTableFunction::ArrowScanInitLocalInternal(ClientContext &context, TableFunc auto result = make_uniq(std::move(current_chunk)); result->column_ids = input.column_ids; result->filters = input.filters.get(); - if (input.CanRemoveFilterColumns()) { + if (!input.projection_ids.empty()) { auto &asgs = global_state_p->Cast(); result->all_columns.Initialize(context, asgs.scanned_types); } @@ -414,6 +508,53 @@ idx_t ArrowTableFunction::ArrowGetBatchIndex(ClientContext &context, const Funct return state.batch_index; } +bool ArrowTableFunction::ArrowPushdownType(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::BOOLEAN: + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::VARCHAR: + case LogicalTypeId::BLOB: + return true; + case LogicalTypeId::DECIMAL: { + switch (type.InternalType()) { + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + return true; + default: + return false; + } + } break; + case LogicalTypeId::STRUCT: { + auto struct_types = StructType::GetChildTypes(type); + for (auto &struct_type : struct_types) { + if (!ArrowPushdownType(struct_type.second)) { + return false; + } + } + return true; + } + default: + return false; + } +} + void ArrowTableFunction::RegisterFunction(BuiltinFunctions &set) { TableFunction arrow("arrow_scan", {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER}, ArrowScanFunction, ArrowScanBind, ArrowScanInitGlobal, ArrowScanInitLocal); @@ -422,6 +563,7 @@ void ArrowTableFunction::RegisterFunction(BuiltinFunctions &set) { arrow.projection_pushdown = true; arrow.filter_pushdown = true; arrow.filter_prune = true; + arrow.supports_pushdown_type = ArrowPushdownType; set.AddFunction(arrow); TableFunction arrow_dumb("arrow_scan_dumb", {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER}, diff --git a/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp b/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp index ae841179..b434b99c 100644 --- a/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp +++ b/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp @@ -13,15 +13,6 @@ const arrow_column_map_t &ArrowTableType::GetColumns() const { return arrow_convert_data; } -void ArrowType::AddChild(unique_ptr child) { - children.emplace_back(std::move(child)); -} - -void ArrowType::AssignChildren(vector> children) { - D_ASSERT(this->children.empty()); - this->children = std::move(children); -} - void ArrowType::SetDictionary(unique_ptr dictionary) { D_ASSERT(!this->dictionary_type); dictionary_type = std::move(dictionary); @@ -37,8 +28,12 @@ const ArrowType &ArrowType::GetDictionary() const { } void ArrowType::SetRunEndEncoded() { - D_ASSERT(children.size() == 2); - auto actual_type = children[1]->GetDuckType(); + D_ASSERT(type_info); + D_ASSERT(type_info->type == ArrowTypeInfoType::STRUCT); + auto &struct_info = type_info->Cast(); + D_ASSERT(struct_info.ChildCount() == 2); + + auto actual_type = struct_info.GetChild(1).GetDuckType(); // Override the duckdb type to the actual type type = actual_type; run_end_encoded = true; @@ -60,29 +55,33 @@ LogicalType ArrowType::GetDuckType(bool use_dictionary) const { auto id = type.id(); switch (id) { case LogicalTypeId::STRUCT: { + auto &struct_info = type_info->Cast(); child_list_t new_children; - for (idx_t i = 0; i < children.size(); i++) { - auto &child = children[i]; + for (idx_t i = 0; i < struct_info.ChildCount(); i++) { + auto &child = struct_info.GetChild(i); auto &child_name = StructType::GetChildName(type, i); - new_children.emplace_back(std::make_pair(child_name, child->GetDuckType(true))); + new_children.emplace_back(std::make_pair(child_name, child.GetDuckType(true))); } return LogicalType::STRUCT(std::move(new_children)); } case LogicalTypeId::LIST: { - auto &child = children[0]; - return LogicalType::LIST(child->GetDuckType(true)); + auto &list_info = type_info->Cast(); + auto &child = list_info.GetChild(); + return LogicalType::LIST(child.GetDuckType(true)); } case LogicalTypeId::MAP: { - auto &struct_child = children[0]; - auto struct_type = struct_child->GetDuckType(true); + auto &list_info = type_info->Cast(); + auto &struct_child = list_info.GetChild(); + auto struct_type = struct_child.GetDuckType(true); return LogicalType::MAP(StructType::GetChildType(struct_type, 0), StructType::GetChildType(struct_type, 1)); } case LogicalTypeId::UNION: { + auto &union_info = type_info->Cast(); child_list_t new_children; - for (idx_t i = 0; i < children.size(); i++) { - auto &child = children[i]; + for (idx_t i = 0; i < union_info.ChildCount(); i++) { + auto &child = union_info.GetChild(i); auto &child_name = UnionType::GetMemberName(type, i); - new_children.emplace_back(std::make_pair(child_name, child->GetDuckType(true))); + new_children.emplace_back(std::make_pair(child_name, child.GetDuckType(true))); } return LogicalType::UNION(std::move(new_children)); } @@ -92,22 +91,4 @@ LogicalType ArrowType::GetDuckType(bool use_dictionary) const { } } -ArrowVariableSizeType ArrowType::GetSizeType() const { - return size_type; -} - -ArrowDateTimeType ArrowType::GetDateTimeType() const { - return date_time_precision; -} - -const ArrowType &ArrowType::operator[](idx_t index) const { - D_ASSERT(index < children.size()); - return *children[index]; -} - -idx_t ArrowType::FixedSize() const { - D_ASSERT(size_type == ArrowVariableSizeType::FIXED_SIZE); - return fixed_size; -} - } // namespace duckdb diff --git a/src/duckdb/src/function/table/arrow/arrow_type_info.cpp b/src/duckdb/src/function/table/arrow/arrow_type_info.cpp new file mode 100644 index 00000000..e012f1b5 --- /dev/null +++ b/src/duckdb/src/function/table/arrow/arrow_type_info.cpp @@ -0,0 +1,135 @@ +#include "duckdb/function/table/arrow/arrow_type_info.hpp" +#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// ArrowTypeInfo +//===--------------------------------------------------------------------===// + +ArrowTypeInfo::ArrowTypeInfo(ArrowTypeInfoType type) : type(type) { +} + +ArrowTypeInfo::~ArrowTypeInfo() { +} + +//===--------------------------------------------------------------------===// +// ArrowStructInfo +//===--------------------------------------------------------------------===// + +ArrowStructInfo::ArrowStructInfo(vector> children) + : ArrowTypeInfo(ArrowTypeInfoType::STRUCT), children(std::move(children)) { +} + +idx_t ArrowStructInfo::ChildCount() const { + return children.size(); +} + +ArrowStructInfo::~ArrowStructInfo() { +} + +const ArrowType &ArrowStructInfo::GetChild(idx_t index) const { + D_ASSERT(index < children.size()); + return *children[index]; +} + +const vector> &ArrowStructInfo::GetChildren() const { + return children; +} + +//===--------------------------------------------------------------------===// +// ArrowDateTimeInfo +//===--------------------------------------------------------------------===// + +ArrowDateTimeInfo::ArrowDateTimeInfo(ArrowDateTimeType size) + : ArrowTypeInfo(ArrowTypeInfoType::DATE_TIME), size_type(size) { +} + +ArrowDateTimeInfo::~ArrowDateTimeInfo() { +} + +ArrowDateTimeType ArrowDateTimeInfo::GetDateTimeType() const { + return size_type; +} + +//===--------------------------------------------------------------------===// +// ArrowStringInfo +//===--------------------------------------------------------------------===// + +ArrowStringInfo::ArrowStringInfo(ArrowVariableSizeType size) + : ArrowTypeInfo(ArrowTypeInfoType::STRING), size_type(size), fixed_size(0) { + D_ASSERT(size != ArrowVariableSizeType::FIXED_SIZE); +} + +ArrowStringInfo::~ArrowStringInfo() { +} + +ArrowStringInfo::ArrowStringInfo(idx_t fixed_size) + : ArrowTypeInfo(ArrowTypeInfoType::STRING), size_type(ArrowVariableSizeType::FIXED_SIZE), fixed_size(fixed_size) { +} + +ArrowVariableSizeType ArrowStringInfo::GetSizeType() const { + return size_type; +} + +idx_t ArrowStringInfo::FixedSize() const { + D_ASSERT(size_type == ArrowVariableSizeType::FIXED_SIZE); + return fixed_size; +} + +//===--------------------------------------------------------------------===// +// ArrowListInfo +//===--------------------------------------------------------------------===// + +ArrowListInfo::ArrowListInfo(unique_ptr child, ArrowVariableSizeType size) + : ArrowTypeInfo(ArrowTypeInfoType::LIST), size_type(size), child(std::move(child)) { +} + +ArrowListInfo::~ArrowListInfo() { +} + +unique_ptr ArrowListInfo::ListView(unique_ptr child, ArrowVariableSizeType size) { + D_ASSERT(size == ArrowVariableSizeType::SUPER_SIZE || size == ArrowVariableSizeType::NORMAL); + auto list_info = unique_ptr(new ArrowListInfo(std::move(child), size)); + list_info->is_view = true; + return list_info; +} + +unique_ptr ArrowListInfo::List(unique_ptr child, ArrowVariableSizeType size) { + D_ASSERT(size == ArrowVariableSizeType::SUPER_SIZE || size == ArrowVariableSizeType::NORMAL); + return unique_ptr(new ArrowListInfo(std::move(child), size)); +} + +ArrowVariableSizeType ArrowListInfo::GetSizeType() const { + return size_type; +} + +bool ArrowListInfo::IsView() const { + return is_view; +} + +ArrowType &ArrowListInfo::GetChild() const { + return *child; +} + +//===--------------------------------------------------------------------===// +// ArrowArrayInfo +//===--------------------------------------------------------------------===// + +ArrowArrayInfo::ArrowArrayInfo(unique_ptr child, idx_t fixed_size) + : ArrowTypeInfo(ArrowTypeInfoType::ARRAY), child(std::move(child)), fixed_size(fixed_size) { + D_ASSERT(fixed_size > 0); +} + +ArrowArrayInfo::~ArrowArrayInfo() { +} + +idx_t ArrowArrayInfo::FixedSize() const { + return fixed_size; +} + +ArrowType &ArrowArrayInfo::GetChild() const { + return *child; +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/arrow_conversion.cpp b/src/duckdb/src/function/table/arrow_conversion.cpp index c1759ef8..b83bbf56 100644 --- a/src/duckdb/src/function/table/arrow_conversion.cpp +++ b/src/duckdb/src/function/table/arrow_conversion.cpp @@ -1,10 +1,13 @@ -#include "duckdb/function/table/arrow.hpp" +#include "duckdb/common/exception/conversion_exception.hpp" #include "duckdb/common/limits.hpp" #include "duckdb/common/operator/multiply.hpp" -#include "duckdb/common/types/hugeint.hpp" #include "duckdb/common/types/arrow_aux_data.hpp" +#include "duckdb/common/types/arrow_string_view_type.hpp" +#include "duckdb/common/types/hugeint.hpp" #include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/common/exception/conversion_exception.hpp" +#include "duckdb/function/table/arrow.hpp" + +#include "duckdb/common/bswap.hpp" namespace duckdb { @@ -35,7 +38,7 @@ static void ShiftRight(unsigned char *ar, int size, int shift) { } } -idx_t GetEffectiveOffset(ArrowArray &array, int64_t parent_offset, const ArrowScanLocalState &state, +idx_t GetEffectiveOffset(const ArrowArray &array, int64_t parent_offset, const ArrowScanLocalState &state, int64_t nested_offset = -1) { if (nested_offset != -1) { // The parent of this array is a list @@ -74,7 +77,7 @@ static void GetValidityMask(ValidityMask &mask, ArrowArray &array, const ArrowSc vector temp_nullmask(n_bitmask_bytes + 1); memcpy(temp_nullmask.data(), ArrowBufferData(array, 0) + bit_offset / 8, n_bitmask_bytes + 1); ShiftRight(temp_nullmask.data(), NumericCast(n_bitmask_bytes + 1), - bit_offset % 8); //! why this has to be a right shift is a mystery to me + NumericCast(bit_offset % 8ull)); //! why this has to be a right shift is a mystery to me memcpy((void *)mask.GetData(), data_ptr_cast(temp_nullmask.data()), n_bitmask_bytes); } #else @@ -107,7 +110,7 @@ static void SetValidityMask(Vector &vector, ArrowArray &array, const ArrowScanLo GetValidityMask(mask, array, scan_state, size, parent_offset, nested_offset, add_null); } -static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, +static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, const ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, const ArrowType &arrow_type, int64_t nested_offset = -1, ValidityMask *parent_mask = nullptr, uint64_t parent_offset = 0); @@ -117,44 +120,111 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca static void ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, const ArrowType &arrow_type, int64_t nested_offset = -1, - ValidityMask *parent_mask = nullptr, uint64_t parent_offset = 0); + const ValidityMask *parent_mask = nullptr, uint64_t parent_offset = 0); -static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset, ValidityMask *parent_mask, - int64_t parent_offset) { - auto size_type = arrow_type.GetSizeType(); - idx_t list_size = 0; - auto &scan_state = array_state.state; +namespace { - SetValidityMask(vector, array, scan_state, size, parent_offset, nested_offset); +struct ArrowListOffsetData { + idx_t list_size = 0; idx_t start_offset = 0; +}; + +} // namespace + +template +static ArrowListOffsetData ConvertArrowListOffsetsTemplated(Vector &vector, ArrowArray &array, idx_t size, + idx_t effective_offset) { + ArrowListOffsetData result; + auto &start_offset = result.start_offset; + auto &list_size = result.list_size; + idx_t cur_offset = 0; - if (size_type == ArrowVariableSizeType::NORMAL) { - auto offsets = - ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); - start_offset = offsets[0]; - auto list_data = FlatVector::GetData(vector); + auto offsets = ArrowBufferData(array, 1) + effective_offset; + start_offset = offsets[0]; + auto list_data = FlatVector::GetData(vector); + for (idx_t i = 0; i < size; i++) { + auto &le = list_data[i]; + le.offset = cur_offset; + le.length = offsets[i + 1] - offsets[i]; + cur_offset += le.length; + } + list_size = offsets[size]; + list_size -= start_offset; + return result; +} + +template +static ArrowListOffsetData ConvertArrowListViewOffsetsTemplated(Vector &vector, ArrowArray &array, idx_t size, + idx_t effective_offset) { + ArrowListOffsetData result; + auto &start_offset = result.start_offset; + auto &list_size = result.list_size; + + list_size = 0; + auto offsets = ArrowBufferData(array, 1) + effective_offset; + auto sizes = ArrowBufferData(array, 2) + effective_offset; + + // In ListArrays the offsets have to be sequential + // ListViewArrays do not have this same constraint + // for that reason we need to keep track of the lowest offset, so we can skip all the data that comes before it + // when we scan the child data + + auto lowest_offset = size ? offsets[0] : 0; + auto list_data = FlatVector::GetData(vector); + for (idx_t i = 0; i < size; i++) { + auto &le = list_data[i]; + le.offset = offsets[i]; + le.length = sizes[i]; + list_size += le.length; + if (sizes[i] != 0) { + lowest_offset = MinValue(lowest_offset, offsets[i]); + } + } + start_offset = lowest_offset; + if (start_offset) { + // We start scanning the child data at the 'start_offset' so we need to fix up the created list entries for (idx_t i = 0; i < size; i++) { auto &le = list_data[i]; - le.offset = cur_offset; - le.length = offsets[i + 1] - offsets[i]; - cur_offset += le.length; + le.offset = le.offset <= start_offset ? 0 : le.offset - start_offset; + } + } + return result; +} + +static ArrowListOffsetData ConvertArrowListOffsets(Vector &vector, ArrowArray &array, idx_t size, + const ArrowType &arrow_type, idx_t effective_offset) { + auto &list_info = arrow_type.GetTypeInfo(); + auto size_type = list_info.GetSizeType(); + if (list_info.IsView()) { + if (size_type == ArrowVariableSizeType::NORMAL) { + return ConvertArrowListViewOffsetsTemplated(vector, array, size, effective_offset); + } else { + D_ASSERT(size_type == ArrowVariableSizeType::SUPER_SIZE); + return ConvertArrowListViewOffsetsTemplated(vector, array, size, effective_offset); } - list_size = offsets[size]; } else { - auto offsets = - ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); - start_offset = offsets[0]; - auto list_data = FlatVector::GetData(vector); - for (idx_t i = 0; i < size; i++) { - auto &le = list_data[i]; - le.offset = cur_offset; - le.length = offsets[i + 1] - offsets[i]; - cur_offset += le.length; + if (size_type == ArrowVariableSizeType::NORMAL) { + return ConvertArrowListOffsetsTemplated(vector, array, size, effective_offset); + } else { + D_ASSERT(size_type == ArrowVariableSizeType::SUPER_SIZE); + return ConvertArrowListOffsetsTemplated(vector, array, size, effective_offset); } - list_size = offsets[size]; } - list_size -= start_offset; +} + +static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, + const ArrowType &arrow_type, int64_t nested_offset, const ValidityMask *parent_mask, + int64_t parent_offset) { + auto &scan_state = array_state.state; + + auto &list_info = arrow_type.GetTypeInfo(); + SetValidityMask(vector, array, scan_state, size, parent_offset, nested_offset); + + auto effective_offset = GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + auto list_data = ConvertArrowListOffsets(vector, array, size, arrow_type, effective_offset); + auto &start_offset = list_data.start_offset; + auto &list_size = list_data.list_size; + ListVector::Reserve(vector, list_size); ListVector::SetListSize(vector, list_size); auto &child_vector = ListVector::GetEntry(vector); @@ -173,7 +243,8 @@ static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, ArrowArrayScanS } auto &child_state = array_state.GetChild(0); auto &child_array = *array.children[0]; - auto &child_type = arrow_type[0]; + auto &child_type = list_info.GetChild(); + if (list_size == 0 && start_offset == 0) { D_ASSERT(!child_array.dictionary); ColumnArrowToDuckDB(child_vector, child_array, child_state, list_size, child_type, -1); @@ -201,12 +272,12 @@ static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, ArrowArrayScanS } static void ArrowToDuckDBArray(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset, ValidityMask *parent_mask, + const ArrowType &arrow_type, int64_t nested_offset, const ValidityMask *parent_mask, int64_t parent_offset) { - D_ASSERT(arrow_type.GetSizeType() == ArrowVariableSizeType::FIXED_SIZE); + auto &array_info = arrow_type.GetTypeInfo(); auto &scan_state = array_state.state; - auto array_size = arrow_type.FixedSize(); + auto array_size = array_info.FixedSize(); auto child_count = array_size * size; auto child_offset = GetEffectiveOffset(array, parent_offset, scan_state, nested_offset) * array_size; @@ -242,7 +313,7 @@ static void ArrowToDuckDBArray(Vector &vector, ArrowArray &array, ArrowArrayScan auto &child_state = array_state.GetChild(0); auto &child_array = *array.children[0]; - auto &child_type = arrow_type[0]; + auto &child_type = array_info.GetChild(); if (child_count == 0 && child_offset == 0) { D_ASSERT(!child_array.dictionary); ColumnArrowToDuckDB(child_vector, child_array, child_state, child_count, child_type, -1); @@ -259,10 +330,11 @@ static void ArrowToDuckDBArray(Vector &vector, ArrowArray &array, ArrowArrayScan static void ArrowToDuckDBBlob(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, idx_t size, const ArrowType &arrow_type, int64_t nested_offset, int64_t parent_offset) { - auto size_type = arrow_type.GetSizeType(); SetValidityMask(vector, array, scan_state, size, parent_offset, nested_offset); + auto &string_info = arrow_type.GetTypeInfo(); + auto size_type = string_info.GetSizeType(); if (size_type == ArrowVariableSizeType::FIXED_SIZE) { - auto fixed_size = arrow_type.FixedSize(); + auto fixed_size = string_info.FixedSize(); //! Have to check validity mask before setting this up idx_t offset = GetEffectiveOffset(array, parent_offset, scan_state, nested_offset) * fixed_size; auto cdata = ArrowBufferData(array, 1); @@ -339,6 +411,35 @@ static void SetVectorString(Vector &vector, idx_t size, char *cdata, T *offsets) } } +static void SetVectorStringView(Vector &vector, idx_t size, ArrowArray &array, idx_t current_pos) { + auto strings = FlatVector::GetData(vector); + auto arrow_string = ArrowBufferData(array, 1) + current_pos; + + for (idx_t row_idx = 0; row_idx < size; row_idx++) { + if (FlatVector::IsNull(vector, row_idx)) { + continue; + } + auto length = UnsafeNumericCast(arrow_string[row_idx].Length()); + if (arrow_string[row_idx].IsInline()) { + // This string is inlined + // | Bytes 0-3 | Bytes 4-15 | + // |------------|---------------------------------------| + // | length | data (padded with 0) | + strings[row_idx] = string_t(arrow_string[row_idx].GetInlineData(), length); + } else { + // This string is not inlined, we have to check a different buffer and offsets + // | Bytes 0-3 | Bytes 4-7 | Bytes 8-11 | Bytes 12-15 | + // |------------|------------|------------|-------------| + // | length | prefix | buf. index | offset | + auto buffer_index = UnsafeNumericCast(arrow_string[row_idx].GetBufferIndex()); + int32_t offset = arrow_string[row_idx].GetOffset(); + D_ASSERT(array.n_buffers > 2 + buffer_index); + auto c_data = ArrowBufferData(array, 2 + buffer_index); + strings[row_idx] = string_t(&c_data[offset], length); + } + } +} + static void DirectConversion(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, int64_t nested_offset, uint64_t parent_offset) { auto internal_type = GetTypeIdSize(vector.GetType().InternalType()); @@ -353,17 +454,35 @@ static void TimeConversion(Vector &vector, ArrowArray &array, const ArrowScanLoc int64_t nested_offset, int64_t parent_offset, idx_t size, int64_t conversion) { auto tgt_ptr = FlatVector::GetData(vector); auto &validity_mask = FlatVector::Validity(vector); - auto src_ptr = (T *)array.buffers[1] + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + auto src_ptr = + static_cast(array.buffers[1]) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); for (idx_t row = 0; row < size; row++) { if (!validity_mask.RowIsValid(row)) { continue; } - if (!TryMultiplyOperator::Operation((int64_t)src_ptr[row], conversion, tgt_ptr[row].micros)) { + if (!TryMultiplyOperator::Operation(static_cast(src_ptr[row]), conversion, tgt_ptr[row].micros)) { throw ConversionException("Could not convert Time to Microsecond"); } } } +static void UUIDConversion(Vector &vector, const ArrowArray &array, const ArrowScanLocalState &scan_state, + int64_t nested_offset, int64_t parent_offset, idx_t size) { + auto tgt_ptr = FlatVector::GetData(vector); + auto &validity_mask = FlatVector::Validity(vector); + auto src_ptr = static_cast(array.buffers[1]) + + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + for (idx_t row = 0; row < size; row++) { + if (!validity_mask.RowIsValid(row)) { + continue; + } + tgt_ptr[row].lower = static_cast(BSwap(src_ptr[row].upper)); + // flip Upper MSD + tgt_ptr[row].upper = + static_cast(static_cast(BSwap(src_ptr[row].lower)) ^ (static_cast(1) << 63)); + } +} + static void TimestampTZConversion(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, int64_t nested_offset, int64_t parent_offset, idx_t size, int64_t conversion) { auto tgt_ptr = FlatVector::GetData(vector); @@ -470,7 +589,7 @@ static void FlattenRunEnds(Vector &result, ArrowRunEndEncodingState &run_end_enc idx_t index = 0; if (value_format.validity.AllValid()) { // None of the compressed values are NULL - for (; run < compressed_size; run++) { + for (; run < compressed_size; ++run) { auto run_end_index = run_end_format.sel->get_index(run); auto value_index = value_format.sel->get_index(run); auto &value = values_data[value_index]; @@ -488,13 +607,13 @@ static void FlattenRunEnds(Vector &result, ArrowRunEndEncodingState &run_end_enc if (index >= count) { if (logical_index + index >= run_end) { // The last run was completed, forward the run index - run++; + ++run; } break; } } } else { - for (; run < compressed_size; run++) { + for (; run < compressed_size; ++run) { auto run_end_index = run_end_format.sel->get_index(run); auto value_index = value_format.sel->get_index(run); auto run_end = static_cast(run_ends_data[run_end_index]); @@ -519,7 +638,7 @@ static void FlattenRunEnds(Vector &result, ArrowRunEndEncodingState &run_end_enc if (index >= count) { if (logical_index + index >= run_end) { // The last run was completed, forward the run index - run++; + ++run; } break; } @@ -584,7 +703,7 @@ static void FlattenRunEndsSwitch(Vector &result, ArrowRunEndEncodingState &run_e } } -static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, +static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, const ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, const ArrowType &arrow_type, int64_t nested_offset, ValidityMask *parent_mask, uint64_t parent_offset) { // Scan the 'run_ends' array @@ -592,8 +711,9 @@ static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, ArrowArray &array, auto &run_ends_array = *array.children[0]; auto &values_array = *array.children[1]; - auto &run_ends_type = arrow_type[0]; - auto &values_type = arrow_type[1]; + auto &struct_info = arrow_type.GetTypeInfo(); + auto &run_ends_type = struct_info.GetChild(0); + auto &values_type = struct_info.GetChild(1); D_ASSERT(vector.GetType() == values_type.GetDuckType()); auto &scan_state = array_state.state; @@ -682,27 +802,45 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca case LogicalTypeId::TIMESTAMP: case LogicalTypeId::TIMESTAMP_SEC: case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: { + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::TIME_TZ: { DirectConversion(vector, array, scan_state, nested_offset, parent_offset); break; } + case LogicalTypeId::UUID: + UUIDConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size); + break; case LogicalTypeId::VARCHAR: { - auto size_type = arrow_type.GetSizeType(); - auto cdata = ArrowBufferData(array, 2); - if (size_type == ArrowVariableSizeType::SUPER_SIZE) { + auto &string_info = arrow_type.GetTypeInfo(); + auto size_type = string_info.GetSizeType(); + switch (size_type) { + case ArrowVariableSizeType::SUPER_SIZE: { + auto cdata = ArrowBufferData(array, 2); auto offsets = ArrowBufferData(array, 1) + GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); SetVectorString(vector, size, cdata, offsets); - } else { + break; + } + case ArrowVariableSizeType::NORMAL: + case ArrowVariableSizeType::FIXED_SIZE: { + auto cdata = ArrowBufferData(array, 2); auto offsets = ArrowBufferData(array, 1) + GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); SetVectorString(vector, size, cdata, offsets); + break; + } + case ArrowVariableSizeType::VIEW: { + SetVectorStringView( + vector, size, array, + GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset)); + break; + } } break; } case LogicalTypeId::DATE: { - - auto precision = arrow_type.GetDateTimeType(); + auto &datetime_info = arrow_type.GetTypeInfo(); + auto precision = datetime_info.GetDateTimeType(); switch (precision) { case ArrowDateTimeType::DAYS: { DirectConversion(vector, array, scan_state, nested_offset, parent_offset); @@ -714,8 +852,8 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); auto tgt_ptr = FlatVector::GetData(vector); for (idx_t row = 0; row < size; row++) { - tgt_ptr[row] = date_t( - UnsafeNumericCast(int64_t(src_ptr[row]) / static_cast(1000 * 60 * 60 * 24))); + tgt_ptr[row] = date_t(UnsafeNumericCast(static_cast(src_ptr[row]) / + static_cast(1000 * 60 * 60 * 24))); } break; } @@ -725,7 +863,8 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca break; } case LogicalTypeId::TIME: { - auto precision = arrow_type.GetDateTimeType(); + auto &datetime_info = arrow_type.GetTypeInfo(); + auto precision = datetime_info.GetDateTimeType(); switch (precision) { case ArrowDateTimeType::SECONDS: { TimeConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, @@ -757,7 +896,8 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca break; } case LogicalTypeId::TIMESTAMP_TZ: { - auto precision = arrow_type.GetDateTimeType(); + auto &datetime_info = arrow_type.GetTypeInfo(); + auto precision = datetime_info.GetDateTimeType(); switch (precision) { case ArrowDateTimeType::SECONDS: { TimestampTZConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, @@ -788,7 +928,8 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca break; } case LogicalTypeId::INTERVAL: { - auto precision = arrow_type.GetDateTimeType(); + auto &datetime_info = arrow_type.GetTypeInfo(); + auto precision = datetime_info.GetDateTimeType(); switch (precision) { case ArrowDateTimeType::SECONDS: { IntervalConversionUs(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, @@ -884,7 +1025,8 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca } break; } - case LogicalTypeId::BLOB: { + case LogicalTypeId::BLOB: + case LogicalTypeId::BIT: { ArrowToDuckDBBlob(vector, array, scan_state, size, arrow_type, nested_offset, NumericCast(parent_offset)); break; @@ -907,12 +1049,13 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca } case LogicalTypeId::STRUCT: { //! Fill the children + auto &struct_info = arrow_type.GetTypeInfo(); auto &child_entries = StructVector::GetEntries(vector); auto &struct_validity_mask = FlatVector::Validity(vector); for (idx_t child_idx = 0; child_idx < NumericCast(array.n_children); child_idx++) { auto &child_entry = *child_entries[child_idx]; auto &child_array = *array.children[child_idx]; - auto &child_type = arrow_type[child_idx]; + auto &child_type = struct_info.GetChild(child_idx); auto &child_state = array_state.GetChild(child_idx); SetValidityMask(child_entry, child_array, scan_state, size, array.offset, nested_offset); @@ -951,13 +1094,13 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca auto members = UnionType::CopyMemberTypes(vector.GetType()); auto &validity_mask = FlatVector::Validity(vector); - + auto &union_info = arrow_type.GetTypeInfo(); duckdb::vector children; for (idx_t child_idx = 0; child_idx < NumericCast(array.n_children); child_idx++) { Vector child(members[child_idx].second, size); auto &child_array = *array.children[child_idx]; auto &child_state = array_state.GetChild(child_idx); - auto &child_type = arrow_type[child_idx]; + auto &child_type = union_info.GetChild(child_idx); SetValidityMask(child, child_array, scan_state, size, NumericCast(parent_offset), nested_offset); auto array_physical_type = GetArrowArrayPhysicalType(child_type); @@ -1032,7 +1175,7 @@ static void SetMaskedSelectionVectorLoop(SelectionVector &sel, data_ptr_t indice } } -static void SetSelectionVector(SelectionVector &sel, data_ptr_t indices_p, LogicalType &logical_type, idx_t size, +static void SetSelectionVector(SelectionVector &sel, data_ptr_t indices_p, const LogicalType &logical_type, idx_t size, ValidityMask *mask = nullptr, idx_t last_element_pos = 0) { sel.Initialize(size); @@ -1121,7 +1264,7 @@ static void SetSelectionVector(SelectionVector &sel, data_ptr_t indices_p, Logic } } -static bool CanContainNull(ArrowArray &array, ValidityMask *parent_mask) { +static bool CanContainNull(const ArrowArray &array, const ValidityMask *parent_mask) { if (array.null_count > 0) { return true; } @@ -1133,7 +1276,7 @@ static bool CanContainNull(ArrowArray &array, ValidityMask *parent_mask) { static void ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, const ArrowType &arrow_type, int64_t nested_offset, - ValidityMask *parent_mask, uint64_t parent_offset) { + const ValidityMask *parent_mask, uint64_t parent_offset) { D_ASSERT(arrow_type.HasDictionary()); auto &scan_state = array_state.state; const bool has_nulls = CanContainNull(array, parent_mask); diff --git a/src/duckdb/src/function/table/copy_csv.cpp b/src/duckdb/src/function/table/copy_csv.cpp index 05a1187a..b2c16a67 100644 --- a/src/duckdb/src/function/table/copy_csv.cpp +++ b/src/duckdb/src/function/table/copy_csv.cpp @@ -11,15 +11,13 @@ #include "duckdb/function/copy_function.hpp" #include "duckdb/function/scalar/string_functions.hpp" #include "duckdb/function/table/read_csv.hpp" -#include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" #include "duckdb/parser/expression/cast_expression.hpp" -#include "duckdb/parser/expression/function_expression.hpp" -#include "duckdb/parser/expression/columnref_expression.hpp" #include "duckdb/parser/expression/constant_expression.hpp" -#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/parsed_data/copy_info.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/execution/column_binding_resolver.hpp" -#include "duckdb/planner/operator/logical_dummy_scan.hpp" + #include namespace duckdb { @@ -70,6 +68,14 @@ void BaseCSVData::Finalize() { options.dialect_options.state_machine_options.escape.GetValue(), "QUOTE", "ESCAPE"); } + // delimiter and quote must not be substrings of each other + AreOptionsEqual(options.dialect_options.state_machine_options.comment.GetValue(), + options.dialect_options.state_machine_options.quote.GetValue(), "COMMENT", "QUOTE"); + + // delimiter and quote must not be substrings of each other + AreOptionsEqual(options.dialect_options.state_machine_options.comment.GetValue(), + options.dialect_options.state_machine_options.delimiter.GetValue(), "COMMENT", "DELIMITER"); + // null string and delimiter must not be substrings of each other for (auto &null_str : options.null_str) { if (!null_str.empty()) { @@ -174,6 +180,21 @@ static unique_ptr WriteCSVBind(ClientContext &context, CopyFunctio } bind_data->Finalize(); + switch (bind_data->options.compression) { + case FileCompressionType::GZIP: + if (!IsFileCompressed(input.file_extension, FileCompressionType::GZIP)) { + input.file_extension += CompressionExtensionFromType(FileCompressionType::GZIP); + } + break; + case FileCompressionType::ZSTD: + if (!IsFileCompressed(input.file_extension, FileCompressionType::ZSTD)) { + input.file_extension += CompressionExtensionFromType(FileCompressionType::ZSTD); + } + break; + default: + break; + } + auto expressions = CreateCastExpressions(*bind_data, context, names, sql_types); bind_data->cast_expressions = std::move(expressions); @@ -225,14 +246,14 @@ static unique_ptr ReadCSVBind(ClientContext &context, CopyInfo &in options.file_path = bind_data->files[0]; options.name_list = expected_names; options.sql_type_list = expected_types; + options.columns_set = true; for (idx_t i = 0; i < expected_types.size(); i++) { options.sql_types_per_column[expected_names[i]] = i; } if (options.auto_detect) { auto buffer_manager = make_shared_ptr(context, options, bind_data->files[0], 0); - CSVSniffer sniffer(options, buffer_manager, CSVStateMachineCache::Get(context), - {&expected_types, &expected_names}); + CSVSniffer sniffer(options, buffer_manager, CSVStateMachineCache::Get(context)); sniffer.SniffCSV(); } bind_data->FinalizeRead(context); @@ -432,11 +453,6 @@ static unique_ptr WriteCSVInitializeGlobal(ClientContext &co return std::move(global_data); } -idx_t WriteCSVFileSize(GlobalFunctionData &gstate) { - auto &global_state = gstate.Cast(); - return global_state.FileSize(); -} - static void WriteCSVChunkInternal(ClientContext &context, FunctionData &bind_data, DataChunk &cast_chunk, MemoryStream &writer, DataChunk &input, bool &written_anything, ExpressionExecutor &executor) { @@ -592,6 +608,18 @@ void WriteCSVFlushBatch(ClientContext &context, FunctionData &bind_data, GlobalF writer.Rewind(); } +//===--------------------------------------------------------------------===// +// File rotation +//===--------------------------------------------------------------------===// +bool WriteCSVRotateFiles(FunctionData &, const optional_idx &file_size_bytes) { + return file_size_bytes.IsValid(); +} + +bool WriteCSVRotateNextFile(GlobalFunctionData &gstate, FunctionData &, const optional_idx &file_size_bytes) { + auto &global_state = gstate.Cast(); + return global_state.FileSize() > file_size_bytes.GetIndex(); +} + void CSVCopyFunction::RegisterFunction(BuiltinFunctions &set) { CopyFunction info("csv"); info.copy_to_bind = WriteCSVBind; @@ -603,7 +631,8 @@ void CSVCopyFunction::RegisterFunction(BuiltinFunctions &set) { info.execution_mode = WriteCSVExecutionMode; info.prepare_batch = WriteCSVPrepareBatch; info.flush_batch = WriteCSVFlushBatch; - info.file_size_bytes = WriteCSVFileSize; + info.rotate_files = WriteCSVRotateFiles; + info.rotate_next_file = WriteCSVRotateNextFile; info.copy_from_bind = ReadCSVBind; info.copy_from_function = ReadCSVTableFunction::GetFunction(); diff --git a/src/duckdb/src/function/table/query_function.cpp b/src/duckdb/src/function/table/query_function.cpp new file mode 100644 index 00000000..d000c03a --- /dev/null +++ b/src/duckdb/src/function/table/query_function.cpp @@ -0,0 +1,80 @@ +#include "duckdb/parser/parser.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/function/table/range.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" + +namespace duckdb { + +static unique_ptr ParseSubquery(const string &query, const ParserOptions &options, const string &err_msg) { + Parser parser(options); + parser.ParseQuery(query); + if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { + throw ParserException(err_msg); + } + auto select_stmt = unique_ptr_cast(std::move(parser.statements[0])); + return duckdb::make_uniq(std::move(select_stmt)); +} + +static void UnionTablesQuery(TableFunctionBindInput &input, string &query) { + for (auto &input_val : input.inputs) { + if (input_val.IsNull()) { + throw BinderException("Cannot use NULL as function argument"); + } + } + string by_name = (input.inputs.size() == 2 && + (input.inputs[1].type().id() == LogicalTypeId::BOOLEAN && input.inputs[1].GetValue())) + ? "BY NAME " + : ""; // 'by_name' variable defaults to false + if (input.inputs[0].type().id() == LogicalTypeId::VARCHAR) { + query += "FROM " + KeywordHelper::WriteOptionallyQuoted(input.inputs[0].ToString()); + } else if (input.inputs[0].type() == LogicalType::LIST(LogicalType::VARCHAR)) { + string union_all_clause = " UNION ALL " + by_name + "FROM "; + const auto &children = ListValue::GetChildren(input.inputs[0]); + if (children.empty()) { + throw InvalidInputException("Input list is empty"); + } + + query += "FROM " + KeywordHelper::WriteOptionallyQuoted(children[0].ToString()); + for (size_t i = 1; i < children.size(); ++i) { + auto child = children[i].ToString(); + query += union_all_clause + KeywordHelper::WriteOptionallyQuoted(child); + } + } else { + throw InvalidInputException("Expected a table or a list with tables as input"); + } +} + +static unique_ptr QueryBindReplace(ClientContext &context, TableFunctionBindInput &input) { + auto query = input.inputs[0].ToString(); + auto subquery_ref = ParseSubquery(query, context.GetParserOptions(), "Expected a single SELECT statement"); + return std::move(subquery_ref); +} + +static unique_ptr TableBindReplace(ClientContext &context, TableFunctionBindInput &input) { + string query; + UnionTablesQuery(input, query); + auto subquery_ref = + ParseSubquery(query, context.GetParserOptions(), "Expected a table or a list with tables as input"); + return std::move(subquery_ref); +} + +void QueryTableFunction::RegisterFunction(BuiltinFunctions &set) { + TableFunction query("query", {LogicalType::VARCHAR}, nullptr, nullptr); + query.bind_replace = QueryBindReplace; + set.AddFunction(query); + + TableFunctionSet query_table("query_table"); + TableFunction query_table_function({LogicalType::VARCHAR}, nullptr, nullptr); + query_table_function.bind_replace = TableBindReplace; + query_table.AddFunction(query_table_function); + + query_table_function.arguments = {LogicalType::LIST(LogicalType::VARCHAR)}; + query_table.AddFunction(query_table_function); + // add by_name option + query_table_function.arguments.emplace_back(LogicalType::BOOLEAN); + query_table.AddFunction(query_table_function); + set.AddFunction(query_table); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/range.cpp b/src/duckdb/src/function/table/range.cpp index 884a51b7..17bcda81 100644 --- a/src/duckdb/src/function/table/range.cpp +++ b/src/duckdb/src/function/table/range.cpp @@ -11,129 +11,196 @@ namespace duckdb { //===--------------------------------------------------------------------===// // Range (integers) //===--------------------------------------------------------------------===// +static void GetParameters(int64_t values[], idx_t value_count, hugeint_t &start, hugeint_t &end, hugeint_t &increment) { + if (value_count < 2) { + // single argument: only the end is specified + start = 0; + end = values[0]; + } else { + // two arguments: first two arguments are start and end + start = values[0]; + end = values[1]; + } + if (value_count < 3) { + increment = 1; + } else { + increment = values[2]; + } +} + struct RangeFunctionBindData : public TableFunctionData { + explicit RangeFunctionBindData(const vector &inputs) : cardinality(0) { + int64_t values[3]; + for (idx_t i = 0; i < inputs.size(); i++) { + if (inputs[i].IsNull()) { + return; + } + values[i] = inputs[i].GetValue(); + } + hugeint_t start; + hugeint_t end; + hugeint_t increment; + GetParameters(values, inputs.size(), start, end, increment); + cardinality = Hugeint::Cast((end - start) / increment); + } + + idx_t cardinality; +}; + +template +static unique_ptr RangeFunctionBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + return_types.emplace_back(LogicalType::BIGINT); + if (GENERATE_SERIES) { + names.emplace_back("generate_series"); + } else { + names.emplace_back("range"); + } + if (input.inputs.empty() || input.inputs.size() > 3) { + return nullptr; + } + return make_uniq(input.inputs); +} + +struct RangeFunctionLocalState : public LocalTableFunctionState { + RangeFunctionLocalState() { + } + + bool initialized_row = false; + idx_t current_input_row = 0; + idx_t current_idx = 0; + hugeint_t start; hugeint_t end; hugeint_t increment; - -public: - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return other.start == start && other.end == end && other.increment == increment; - } }; +static unique_ptr RangeFunctionLocalInit(ExecutionContext &context, + TableFunctionInitInput &input, + GlobalTableFunctionState *global_state) { + return make_uniq(); +} + template -static void GenerateRangeParameters(const vector &inputs, RangeFunctionBindData &result) { - for (auto &input : inputs) { - if (input.IsNull()) { +static void GenerateRangeParameters(DataChunk &input, idx_t row_id, RangeFunctionLocalState &result) { + input.Flatten(); + for (idx_t c = 0; c < input.ColumnCount(); c++) { + if (FlatVector::IsNull(input.data[c], row_id)) { result.start = GENERATE_SERIES ? 1 : 0; result.end = 0; result.increment = 1; return; } } - if (inputs.size() < 2) { - // single argument: only the end is specified - result.start = 0; - result.end = inputs[0].GetValue(); - } else { - // two arguments: first two arguments are start and end - result.start = inputs[0].GetValue(); - result.end = inputs[1].GetValue(); - } - if (inputs.size() < 3) { - result.increment = 1; - } else { - result.increment = inputs[2].GetValue(); + int64_t values[3]; + for (idx_t c = 0; c < input.ColumnCount(); c++) { + if (c >= 3) { + throw InternalException("Unsupported parameter count for range function"); + } + values[c] = FlatVector::GetValue(input.data[c], row_id); } + GetParameters(values, input.ColumnCount(), result.start, result.end, result.increment); if (result.increment == 0) { throw BinderException("interval cannot be 0!"); } if (result.start > result.end && result.increment > 0) { throw BinderException("start is bigger than end, but increment is positive: cannot generate infinite series"); - } else if (result.start < result.end && result.increment < 0) { + } + if (result.start < result.end && result.increment < 0) { throw BinderException("start is smaller than end, but increment is negative: cannot generate infinite series"); } -} - -template -static unique_ptr RangeFunctionBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - auto &inputs = input.inputs; - GenerateRangeParameters(inputs, *result); - - return_types.emplace_back(LogicalType::BIGINT); if (GENERATE_SERIES) { // generate_series has inclusive bounds on the RHS - if (result->increment < 0) { - result->end = result->end - 1; + if (result.increment < 0) { + result.end = result.end - 1; } else { - result->end = result->end + 1; + result.end = result.end + 1; } - names.emplace_back("generate_series"); - } else { - names.emplace_back("range"); } - return std::move(result); } -struct RangeFunctionState : public GlobalTableFunctionState { - RangeFunctionState() : current_idx(0) { - } - - int64_t current_idx; -}; - -static unique_ptr RangeFunctionInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static void RangeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &state = data_p.global_state->Cast(); - - auto increment = bind_data.increment; - auto end = bind_data.end; - hugeint_t current_value = bind_data.start + increment * state.current_idx; - int64_t current_value_i64; - if (!Hugeint::TryCast(current_value, current_value_i64)) { - return; +template +static OperatorResultType RangeFunction(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input, + DataChunk &output) { + auto &state = data_p.local_state->Cast(); + while (true) { + if (!state.initialized_row) { + // initialize for the current input row + if (state.current_input_row >= input.size()) { + // ran out of rows + state.current_input_row = 0; + state.initialized_row = false; + return OperatorResultType::NEED_MORE_INPUT; + } + GenerateRangeParameters(input, state.current_input_row, state); + state.initialized_row = true; + state.current_idx = 0; + } + auto increment = state.increment; + auto end = state.end; + hugeint_t current_value = state.start + increment * UnsafeNumericCast(state.current_idx); + int64_t current_value_i64; + if (!Hugeint::TryCast(current_value, current_value_i64)) { + // move to next row + state.current_input_row++; + state.initialized_row = false; + continue; + } + int64_t offset = increment < 0 ? 1 : -1; + idx_t remaining = MinValue( + Hugeint::Cast((end - current_value + (increment + offset)) / increment), STANDARD_VECTOR_SIZE); + // set the result vector as a sequence vector + output.data[0].Sequence(current_value_i64, Hugeint::Cast(increment), remaining); + // increment the index pointer by the remaining count + state.current_idx += remaining; + output.SetCardinality(remaining); + if (remaining == 0) { + // move to next row + state.current_input_row++; + state.initialized_row = false; + continue; + } + return OperatorResultType::HAVE_MORE_OUTPUT; } - int64_t offset = increment < 0 ? 1 : -1; - idx_t remaining = MinValue(Hugeint::Cast((end - current_value + (increment + offset)) / increment), - STANDARD_VECTOR_SIZE); - // set the result vector as a sequence vector - output.data[0].Sequence(current_value_i64, Hugeint::Cast(increment), remaining); - // increment the index pointer by the remaining count - state.current_idx += remaining; - output.SetCardinality(remaining); } unique_ptr RangeCardinality(ClientContext &context, const FunctionData *bind_data_p) { + if (!bind_data_p) { + return nullptr; + } auto &bind_data = bind_data_p->Cast(); - idx_t cardinality = Hugeint::Cast((bind_data.end - bind_data.start) / bind_data.increment); - return make_uniq(cardinality, cardinality); + return make_uniq(bind_data.cardinality, bind_data.cardinality); } //===--------------------------------------------------------------------===// // Range (timestamp) //===--------------------------------------------------------------------===// -struct RangeDateTimeBindData : public TableFunctionData { +template +static unique_ptr RangeDateTimeBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + return_types.push_back(LogicalType::TIMESTAMP); + if (GENERATE_SERIES) { + names.emplace_back("generate_series"); + } else { + names.emplace_back("range"); + } + return nullptr; +} + +struct RangeDateTimeLocalState : public LocalTableFunctionState { + RangeDateTimeLocalState() { + } + + bool initialized_row = false; + idx_t current_input_row = 0; + timestamp_t current_state; + timestamp_t start; timestamp_t end; interval_t increment; bool inclusive_bound; bool greater_than_check; -public: - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return other.start == start && other.end == end && other.increment == increment && - other.inclusive_bound == inclusive_bound && other.greater_than_check == greater_than_check; - } - bool Finished(timestamp_t current_value) const { if (greater_than_check) { if (inclusive_bound) { @@ -152,98 +219,105 @@ struct RangeDateTimeBindData : public TableFunctionData { }; template -static unique_ptr RangeDateTimeBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - auto &inputs = input.inputs; - D_ASSERT(inputs.size() == 3); - for (idx_t i = 0; i < inputs.size(); ++i) { - if (inputs[i].IsNull()) { - throw BinderException("RANGE with NULL argument is not supported"); +static void GenerateRangeDateTimeParameters(DataChunk &input, idx_t row_id, RangeDateTimeLocalState &result) { + input.Flatten(); + + for (idx_t c = 0; c < input.ColumnCount(); c++) { + if (FlatVector::IsNull(input.data[c], row_id)) { + result.start = timestamp_t(0); + result.end = timestamp_t(0); + result.increment = interval_t(); + result.greater_than_check = true; + result.inclusive_bound = false; + return; } } - result->start = inputs[0].GetValue(); - result->end = inputs[1].GetValue(); - result->increment = inputs[2].GetValue(); + + result.start = FlatVector::GetValue(input.data[0], row_id); + result.end = FlatVector::GetValue(input.data[1], row_id); + result.increment = FlatVector::GetValue(input.data[2], row_id); // Infinities either cause errors or infinite loops, so just ban them - if (!Timestamp::IsFinite(result->start) || !Timestamp::IsFinite(result->end)) { + if (!Timestamp::IsFinite(result.start) || !Timestamp::IsFinite(result.end)) { throw BinderException("RANGE with infinite bounds is not supported"); } - if (result->increment.months == 0 && result->increment.days == 0 && result->increment.micros == 0) { + if (result.increment.months == 0 && result.increment.days == 0 && result.increment.micros == 0) { throw BinderException("interval cannot be 0!"); } // all elements should point in the same direction - if (result->increment.months > 0 || result->increment.days > 0 || result->increment.micros > 0) { - if (result->increment.months < 0 || result->increment.days < 0 || result->increment.micros < 0) { + if (result.increment.months > 0 || result.increment.days > 0 || result.increment.micros > 0) { + if (result.increment.months < 0 || result.increment.days < 0 || result.increment.micros < 0) { throw BinderException("RANGE with composite interval that has mixed signs is not supported"); } - result->greater_than_check = true; - if (result->start > result->end) { + result.greater_than_check = true; + if (result.start > result.end) { throw BinderException( "start is bigger than end, but increment is positive: cannot generate infinite series"); } } else { - result->greater_than_check = false; - if (result->start < result->end) { + result.greater_than_check = false; + if (result.start < result.end) { throw BinderException( "start is smaller than end, but increment is negative: cannot generate infinite series"); } } - return_types.push_back(inputs[0].type()); - if (GENERATE_SERIES) { - // generate_series has inclusive bounds on the RHS - result->inclusive_bound = true; - names.emplace_back("generate_series"); - } else { - result->inclusive_bound = false; - names.emplace_back("range"); - } - return std::move(result); + result.inclusive_bound = GENERATE_SERIES; } -struct RangeDateTimeState : public GlobalTableFunctionState { - explicit RangeDateTimeState(timestamp_t start_p) : current_state(start_p) { - } - - timestamp_t current_state; - bool finished = false; -}; - -static unique_ptr RangeDateTimeInit(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - return make_uniq(bind_data.start); +static unique_ptr RangeDateTimeLocalInit(ExecutionContext &context, + TableFunctionInitInput &input, + GlobalTableFunctionState *global_state) { + return make_uniq(); } -static void RangeDateTimeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &state = data_p.global_state->Cast(); - if (state.finished) { - return; - } - - idx_t size = 0; - auto data = FlatVector::GetData(output.data[0]); +template +static OperatorResultType RangeDateTimeFunction(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input, + DataChunk &output) { + auto &state = data_p.local_state->Cast(); while (true) { - if (bind_data.Finished(state.current_state)) { - state.finished = true; - break; + if (!state.initialized_row) { + // initialize for the current input row + if (state.current_input_row >= input.size()) { + // ran out of rows + state.current_input_row = 0; + state.initialized_row = false; + return OperatorResultType::NEED_MORE_INPUT; + } + GenerateRangeDateTimeParameters(input, state.current_input_row, state); + state.initialized_row = true; + state.current_state = state.start; + } + idx_t size = 0; + auto data = FlatVector::GetData(output.data[0]); + while (true) { + if (state.Finished(state.current_state)) { + break; + } + if (size >= STANDARD_VECTOR_SIZE) { + break; + } + data[size++] = state.current_state; + state.current_state = + AddOperator::Operation(state.current_state, state.increment); } - if (size >= STANDARD_VECTOR_SIZE) { - break; + if (size == 0) { + // move to next row + state.current_input_row++; + state.initialized_row = false; + continue; } - data[size++] = state.current_state; - state.current_state = - AddOperator::Operation(state.current_state, bind_data.increment); + output.SetCardinality(size); + return OperatorResultType::HAVE_MORE_OUTPUT; } - output.SetCardinality(size); } void RangeTableFunction::RegisterFunction(BuiltinFunctions &set) { TableFunctionSet range("range"); - TableFunction range_function({LogicalType::BIGINT}, RangeFunction, RangeFunctionBind, RangeFunctionInit); + TableFunction range_function({LogicalType::BIGINT}, nullptr, RangeFunctionBind, nullptr, + RangeFunctionLocalInit); + range_function.in_out_function = RangeFunction; range_function.cardinality = RangeCardinality; // single argument range: (end) - implicit start = 0 and increment = 1 @@ -254,20 +328,25 @@ void RangeTableFunction::RegisterFunction(BuiltinFunctions &set) { // three arguments range: (start, end, increment) range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}; range.AddFunction(range_function); - range.AddFunction(TableFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, - RangeDateTimeFunction, RangeDateTimeBind, RangeDateTimeInit)); + TableFunction range_in_out({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, nullptr, + RangeDateTimeBind, nullptr, RangeDateTimeLocalInit); + range_in_out.in_out_function = RangeDateTimeFunction; + range.AddFunction(range_in_out); set.AddFunction(range); // generate_series: similar to range, but inclusive instead of exclusive bounds on the RHS TableFunctionSet generate_series("generate_series"); range_function.bind = RangeFunctionBind; + range_function.in_out_function = RangeFunction; range_function.arguments = {LogicalType::BIGINT}; generate_series.AddFunction(range_function); range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT}; generate_series.AddFunction(range_function); range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}; generate_series.AddFunction(range_function); - generate_series.AddFunction(TableFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, - RangeDateTimeFunction, RangeDateTimeBind, RangeDateTimeInit)); + TableFunction generate_series_in_out({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, + nullptr, RangeDateTimeBind, nullptr, RangeDateTimeLocalInit); + generate_series_in_out.in_out_function = RangeDateTimeFunction; + generate_series.AddFunction(generate_series_in_out); set.AddFunction(generate_series); } @@ -282,6 +361,7 @@ void BuiltinFunctions::RegisterTableFunctions() { CSVSnifferFunction::RegisterFunction(*this); ReadBlobFunction::RegisterFunction(*this); ReadTextFunction::RegisterFunction(*this); + QueryTableFunction::RegisterFunction(*this); } } // namespace duckdb diff --git a/src/duckdb/src/function/table/read_csv.cpp b/src/duckdb/src/function/table/read_csv.cpp index 58cbfdd7..58948af7 100644 --- a/src/duckdb/src/function/table/read_csv.cpp +++ b/src/duckdb/src/function/table/read_csv.cpp @@ -51,7 +51,7 @@ static unique_ptr ReadCSVBind(ClientContext &context, TableFunctio auto multi_file_reader = MultiFileReader::Create(input.table_function); auto multi_file_list = multi_file_reader->CreateFileList(context, input.inputs[0]); - options.FromNamedParameters(input.named_parameters, context, return_types, names); + options.FromNamedParameters(input.named_parameters, context); if (options.rejects_table_name.IsSetByUser() && !options.store_rejects.GetValue() && options.store_rejects.IsSetByUser()) { throw BinderException("REJECTS_TABLE option is only supported when store_rejects is not manually set to false"); @@ -82,16 +82,20 @@ static unique_ptr ReadCSVBind(ClientContext &context, TableFunctio options.file_options.AutoDetectHivePartitioning(*multi_file_list, context); - if (!options.auto_detect && return_types.empty()) { - throw BinderException("read_csv requires columns to be specified through the 'columns' option. Use " - "read_csv_auto or set read_csv(..., " - "AUTO_DETECT=TRUE) to automatically guess columns."); + if (!options.auto_detect) { + if (!options.columns_set) { + throw BinderException("read_csv requires columns to be specified through the 'columns' option. Use " + "read_csv_auto or set read_csv(..., " + "AUTO_DETECT=TRUE) to automatically guess columns."); + } else { + names = options.name_list; + return_types = options.sql_type_list; + } } if (options.auto_detect && !options.file_options.union_by_name) { options.file_path = multi_file_list->GetFirstFile(); result->buffer_manager = make_shared_ptr(context, options, options.file_path, 0); - CSVSniffer sniffer(options, result->buffer_manager, CSVStateMachineCache::Get(context), - {&return_types, &names}); + CSVSniffer sniffer(options, result->buffer_manager, CSVStateMachineCache::Get(context)); auto sniffer_result = sniffer.SniffCSV(); if (names.empty()) { names = sniffer_result.names; @@ -107,8 +111,7 @@ static unique_ptr ReadCSVBind(ClientContext &context, TableFunctio result->reader_bind = multi_file_reader->BindUnionReader(context, return_types, names, *multi_file_list, *result, options); if (result->union_readers.size() > 1) { - result->column_info.emplace_back(result->initial_reader->names, result->initial_reader->types); - for (idx_t i = 1; i < result->union_readers.size(); i++) { + for (idx_t i = 0; i < result->union_readers.size(); i++) { result->column_info.emplace_back(result->union_readers[i]->names, result->union_readers[i]->types); } } @@ -202,6 +205,10 @@ unique_ptr ReadCSVInitLocal(ExecutionContext &context, return nullptr; } auto &global_state = global_state_p->Cast(); + if (global_state.IsDone()) { + // nothing to do + return nullptr; + } auto csv_scanner = global_state.Next(nullptr); if (!csv_scanner) { global_state.DecrementThread(); @@ -215,6 +222,9 @@ static void ReadCSVFunction(ClientContext &context, TableFunctionInput &data_p, return; } auto &csv_global_state = data_p.global_state->Cast(); + if (!data_p.local_state) { + return; + } auto &csv_local_state = data_p.local_state->Cast(); if (!csv_local_state.csv_reader) { @@ -282,6 +292,7 @@ void ReadCSVTableFunction::ReadCSVAddNamedParameters(TableFunction &table_functi table_function.named_parameters["names"] = LogicalType::LIST(LogicalType::VARCHAR); table_function.named_parameters["column_names"] = LogicalType::LIST(LogicalType::VARCHAR); table_function.named_parameters["parallel"] = LogicalType::BOOLEAN; + table_function.named_parameters["comment"] = LogicalType::VARCHAR; MultiFileReader::AddParameters(table_function); } @@ -300,8 +311,9 @@ void CSVComplexFilterPushdown(ClientContext &context, LogicalGet &get, FunctionD vector> &filters) { auto &data = bind_data_p->Cast(); SimpleMultiFileList file_list(data.files); + MultiFilePushdownInfo info(get); auto filtered_list = - MultiFileReader().ComplexFilterPushdown(context, file_list, data.options.file_options, get, filters); + MultiFileReader().ComplexFilterPushdown(context, file_list, data.options.file_options, info, filters); if (filtered_list) { data.files = filtered_list->GetAllFiles(); MultiFileReader::PruneReaders(data, file_list); @@ -375,12 +387,12 @@ void ReadCSVTableFunction::RegisterFunction(BuiltinFunctions &set) { unique_ptr ReadCSVReplacement(ClientContext &context, ReplacementScanInput &input, optional_ptr data) { - auto &table_name = input.table_name; + auto table_name = ReplacementScan::GetFullPath(input); auto lower_name = StringUtil::Lower(table_name); // remove any compression - if (StringUtil::EndsWith(lower_name, ".gz")) { + if (StringUtil::EndsWith(lower_name, CompressionExtensionFromType(FileCompressionType::GZIP))) { lower_name = lower_name.substr(0, lower_name.size() - 3); - } else if (StringUtil::EndsWith(lower_name, ".zst")) { + } else if (StringUtil::EndsWith(lower_name, CompressionExtensionFromType(FileCompressionType::ZSTD))) { if (!Catalog::TryAutoLoad(context, "parquet")) { throw MissingExtensionException("parquet extension is required for reading zst compressed file"); } diff --git a/src/duckdb/src/function/table/sniff_csv.cpp b/src/duckdb/src/function/table/sniff_csv.cpp index d395e9dd..11e5cca8 100644 --- a/src/duckdb/src/function/table/sniff_csv.cpp +++ b/src/duckdb/src/function/table/sniff_csv.cpp @@ -36,6 +36,10 @@ static unique_ptr CSVSniffInitGlobal(ClientContext &co static unique_ptr CSVSniffBind(ClientContext &context, TableFunctionBindInput &input, vector &return_types, vector &names) { auto result = make_uniq(); + auto &config = DBConfig::GetConfig(context); + if (!config.options.enable_external_access) { + throw PermissionException("sniff_csv is disabled through configuration"); + } result->path = input.inputs[0].ToString(); auto it = input.named_parameters.find("auto_detect"); if (it != input.named_parameters.end()) { @@ -45,7 +49,7 @@ static unique_ptr CSVSniffBind(ClientContext &context, TableFuncti // otherwise remove it input.named_parameters.erase("auto_detect"); } - result->options.FromNamedParameters(input.named_parameters, context, result->return_types_csv, result->names_csv); + result->options.FromNamedParameters(input.named_parameters, context); // We want to return the whole CSV Configuration // 1. Delimiter return_types.emplace_back(LogicalType::VARCHAR); @@ -59,27 +63,30 @@ static unique_ptr CSVSniffBind(ClientContext &context, TableFuncti // 4. NewLine Delimiter return_types.emplace_back(LogicalType::VARCHAR); names.emplace_back("NewLineDelimiter"); - // 5. Skip Rows + // 5. Comment + return_types.emplace_back(LogicalType::VARCHAR); + names.emplace_back("Comment"); + // 6. Skip Rows return_types.emplace_back(LogicalType::UINTEGER); names.emplace_back("SkipRows"); - // 6. Has Header + // 7. Has Header return_types.emplace_back(LogicalType::BOOLEAN); names.emplace_back("HasHeader"); - // 7. List> + // 8. List> child_list_t struct_children {{"name", LogicalType::VARCHAR}, {"type", LogicalType::VARCHAR}}; auto list_child = LogicalType::STRUCT(struct_children); return_types.emplace_back(LogicalType::LIST(list_child)); names.emplace_back("Columns"); - // 8. Date Format + // 9. Date Format return_types.emplace_back(LogicalType::VARCHAR); names.emplace_back("DateFormat"); - // 9. Timestamp Format + // 10. Timestamp Format return_types.emplace_back(LogicalType::VARCHAR); names.emplace_back("TimestampFormat"); - // 10. CSV read function with all the options used + // 11. CSV read function with all the options used return_types.emplace_back(LogicalType::VARCHAR); names.emplace_back("UserArguments"); - // 11. CSV read function with all the options used + // 12. CSV read function with all the options used return_types.emplace_back(LogicalType::VARCHAR); names.emplace_back("Prompt"); return std::move(result); @@ -103,18 +110,20 @@ static void CSVSniffFunction(ClientContext &context, TableFunctionInput &data_p, const CSVSniffFunctionData &data = data_p.bind_data->Cast(); auto &fs = duckdb::FileSystem::GetFileSystem(context); - if (data.path.rfind("http://", 0) != 0 && data.path.rfind("https://", 0) != 0 && fs.HasGlob(data.path)) { - throw NotImplementedException("sniff_csv does not operate on globs yet"); + auto paths = fs.GlobFiles(data.path, context, FileGlobOptions::DISALLOW_EMPTY); + if (paths.size() > 1) { + throw NotImplementedException("sniff_csv does not operate on more than one file yet"); } // We must run the sniffer. auto sniffer_options = data.options; - sniffer_options.file_path = data.path; + sniffer_options.file_path = paths[0]; auto buffer_manager = make_shared_ptr(context, sniffer_options, sniffer_options.file_path, 0); if (sniffer_options.name_list.empty()) { sniffer_options.name_list = data.names_csv; } + if (sniffer_options.sql_type_list.empty()) { sniffer_options.sql_type_list = data.return_types_csv; } @@ -137,12 +146,15 @@ static void CSVSniffFunction(ClientContext &context, TableFunctionInput &data_p, // 4. NewLine Delimiter auto new_line_identifier = sniffer_options.NewLineIdentifierToString(); output.SetValue(3, 0, new_line_identifier); - // 5. Skip Rows - output.SetValue(4, 0, Value::UINTEGER(NumericCast(sniffer_options.dialect_options.skip_rows.GetValue()))); - // 6. Has Header + // 5. Comment + str_opt = sniffer_options.dialect_options.state_machine_options.comment.GetValue(); + output.SetValue(4, 0, str_opt); + // 6. Skip Rows + output.SetValue(5, 0, Value::UINTEGER(NumericCast(sniffer_options.dialect_options.skip_rows.GetValue()))); + // 7. Has Header auto has_header = Value::BOOLEAN(sniffer_options.dialect_options.header.GetValue()).ToString(); - output.SetValue(5, 0, has_header); - // 7. List> {'col1': 'INTEGER', 'col2': 'VARCHAR'} + output.SetValue(6, 0, has_header); + // 8. List> {'col1': 'INTEGER', 'col2': 'VARCHAR'} vector values; std::ostringstream columns; columns << "{"; @@ -156,45 +168,45 @@ static void CSVSniffFunction(ClientContext &context, TableFunctionInput &data_p, } } columns << "}"; - output.SetValue(6, 0, Value::LIST(values)); - // 8. Date Format + output.SetValue(7, 0, Value::LIST(values)); + // 9. Date Format auto date_format = sniffer_options.dialect_options.date_format[LogicalType::DATE].GetValue(); if (!date_format.Empty()) { - output.SetValue(7, 0, date_format.format_specifier); + output.SetValue(8, 0, date_format.format_specifier); } else { bool has_date = false; for (auto &c_type : sniffer_result.return_types) { // Must be ISO 8601 if (c_type.id() == LogicalTypeId::DATE) { - output.SetValue(7, 0, Value("%Y-%m-%d")); + output.SetValue(8, 0, Value("%Y-%m-%d")); has_date = true; } } if (!has_date) { - output.SetValue(7, 0, Value(nullptr)); + output.SetValue(8, 0, Value(nullptr)); } } - // 9. Timestamp Format + // 10. Timestamp Format auto timestamp_format = sniffer_options.dialect_options.date_format[LogicalType::TIMESTAMP].GetValue(); if (!timestamp_format.Empty()) { - output.SetValue(8, 0, timestamp_format.format_specifier); + output.SetValue(9, 0, timestamp_format.format_specifier); } else { - output.SetValue(8, 0, Value(nullptr)); + output.SetValue(9, 0, Value(nullptr)); } - // 10. The Extra User Arguments + // 11. The Extra User Arguments if (data.options.user_defined_parameters.empty()) { - output.SetValue(9, 0, Value()); + output.SetValue(10, 0, Value()); } else { - output.SetValue(9, 0, Value(data.options.user_defined_parameters)); + output.SetValue(10, 0, Value(data.options.user_defined_parameters)); } - // 11. csv_read string + // 12. csv_read string std::ostringstream csv_read; // Base, Path and auto_detect=false - csv_read << "FROM read_csv('" << data.path << "'" << separator << "auto_detect=false" << separator; + csv_read << "FROM read_csv('" << paths[0] << "'" << separator << "auto_detect=false" << separator; // 10.1. Delimiter if (!sniffer_options.dialect_options.state_machine_options.delimiter.IsSetByUser()) { csv_read << "delim=" @@ -224,13 +236,21 @@ static void CSVSniffFunction(ClientContext &context, TableFunctionInput &data_p, if (!sniffer_options.dialect_options.skip_rows.IsSetByUser()) { csv_read << "skip=" << sniffer_options.dialect_options.skip_rows.GetValue() << separator; } - // 11.6. Has Header + + // 11.6. Comment + if (!sniffer_options.dialect_options.state_machine_options.comment.IsSetByUser()) { + csv_read << "comment=" + << "'" << FormatOptions(sniffer_options.dialect_options.state_machine_options.comment.GetValue()) + << "'" << separator; + } + + // 11.7. Has Header if (!sniffer_options.dialect_options.header.IsSetByUser()) { csv_read << "header=" << has_header << separator; } - // 11.7. column={'col1': 'INTEGER', 'col2': 'VARCHAR'} + // 11.8. column={'col1': 'INTEGER', 'col2': 'VARCHAR'} csv_read << "columns=" << columns.str(); - // 11.8. Date Format + // 11.9. Date Format if (!sniffer_options.dialect_options.date_format[LogicalType::DATE].IsSetByUser()) { if (!sniffer_options.dialect_options.date_format[LogicalType::DATE].GetValue().format_specifier.empty()) { csv_read << separator << "dateformat=" @@ -248,7 +268,7 @@ static void CSVSniffFunction(ClientContext &context, TableFunctionInput &data_p, } } } - // 11.9. Timestamp Format + // 11.10. Timestamp Format if (!sniffer_options.dialect_options.date_format[LogicalType::TIMESTAMP].IsSetByUser()) { if (!sniffer_options.dialect_options.date_format[LogicalType::TIMESTAMP].GetValue().format_specifier.empty()) { csv_read << separator << "timestampformat=" @@ -257,12 +277,12 @@ static void CSVSniffFunction(ClientContext &context, TableFunctionInput &data_p, << "'"; } } - // 11.10 User Arguments + // 11.11 User Arguments if (!data.options.user_defined_parameters.empty()) { csv_read << separator << data.options.user_defined_parameters; } csv_read << ");"; - output.SetValue(10, 0, csv_read.str()); + output.SetValue(11, 0, csv_read.str()); global_state.done = true; } diff --git a/src/duckdb/src/function/table/system/duckdb_constraints.cpp b/src/duckdb/src/function/table/system/duckdb_constraints.cpp index 71fabb16..8637feaf 100644 --- a/src/duckdb/src/function/table/system/duckdb_constraints.cpp +++ b/src/duckdb/src/function/table/system/duckdb_constraints.cpp @@ -1,52 +1,19 @@ -#include "duckdb/function/table/system_functions.hpp" - #include "duckdb/catalog/catalog.hpp" #include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" #include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/common/exception.hpp" +#include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/client_data.hpp" #include "duckdb/parser/constraint.hpp" #include "duckdb/parser/constraints/check_constraint.hpp" +#include "duckdb/parser/constraints/foreign_key_constraint.hpp" +#include "duckdb/parser/constraints/not_null_constraint.hpp" #include "duckdb/parser/constraints/unique_constraint.hpp" -#include "duckdb/planner/constraints/bound_unique_constraint.hpp" -#include "duckdb/planner/constraints/bound_check_constraint.hpp" -#include "duckdb/planner/constraints/bound_not_null_constraint.hpp" -#include "duckdb/planner/constraints/bound_foreign_key_constraint.hpp" -#include "duckdb/storage/data_table.hpp" #include "duckdb/planner/binder.hpp" - -namespace duckdb { - -struct UniqueKeyInfo { - string schema; - string table; - vector columns; - - bool operator==(const UniqueKeyInfo &other) const { - return (schema == other.schema) && (table == other.table) && (columns == other.columns); - } -}; - -} // namespace duckdb - -namespace std { - -template <> -struct hash { - template - static size_t ComputeHash(const X &x) { - return hash()(x); - } - - size_t operator()(const duckdb::UniqueKeyInfo &j) const { - D_ASSERT(j.columns.size() > 0); - return ComputeHash(j.schema) + ComputeHash(j.table) + ComputeHash(j.columns[0].index); - } -}; - -} // namespace std +#include "duckdb/planner/constraints/bound_check_constraint.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" namespace duckdb { @@ -71,7 +38,7 @@ struct DuckDBConstraintsData : public GlobalTableFunctionState { idx_t offset; idx_t constraint_offset; idx_t unique_constraint_offset; - unordered_map known_fk_unique_constraint_offsets; + case_insensitive_set_t constraint_names; }; static unique_ptr DuckDBConstraintsBind(ClientContext &context, TableFunctionBindInput &input, @@ -113,6 +80,16 @@ static unique_ptr DuckDBConstraintsBind(ClientContext &context, Ta names.emplace_back("constraint_column_names"); return_types.push_back(LogicalType::LIST(LogicalType::VARCHAR)); + names.emplace_back("constraint_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + // FOREIGN KEY + names.emplace_back("referenced_table"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("referenced_column_names"); + return_types.push_back(LogicalType::LIST(LogicalType::VARCHAR)); + return nullptr; } @@ -140,6 +117,97 @@ unique_ptr DuckDBConstraintsInit(ClientContext &contex return std::move(result); } +struct ExtraConstraintInfo { + vector column_indexes; + vector column_names; + string referenced_table; + vector referenced_columns; +}; + +void ExtractReferencedColumns(const ParsedExpression &expr, vector &result) { + if (expr.GetExpressionClass() == ExpressionClass::COLUMN_REF) { + auto &colref = expr.Cast(); + result.push_back(colref.GetColumnName()); + } + ParsedExpressionIterator::EnumerateChildren( + expr, [&](const ParsedExpression &child) { ExtractReferencedColumns(child, result); }); +} + +ExtraConstraintInfo GetExtraConstraintInfo(const TableCatalogEntry &table, const Constraint &constraint) { + ExtraConstraintInfo result; + switch (constraint.type) { + case ConstraintType::CHECK: { + auto &check_constraint = constraint.Cast(); + ExtractReferencedColumns(*check_constraint.expression, result.column_names); + break; + } + case ConstraintType::NOT_NULL: { + auto ¬_null_constraint = constraint.Cast(); + result.column_indexes.push_back(not_null_constraint.index); + break; + } + case ConstraintType::UNIQUE: { + auto &unique = constraint.Cast(); + if (unique.HasIndex()) { + result.column_indexes.push_back(unique.GetIndex()); + } else { + result.column_names = unique.GetColumnNames(); + } + break; + } + case ConstraintType::FOREIGN_KEY: { + auto &fk = constraint.Cast(); + result.referenced_columns = fk.pk_columns; + result.referenced_table = fk.info.table; + result.column_names = fk.fk_columns; + break; + } + default: + throw InternalException("Unsupported type for constraint name"); + } + if (result.column_indexes.empty()) { + // generate column indexes from names + for (auto &name : result.column_names) { + result.column_indexes.push_back(table.GetColumnIndex(name)); + } + } else { + // generate names from column indexes + for (auto &index : result.column_indexes) { + result.column_names.push_back(table.GetColumn(index).GetName()); + } + } + return result; +} + +string GetConstraintName(const TableCatalogEntry &table, Constraint &constraint, const ExtraConstraintInfo &info) { + string result = table.name + "_"; + for (auto &col : info.column_names) { + result += StringUtil::Lower(col) + "_"; + } + for (auto &col : info.referenced_columns) { + result += StringUtil::Lower(col) + "_"; + } + switch (constraint.type) { + case ConstraintType::CHECK: + result += "check"; + break; + case ConstraintType::NOT_NULL: + result += "not_null"; + break; + case ConstraintType::UNIQUE: { + auto &unique = constraint.Cast(); + result += unique.IsPrimaryKey() ? "pkey" : "key"; + break; + } + case ConstraintType::FOREIGN_KEY: + result += "fkey"; + break; + default: + throw InternalException("Unsupported type for constraint name"); + } + return result; +} + void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { auto &data = data_p.global_state->Cast(); if (data.offset >= data.entries.size()) { @@ -154,7 +222,6 @@ void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_ auto &table = entry.table; auto &constraints = table.GetConstraints(); - bool is_duck_table = table.IsDuckTable(); for (; data.constraint_offset < constraints.size() && count < STANDARD_VECTOR_SIZE; data.constraint_offset++) { auto &constraint = constraints[data.constraint_offset]; // return values: @@ -174,12 +241,8 @@ void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_ constraint_type = "NOT NULL"; break; case ConstraintType::FOREIGN_KEY: { - if (!is_duck_table) { - continue; - } - auto &bound_constraints = entry.bound_constraints; - auto &bound_foreign_key = bound_constraints[data.constraint_offset]->Cast(); - if (bound_foreign_key.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE) { + auto &fk = constraint->Cast(); + if (fk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE) { // Those are already covered by PRIMARY KEY and UNIQUE entries continue; } @@ -204,52 +267,21 @@ void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_ // table_oid, LogicalType::BIGINT output.SetValue(col++, count, Value::BIGINT(NumericCast(table.oid))); - // constraint_index, BIGINT - UniqueKeyInfo uk_info; - - if (is_duck_table) { - auto &bound_constraint = *entry.bound_constraints[data.constraint_offset]; - switch (bound_constraint.type) { - case ConstraintType::UNIQUE: { - auto &bound_unique = bound_constraint.Cast(); - uk_info = {table.schema.name, table.name, bound_unique.keys}; - break; - } - case ConstraintType::FOREIGN_KEY: { - const auto &bound_foreign_key = bound_constraint.Cast(); - const auto &info = bound_foreign_key.info; - // find the other table - auto table_entry = Catalog::GetEntry( - context, table.catalog.GetName(), info.schema, info.table, OnEntryNotFound::RETURN_NULL); - if (!table_entry) { - throw InternalException("dukdb_constraints: entry %s.%s referenced in foreign key not found", - info.schema, info.table); - } - vector index; - for (auto &key : info.pk_keys) { - index.push_back(table_entry->GetColumns().PhysicalToLogical(key)); - } - uk_info = {table_entry->schema.name, table_entry->name, index}; - break; - } - default: - break; + auto info = GetExtraConstraintInfo(table, *constraint); + auto constraint_name = GetConstraintName(table, *constraint, info); + if (data.constraint_names.find(constraint_name) != data.constraint_names.end()) { + // duplicate constraint name + idx_t index = 2; + while (data.constraint_names.find(constraint_name + "_" + to_string(index)) != + data.constraint_names.end()) { + index++; } + constraint_name += "_" + to_string(index); } + // constraint_index, BIGINT + output.SetValue(col++, count, Value::BIGINT(NumericCast(data.unique_constraint_offset++))); - if (uk_info.columns.empty()) { - output.SetValue(col++, count, Value::BIGINT(NumericCast(data.unique_constraint_offset++))); - } else { - auto known_unique_constraint_offset = data.known_fk_unique_constraint_offsets.find(uk_info); - if (known_unique_constraint_offset == data.known_fk_unique_constraint_offsets.end()) { - data.known_fk_unique_constraint_offsets.insert(make_pair(uk_info, data.unique_constraint_offset)); - output.SetValue(col++, count, Value::BIGINT(NumericCast(data.unique_constraint_offset))); - data.unique_constraint_offset++; - } else { - output.SetValue(col++, count, - Value::BIGINT(NumericCast(known_unique_constraint_offset->second))); - } - } + // constraint_type, VARCHAR output.SetValue(col++, count, Value(constraint_type)); // constraint_text, VARCHAR @@ -263,54 +295,34 @@ void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_ } output.SetValue(col++, count, expression_text); - vector column_index_list; - if (is_duck_table) { - auto &bound_constraint = *entry.bound_constraints[data.constraint_offset]; - switch (bound_constraint.type) { - case ConstraintType::CHECK: { - auto &bound_check = bound_constraint.Cast(); - for (auto &col_idx : bound_check.bound_columns) { - column_index_list.push_back(table.GetColumns().PhysicalToLogical(col_idx)); - } - break; - } - case ConstraintType::UNIQUE: { - auto &bound_unique = bound_constraint.Cast(); - for (auto &col_idx : bound_unique.keys) { - column_index_list.push_back(col_idx); - } - break; - } - case ConstraintType::NOT_NULL: { - auto &bound_not_null = bound_constraint.Cast(); - column_index_list.push_back(table.GetColumns().PhysicalToLogical(bound_not_null.index)); - break; - } - case ConstraintType::FOREIGN_KEY: { - auto &bound_foreign_key = bound_constraint.Cast(); - for (auto &col_idx : bound_foreign_key.info.fk_keys) { - column_index_list.push_back(table.GetColumns().PhysicalToLogical(col_idx)); - } - break; - } - default: - throw NotImplementedException("Unimplemented constraint for duckdb_constraints"); - } - } - - vector index_list; + vector column_index_list; vector column_name_list; - for (auto column_index : column_index_list) { - index_list.push_back(Value::BIGINT(NumericCast(column_index.index))); - column_name_list.emplace_back(table.GetColumn(column_index).Name()); + vector referenced_column_name_list; + for (auto &col_index : info.column_indexes) { + column_index_list.push_back(Value::UBIGINT(col_index.index)); + } + for (auto &name : info.column_names) { + column_name_list.push_back(Value(std::move(name))); + } + for (auto &name : info.referenced_columns) { + referenced_column_name_list.push_back(Value(std::move(name))); } // constraint_column_indexes, LIST - output.SetValue(col++, count, Value::LIST(LogicalType::BIGINT, std::move(index_list))); + output.SetValue(col++, count, Value::LIST(LogicalType::BIGINT, std::move(column_index_list))); // constraint_column_names, LIST output.SetValue(col++, count, Value::LIST(LogicalType::VARCHAR, std::move(column_name_list))); + // constraint_name, VARCHAR + output.SetValue(col++, count, Value(std::move(constraint_name))); + + // referenced_table, VARCHAR + output.SetValue(col++, count, + info.referenced_table.empty() ? Value() : Value(std::move(info.referenced_table))); + + // referenced_column_names, LIST + output.SetValue(col++, count, Value::LIST(LogicalType::VARCHAR, std::move(referenced_column_name_list))); count++; } if (data.constraint_offset >= constraints.size()) { diff --git a/src/duckdb/src/function/table/system/duckdb_extensions.cpp b/src/duckdb/src/function/table/system/duckdb_extensions.cpp index 3c7f2396..c61fa743 100644 --- a/src/duckdb/src/function/table/system/duckdb_extensions.cpp +++ b/src/duckdb/src/function/table/system/duckdb_extensions.cpp @@ -138,21 +138,32 @@ unique_ptr DuckDBExtensionsInit(ClientContext &context #endif // Finally, we check the list of currently loaded extensions - auto &loaded_extensions = db.LoadedExtensionsData(); - for (auto &e : loaded_extensions) { + auto &extensions = db.GetExtensions(); + for (auto &e : extensions) { + if (!e.second.is_loaded) { + continue; + } auto &ext_name = e.first; - auto &ext_info = e.second; - auto entry = installed_extensions.find(ext_name); - if (entry == installed_extensions.end() || !entry->second.installed) { - ExtensionInformation &info = installed_extensions[ext_name]; - info.name = ext_name; - info.loaded = true; - info.extension_version = ext_info.version; - info.installed = ext_info.mode == ExtensionInstallMode::STATICALLY_LINKED; - info.install_mode = ext_info.mode; - } else { - entry->second.loaded = true; - entry->second.extension_version = ext_info.version; + auto &ext_data = e.second; + if (auto &ext_install_info = ext_data.install_info) { + auto entry = installed_extensions.find(ext_name); + if (entry == installed_extensions.end() || !entry->second.installed) { + ExtensionInformation &info = installed_extensions[ext_name]; + info.name = ext_name; + info.loaded = true; + info.extension_version = ext_install_info->version; + info.installed = ext_install_info->mode == ExtensionInstallMode::STATICALLY_LINKED; + info.install_mode = ext_install_info->mode; + } else { + entry->second.loaded = true; + entry->second.extension_version = ext_install_info->version; + } + } + if (auto &ext_load_info = ext_data.load_info) { + auto entry = installed_extensions.find(ext_name); + if (entry != installed_extensions.end()) { + entry->second.description = ext_load_info->description; + } } } diff --git a/src/duckdb/src/function/table/system/duckdb_functions.cpp b/src/duckdb/src/function/table/system/duckdb_functions.cpp index 1ee4b851..80ac7bcb 100644 --- a/src/duckdb/src/function/table/system/duckdb_functions.cpp +++ b/src/duckdb/src/function/table/system/duckdb_functions.cpp @@ -222,7 +222,7 @@ struct AggregateFunctionExtractor { struct MacroExtractor { static idx_t FunctionCount(ScalarMacroCatalogEntry &entry) { - return 1; + return entry.macros.size(); } static Value GetFunctionType() { @@ -235,12 +235,13 @@ struct MacroExtractor { static vector GetParameters(ScalarMacroCatalogEntry &entry, idx_t offset) { vector results; - for (auto ¶m : entry.function->parameters) { + auto ¯o_entry = *entry.macros[offset]; + for (auto ¶m : macro_entry.parameters) { D_ASSERT(param->type == ExpressionType::COLUMN_REF); auto &colref = param->Cast(); results.emplace_back(colref.GetColumnName()); } - for (auto ¶m_entry : entry.function->default_parameters) { + for (auto ¶m_entry : macro_entry.default_parameters) { results.emplace_back(param_entry.first); } return results; @@ -248,10 +249,11 @@ struct MacroExtractor { static Value GetParameterTypes(ScalarMacroCatalogEntry &entry, idx_t offset) { vector results; - for (idx_t i = 0; i < entry.function->parameters.size(); i++) { + auto ¯o_entry = *entry.macros[offset]; + for (idx_t i = 0; i < macro_entry.parameters.size(); i++) { results.emplace_back(LogicalType::VARCHAR); } - for (idx_t i = 0; i < entry.function->default_parameters.size(); i++) { + for (idx_t i = 0; i < macro_entry.default_parameters.size(); i++) { results.emplace_back(LogicalType::VARCHAR); } return Value::LIST(LogicalType::VARCHAR, std::move(results)); @@ -262,8 +264,9 @@ struct MacroExtractor { } static Value GetMacroDefinition(ScalarMacroCatalogEntry &entry, idx_t offset) { - D_ASSERT(entry.function->type == MacroType::SCALAR_MACRO); - auto &func = entry.function->Cast(); + auto ¯o_entry = *entry.macros[offset]; + D_ASSERT(macro_entry.type == MacroType::SCALAR_MACRO); + auto &func = macro_entry.Cast(); return func.expression->ToString(); } @@ -278,7 +281,7 @@ struct MacroExtractor { struct TableMacroExtractor { static idx_t FunctionCount(TableMacroCatalogEntry &entry) { - return 1; + return entry.macros.size(); } static Value GetFunctionType() { @@ -291,12 +294,13 @@ struct TableMacroExtractor { static vector GetParameters(TableMacroCatalogEntry &entry, idx_t offset) { vector results; - for (auto ¶m : entry.function->parameters) { + auto ¯o_entry = *entry.macros[offset]; + for (auto ¶m : macro_entry.parameters) { D_ASSERT(param->type == ExpressionType::COLUMN_REF); auto &colref = param->Cast(); results.emplace_back(colref.GetColumnName()); } - for (auto ¶m_entry : entry.function->default_parameters) { + for (auto ¶m_entry : macro_entry.default_parameters) { results.emplace_back(param_entry.first); } return results; @@ -304,10 +308,11 @@ struct TableMacroExtractor { static Value GetParameterTypes(TableMacroCatalogEntry &entry, idx_t offset) { vector results; - for (idx_t i = 0; i < entry.function->parameters.size(); i++) { + auto ¯o_entry = *entry.macros[offset]; + for (idx_t i = 0; i < macro_entry.parameters.size(); i++) { results.emplace_back(LogicalType::VARCHAR); } - for (idx_t i = 0; i < entry.function->default_parameters.size(); i++) { + for (idx_t i = 0; i < macro_entry.default_parameters.size(); i++) { results.emplace_back(LogicalType::VARCHAR); } return Value::LIST(LogicalType::VARCHAR, std::move(results)); @@ -318,8 +323,9 @@ struct TableMacroExtractor { } static Value GetMacroDefinition(TableMacroCatalogEntry &entry, idx_t offset) { - if (entry.function->type == MacroType::TABLE_MACRO) { - auto &func = entry.function->Cast(); + auto ¯o_entry = *entry.macros[offset]; + if (macro_entry.type == MacroType::TABLE_MACRO) { + auto &func = macro_entry.Cast(); return func.query_node->ToString(); } return Value(); diff --git a/src/duckdb/src/function/table/system/duckdb_indexes.cpp b/src/duckdb/src/function/table/system/duckdb_indexes.cpp index 80938f60..78cc88d8 100644 --- a/src/duckdb/src/function/table/system/duckdb_indexes.cpp +++ b/src/duckdb/src/function/table/system/duckdb_indexes.cpp @@ -75,6 +75,20 @@ unique_ptr DuckDBIndexesInit(ClientContext &context, T return std::move(result); } +Value GetIndexExpressions(IndexCatalogEntry &index) { + auto create_info = index.GetInfo(); + auto &create_index_info = create_info->Cast(); + + auto vec = create_index_info.ExpressionsToList(); + + vector content; + content.reserve(vec.size()); + for (auto &item : vec) { + content.push_back(Value(item)); + } + return Value::LIST(LogicalType::VARCHAR, std::move(content)); +} + void DuckDBIndexesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { auto &data = data_p.global_state->Cast(); if (data.offset >= data.entries.size()) { @@ -119,7 +133,7 @@ void DuckDBIndexesFunction(ClientContext &context, TableFunctionInput &data_p, D // is_primary, BOOLEAN output.SetValue(col++, count, Value::BOOLEAN(index.IsPrimary())); // expressions, VARCHAR - output.SetValue(col++, count, Value()); + output.SetValue(col++, count, GetIndexExpressions(index).ToString()); // sql, VARCHAR auto sql = index.ToSQL(); output.SetValue(col++, count, sql.empty() ? Value() : Value(std::move(sql))); diff --git a/src/duckdb/src/function/table/system/duckdb_variables.cpp b/src/duckdb/src/function/table/system/duckdb_variables.cpp new file mode 100644 index 00000000..62cfbcb8 --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_variables.cpp @@ -0,0 +1,84 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_config.hpp" +#include "duckdb/main/client_data.hpp" + +namespace duckdb { + +struct VariableData { + string name; + Value value; +}; + +struct DuckDBVariablesData : public GlobalTableFunctionState { + DuckDBVariablesData() : offset(0) { + } + + vector variables; + idx_t offset; +}; + +static unique_ptr DuckDBVariablesBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("value"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("type"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +unique_ptr DuckDBVariablesInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + auto &config = ClientConfig::GetConfig(context); + + for (auto &entry : config.user_variables) { + VariableData data; + data.name = entry.first; + data.value = entry.second; + result->variables.push_back(std::move(data)); + } + return std::move(result); +} + +void DuckDBVariablesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.variables.size()) { + // finished returning values + return; + } + // start returning values + // either fill up the chunk or return all the remaining columns + idx_t count = 0; + while (data.offset < data.variables.size() && count < STANDARD_VECTOR_SIZE) { + auto &variable_entry = data.variables[data.offset++]; + + // return values: + idx_t col = 0; + // name, VARCHAR + output.SetValue(col++, count, Value(variable_entry.name)); + // value, BIGINT + output.SetValue(col++, count, Value(variable_entry.value.ToString())); + // type, VARCHAR + output.SetValue(col, count, Value(variable_entry.value.type().ToString())); + count++; + } + output.SetCardinality(count); +} + +void DuckDBVariablesFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction( + TableFunction("duckdb_variables", {}, DuckDBVariablesFunction, DuckDBVariablesBind, DuckDBVariablesInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/test_all_types.cpp b/src/duckdb/src/function/table/system/test_all_types.cpp index 94450345..2b1e0803 100644 --- a/src/duckdb/src/function/table/system/test_all_types.cpp +++ b/src/duckdb/src/function/table/system/test_all_types.cpp @@ -31,6 +31,7 @@ vector TestAllTypesFun::GetTestTypes(bool use_large_enum) { result.emplace_back(LogicalType::USMALLINT, "usmallint"); result.emplace_back(LogicalType::UINTEGER, "uint"); result.emplace_back(LogicalType::UBIGINT, "ubigint"); + result.emplace_back(LogicalType::VARINT, "varint"); result.emplace_back(LogicalType::DATE, "date"); result.emplace_back(LogicalType::TIME, "time"); result.emplace_back(LogicalType::TIMESTAMP, "timestamp"); diff --git a/src/duckdb/src/function/table/system/test_vector_types.cpp b/src/duckdb/src/function/table/system/test_vector_types.cpp index 7562a031..23dab875 100644 --- a/src/duckdb/src/function/table/system/test_vector_types.cpp +++ b/src/duckdb/src/function/table/system/test_vector_types.cpp @@ -78,6 +78,14 @@ struct TestVectorFlat { break; } case PhysicalType::LIST: { + if (type.id() == LogicalTypeId::MAP) { + auto &child_type = ListType::GetChildType(type); + auto child_values = GenerateValues(info, child_type); + result.push_back(Value::MAP(child_type, {child_values[0]})); + result.push_back(Value(type)); + result.push_back(Value::MAP(child_type, {child_values[1]})); + break; + } auto &child_type = ListType::GetChildType(type); auto child_values = GenerateValues(info, child_type); @@ -157,6 +165,9 @@ struct TestVectorSequence { case LogicalTypeId::UINTEGER: case LogicalTypeId::UBIGINT: result.Sequence(3, 2, 3); +#if STANDARD_VECTOR_SIZE <= 2 + result.Flatten(3); +#endif return; default: break; @@ -170,6 +181,7 @@ struct TestVectorSequence { break; } case PhysicalType::LIST: { + D_ASSERT(type.id() != LogicalTypeId::MAP); auto data = FlatVector::GetData(result); data[0].offset = 0; data[0].length = 2; @@ -196,15 +208,33 @@ struct TestVectorSequence { } static void Generate(TestVectorInfo &info) { -#if STANDARD_VECTOR_SIZE > 2 + static constexpr const idx_t SEQ_CARDINALITY = 3; + auto result = make_uniq(); - result->Initialize(Allocator::DefaultAllocator(), info.types); + result->Initialize(Allocator::DefaultAllocator(), info.types, + MaxValue(SEQ_CARDINALITY, STANDARD_VECTOR_SIZE)); for (idx_t c = 0; c < info.types.size(); c++) { + if (info.types[c].id() == LogicalTypeId::MAP) { + // FIXME: we don't support MAP in the TestVectorSequence + return; + } GenerateVector(info, info.types[c], result->data[c]); } - result->SetCardinality(3); + result->SetCardinality(SEQ_CARDINALITY); +#if STANDARD_VECTOR_SIZE > 2 info.entries.push_back(std::move(result)); +#else + // vsize = 2, split into two smaller data chunks + for (idx_t offset = 0; offset < SEQ_CARDINALITY; offset += STANDARD_VECTOR_SIZE) { + auto new_result = make_uniq(); + new_result->Initialize(Allocator::DefaultAllocator(), info.types); + + idx_t copy_count = MinValue(STANDARD_VECTOR_SIZE, SEQ_CARDINALITY - offset); + result->Copy(*new_result, *FlatVector::IncrementalSelectionVector(), offset + copy_count, offset); + + info.entries.push_back(std::move(new_result)); + } #endif } }; diff --git a/src/duckdb/src/function/table/system_functions.cpp b/src/duckdb/src/function/table/system_functions.cpp index a1a821db..12e8bcc3 100644 --- a/src/duckdb/src/function/table/system_functions.cpp +++ b/src/duckdb/src/function/table/system_functions.cpp @@ -36,6 +36,7 @@ void BuiltinFunctions::RegisterSQLiteFunctions() { DuckDBTablesFun::RegisterFunction(*this); DuckDBTemporaryFilesFun::RegisterFunction(*this); DuckDBTypesFun::RegisterFunction(*this); + DuckDBVariablesFun::RegisterFunction(*this); DuckDBViewsFun::RegisterFunction(*this); TestAllTypesFun::RegisterFunction(*this); TestVectorTypesFun::RegisterFunction(*this); diff --git a/src/duckdb/src/function/table/table_scan.cpp b/src/duckdb/src/function/table/table_scan.cpp index d6c7bdde..7733bf2b 100644 --- a/src/duckdb/src/function/table/table_scan.cpp +++ b/src/duckdb/src/function/table/table_scan.cpp @@ -86,7 +86,6 @@ static unique_ptr TableScanInitLocal(ExecutionContext & } unique_ptr TableScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) { - D_ASSERT(input.bind_data); auto &bind_data = input.bind_data->Cast(); auto result = make_uniq(context, input.bind_data.get()); @@ -167,7 +166,7 @@ double TableScanProgress(ClientContext &context, const FunctionData *bind_data_p } idx_t scanned_rows = gstate.state.scan_state.processed_rows; scanned_rows += gstate.state.local_state.processed_rows; - auto percentage = 100 * (double(scanned_rows) / total_rows); + auto percentage = 100 * (static_cast(scanned_rows) / static_cast(total_rows)); if (percentage > 100) { //! In case the last chunk has less elements than STANDARD_VECTOR_SIZE, if our percentage is over 100 //! It means we finished this table. @@ -211,10 +210,13 @@ unique_ptr TableScanCardinality(ClientContext &context, const Fu // Index Scan //===--------------------------------------------------------------------===// struct IndexScanGlobalState : public GlobalTableFunctionState { - explicit IndexScanGlobalState(data_ptr_t row_id_data) : row_ids(LogicalType::ROW_TYPE, row_id_data) { + IndexScanGlobalState(const data_ptr_t row_id_data, const idx_t count) + : row_ids(LogicalType::ROW_TYPE, row_id_data), row_ids_count(count), row_ids_offset(0) { } - Vector row_ids; + const Vector row_ids; + const idx_t row_ids_count; + idx_t row_ids_offset; ColumnFetchState fetch_state; TableScanState local_storage_state; vector column_ids; @@ -223,19 +225,21 @@ struct IndexScanGlobalState : public GlobalTableFunctionState { static unique_ptr IndexScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) { auto &bind_data = input.bind_data->Cast(); + data_ptr_t row_id_data = nullptr; - if (!bind_data.result_ids.empty()) { - row_id_data = (data_ptr_t)&bind_data.result_ids[0]; // NOLINT - this is not pretty + if (!bind_data.row_ids.empty()) { + row_id_data = (data_ptr_t)&bind_data.row_ids[0]; // NOLINT - this is not pretty } - auto result = make_uniq(row_id_data); + + auto result = make_uniq(row_id_data, bind_data.row_ids.size()); auto &local_storage = LocalStorage::Get(context, bind_data.table.catalog); result->local_storage_state.options.force_fetch_row = ClientConfig::GetConfig(context).force_fetch_row; - result->column_ids.reserve(input.column_ids.size()); for (auto &id : input.column_ids) { result->column_ids.push_back(GetStorageIndex(bind_data.table, id)); } + result->local_storage_state.Initialize(result->column_ids, input.filters.get()); local_storage.InitializeScan(bind_data.table.GetStorage(), result->local_storage_state.local_state, input.filters); @@ -250,9 +254,17 @@ static void IndexScanFunction(ClientContext &context, TableFunctionInput &data_p auto &local_storage = LocalStorage::Get(transaction); if (!state.finished) { - bind_data.table.GetStorage().Fetch(transaction, output, state.column_ids, state.row_ids, - bind_data.result_ids.size(), state.fetch_state); - state.finished = true; + auto remaining = state.row_ids_count - state.row_ids_offset; + auto scan_count = remaining < STANDARD_VECTOR_SIZE ? remaining : STANDARD_VECTOR_SIZE; + + Vector row_ids(state.row_ids, state.row_ids_offset, state.row_ids_offset + scan_count); + bind_data.table.GetStorage().Fetch(transaction, output, state.column_ids, row_ids, scan_count, + state.fetch_state); + + state.row_ids_offset += scan_count; + if (state.row_ids_offset == state.row_ids_count) { + state.finished = true; + } } if (output.size() == 0) { local_storage.Scan(state.local_storage_state.local_state, state.column_ids, output); @@ -265,10 +277,11 @@ static void RewriteIndexExpression(Index &index, LogicalGet &get, Expression &ex // bound column ref: rewrite to fit in the current set of bound column ids bound_colref.binding.table_index = get.table_index; auto &column_ids = index.GetColumnIds(); + auto &get_column_ids = get.GetColumnIds(); column_t referenced_column = column_ids[bound_colref.binding.column_index]; // search for the referenced column in the set of column_ids - for (idx_t i = 0; i < get.column_ids.size(); i++) { - if (get.column_ids[i] == referenced_column) { + for (idx_t i = 0; i < get_column_ids.size(); i++) { + if (get_column_ids[i] == referenced_column) { bound_colref.binding.column_index = i; return; } @@ -310,7 +323,6 @@ void TableScanPushdownComplexFilter(ClientContext &context, LogicalGet &get, Fun auto checkpoint_lock = storage.GetSharedCheckpointLock(); auto &info = storage.GetDataTableInfo(); - auto &transaction = Transaction::Get(context, bind_data.table.catalog); // bind and scan any ART indexes info->GetIndexes().BindAndScan(context, *info, [&](ART &art_index) { @@ -328,17 +340,28 @@ void TableScanPushdownComplexFilter(ClientContext &context, LogicalGet &get, Fun return false; } - // try to find a matching index for any of the filter expressions + // Try to find a matching index for any of the filter expressions. for (auto &filter : filters) { - auto index_state = art_index.TryInitializeScan(transaction, *index_expression, *filter); + auto index_state = art_index.TryInitializeScan(*index_expression, *filter); if (index_state != nullptr) { - if (art_index.Scan(transaction, storage, *index_state, STANDARD_VECTOR_SIZE, bind_data.result_ids)) { - // use an index scan! + + auto &db_config = DBConfig::GetConfig(context); + auto index_scan_percentage = db_config.options.index_scan_percentage; + auto index_scan_max_count = db_config.options.index_scan_max_count; + + auto total_rows = storage.GetTotalRows(); + auto total_rows_from_percentage = LossyNumericCast(double(total_rows) * index_scan_percentage); + auto max_count = MaxValue(index_scan_max_count, total_rows_from_percentage); + + // Check if we can use an index scan, and already retrieve the matching row ids. + if (art_index.Scan(*index_state, max_count, bind_data.row_ids)) { bind_data.is_index_scan = true; get.function = TableScanFunction::GetIndexScanFunction(); - } else { - bind_data.result_ids.clear(); + return true; } + + // Clear the row ids in case we exceeded the maximum count and stopped scanning. + bind_data.row_ids.clear(); return true; } } @@ -360,7 +383,7 @@ static void TableScanSerialize(Serializer &serializer, const optional_ptr TableScanDeserialize(Deserializer &deserializer, TableFunction &function) { @@ -375,7 +398,7 @@ static unique_ptr TableScanDeserialize(Deserializer &deserializer, auto result = make_uniq(catalog_entry.Cast()); deserializer.ReadProperty(103, "is_index_scan", result->is_index_scan); deserializer.ReadProperty(104, "is_create_index", result->is_create_index); - deserializer.ReadProperty(105, "result_ids", result->result_ids); + deserializer.ReadProperty(105, "result_ids", result->row_ids); return std::move(result); } diff --git a/src/duckdb/src/function/table/unnest.cpp b/src/duckdb/src/function/table/unnest.cpp index b6485bc3..7abdf9df 100644 --- a/src/duckdb/src/function/table/unnest.cpp +++ b/src/duckdb/src/function/table/unnest.cpp @@ -47,7 +47,7 @@ static unique_ptr UnnestBind(ClientContext &context, TableFunction throw BinderException("UNNEST requires a single list as input"); } return_types.push_back(ListType::GetChildType(input.input_table_types[0])); - names.push_back(input.input_table_names[0]); + names.push_back("unnest"); return make_uniq(input.input_table_types[0]); } @@ -78,7 +78,7 @@ static OperatorResultType UnnestFunction(ExecutionContext &context, TableFunctio } void UnnestTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction unnest_function("unnest", {LogicalTypeId::TABLE}, nullptr, UnnestBind, UnnestInit, UnnestLocalInit); + TableFunction unnest_function("unnest", {LogicalType::ANY}, nullptr, UnnestBind, UnnestInit, UnnestLocalInit); unnest_function.in_out_function = UnnestFunction; set.AddFunction(unnest_function); } diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index 70242dd5..d69fd257 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -2,21 +2,21 @@ #define DUCKDB_PATCH_VERSION "0" #endif #ifndef DUCKDB_MINOR_VERSION -#define DUCKDB_MINOR_VERSION 0 +#define DUCKDB_MINOR_VERSION 1 #endif #ifndef DUCKDB_MAJOR_VERSION #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.0.0" +#define DUCKDB_VERSION "v1.1.0" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "1f98600c2c" +#define DUCKDB_SOURCE_ID "fa5c2fe15f" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" #include "duckdb/common/string_util.hpp" -#include "duckdb/common/platform.h" +#include "duckdb/common/platform.hpp" #include diff --git a/src/duckdb/src/function/table_function.cpp b/src/duckdb/src/function/table_function.cpp index 802f89fb..a5190aee 100644 --- a/src/duckdb/src/function/table_function.cpp +++ b/src/duckdb/src/function/table_function.cpp @@ -18,8 +18,9 @@ TableFunction::TableFunction(string name, vector arguments, table_f init_global(init_global), init_local(init_local), function(function), in_out_function(nullptr), in_out_function_final(nullptr), statistics(nullptr), dependency(nullptr), cardinality(nullptr), pushdown_complex_filter(nullptr), to_string(nullptr), table_scan_progress(nullptr), get_batch_index(nullptr), - get_bind_info(nullptr), type_pushdown(nullptr), get_multi_file_reader(nullptr), serialize(nullptr), - deserialize(nullptr), projection_pushdown(false), filter_pushdown(false), filter_prune(false) { + get_bind_info(nullptr), type_pushdown(nullptr), get_multi_file_reader(nullptr), supports_pushdown_type(nullptr), + serialize(nullptr), deserialize(nullptr), projection_pushdown(false), filter_pushdown(false), + filter_prune(false) { } TableFunction::TableFunction(const vector &arguments, table_function_t function, @@ -32,8 +33,8 @@ TableFunction::TableFunction() init_local(nullptr), function(nullptr), in_out_function(nullptr), statistics(nullptr), dependency(nullptr), cardinality(nullptr), pushdown_complex_filter(nullptr), to_string(nullptr), table_scan_progress(nullptr), get_batch_index(nullptr), get_bind_info(nullptr), type_pushdown(nullptr), get_multi_file_reader(nullptr), - serialize(nullptr), deserialize(nullptr), projection_pushdown(false), filter_pushdown(false), - filter_prune(false) { + supports_pushdown_type(nullptr), serialize(nullptr), deserialize(nullptr), projection_pushdown(false), + filter_pushdown(false), filter_prune(false) { } bool TableFunction::Equal(const TableFunction &rhs) const { diff --git a/src/duckdb/src/function/table_macro_function.cpp b/src/duckdb/src/function/table_macro_function.cpp index 9fbb1792..becb1fe6 100644 --- a/src/duckdb/src/function/table_macro_function.cpp +++ b/src/duckdb/src/function/table_macro_function.cpp @@ -27,8 +27,8 @@ unique_ptr TableMacroFunction::Copy() const { return std::move(result); } -string TableMacroFunction::ToSQL(const string &schema, const string &name) const { - return MacroFunction::ToSQL(schema, name) + StringUtil::Format("TABLE (%s);", query_node->ToString()); +string TableMacroFunction::ToSQL() const { + return MacroFunction::ToSQL() + StringUtil::Format("TABLE (%s)", query_node->ToString()); } } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb.h b/src/duckdb/src/include/duckdb.h index 395befc0..b66253fc 100644 --- a/src/duckdb/src/include/duckdb.h +++ b/src/duckdb/src/include/duckdb.h @@ -6,17 +6,25 @@ // // //===----------------------------------------------------------------------===// +// +// !!!!!!! +// WARNING: this file is autogenerated by scripts/generate_c_api.py, manual changes will be overwritten +// !!!!!!! #pragma once //! duplicate of duckdb/main/winapi.hpp #ifndef DUCKDB_API #ifdef _WIN32 +#ifdef DUCKDB_STATIC_BUILD +#define DUCKDB_API +#else #if defined(DUCKDB_BUILD_LIBRARY) && !defined(DUCKDB_BUILD_LOADABLE_EXTENSION) #define DUCKDB_API __declspec(dllexport) #else #define DUCKDB_API __declspec(dllimport) #endif +#endif #else #define DUCKDB_API #endif @@ -25,41 +33,20 @@ //! duplicate of duckdb/main/winapi.hpp #ifndef DUCKDB_EXTENSION_API #ifdef _WIN32 +#ifdef DUCKDB_STATIC_BUILD +#define DUCKDB_EXTENSION_API +#else #ifdef DUCKDB_BUILD_LOADABLE_EXTENSION #define DUCKDB_EXTENSION_API __declspec(dllexport) #else #define DUCKDB_EXTENSION_API #endif +#endif #else #define DUCKDB_EXTENSION_API __attribute__((visibility("default"))) #endif #endif -//! In the future, we are planning to move extension functions to a separate header. For now you can set the define -//! below to remove the functions that are planned to be moved out of this header. -// #define DUCKDB_NO_EXTENSION_FUNCTIONS - -//! Set the define below to remove all functions that are deprecated or planned to be deprecated -// #define DUCKDB_API_NO_DEPRECATED - -//! API versions -//! If no explicit API version is defined, the latest API version is used. -//! Note that using older API versions (i.e. not using DUCKDB_API_LATEST) is deprecated. -//! These will not be supported long-term, and will be removed in future versions. -#ifndef DUCKDB_API_0_3_1 -#define DUCKDB_API_0_3_1 1 -#endif -#ifndef DUCKDB_API_0_3_2 -#define DUCKDB_API_0_3_2 2 -#endif -#ifndef DUCKDB_API_LATEST -#define DUCKDB_API_LATEST DUCKDB_API_0_3_2 -#endif - -#ifndef DUCKDB_API_VERSION -#define DUCKDB_API_VERSION DUCKDB_API_LATEST -#endif - #include #include #include @@ -142,25 +129,29 @@ typedef enum DUCKDB_TYPE { DUCKDB_TYPE_TIME_TZ = 30, // duckdb_timestamp DUCKDB_TYPE_TIMESTAMP_TZ = 31, + // ANY type + DUCKDB_TYPE_ANY = 34, + // duckdb_varint + DUCKDB_TYPE_VARINT = 35, } duckdb_type; //! An enum over the returned state of different functions. -typedef enum { DuckDBSuccess = 0, DuckDBError = 1 } duckdb_state; +typedef enum duckdb_state { DuckDBSuccess = 0, DuckDBError = 1 } duckdb_state; //! An enum over the pending state of a pending query result. -typedef enum { +typedef enum duckdb_pending_state { DUCKDB_PENDING_RESULT_READY = 0, DUCKDB_PENDING_RESULT_NOT_READY = 1, DUCKDB_PENDING_ERROR = 2, DUCKDB_PENDING_NO_TASKS_AVAILABLE = 3 } duckdb_pending_state; //! An enum over DuckDB's different result types. -typedef enum { +typedef enum duckdb_result_type { DUCKDB_RESULT_TYPE_INVALID = 0, DUCKDB_RESULT_TYPE_CHANGED_ROWS = 1, DUCKDB_RESULT_TYPE_NOTHING = 2, DUCKDB_RESULT_TYPE_QUERY_RESULT = 3, } duckdb_result_type; //! An enum over DuckDB's different statement types. -typedef enum { +typedef enum duckdb_statement_type { DUCKDB_STATEMENT_TYPE_INVALID = 0, DUCKDB_STATEMENT_TYPE_SELECT = 1, DUCKDB_STATEMENT_TYPE_INSERT = 2, @@ -190,6 +181,54 @@ typedef enum { DUCKDB_STATEMENT_TYPE_DETACH = 26, DUCKDB_STATEMENT_TYPE_MULTI = 27, } duckdb_statement_type; +//! An enum over DuckDB's different result types. +typedef enum duckdb_error_type { + DUCKDB_ERROR_INVALID = 0, + DUCKDB_ERROR_OUT_OF_RANGE = 1, + DUCKDB_ERROR_CONVERSION = 2, + DUCKDB_ERROR_UNKNOWN_TYPE = 3, + DUCKDB_ERROR_DECIMAL = 4, + DUCKDB_ERROR_MISMATCH_TYPE = 5, + DUCKDB_ERROR_DIVIDE_BY_ZERO = 6, + DUCKDB_ERROR_OBJECT_SIZE = 7, + DUCKDB_ERROR_INVALID_TYPE = 8, + DUCKDB_ERROR_SERIALIZATION = 9, + DUCKDB_ERROR_TRANSACTION = 10, + DUCKDB_ERROR_NOT_IMPLEMENTED = 11, + DUCKDB_ERROR_EXPRESSION = 12, + DUCKDB_ERROR_CATALOG = 13, + DUCKDB_ERROR_PARSER = 14, + DUCKDB_ERROR_PLANNER = 15, + DUCKDB_ERROR_SCHEDULER = 16, + DUCKDB_ERROR_EXECUTOR = 17, + DUCKDB_ERROR_CONSTRAINT = 18, + DUCKDB_ERROR_INDEX = 19, + DUCKDB_ERROR_STAT = 20, + DUCKDB_ERROR_CONNECTION = 21, + DUCKDB_ERROR_SYNTAX = 22, + DUCKDB_ERROR_SETTINGS = 23, + DUCKDB_ERROR_BINDER = 24, + DUCKDB_ERROR_NETWORK = 25, + DUCKDB_ERROR_OPTIMIZER = 26, + DUCKDB_ERROR_NULL_POINTER = 27, + DUCKDB_ERROR_IO = 28, + DUCKDB_ERROR_INTERRUPT = 29, + DUCKDB_ERROR_FATAL = 30, + DUCKDB_ERROR_INTERNAL = 31, + DUCKDB_ERROR_INVALID_INPUT = 32, + DUCKDB_ERROR_OUT_OF_MEMORY = 33, + DUCKDB_ERROR_PERMISSION = 34, + DUCKDB_ERROR_PARAMETER_NOT_RESOLVED = 35, + DUCKDB_ERROR_PARAMETER_NOT_ALLOWED = 36, + DUCKDB_ERROR_DEPENDENCY = 37, + DUCKDB_ERROR_HTTP = 38, + DUCKDB_ERROR_MISSING_EXTENSION = 39, + DUCKDB_ERROR_AUTOLOAD = 40, + DUCKDB_ERROR_SEQUENCE = 41, + DUCKDB_INVALID_CONFIGURATION = 42 +} duckdb_error_type; +//! An enum over DuckDB's different cast modes. +typedef enum duckdb_cast_mode { DUCKDB_CAST_NORMAL = 0, DUCKDB_CAST_TRY = 1 } duckdb_cast_mode; //===--------------------------------------------------------------------===// // General type definitions @@ -313,28 +352,21 @@ typedef struct { //! duckdb_column_type, and duckdb_column_name, which take the result and the column index //! as their parameters typedef struct { -#if DUCKDB_API_VERSION < DUCKDB_API_0_3_2 - void *data; - bool *nullmask; - duckdb_type type; - char *name; -#else // deprecated, use duckdb_column_data - void *__deprecated_data; + void *deprecated_data; // deprecated, use duckdb_nullmask_data - bool *__deprecated_nullmask; + bool *deprecated_nullmask; // deprecated, use duckdb_column_type - duckdb_type __deprecated_type; + duckdb_type deprecated_type; // deprecated, use duckdb_column_name - char *__deprecated_name; -#endif + char *deprecated_name; void *internal_data; } duckdb_column; //! A vector to a specified column in a data chunk. Lives as long as the //! data chunk lives, i.e., must not be destroyed. typedef struct _duckdb_vector { - void *__vctr; + void *internal_ptr; } * duckdb_vector; //===--------------------------------------------------------------------===// @@ -358,100 +390,177 @@ typedef struct { //! A query result consists of a pointer to its internal data. //! Must be freed with 'duckdb_destroy_result'. typedef struct { -#if DUCKDB_API_VERSION < DUCKDB_API_0_3_2 - idx_t column_count; - idx_t row_count; - idx_t rows_changed; - duckdb_column *columns; - char *error_message; -#else // deprecated, use duckdb_column_count - idx_t __deprecated_column_count; + idx_t deprecated_column_count; // deprecated, use duckdb_row_count - idx_t __deprecated_row_count; + idx_t deprecated_row_count; // deprecated, use duckdb_rows_changed - idx_t __deprecated_rows_changed; + idx_t deprecated_rows_changed; // deprecated, use duckdb_column_*-family of functions - duckdb_column *__deprecated_columns; + duckdb_column *deprecated_columns; // deprecated, use duckdb_result_error - char *__deprecated_error_message; -#endif + char *deprecated_error_message; void *internal_data; } duckdb_result; //! A database object. Should be closed with `duckdb_close`. typedef struct _duckdb_database { - void *__db; + void *internal_ptr; } * duckdb_database; //! A connection to a duckdb database. Must be closed with `duckdb_disconnect`. typedef struct _duckdb_connection { - void *__conn; + void *internal_ptr; } * duckdb_connection; //! A prepared statement is a parameterized query that allows you to bind parameters to it. //! Must be destroyed with `duckdb_destroy_prepare`. typedef struct _duckdb_prepared_statement { - void *__prep; + void *internal_ptr; } * duckdb_prepared_statement; //! Extracted statements. Must be destroyed with `duckdb_destroy_extracted`. typedef struct _duckdb_extracted_statements { - void *__extrac; + void *internal_ptr; } * duckdb_extracted_statements; //! The pending result represents an intermediate structure for a query that is not yet fully executed. //! Must be destroyed with `duckdb_destroy_pending`. typedef struct _duckdb_pending_result { - void *__pend; + void *internal_ptr; } * duckdb_pending_result; //! The appender enables fast data loading into DuckDB. //! Must be destroyed with `duckdb_appender_destroy`. typedef struct _duckdb_appender { - void *__appn; + void *internal_ptr; } * duckdb_appender; +//! The table description allows querying info about the table. +//! Must be destroyed with `duckdb_table_description_destroy`. +typedef struct _duckdb_table_description { + void *internal_ptr; +} * duckdb_table_description; + //! Can be used to provide start-up options for the DuckDB instance. //! Must be destroyed with `duckdb_destroy_config`. typedef struct _duckdb_config { - void *__cnfg; + void *internal_ptr; } * duckdb_config; //! Holds an internal logical type. //! Must be destroyed with `duckdb_destroy_logical_type`. typedef struct _duckdb_logical_type { - void *__lglt; + void *internal_ptr; } * duckdb_logical_type; +//! Holds extra information used when registering a custom logical type. +//! Reserved for future use. +typedef struct _duckdb_create_type_info { + void *internal_ptr; +} * duckdb_create_type_info; + //! Contains a data chunk from a duckdb_result. //! Must be destroyed with `duckdb_destroy_data_chunk`. typedef struct _duckdb_data_chunk { - void *__dtck; + void *internal_ptr; } * duckdb_data_chunk; //! Holds a DuckDB value, which wraps a type. //! Must be destroyed with `duckdb_destroy_value`. typedef struct _duckdb_value { - void *__val; + void *internal_ptr; } * duckdb_value; +//! Holds a recursive tree that matches the query plan. +typedef struct _duckdb_profiling_info { + void *internal_ptr; +} * duckdb_profiling_info; + +//===--------------------------------------------------------------------===// +// C API Extension info +//===--------------------------------------------------------------------===// +//! Holds state during the C API extension intialization process +typedef struct _duckdb_extension_info { + void *internal_ptr; +} * duckdb_extension_info; + +//===--------------------------------------------------------------------===// +// Function types +//===--------------------------------------------------------------------===// +//! Additional function info. When setting this info, it is necessary to pass a destroy-callback function. +typedef struct _duckdb_function_info { + void *internal_ptr; +} * duckdb_function_info; + +//===--------------------------------------------------------------------===// +// Scalar function types +//===--------------------------------------------------------------------===// +//! A scalar function. Must be destroyed with `duckdb_destroy_scalar_function`. +typedef struct _duckdb_scalar_function { + void *internal_ptr; +} * duckdb_scalar_function; + +//! A scalar function set. Must be destroyed with `duckdb_destroy_scalar_function_set`. +typedef struct _duckdb_scalar_function_set { + void *internal_ptr; +} * duckdb_scalar_function_set; + +//! The main function of the scalar function. +typedef void (*duckdb_scalar_function_t)(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output); + +//===--------------------------------------------------------------------===// +// Aggregate function types +//===--------------------------------------------------------------------===// +//! An aggregate function. Must be destroyed with `duckdb_destroy_aggregate_function`. +typedef struct _duckdb_aggregate_function { + void *internal_ptr; +} * duckdb_aggregate_function; + +//! A aggregate function set. Must be destroyed with `duckdb_destroy_aggregate_function_set`. +typedef struct _duckdb_aggregate_function_set { + void *internal_ptr; +} * duckdb_aggregate_function_set; + +//! Aggregate state +typedef struct _duckdb_aggregate_state { + void *internal_ptr; +} * duckdb_aggregate_state; + +//! Returns the aggregate state size +typedef idx_t (*duckdb_aggregate_state_size)(duckdb_function_info info); +//! Initialize the aggregate state +typedef void (*duckdb_aggregate_init_t)(duckdb_function_info info, duckdb_aggregate_state state); +//! Destroy aggregate state (optional) +typedef void (*duckdb_aggregate_destroy_t)(duckdb_aggregate_state *states, idx_t count); +//! Update a set of aggregate states with new values +typedef void (*duckdb_aggregate_update_t)(duckdb_function_info info, duckdb_data_chunk input, + duckdb_aggregate_state *states); +//! Combine aggregate states +typedef void (*duckdb_aggregate_combine_t)(duckdb_function_info info, duckdb_aggregate_state *source, + duckdb_aggregate_state *target, idx_t count); +//! Finalize aggregate states into a result vector +typedef void (*duckdb_aggregate_finalize_t)(duckdb_function_info info, duckdb_aggregate_state *source, + duckdb_vector result, idx_t count, idx_t offset); + //===--------------------------------------------------------------------===// // Table function types //===--------------------------------------------------------------------===// -#ifndef DUCKDB_NO_EXTENSION_FUNCTIONS //! A table function. Must be destroyed with `duckdb_destroy_table_function`. -typedef void *duckdb_table_function; +typedef struct _duckdb_table_function { + void *internal_ptr; +} * duckdb_table_function; //! The bind info of the function. When setting this info, it is necessary to pass a destroy-callback function. -typedef void *duckdb_bind_info; +typedef struct _duckdb_bind_info { + void *internal_ptr; +} * duckdb_bind_info; //! Additional function init info. When setting this info, it is necessary to pass a destroy-callback function. -typedef void *duckdb_init_info; - -//! Additional function info. When setting this info, it is necessary to pass a destroy-callback function. -typedef void *duckdb_function_info; +typedef struct _duckdb_init_info { + void *internal_ptr; +} * duckdb_init_info; //! The bind function of the table function. typedef void (*duckdb_table_function_bind_t)(duckdb_bind_info info); @@ -462,16 +571,29 @@ typedef void (*duckdb_table_function_init_t)(duckdb_init_info info); //! The main function of the table function. typedef void (*duckdb_table_function_t)(duckdb_function_info info, duckdb_data_chunk output); +//===--------------------------------------------------------------------===// +// Cast types +//===--------------------------------------------------------------------===// + +//! A cast function. Must be destroyed with `duckdb_destroy_cast_function`. +typedef struct _duckdb_cast_function { + void *internal_ptr; +} * duckdb_cast_function; + +typedef bool (*duckdb_cast_function_t)(duckdb_function_info info, idx_t count, duckdb_vector input, + duckdb_vector output); + //===--------------------------------------------------------------------===// // Replacement scan types //===--------------------------------------------------------------------===// //! Additional replacement scan info. When setting this info, it is necessary to pass a destroy-callback function. -typedef void *duckdb_replacement_scan_info; +typedef struct _duckdb_replacement_scan_info { + void *internal_ptr; +} * duckdb_replacement_scan_info; //! A replacement scan function that can be added to a database. typedef void (*duckdb_replacement_callback_t)(duckdb_replacement_scan_info info, const char *table_name, void *data); -#endif //===--------------------------------------------------------------------===// // Arrow-related types @@ -479,30 +601,43 @@ typedef void (*duckdb_replacement_callback_t)(duckdb_replacement_scan_info info, //! Holds an arrow query result. Must be destroyed with `duckdb_destroy_arrow`. typedef struct _duckdb_arrow { - void *__arrw; + void *internal_ptr; } * duckdb_arrow; //! Holds an arrow array stream. Must be destroyed with `duckdb_destroy_arrow_stream`. typedef struct _duckdb_arrow_stream { - void *__arrwstr; + void *internal_ptr; } * duckdb_arrow_stream; //! Holds an arrow schema. Remember to release the respective ArrowSchema object. typedef struct _duckdb_arrow_schema { - void *__arrs; + void *internal_ptr; } * duckdb_arrow_schema; //! Holds an arrow array. Remember to release the respective ArrowArray object. typedef struct _duckdb_arrow_array { - void *__arra; + void *internal_ptr; } * duckdb_arrow_array; +//===--------------------------------------------------------------------===// +// DuckDB extension access +//===--------------------------------------------------------------------===// +//! Passed to C API extension as parameter to the entrypoint +struct duckdb_extension_access { + //! Indicate that an error has occured + void (*set_error)(duckdb_extension_info info, const char *error); + //! Fetch the database from duckdb to register extensions to + duckdb_database *(*get_database)(duckdb_extension_info info); + //! Fetch the API + void *(*get_api)(duckdb_extension_info info, const char *version); +}; + //===--------------------------------------------------------------------===// // Functions //===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===// -// Open/Connect +// Open Connect //===--------------------------------------------------------------------===// /*! @@ -510,9 +645,9 @@ Creates a new database or opens an existing database file stored at the given pa If no path is given a new in-memory database is created instead. The instantiated database should be closed with 'duckdb_close'. -* path: Path to the database file on disk, or `nullptr` or `:memory:` to open an in-memory database. -* out_database: The result database object. -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +* @param path Path to the database file on disk, or `nullptr` or `:memory:` to open an in-memory database. +* @param out_database The result database object. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_open(const char *path, duckdb_database *out_database); @@ -520,12 +655,12 @@ DUCKDB_API duckdb_state duckdb_open(const char *path, duckdb_database *out_datab Extended version of duckdb_open. Creates a new database or opens an existing database file stored at the given path. The instantiated database should be closed with 'duckdb_close'. -* path: Path to the database file on disk, or `nullptr` or `:memory:` to open an in-memory database. -* out_database: The result database object. -* config: (Optional) configuration used to start up the database system. -* out_error: If set and the function returns DuckDBError, this will contain the reason why the start-up failed. +* @param path Path to the database file on disk, or `nullptr` or `:memory:` to open an in-memory database. +* @param out_database The result database object. +* @param config (Optional) configuration used to start up the database system. +* @param out_error If set and the function returns DuckDBError, this will contain the reason why the start-up failed. Note that the error must be freed using `duckdb_free`. -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_open_ext(const char *path, duckdb_database *out_database, duckdb_config config, char **out_error); @@ -536,7 +671,7 @@ This should be called after you are done with any database allocated through `du Note that failing to call `duckdb_close` (in case of e.g. a program crash) will not cause data corruption. Still, it is recommended to always correctly close a database object after you are done with it. -* database: The database object to shut down. +* @param database The database object to shut down. */ DUCKDB_API void duckdb_close(duckdb_database *database); @@ -545,31 +680,31 @@ Opens a connection to a database. Connections are required to query the database associated with the connection. The instantiated connection should be closed using 'duckdb_disconnect'. -* database: The database file to connect to. -* out_connection: The result connection object. -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +* @param database The database file to connect to. +* @param out_connection The result connection object. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_connect(duckdb_database database, duckdb_connection *out_connection); /*! Interrupt running query -* connection: The connection to interrupt +* @param connection The connection to interrupt */ DUCKDB_API void duckdb_interrupt(duckdb_connection connection); /*! Get progress of the running query -* connection: The working connection -* returns: -1 if no progress or a percentage of the progress +* @param connection The working connection +* @return -1 if no progress or a percentage of the progress */ DUCKDB_API duckdb_query_progress_type duckdb_query_progress(duckdb_connection connection); /*! Closes the specified connection and de-allocates all memory allocated for that connection. -* connection: The connection to close. +* @param connection The connection to close. */ DUCKDB_API void duckdb_disconnect(duckdb_connection *connection); @@ -591,8 +726,11 @@ The duckdb_config must be destroyed using 'duckdb_destroy_config' This will always succeed unless there is a malloc failure. -* out_config: The result configuration object. -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +Note that `duckdb_destroy_config` should always be called on the resulting config, even if the function returns +`DuckDBError`. + +* @param out_config The result configuration object. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_create_config(duckdb_config *out_config); @@ -601,7 +739,7 @@ This returns the total amount of configuration options available for usage with This should not be called in a loop as it internally loops over all the options. -* returns: The amount of config options available. +* @return The amount of config options available. */ DUCKDB_API size_t duckdb_config_count(); @@ -611,10 +749,10 @@ display configuration options. This will succeed unless `index` is out of range The result name or description MUST NOT be freed. -* index: The index of the configuration option (between 0 and `duckdb_config_count`) -* out_name: A name of the configuration flag. -* out_description: A description of the configuration flag. -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +* @param index The index of the configuration option (between 0 and `duckdb_config_count`) +* @param out_name A name of the configuration flag. +* @param out_description A description of the configuration flag. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_get_config_flag(size_t index, const char **out_name, const char **out_description); @@ -626,17 +764,17 @@ In the source code, configuration options are defined in `config.cpp`. This can fail if either the name is invalid, or if the value provided for the option is invalid. -* duckdb_config: The configuration object to set the option on. -* name: The name of the configuration flag to set. -* option: The value to set the configuration flag to. -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +* @param config The configuration object to set the option on. +* @param name The name of the configuration flag to set. +* @param option The value to set the configuration flag to. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_set_config(duckdb_config config, const char *name, const char *option); /*! Destroys the specified configuration object and de-allocates all memory allocated for the object. -* config: The configuration object to destroy. +* @param config The configuration object to destroy. */ DUCKDB_API void duckdb_destroy_config(duckdb_config *config); @@ -652,17 +790,17 @@ If the query fails to execute, DuckDBError is returned and the error message can Note that after running `duckdb_query`, `duckdb_destroy_result` must be called on the result object even if the query fails, otherwise the error stored within the result will not be freed correctly. -* connection: The connection to perform the query in. -* query: The SQL query to run. -* out_result: The query result. -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +* @param connection The connection to perform the query in. +* @param query The SQL query to run. +* @param out_result The query result. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_query(duckdb_connection connection, const char *query, duckdb_result *out_result); /*! Closes the result and de-allocates all memory allocated for that connection. -* result: The result to destroy. +* @param result The result to destroy. */ DUCKDB_API void duckdb_destroy_result(duckdb_result *result); @@ -672,9 +810,9 @@ automatically be destroyed when the result is destroyed. Returns `NULL` if the column is out of range. -* result: The result object to fetch the column name from. -* col: The column index. -* returns: The column name of the specified column. +* @param result The result object to fetch the column name from. +* @param col The column index. +* @return The column name of the specified column. */ DUCKDB_API const char *duckdb_column_name(duckdb_result *result, idx_t col); @@ -683,18 +821,18 @@ Returns the column type of the specified column. Returns `DUCKDB_TYPE_INVALID` if the column is out of range. -* result: The result object to fetch the column type from. -* col: The column index. -* returns: The column type of the specified column. +* @param result The result object to fetch the column type from. +* @param col The column index. +* @return The column type of the specified column. */ DUCKDB_API duckdb_type duckdb_column_type(duckdb_result *result, idx_t col); /*! Returns the statement type of the statement that was executed -* result: The result object to fetch the statement type from. - * returns: duckdb_statement_type value or DUCKDB_STATEMENT_TYPE_INVALID - */ +* @param result The result object to fetch the statement type from. +* @return duckdb_statement_type value or DUCKDB_STATEMENT_TYPE_INVALID +*/ DUCKDB_API duckdb_statement_type duckdb_result_statement_type(duckdb_result result); /*! @@ -704,17 +842,17 @@ The return type of this call should be destroyed with `duckdb_destroy_logical_ty Returns `NULL` if the column is out of range. -* result: The result object to fetch the column type from. -* col: The column index. -* returns: The logical column type of the specified column. +* @param result The result object to fetch the column type from. +* @param col The column index. +* @return The logical column type of the specified column. */ DUCKDB_API duckdb_logical_type duckdb_column_logical_type(duckdb_result *result, idx_t col); /*! Returns the number of columns present in a the result object. -* result: The result object. -* returns: The number of columns present in the result object. +* @param result The result object. +* @return The number of columns present in the result object. */ DUCKDB_API idx_t duckdb_column_count(duckdb_result *result); @@ -724,8 +862,8 @@ DUCKDB_API idx_t duckdb_column_count(duckdb_result *result); Returns the number of rows present in the result object. -* result: The result object. -* returns: The number of rows present in the result object. +* @param result The result object. +* @return The number of rows present in the result object. */ DUCKDB_API idx_t duckdb_row_count(duckdb_result *result); #endif @@ -734,8 +872,8 @@ DUCKDB_API idx_t duckdb_row_count(duckdb_result *result); Returns the number of rows changed by the query stored in the result. This is relevant only for INSERT/UPDATE/DELETE queries. For other queries the rows_changed will be 0. -* result: The result object. -* returns: The number of rows changed. +* @param result The result object. +* @return The number of rows changed. */ DUCKDB_API idx_t duckdb_rows_changed(duckdb_result *result); @@ -755,12 +893,14 @@ int32_t *data = (int32_t *) duckdb_column_data(&result, 0); printf("Data for row %d: %d\n", row, data[row]); ``` -* result: The result object to fetch the column data from. -* col: The column index. -* returns: The column data of the specified column. +* @param result The result object to fetch the column data from. +* @param col The column index. +* @return The column data of the specified column. */ DUCKDB_API void *duckdb_column_data(duckdb_result *result, idx_t col); +#endif +#ifndef DUCKDB_API_NO_DEPRECATED /*! **DEPRECATED**: Prefer using `duckdb_result_get_chunk` instead. @@ -778,9 +918,9 @@ if (nullmask[row]) { } ``` -* result: The result object to fetch the nullmask from. -* col: The column index. -* returns: The nullmask of the specified column. +* @param result The result object to fetch the nullmask from. +* @param col The column index. +* @return The nullmask of the specified column. */ DUCKDB_API bool *duckdb_nullmask_data(duckdb_result *result, idx_t col); #endif @@ -790,14 +930,24 @@ Returns the error message contained within the result. The error is only set if The result of this function must not be freed. It will be cleaned up when `duckdb_destroy_result` is called. -* result: The result object to fetch the error from. -* returns: The error of the result. +* @param result The result object to fetch the error from. +* @return The error of the result. */ DUCKDB_API const char *duckdb_result_error(duckdb_result *result); +/*! +Returns the result error type contained within the result. The error is only set if `duckdb_query` returns +`DuckDBError`. + +* @param result The result object to fetch the error from. +* @return The error type of the result. +*/ +DUCKDB_API duckdb_error_type duckdb_result_error_type(duckdb_result *result); + //===--------------------------------------------------------------------===// // Result Functions //===--------------------------------------------------------------------===// + #ifndef DUCKDB_API_NO_DEPRECATED /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. @@ -814,9 +964,9 @@ mixed with the legacy result functions). Use `duckdb_result_chunk_count` to figure out how many chunks there are in the result. -* result: The result object to fetch the data chunk from. -* chunk_index: The chunk index to fetch from. -* returns: The resulting data chunk. Returns `NULL` if the chunk index is out of bounds. +* @param result The result object to fetch the data chunk from. +* @param chunk_index The chunk index to fetch from. +* @return The resulting data chunk. Returns `NULL` if the chunk index is out of bounds. */ DUCKDB_API duckdb_data_chunk duckdb_result_get_chunk(duckdb_result result, idx_t chunk_index); @@ -825,8 +975,8 @@ DUCKDB_API duckdb_data_chunk duckdb_result_get_chunk(duckdb_result result, idx_t Checks if the type of the internal result is StreamQueryResult. -* result: The result object to check. -* returns: Whether or not the result object is of the type StreamQueryResult +* @param result The result object to check. +* @return Whether or not the result object is of the type StreamQueryResult */ DUCKDB_API bool duckdb_result_is_streaming(duckdb_result result); @@ -835,159 +985,160 @@ DUCKDB_API bool duckdb_result_is_streaming(duckdb_result result); Returns the number of data chunks present in the result. -* result: The result object -* returns: Number of data chunks present in the result. +* @param result The result object +* @return Number of data chunks present in the result. */ DUCKDB_API idx_t duckdb_result_chunk_count(duckdb_result result); -#endif /*! Returns the return_type of the given result, or DUCKDB_RETURN_TYPE_INVALID on error -* result: The result object -* returns: The return_type - */ +* @param result The result object +* @return The return_type +*/ DUCKDB_API duckdb_result_type duckdb_result_return_type(duckdb_result result); -#ifndef DUCKDB_API_NO_DEPRECATED +#endif //===--------------------------------------------------------------------===// -// Safe fetch functions +// Safe Fetch Functions //===--------------------------------------------------------------------===// // These functions will perform conversions if necessary. // On failure (e.g. if conversion cannot be performed or if the value is NULL) a default value is returned. // Note that these functions are slow since they perform bounds checking and conversion // For fast access of values prefer using `duckdb_result_get_chunk` - +#ifndef DUCKDB_API_NO_DEPRECATED /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The boolean value at the specified location, or false if the value cannot be converted. - */ +* @return The boolean value at the specified location, or false if the value cannot be converted. +*/ DUCKDB_API bool duckdb_value_boolean(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The int8_t value at the specified location, or 0 if the value cannot be converted. - */ +* @return The int8_t value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API int8_t duckdb_value_int8(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The int16_t value at the specified location, or 0 if the value cannot be converted. - */ +* @return The int16_t value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API int16_t duckdb_value_int16(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The int32_t value at the specified location, or 0 if the value cannot be converted. - */ +* @return The int32_t value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API int32_t duckdb_value_int32(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The int64_t value at the specified location, or 0 if the value cannot be converted. - */ +* @return The int64_t value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API int64_t duckdb_value_int64(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The duckdb_hugeint value at the specified location, or 0 if the value cannot be converted. - */ +* @return The duckdb_hugeint value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API duckdb_hugeint duckdb_value_hugeint(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The duckdb_uhugeint value at the specified location, or 0 if the value cannot be converted. - */ +* @return The duckdb_uhugeint value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API duckdb_uhugeint duckdb_value_uhugeint(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The duckdb_decimal value at the specified location, or 0 if the value cannot be converted. - */ +* @return The duckdb_decimal value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API duckdb_decimal duckdb_value_decimal(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The uint8_t value at the specified location, or 0 if the value cannot be converted. - */ +* @return The uint8_t value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API uint8_t duckdb_value_uint8(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The uint16_t value at the specified location, or 0 if the value cannot be converted. - */ +* @return The uint16_t value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API uint16_t duckdb_value_uint16(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The uint32_t value at the specified location, or 0 if the value cannot be converted. - */ +* @return The uint32_t value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API uint32_t duckdb_value_uint32(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The uint64_t value at the specified location, or 0 if the value cannot be converted. - */ +* @return The uint64_t value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API uint64_t duckdb_value_uint64(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The float value at the specified location, or 0 if the value cannot be converted. - */ +* @return The float value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API float duckdb_value_float(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The double value at the specified location, or 0 if the value cannot be converted. - */ +* @return The double value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API double duckdb_value_double(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The duckdb_date value at the specified location, or 0 if the value cannot be converted. - */ +* @return The duckdb_date value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API duckdb_date duckdb_value_date(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The duckdb_time value at the specified location, or 0 if the value cannot be converted. - */ +* @return The duckdb_time value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API duckdb_time duckdb_value_time(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The duckdb_timestamp value at the specified location, or 0 if the value cannot be converted. - */ +* @return The duckdb_timestamp value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API duckdb_timestamp duckdb_value_timestamp(duckdb_result *result, idx_t col, idx_t row); /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The duckdb_interval value at the specified location, or 0 if the value cannot be converted. - */ +* @return The duckdb_interval value at the specified location, or 0 if the value cannot be converted. +*/ DUCKDB_API duckdb_interval duckdb_value_interval(duckdb_result *result, idx_t col, idx_t row); /*! -* DEPRECATED: use duckdb_value_string instead. This function does not work correctly if the string contains null bytes. -* returns: The text value at the specified location as a null-terminated string, or nullptr if the value cannot be +**DEPRECATED**: Use duckdb_value_string instead. This function does not work correctly if the string contains null +bytes. + +* @return The text value at the specified location as a null-terminated string, or nullptr if the value cannot be converted. The result must be freed with `duckdb_free`. */ DUCKDB_API char *duckdb_value_varchar(duckdb_result *result, idx_t col, idx_t row); @@ -995,16 +1146,18 @@ DUCKDB_API char *duckdb_value_varchar(duckdb_result *result, idx_t col, idx_t ro /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: The string value at the specified location. Attempts to cast the result value to string. - * No support for nested types, and for other complex types. - * The resulting field "string.data" must be freed with `duckdb_free.` - */ +No support for nested types, and for other complex types. +The resulting field "string.data" must be freed with `duckdb_free.` + +* @return The string value at the specified location. Attempts to cast the result value to string. +*/ DUCKDB_API duckdb_string duckdb_value_string(duckdb_result *result, idx_t col, idx_t row); /*! -* DEPRECATED: use duckdb_value_string_internal instead. This function does not work correctly if the string contains +**DEPRECATED**: Use duckdb_value_string_internal instead. This function does not work correctly if the string contains null bytes. -* returns: The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast. + +* @return The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast. If the column is NOT a VARCHAR column this function will return NULL. The result must NOT be freed. @@ -1012,9 +1165,9 @@ The result must NOT be freed. DUCKDB_API char *duckdb_value_varchar_internal(duckdb_result *result, idx_t col, idx_t row); /*! -* DEPRECATED: use duckdb_value_string_internal instead. This function does not work correctly if the string contains +**DEPRECATED**: Use duckdb_value_string_internal instead. This function does not work correctly if the string contains null bytes. -* returns: The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast. +* @return The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast. If the column is NOT a VARCHAR column this function will return NULL. The result must NOT be freed. @@ -1024,7 +1177,7 @@ DUCKDB_API duckdb_string duckdb_value_string_internal(duckdb_result *result, idx /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. -* returns: The duckdb_blob value at the specified location. Returns a blob with blob.data set to nullptr if the +* @return The duckdb_blob value at the specified location. Returns a blob with blob.data set to nullptr if the value cannot be converted. The resulting field "blob.data" must be freed with `duckdb_free.` */ DUCKDB_API duckdb_blob duckdb_value_blob(duckdb_result *result, idx_t col, idx_t row); @@ -1032,11 +1185,11 @@ DUCKDB_API duckdb_blob duckdb_value_blob(duckdb_result *result, idx_t col, idx_t /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. - * returns: Returns true if the value at the specified index is NULL, and false otherwise. - */ +* @return Returns true if the value at the specified index is NULL, and false otherwise. +*/ DUCKDB_API bool duckdb_value_is_null(duckdb_result *result, idx_t col, idx_t row); -#endif +#endif //===--------------------------------------------------------------------===// // Helpers //===--------------------------------------------------------------------===// @@ -1045,8 +1198,8 @@ DUCKDB_API bool duckdb_value_is_null(duckdb_result *result, idx_t col, idx_t row Allocate `size` bytes of memory using the duckdb internal malloc function. Any memory allocated in this manner should be freed using `duckdb_free`. -* size: The number of bytes to allocate. -* returns: A pointer to the allocated memory region. +* @param size The number of bytes to allocate. +* @return A pointer to the allocated memory region. */ DUCKDB_API void *duckdb_malloc(size_t size); @@ -1054,7 +1207,7 @@ DUCKDB_API void *duckdb_malloc(size_t size); Free a value returned from `duckdb_malloc`, `duckdb_value_varchar`, `duckdb_value_blob`, or `duckdb_value_string`. -* ptr: The memory region to de-allocate. +* @param ptr The memory region to de-allocate. */ DUCKDB_API void duckdb_free(void *ptr); @@ -1062,7 +1215,7 @@ DUCKDB_API void duckdb_free(void *ptr); The internal vector size used by DuckDB. This is the amount of tuples that will fit into a data chunk created by `duckdb_create_data_chunk`. -* returns: The vector size. +* @return The vector size. */ DUCKDB_API idx_t duckdb_vector_size(); @@ -1073,48 +1226,64 @@ This means that the data of the string does not have a separate allocation. */ DUCKDB_API bool duckdb_string_is_inlined(duckdb_string_t string); +/*! +Get the string length of a string_t + +* @param string The string to get the length of. +* @return The length. +*/ +DUCKDB_API uint32_t duckdb_string_t_length(duckdb_string_t string); + +/*! +Get a pointer to the string data of a string_t + +* @param string The string to get the pointer to. +* @return The pointer. +*/ +DUCKDB_API const char *duckdb_string_t_data(duckdb_string_t *string); + //===--------------------------------------------------------------------===// -// Date/Time/Timestamp Helpers +// Date Time Timestamp Helpers //===--------------------------------------------------------------------===// /*! Decompose a `duckdb_date` object into year, month and date (stored as `duckdb_date_struct`). -* date: The date object, as obtained from a `DUCKDB_TYPE_DATE` column. -* returns: The `duckdb_date_struct` with the decomposed elements. +* @param date The date object, as obtained from a `DUCKDB_TYPE_DATE` column. +* @return The `duckdb_date_struct` with the decomposed elements. */ DUCKDB_API duckdb_date_struct duckdb_from_date(duckdb_date date); /*! Re-compose a `duckdb_date` from year, month and date (`duckdb_date_struct`). -* date: The year, month and date stored in a `duckdb_date_struct`. -* returns: The `duckdb_date` element. +* @param date The year, month and date stored in a `duckdb_date_struct`. +* @return The `duckdb_date` element. */ DUCKDB_API duckdb_date duckdb_to_date(duckdb_date_struct date); /*! Test a `duckdb_date` to see if it is a finite value. -* date: The date object, as obtained from a `DUCKDB_TYPE_DATE` column. -* returns: True if the date is finite, false if it is ±infinity. +* @param date The date object, as obtained from a `DUCKDB_TYPE_DATE` column. +* @return True if the date is finite, false if it is ±infinity. */ DUCKDB_API bool duckdb_is_finite_date(duckdb_date date); /*! Decompose a `duckdb_time` object into hour, minute, second and microsecond (stored as `duckdb_time_struct`). -* time: The time object, as obtained from a `DUCKDB_TYPE_TIME` column. -* returns: The `duckdb_time_struct` with the decomposed elements. +* @param time The time object, as obtained from a `DUCKDB_TYPE_TIME` column. +* @return The `duckdb_time_struct` with the decomposed elements. */ DUCKDB_API duckdb_time_struct duckdb_from_time(duckdb_time time); /*! Create a `duckdb_time_tz` object from micros and a timezone offset. -* micros: The microsecond component of the time. -* offset: The timezone offset component of the time. -* returns: The `duckdb_time_tz` element. +* @param micros The microsecond component of the time. +* @param offset The timezone offset component of the time. +* @return The `duckdb_time_tz` element. */ DUCKDB_API duckdb_time_tz duckdb_create_time_tz(int64_t micros, int32_t offset); @@ -1123,41 +1292,39 @@ Decompose a TIME_TZ objects into micros and a timezone offset. Use `duckdb_from_time` to further decompose the micros into hour, minute, second and microsecond. -* micros: The time object, as obtained from a `DUCKDB_TYPE_TIME_TZ` column. -* out_micros: The microsecond component of the time. -* out_offset: The timezone offset component of the time. +* @param micros The time object, as obtained from a `DUCKDB_TYPE_TIME_TZ` column. */ DUCKDB_API duckdb_time_tz_struct duckdb_from_time_tz(duckdb_time_tz micros); /*! Re-compose a `duckdb_time` from hour, minute, second and microsecond (`duckdb_time_struct`). -* time: The hour, minute, second and microsecond in a `duckdb_time_struct`. -* returns: The `duckdb_time` element. +* @param time The hour, minute, second and microsecond in a `duckdb_time_struct`. +* @return The `duckdb_time` element. */ DUCKDB_API duckdb_time duckdb_to_time(duckdb_time_struct time); /*! Decompose a `duckdb_timestamp` object into a `duckdb_timestamp_struct`. -* ts: The ts object, as obtained from a `DUCKDB_TYPE_TIMESTAMP` column. -* returns: The `duckdb_timestamp_struct` with the decomposed elements. +* @param ts The ts object, as obtained from a `DUCKDB_TYPE_TIMESTAMP` column. +* @return The `duckdb_timestamp_struct` with the decomposed elements. */ DUCKDB_API duckdb_timestamp_struct duckdb_from_timestamp(duckdb_timestamp ts); /*! Re-compose a `duckdb_timestamp` from a duckdb_timestamp_struct. -* ts: The de-composed elements in a `duckdb_timestamp_struct`. -* returns: The `duckdb_timestamp` element. +* @param ts The de-composed elements in a `duckdb_timestamp_struct`. +* @return The `duckdb_timestamp` element. */ DUCKDB_API duckdb_timestamp duckdb_to_timestamp(duckdb_timestamp_struct ts); /*! Test a `duckdb_timestamp` to see if it is a finite value. -* ts: The timestamp object, as obtained from a `DUCKDB_TYPE_TIMESTAMP` column. -* returns: True if the timestamp is finite, false if it is ±infinity. +* @param ts The timestamp object, as obtained from a `DUCKDB_TYPE_TIMESTAMP` column. +* @return True if the timestamp is finite, false if it is ±infinity. */ DUCKDB_API bool duckdb_is_finite_timestamp(duckdb_timestamp ts); @@ -1168,8 +1335,8 @@ DUCKDB_API bool duckdb_is_finite_timestamp(duckdb_timestamp ts); /*! Converts a duckdb_hugeint object (as obtained from a `DUCKDB_TYPE_HUGEINT` column) into a double. -* val: The hugeint value. -* returns: The converted `double` element. +* @param val The hugeint value. +* @return The converted `double` element. */ DUCKDB_API double duckdb_hugeint_to_double(duckdb_hugeint val); @@ -1178,8 +1345,8 @@ Converts a double value to a duckdb_hugeint object. If the conversion fails because the double value is too big the result will be 0. -* val: The double value. -* returns: The converted `duckdb_hugeint` element. +* @param val The double value. +* @return The converted `duckdb_hugeint` element. */ DUCKDB_API duckdb_hugeint duckdb_double_to_hugeint(double val); @@ -1190,8 +1357,8 @@ DUCKDB_API duckdb_hugeint duckdb_double_to_hugeint(double val); /*! Converts a duckdb_uhugeint object (as obtained from a `DUCKDB_TYPE_UHUGEINT` column) into a double. -* val: The uhugeint value. -* returns: The converted `double` element. +* @param val The uhugeint value. +* @return The converted `double` element. */ DUCKDB_API double duckdb_uhugeint_to_double(duckdb_uhugeint val); @@ -1200,8 +1367,8 @@ Converts a double value to a duckdb_uhugeint object. If the conversion fails because the double value is too big the result will be 0. -* val: The double value. -* returns: The converted `duckdb_uhugeint` element. +* @param val The double value. +* @return The converted `duckdb_uhugeint` element. */ DUCKDB_API duckdb_uhugeint duckdb_double_to_uhugeint(double val); @@ -1214,16 +1381,16 @@ Converts a double value to a duckdb_decimal object. If the conversion fails because the double value is too big, or the width/scale are invalid the result will be 0. -* val: The double value. -* returns: The converted `duckdb_decimal` element. +* @param val The double value. +* @return The converted `duckdb_decimal` element. */ DUCKDB_API duckdb_decimal duckdb_double_to_decimal(double val, uint8_t width, uint8_t scale); /*! Converts a duckdb_decimal object (as obtained from a `DUCKDB_TYPE_DECIMAL` column) into a double. -* val: The decimal value. -* returns: The converted `double` element. +* @param val The decimal value. +* @return The converted `double` element. */ DUCKDB_API double duckdb_decimal_to_double(duckdb_decimal val); @@ -1240,7 +1407,6 @@ DUCKDB_API double duckdb_decimal_to_double(duckdb_decimal val); // SELECT * FROM tbl WHERE id=? // Or a query with multiple parameters: // SELECT * FROM tbl WHERE id=$1 OR name=$2 - /*! Create a prepared statement object from a query. @@ -1249,10 +1415,10 @@ Note that after calling `duckdb_prepare`, the prepared statement should always b If the prepare fails, `duckdb_prepare_error` can be called to obtain the reason why the prepare failed. -* connection: The connection object -* query: The SQL query to prepare -* out_prepared_statement: The resulting prepared statement object -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +* @param connection The connection object +* @param query The SQL query to prepare +* @param out_prepared_statement The resulting prepared statement object +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_prepare(duckdb_connection connection, const char *query, duckdb_prepared_statement *out_prepared_statement); @@ -1260,7 +1426,7 @@ DUCKDB_API duckdb_state duckdb_prepare(duckdb_connection connection, const char /*! Closes the prepared statement and de-allocates all memory allocated for the statement. -* prepared_statement: The prepared statement to destroy. +* @param prepared_statement The prepared statement to destroy. */ DUCKDB_API void duckdb_destroy_prepare(duckdb_prepared_statement *prepared_statement); @@ -1270,8 +1436,8 @@ If the prepared statement has no error message, this returns `nullptr` instead. The error message should not be freed. It will be de-allocated when `duckdb_destroy_prepare` is called. -* prepared_statement: The prepared statement to obtain the error from. -* returns: The error message, or `nullptr` if there is none. +* @param prepared_statement The prepared statement to obtain the error from. +* @return The error message, or `nullptr` if there is none. */ DUCKDB_API const char *duckdb_prepare_error(duckdb_prepared_statement prepared_statement); @@ -1280,7 +1446,7 @@ Returns the number of parameters that can be provided to the given prepared stat Returns 0 if the query was not successfully prepared. -* prepared_statement: The prepared statement to obtain the number of parameters for. +* @param prepared_statement The prepared statement to obtain the number of parameters for. */ DUCKDB_API idx_t duckdb_nparams(duckdb_prepared_statement prepared_statement); @@ -1290,7 +1456,7 @@ The returned string should be freed using `duckdb_free`. Returns NULL if the index is out of range for the provided prepared statement. -* prepared_statement: The prepared statement for which to get the parameter name from. +* @param prepared_statement The prepared statement for which to get the parameter name from. */ DUCKDB_API const char *duckdb_parameter_name(duckdb_prepared_statement prepared_statement, idx_t index); @@ -1299,9 +1465,9 @@ Returns the parameter type for the parameter at the given index. Returns `DUCKDB_TYPE_INVALID` if the parameter index is out of range or the statement was not successfully prepared. -* prepared_statement: The prepared statement. -* param_idx: The parameter index. -* returns: The parameter type +* @param prepared_statement The prepared statement. +* @param param_idx The parameter index. +* @return The parameter type */ DUCKDB_API duckdb_type duckdb_param_type(duckdb_prepared_statement prepared_statement, idx_t param_idx); @@ -1313,13 +1479,13 @@ DUCKDB_API duckdb_state duckdb_clear_bindings(duckdb_prepared_statement prepared /*! Returns the statement type of the statement to be executed - * statement: The prepared statement. - * returns: duckdb_statement_type value or DUCKDB_STATEMENT_TYPE_INVALID - */ +* @param statement The prepared statement. +* @return duckdb_statement_type value or DUCKDB_STATEMENT_TYPE_INVALID +*/ DUCKDB_API duckdb_statement_type duckdb_prepared_statement_type(duckdb_prepared_statement statement); //===--------------------------------------------------------------------===// -// Bind Values to Prepared Statements +// Bind Values To Prepared Statements //===--------------------------------------------------------------------===// /*! @@ -1364,11 +1530,13 @@ Binds a duckdb_hugeint value to the prepared statement at the specified index. */ DUCKDB_API duckdb_state duckdb_bind_hugeint(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_hugeint val); + /*! Binds an duckdb_uhugeint value to the prepared statement at the specified index. */ DUCKDB_API duckdb_state duckdb_bind_uhugeint(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_uhugeint val); + /*! Binds a duckdb_decimal value to the prepared statement at the specified index. */ @@ -1423,6 +1591,12 @@ Binds a duckdb_timestamp value to the prepared statement at the specified index. DUCKDB_API duckdb_state duckdb_bind_timestamp(duckdb_prepared_statement prepared_statement, idx_t param_idx, duckdb_timestamp val); +/*! +Binds a duckdb_timestamp value to the prepared statement at the specified index. +*/ +DUCKDB_API duckdb_state duckdb_bind_timestamp_tz(duckdb_prepared_statement prepared_statement, idx_t param_idx, + duckdb_timestamp val); + /*! Binds a duckdb_interval value to the prepared statement at the specified index. */ @@ -1464,9 +1638,9 @@ between calls to this function. Note that the result must be freed with `duckdb_destroy_result`. -* prepared_statement: The prepared statement to execute. -* out_result: The query result. -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +* @param prepared_statement The prepared statement to execute. +* @param out_result The query result. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_execute_prepared(duckdb_prepared_statement prepared_statement, duckdb_result *out_result); @@ -1483,9 +1657,9 @@ between calls to this function. Note that the result must be freed with `duckdb_destroy_result`. -* prepared_statement: The prepared statement to execute. -* out_result: The query result. -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +* @param prepared_statement The prepared statement to execute. +* @param out_result The query result. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_execute_prepared_streaming(duckdb_prepared_statement prepared_statement, duckdb_result *out_result); @@ -1496,7 +1670,6 @@ DUCKDB_API duckdb_state duckdb_execute_prepared_streaming(duckdb_prepared_statem //===--------------------------------------------------------------------===// // A query string can be extracted into multiple SQL statements. Each statement can be prepared and executed separately. - /*! Extract all statements from a query. Note that after calling `duckdb_extract_statements`, the extracted statements should always be destroyed using @@ -1504,10 +1677,10 @@ Note that after calling `duckdb_extract_statements`, the extracted statements sh If the extract fails, `duckdb_extract_statements_error` can be called to obtain the reason why the extract failed. -* connection: The connection object -* query: The SQL query to extract -* out_extracted_statements: The resulting extracted statements object -* returns: The number of extracted statements or 0 on failure. +* @param connection The connection object +* @param query The SQL query to extract +* @param out_extracted_statements The resulting extracted statements object +* @return The number of extracted statements or 0 on failure. */ DUCKDB_API idx_t duckdb_extract_statements(duckdb_connection connection, const char *query, duckdb_extracted_statements *out_extracted_statements); @@ -1519,28 +1692,29 @@ Note that after calling `duckdb_prepare_extracted_statement`, the prepared state If the prepare fails, `duckdb_prepare_error` can be called to obtain the reason why the prepare failed. -* connection: The connection object -* extracted_statements: The extracted statements object -* index: The index of the extracted statement to prepare -* out_prepared_statement: The resulting prepared statement object -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +* @param connection The connection object +* @param extracted_statements The extracted statements object +* @param index The index of the extracted statement to prepare +* @param out_prepared_statement The resulting prepared statement object +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_prepare_extracted_statement(duckdb_connection connection, duckdb_extracted_statements extracted_statements, idx_t index, duckdb_prepared_statement *out_prepared_statement); + /*! Returns the error message contained within the extracted statements. The result of this function must not be freed. It will be cleaned up when `duckdb_destroy_extracted` is called. -* result: The extracted statements to fetch the error from. -* returns: The error of the extracted statements. +* @param extracted_statements The extracted statements to fetch the error from. +* @return The error of the extracted statements. */ DUCKDB_API const char *duckdb_extract_statements_error(duckdb_extracted_statements extracted_statements); /*! De-allocates all memory allocated for the extracted statements. -* extracted_statements: The extracted statements to destroy. +* @param extracted_statements The extracted statements to destroy. */ DUCKDB_API void duckdb_destroy_extracted(duckdb_extracted_statements *extracted_statements); @@ -1556,12 +1730,13 @@ The pending result can be used to incrementally execute a query, returning contr Note that after calling `duckdb_pending_prepared`, the pending result should always be destroyed using `duckdb_destroy_pending`, even if this function returns DuckDBError. -* prepared_statement: The prepared statement to execute. -* out_result: The pending query result. -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +* @param prepared_statement The prepared statement to execute. +* @param out_result The pending query result. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_pending_prepared(duckdb_prepared_statement prepared_statement, duckdb_pending_result *out_result); + #ifndef DUCKDB_API_NO_DEPRECATED /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. @@ -1573,9 +1748,9 @@ The pending result represents an intermediate structure for a query that is not Note that after calling `duckdb_pending_prepared_streaming`, the pending result should always be destroyed using `duckdb_destroy_pending`, even if this function returns DuckDBError. -* prepared_statement: The prepared statement to execute. -* out_result: The pending query result. -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +* @param prepared_statement The prepared statement to execute. +* @param out_result The pending query result. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_pending_prepared_streaming(duckdb_prepared_statement prepared_statement, duckdb_pending_result *out_result); @@ -1584,7 +1759,7 @@ DUCKDB_API duckdb_state duckdb_pending_prepared_streaming(duckdb_prepared_statem /*! Closes the pending result and de-allocates all memory allocated for the result. -* pending_result: The pending result to destroy. +* @param pending_result The pending result to destroy. */ DUCKDB_API void duckdb_destroy_pending(duckdb_pending_result *pending_result); @@ -1593,8 +1768,8 @@ Returns the error message contained within the pending result. The result of this function must not be freed. It will be cleaned up when `duckdb_destroy_pending` is called. -* result: The pending result to fetch the error from. -* returns: The error of the pending result. +* @param pending_result The pending result to fetch the error from. +* @return The error of the pending result. */ DUCKDB_API const char *duckdb_pending_error(duckdb_pending_result pending_result); @@ -1607,8 +1782,8 @@ If this returns DUCKDB_PENDING_ERROR, an error occurred during execution. The error message can be obtained by calling duckdb_pending_error on the pending_result. -* pending_result: The pending result to execute a task within. -* returns: The state of the pending result after the execution. +* @param pending_result The pending result to execute a task within. +* @return The state of the pending result after the execution. */ DUCKDB_API duckdb_pending_state duckdb_pending_execute_task(duckdb_pending_result pending_result); @@ -1619,8 +1794,8 @@ If this returns DUCKDB_PENDING_ERROR, an error occurred during execution. The error message can be obtained by calling duckdb_pending_error on the pending_result. -* pending_result: The pending result. -* returns: The state of the pending result. +* @param pending_result The pending result. +* @return The state of the pending result. */ DUCKDB_API duckdb_pending_state duckdb_pending_execute_check_state(duckdb_pending_result pending_result); @@ -1632,9 +1807,9 @@ Otherwise, all remaining tasks must be executed first. Note that the result must be freed with `duckdb_destroy_result`. -* pending_result: The pending result to execute. -* out_result: The result object. -* returns: `DuckDBSuccess` on success or `DuckDBError` on failure. +* @param pending_result The pending result to execute. +* @param out_result The result object. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure. */ DUCKDB_API duckdb_state duckdb_execute_pending(duckdb_pending_result pending_result, duckdb_result *out_result); @@ -1642,8 +1817,8 @@ DUCKDB_API duckdb_state duckdb_execute_pending(duckdb_pending_result pending_res Returns whether a duckdb_pending_state is finished executing. For example if `pending_state` is DUCKDB_PENDING_RESULT_READY, this function will return true. -* pending_state: The pending state on which to decide whether to finish execution. -* returns: Boolean indicating pending execution should be considered finished. +* @param pending_state The pending state on which to decide whether to finish execution. +* @return Boolean indicating pending execution should be considered finished. */ DUCKDB_API bool duckdb_pending_execution_is_finished(duckdb_pending_state pending_state); @@ -1654,447 +1829,794 @@ DUCKDB_API bool duckdb_pending_execution_is_finished(duckdb_pending_state pendin /*! Destroys the value and de-allocates all memory allocated for that type. -* value: The value to destroy. +* @param value The value to destroy. */ DUCKDB_API void duckdb_destroy_value(duckdb_value *value); /*! Creates a value from a null-terminated string -* value: The null-terminated string -* returns: The value. This must be destroyed with `duckdb_destroy_value`. +* @param text The null-terminated string +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ DUCKDB_API duckdb_value duckdb_create_varchar(const char *text); /*! Creates a value from a string -* value: The text -* length: The length of the text -* returns: The value. This must be destroyed with `duckdb_destroy_value`. +* @param text The text +* @param length The length of the text +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ DUCKDB_API duckdb_value duckdb_create_varchar_length(const char *text, idx_t length); /*! -Creates a value from an int64 +Creates a value from a boolean -* value: The bigint value -* returns: The value. This must be destroyed with `duckdb_destroy_value`. +* @param input The boolean value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_value duckdb_create_int64(int64_t val); +DUCKDB_API duckdb_value duckdb_create_bool(bool input); /*! -Creates a struct value from a type and an array of values +Creates a value from a int8_t (a tinyint) -* type: The type of the struct -* values: The values for the struct fields -* returns: The value. This must be destroyed with `duckdb_destroy_value`. +* @param input The tinyint value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_value duckdb_create_struct_value(duckdb_logical_type type, duckdb_value *values); +DUCKDB_API duckdb_value duckdb_create_int8(int8_t input); /*! -Creates a list value from a type and an array of values of length `value_count` +Creates a value from a uint8_t (a utinyint) -* type: The type of the list -* values: The values for the list -* value_count: The number of values in the list -* returns: The value. This must be destroyed with `duckdb_destroy_value`. +* @param input The utinyint value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_value duckdb_create_list_value(duckdb_logical_type type, duckdb_value *values, idx_t value_count); +DUCKDB_API duckdb_value duckdb_create_uint8(uint8_t input); /*! -Creates a array value from a type and an array of values of length `value_count` +Creates a value from a int16_t (a smallint) -* type: The type of the array -* values: The values for the array -* value_count: The number of values in the array -* returns: The value. This must be destroyed with `duckdb_destroy_value`. +* @param input The smallint value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_value duckdb_create_array_value(duckdb_logical_type type, duckdb_value *values, idx_t value_count); +DUCKDB_API duckdb_value duckdb_create_int16(int16_t input); /*! -Obtains a string representation of the given value. -The result must be destroyed with `duckdb_free`. +Creates a value from a uint16_t (a usmallint) -* value: The value -* returns: The string value. This must be destroyed with `duckdb_free`. +* @param input The usmallint value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API char *duckdb_get_varchar(duckdb_value value); +DUCKDB_API duckdb_value duckdb_create_uint16(uint16_t input); /*! -Obtains an int64 of the given value. +Creates a value from a int32_t (an integer) -* value: The value -* returns: The int64 value, or 0 if no conversion is possible +* @param input The integer value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API int64_t duckdb_get_int64(duckdb_value value); - -//===--------------------------------------------------------------------===// -// Logical Type Interface -//===--------------------------------------------------------------------===// +DUCKDB_API duckdb_value duckdb_create_int32(int32_t input); /*! -Creates a `duckdb_logical_type` from a standard primitive type. -The resulting type should be destroyed with `duckdb_destroy_logical_type`. +Creates a value from a uint32_t (a uinteger) -This should not be used with `DUCKDB_TYPE_DECIMAL`. - -* type: The primitive type to create. -* returns: The logical type. +* @param input The uinteger value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_logical_type duckdb_create_logical_type(duckdb_type type); +DUCKDB_API duckdb_value duckdb_create_uint32(uint32_t input); /*! -Returns the alias of a duckdb_logical_type, if one is set, else `NULL`. -The result must be destroyed with `duckdb_free`. +Creates a value from a uint64_t (a ubigint) -* type: The logical type to return the alias of -* returns: The alias or `NULL` - */ -DUCKDB_API char *duckdb_logical_type_get_alias(duckdb_logical_type type); +* @param input The ubigint value +* @return The value. This must be destroyed with `duckdb_destroy_value`. +*/ +DUCKDB_API duckdb_value duckdb_create_uint64(uint64_t input); /*! -Creates a list type from its child type. -The resulting type should be destroyed with `duckdb_destroy_logical_type`. +Creates a value from an int64 -* type: The child type of list type to create. -* returns: The logical type. +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_logical_type duckdb_create_list_type(duckdb_logical_type type); +DUCKDB_API duckdb_value duckdb_create_int64(int64_t val); /*! -Creates a array type from its child type. -The resulting type should be destroyed with `duckdb_destroy_logical_type`. +Creates a value from a hugeint -* type: The child type of array type to create. -* array_size: The number of elements in the array. -* returns: The logical type. +* @param input The hugeint value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_logical_type duckdb_create_array_type(duckdb_logical_type type, idx_t array_size); +DUCKDB_API duckdb_value duckdb_create_hugeint(duckdb_hugeint input); /*! -Creates a map type from its key type and value type. -The resulting type should be destroyed with `duckdb_destroy_logical_type`. +Creates a value from a uhugeint -* type: The key type and value type of map type to create. -* returns: The logical type. +* @param input The uhugeint value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_logical_type duckdb_create_map_type(duckdb_logical_type key_type, duckdb_logical_type value_type); +DUCKDB_API duckdb_value duckdb_create_uhugeint(duckdb_uhugeint input); /*! -Creates a UNION type from the passed types array. -The resulting type should be destroyed with `duckdb_destroy_logical_type`. +Creates a value from a float -* types: The array of types that the union should consist of. -* type_amount: The size of the types array. -* returns: The logical type. +* @param input The float value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_logical_type duckdb_create_union_type(duckdb_logical_type *member_types, const char **member_names, - idx_t member_count); +DUCKDB_API duckdb_value duckdb_create_float(float input); /*! -Creates a STRUCT type from the passed member name and type arrays. -The resulting type should be destroyed with `duckdb_destroy_logical_type`. +Creates a value from a double -* member_types: The array of types that the struct should consist of. -* member_names: The array of names that the struct should consist of. -* member_count: The number of members that were specified for both arrays. -* returns: The logical type. +* @param input The double value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_logical_type duckdb_create_struct_type(duckdb_logical_type *member_types, const char **member_names, - idx_t member_count); +DUCKDB_API duckdb_value duckdb_create_double(double input); /*! -Creates an ENUM type from the passed member name array. -The resulting type should be destroyed with `duckdb_destroy_logical_type`. +Creates a value from a date -* enum_name: The name of the enum. -* member_names: The array of names that the enum should consist of. -* member_count: The number of elements that were specified in the array. -* returns: The logical type. +* @param input The date value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_logical_type duckdb_create_enum_type(const char **member_names, idx_t member_count); +DUCKDB_API duckdb_value duckdb_create_date(duckdb_date input); /*! -Creates a `duckdb_logical_type` of type decimal with the specified width and scale. -The resulting type should be destroyed with `duckdb_destroy_logical_type`. +Creates a value from a time -* width: The width of the decimal type -* scale: The scale of the decimal type -* returns: The logical type. +* @param input The time value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_logical_type duckdb_create_decimal_type(uint8_t width, uint8_t scale); +DUCKDB_API duckdb_value duckdb_create_time(duckdb_time input); /*! -Retrieves the enum type class of a `duckdb_logical_type`. +Creates a value from a time_tz. +Not to be confused with `duckdb_create_time_tz`, which creates a duckdb_time_tz_t. -* type: The logical type object -* returns: The type id +* @param value The time_tz value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_type duckdb_get_type_id(duckdb_logical_type type); +DUCKDB_API duckdb_value duckdb_create_time_tz_value(duckdb_time_tz value); /*! -Retrieves the width of a decimal type. +Creates a value from a timestamp -* type: The logical type object -* returns: The width of the decimal type +* @param input The timestamp value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API uint8_t duckdb_decimal_width(duckdb_logical_type type); +DUCKDB_API duckdb_value duckdb_create_timestamp(duckdb_timestamp input); /*! -Retrieves the scale of a decimal type. +Creates a value from an interval -* type: The logical type object -* returns: The scale of the decimal type +* @param input The interval value +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API uint8_t duckdb_decimal_scale(duckdb_logical_type type); +DUCKDB_API duckdb_value duckdb_create_interval(duckdb_interval input); /*! -Retrieves the internal storage type of a decimal type. +Creates a value from a blob -* type: The logical type object -* returns: The internal type of the decimal type +* @param data The blob data +* @param length The length of the blob data +* @return The value. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_API duckdb_type duckdb_decimal_internal_type(duckdb_logical_type type); +DUCKDB_API duckdb_value duckdb_create_blob(const uint8_t *data, idx_t length); /*! -Retrieves the internal storage type of an enum type. +Returns the boolean value of the given value. -* type: The logical type object -* returns: The internal type of the enum type +* @param val A duckdb_value containing a boolean +* @return A boolean, or false if the value cannot be converted */ -DUCKDB_API duckdb_type duckdb_enum_internal_type(duckdb_logical_type type); +DUCKDB_API bool duckdb_get_bool(duckdb_value val); /*! -Retrieves the dictionary size of the enum type. +Returns the int8_t value of the given value. -* type: The logical type object -* returns: The dictionary size of the enum type +* @param val A duckdb_value containing a tinyint +* @return A int8_t, or MinValue if the value cannot be converted */ -DUCKDB_API uint32_t duckdb_enum_dictionary_size(duckdb_logical_type type); +DUCKDB_API int8_t duckdb_get_int8(duckdb_value val); /*! -Retrieves the dictionary value at the specified position from the enum. +Returns the uint8_t value of the given value. -The result must be freed with `duckdb_free`. - -* type: The logical type object -* index: The index in the dictionary -* returns: The string value of the enum type. Must be freed with `duckdb_free`. +* @param val A duckdb_value containing a utinyint +* @return A uint8_t, or MinValue if the value cannot be converted */ -DUCKDB_API char *duckdb_enum_dictionary_value(duckdb_logical_type type, idx_t index); +DUCKDB_API uint8_t duckdb_get_uint8(duckdb_value val); /*! -Retrieves the child type of the given list type. - -The result must be freed with `duckdb_destroy_logical_type`. +Returns the int16_t value of the given value. -* type: The logical type object -* returns: The child type of the list type. Must be destroyed with `duckdb_destroy_logical_type`. +* @param val A duckdb_value containing a smallint +* @return A int16_t, or MinValue if the value cannot be converted */ -DUCKDB_API duckdb_logical_type duckdb_list_type_child_type(duckdb_logical_type type); +DUCKDB_API int16_t duckdb_get_int16(duckdb_value val); /*! -Retrieves the child type of the given array type. - -The result must be freed with `duckdb_destroy_logical_type`. +Returns the uint16_t value of the given value. -* type: The logical type object -* returns: The child type of the array type. Must be destroyed with `duckdb_destroy_logical_type`. +* @param val A duckdb_value containing a usmallint +* @return A uint16_t, or MinValue if the value cannot be converted */ -DUCKDB_API duckdb_logical_type duckdb_array_type_child_type(duckdb_logical_type type); +DUCKDB_API uint16_t duckdb_get_uint16(duckdb_value val); /*! -Retrieves the array size of the given array type. +Returns the int32_t value of the given value. -* type: The logical type object -* returns: The fixed number of elements the values of this array type can store. +* @param val A duckdb_value containing a integer +* @return A int32_t, or MinValue if the value cannot be converted */ -DUCKDB_API idx_t duckdb_array_type_array_size(duckdb_logical_type type); +DUCKDB_API int32_t duckdb_get_int32(duckdb_value val); /*! -Retrieves the key type of the given map type. +Returns the uint32_t value of the given value. -The result must be freed with `duckdb_destroy_logical_type`. - -* type: The logical type object -* returns: The key type of the map type. Must be destroyed with `duckdb_destroy_logical_type`. +* @param val A duckdb_value containing a uinteger +* @return A uint32_t, or MinValue if the value cannot be converted */ -DUCKDB_API duckdb_logical_type duckdb_map_type_key_type(duckdb_logical_type type); +DUCKDB_API uint32_t duckdb_get_uint32(duckdb_value val); /*! -Retrieves the value type of the given map type. +Returns the int64_t value of the given value. -The result must be freed with `duckdb_destroy_logical_type`. - -* type: The logical type object -* returns: The value type of the map type. Must be destroyed with `duckdb_destroy_logical_type`. +* @param val A duckdb_value containing a bigint +* @return A int64_t, or MinValue if the value cannot be converted */ -DUCKDB_API duckdb_logical_type duckdb_map_type_value_type(duckdb_logical_type type); +DUCKDB_API int64_t duckdb_get_int64(duckdb_value val); /*! -Returns the number of children of a struct type. +Returns the uint64_t value of the given value. -* type: The logical type object -* returns: The number of children of a struct type. +* @param val A duckdb_value containing a ubigint +* @return A uint64_t, or MinValue if the value cannot be converted */ -DUCKDB_API idx_t duckdb_struct_type_child_count(duckdb_logical_type type); +DUCKDB_API uint64_t duckdb_get_uint64(duckdb_value val); /*! -Retrieves the name of the struct child. - -The result must be freed with `duckdb_free`. +Returns the hugeint value of the given value. -* type: The logical type object -* index: The child index -* returns: The name of the struct type. Must be freed with `duckdb_free`. +* @param val A duckdb_value containing a hugeint +* @return A duckdb_hugeint, or MinValue if the value cannot be converted */ -DUCKDB_API char *duckdb_struct_type_child_name(duckdb_logical_type type, idx_t index); +DUCKDB_API duckdb_hugeint duckdb_get_hugeint(duckdb_value val); /*! -Retrieves the child type of the given struct type at the specified index. - -The result must be freed with `duckdb_destroy_logical_type`. +Returns the uhugeint value of the given value. -* type: The logical type object -* index: The child index -* returns: The child type of the struct type. Must be destroyed with `duckdb_destroy_logical_type`. +* @param val A duckdb_value containing a uhugeint +* @return A duckdb_uhugeint, or MinValue if the value cannot be converted */ -DUCKDB_API duckdb_logical_type duckdb_struct_type_child_type(duckdb_logical_type type, idx_t index); +DUCKDB_API duckdb_uhugeint duckdb_get_uhugeint(duckdb_value val); /*! -Returns the number of members that the union type has. +Returns the float value of the given value. -* type: The logical type (union) object -* returns: The number of members of a union type. +* @param val A duckdb_value containing a float +* @return A float, or NAN if the value cannot be converted */ -DUCKDB_API idx_t duckdb_union_type_member_count(duckdb_logical_type type); +DUCKDB_API float duckdb_get_float(duckdb_value val); /*! -Retrieves the name of the union member. +Returns the double value of the given value. -The result must be freed with `duckdb_free`. +* @param val A duckdb_value containing a double +* @return A double, or NAN if the value cannot be converted +*/ +DUCKDB_API double duckdb_get_double(duckdb_value val); + +/*! +Returns the date value of the given value. -* type: The logical type object -* index: The child index -* returns: The name of the union member. Must be freed with `duckdb_free`. +* @param val A duckdb_value containing a date +* @return A duckdb_date, or MinValue if the value cannot be converted */ -DUCKDB_API char *duckdb_union_type_member_name(duckdb_logical_type type, idx_t index); +DUCKDB_API duckdb_date duckdb_get_date(duckdb_value val); /*! -Retrieves the child type of the given union member at the specified index. +Returns the time value of the given value. -The result must be freed with `duckdb_destroy_logical_type`. +* @param val A duckdb_value containing a time +* @return A duckdb_time, or MinValue