diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index 8958ca6fae62..54a7da63da0a 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -309,27 +309,31 @@ impl RecordBatch { return Err(ArrowError::InvalidArgumentError(err.to_string())); } - // function for comparing column type and field type - // return true if 2 types are not matched - let type_not_match = if options.match_field_names { - |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type - } else { - |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| { - !col_type.equals_datatype(field_type) - } - }; + if !options.skip_schema_check { + // function for comparing column type and field type + // return true if 2 types are not matched + let type_not_match = if options.match_field_names { + |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| { + col_type != field_type + } + } else { + |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| { + !col_type.equals_datatype(field_type) + } + }; - // check that all columns match the schema - let not_match = columns - .iter() - .zip(schema.fields().iter()) - .map(|(col, field)| (col.data_type(), field.data_type())) - .enumerate() - .find(type_not_match); + // check that all columns match the schema + let not_match = columns + .iter() + .zip(schema.fields().iter()) + .map(|(col, field)| (col.data_type(), field.data_type())) + .enumerate() + .find(type_not_match); - if let Some((i, (col_type, field_type))) = not_match { - return Err(ArrowError::InvalidArgumentError(format!( - "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}"))); + if let Some((i, (col_type, field_type))) = not_match { + return Err(ArrowError::InvalidArgumentError(format!( + "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}"))); + } } Ok(RecordBatch { @@ -390,6 +394,7 @@ 
impl RecordBatch { &RecordBatchOptions { match_field_names: true, row_count: Some(self.row_count), + skip_schema_check: false, }, ) } @@ -631,6 +636,13 @@ pub struct RecordBatchOptions { /// Optional row count, useful for specifying a row count for a RecordBatch with no columns pub row_count: Option<usize>, + + /// Option to skip schema checking when creating new record batches. This is intended for + /// cases where the schema has already been checked or where more flexibility is required + /// in downstream projects, such as allowing either Utf8 or Dictionary<_, Utf8> for a + /// schema with type Utf8. Using this option is likely to break compatibility with arrow-rs + /// kernels that operate on RecordBatch so should be used with caution. + pub skip_schema_check: bool, } impl RecordBatchOptions { @@ -639,6 +651,7 @@ impl RecordBatchOptions { Self { match_field_names: true, row_count: None, + skip_schema_check: false, } } /// Sets the row_count of RecordBatchOptions and returns self @@ -651,6 +664,11 @@ impl RecordBatchOptions { self.match_field_names = match_field_names; self } + /// Sets the skip_schema_check of RecordBatchOptions and returns self + pub fn with_skip_schema_check(mut self, skip_schema_check: bool) -> Self { + self.skip_schema_check = skip_schema_check; + self + } } impl Default for RecordBatchOptions { fn default() -> Self { @@ -942,6 +960,18 @@ mod tests { assert!(batch.is_err()); } + #[test] + fn create_record_batch_schema_mismatch_allowed() { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let a = Int64Array::from(vec![1, 2, 3, 4, 5]); + + let options = RecordBatchOptions::new().with_skip_schema_check(true); + let batch = + RecordBatch::try_new_with_options(Arc::new(schema), vec![Arc::new(a)], &options); + assert!(batch.is_ok()); + } + #[test] fn create_record_batch_field_name_mismatch() { let fields = vec![ @@ -982,6 +1012,7 @@ mod tests { let options = RecordBatchOptions { match_field_names: false, row_count: None, + 
skip_schema_check: false, }; let batch = RecordBatch::try_new_with_options(schema, vec![a], &options); assert!(batch.is_ok()); @@ -1226,6 +1257,7 @@ mod tests { &RecordBatchOptions { match_field_names: true, row_count: Some(3), + skip_schema_check: false, }, ) .expect("valid conversion");