diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 6d5299f71647..ca154925dfdd 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -24,6 +24,8 @@ use arrow::{ use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; +use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder}; +use crate::spill::get_record_batch_memory_size; use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}; use arrow_array::{Array, ArrayRef, RecordBatch}; use arrow_schema::SchemaRef; @@ -36,8 +38,6 @@ use datafusion_execution::{ use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder}; - /// Global TopK /// /// # Background @@ -575,7 +575,7 @@ impl RecordBatchStore { pub fn insert(&mut self, entry: RecordBatchEntry) { // uses of 0 means that none of the rows in the batch were stored in the topk if entry.uses > 0 { - self.batches_size += entry.batch.get_array_memory_size(); + self.batches_size += get_record_batch_memory_size(&entry.batch); self.batches.insert(entry.id, entry); } } @@ -630,7 +630,7 @@ impl RecordBatchStore { let old_entry = self.batches.remove(&id).unwrap(); self.batches_size = self .batches_size - .checked_sub(old_entry.batch.get_array_memory_size()) + .checked_sub(get_record_batch_memory_size(&old_entry.batch)) .unwrap(); } } @@ -643,3 +643,44 @@ impl RecordBatchStore { + self.batches_size } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use arrow_array::Float64Array; + + /// This test ensures the size calculation is correct for RecordBatches with multiple columns. + #[test] + fn test_record_batch_store_size() { + // given + let schema = Arc::new(Schema::new(vec![ + Field::new("ints", DataType::Int32, true), + Field::new("float64", DataType::Float64, false), + ])); + let mut record_batch_store = RecordBatchStore::new(Arc::clone(&schema)); + let int_array = + Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); // 5 * 4 = 20 + let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]); // 5 * 8 = 40 + + let record_batch_entry = RecordBatchEntry { + id: 0, + batch: RecordBatch::try_new( + schema, + vec![Arc::new(int_array), Arc::new(float64_array)], + ) + .unwrap(), + uses: 1, + }; + + // when insert record batch entry + record_batch_store.insert(record_batch_entry); + assert_eq!(record_batch_store.batches_size, 60); + + // when unuse record batch entry + record_batch_store.unuse(0); + assert_eq!(record_batch_store.batches_size, 0); + } +}