diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index f55053e5034..9bdb80ef31c 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -779,42 +779,66 @@ fn parse( match key_type.as_ref() { DataType::Int8 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::Int16 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::Int32 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::Int64 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::UInt8 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::UInt16 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::UInt32 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), DataType::UInt64 => Ok(Arc::new( rows.iter() - .map(|row| row.get(i)) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::>(), ) as ArrayRef), _ => Err(ArrowError::ParseError(format!( @@ -1475,6 +1499,40 @@ mod tests { assert_eq!(strings.value(29), "Uckfield, East Sussex, UK"); } + #[test] + fn test_csv_with_nullable_dictionary() { + let offset_type = vec![ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + ]; + for data_type in offset_type { + let file = File::open("test/data/dictionary_nullable_test.csv").unwrap(); + let dictionary_type = + DataType::Dictionary(Box::new(data_type), Box::new(DataType::Utf8)); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("name", dictionary_type.clone(), true), + ])); + + let mut csv = ReaderBuilder::new(schema) + .build(file.try_clone().unwrap()) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + assert_eq!(3, batch.num_rows()); + assert_eq!(2, batch.num_columns()); + + let names = arrow_cast::cast(batch.column(1), &dictionary_type).unwrap(); + assert!(!names.is_null(2)); + assert!(names.is_null(1)); + } + } #[test] fn test_nulls() { let schema = Arc::new(Schema::new(vec![ diff --git a/arrow-csv/test/data/dictionary_nullable_test.csv b/arrow-csv/test/data/dictionary_nullable_test.csv new file mode 100644 index 00000000000..c9ada5293b7 --- /dev/null +++ b/arrow-csv/test/data/dictionary_nullable_test.csv @@ -0,0 +1,3 @@ +id,name +1, +2,bob