diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index 4d7d60e10cf6..09087ca31958 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -231,6 +231,7 @@ pub struct Format { quote: Option, terminator: Option, null_regex: NullRegex, + truncated_rows: bool, } impl Format { @@ -265,6 +266,17 @@ impl Format { self } + /// Whether to allow truncated rows when parsing. + /// + /// By default this is set to `false` and will error if the CSV rows have different lengths. + /// When set to true then it will allow records with less than the expected number of columns + /// and fill the missing columns with nulls. If the record's schema is not nullable, then it + /// will still return an error. + pub fn with_truncated_rows(mut self, allow: bool) -> Self { + self.truncated_rows = allow; + self + } + /// Infer schema of CSV records from the provided `reader` /// /// If `max_records` is `None`, all records will be read, otherwise up to `max_records` @@ -329,6 +341,7 @@ impl Format { fn build_reader(&self, reader: R) -> csv::Reader { let mut builder = csv::ReaderBuilder::new(); builder.has_headers(self.header); + builder.flexible(self.truncated_rows); if let Some(c) = self.delimiter { builder.delimiter(c); @@ -1121,6 +1134,17 @@ impl ReaderBuilder { self } + /// Whether to allow truncated rows when parsing. + /// + /// By default this is set to `false` and will error if the CSV rows have different lengths. + /// When set to true then it will allow records with less than the expected number of columns + /// and fill the missing columns with nulls. If the record's schema is not nullable, then it + /// will still return an error. + pub fn with_truncated_rows(mut self, allow: bool) -> Self { + self.format.truncated_rows = allow; + self + } + /// Create a new `Reader` from a non-buffered reader /// /// If `R: BufRead` consider using [`Self::build_buffered`] to avoid unnecessary additional @@ -1140,7 +1164,11 @@ impl ReaderBuilder { /// Builds a decoder that can be used to decode CSV from an arbitrary byte stream pub fn build_decoder(self) -> Decoder { let delimiter = self.format.build_parser(); - let record_decoder = RecordDecoder::new(delimiter, self.schema.fields().len()); + let record_decoder = RecordDecoder::new( + delimiter, + self.schema.fields().len(), + self.format.truncated_rows, + ); let header = self.format.header as usize; @@ -2164,6 +2192,133 @@ mod tests { assert!(c.is_null(3)); } + #[test] + fn test_truncated_rows() { + let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8"; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ])); + + let reader = ReaderBuilder::new(schema.clone()) + .with_header(true) + .with_truncated_rows(true) + .build(Cursor::new(data)) + .unwrap(); + + let batches = reader.collect::, _>>(); + assert!(batches.is_ok()); + let batch = batches.unwrap().into_iter().next().unwrap(); + // Empty rows are skipped by the underlying csv parser + assert_eq!(batch.num_rows(), 3); + + let reader = ReaderBuilder::new(schema.clone()) + .with_header(true) + .with_truncated_rows(false) + .build(Cursor::new(data)) + .unwrap(); + + let batches = reader.collect::, _>>(); + assert!(match batches { + Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"), + _ => false, + }); + } + + #[test] + fn test_truncated_rows_csv() { + let file = File::open("test/data/truncated_rows.csv").unwrap(); + let schema = Arc::new(Schema::new(vec![ + Field::new("Name", DataType::Utf8, true), + Field::new("Age", DataType::UInt32, true), + Field::new("Occupation", DataType::Utf8, true), + Field::new("DOB", DataType::Date32, true), + ])); + let reader = ReaderBuilder::new(schema.clone()) + .with_header(true) + .with_batch_size(24) + .with_truncated_rows(true); + let csv = reader.build(file).unwrap(); + let batches = csv.collect::, _>>().unwrap(); + + assert_eq!(batches.len(), 1); + let batch = &batches[0]; + assert_eq!(batch.num_rows(), 6); + assert_eq!(batch.num_columns(), 4); + let name = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let age = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + let occupation = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + let dob = batch + .column(3) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(name.value(0), "A1"); + assert_eq!(name.value(1), "B2"); + assert!(name.is_null(2)); + assert_eq!(name.value(3), "C3"); + assert_eq!(name.value(4), "D4"); + assert_eq!(name.value(5), "E5"); + + assert_eq!(age.value(0), 34); + assert_eq!(age.value(1), 29); + assert!(age.is_null(2)); + assert_eq!(age.value(3), 45); + assert!(age.is_null(4)); + assert_eq!(age.value(5), 31); + + assert_eq!(occupation.value(0), "Engineer"); + assert_eq!(occupation.value(1), "Doctor"); + assert!(occupation.is_null(2)); + assert_eq!(occupation.value(3), "Artist"); + assert!(occupation.is_null(4)); + assert!(occupation.is_null(5)); + + assert_eq!(dob.value(0), 5675); + assert!(dob.is_null(1)); + assert!(dob.is_null(2)); + assert_eq!(dob.value(3), -1858); + assert!(dob.is_null(4)); + assert!(dob.is_null(5)); + } + + #[test] + fn test_truncated_rows_not_nullable_error() { + let data = "a,b,c\n1,2,3\n4,5"; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + + let reader = ReaderBuilder::new(schema.clone()) + .with_header(true) + .with_truncated_rows(true) + .build(Cursor::new(data)) + .unwrap(); + + let batches = reader.collect::, _>>(); + assert!(match batches { + Err(ArrowError::InvalidArgumentError(e)) => + e.to_string().contains("contains null values"), + _ => false, + }); + } + #[test] fn test_buffered() { let tests = [ diff --git a/arrow-csv/src/reader/records.rs b/arrow-csv/src/reader/records.rs index 877cfb3ee653..a07fc9c94ffa 100644 --- a/arrow-csv/src/reader/records.rs +++ b/arrow-csv/src/reader/records.rs @@ -56,10 +56,16 @@ pub struct RecordDecoder { /// /// We track this independently of Vec to avoid re-zeroing memory data_len: usize, + + /// Whether rows with less than expected columns are considered valid + /// + /// Default value is false + /// When enabled fills in missing columns with null + truncated_rows: bool, } impl RecordDecoder { - pub fn new(delimiter: Reader, num_columns: usize) -> Self { + pub fn new(delimiter: Reader, num_columns: usize, truncated_rows: bool) -> Self { Self { delimiter, num_columns, @@ -70,6 +76,7 @@ impl RecordDecoder { data_len: 0, data: vec![], num_rows: 0, + truncated_rows, } } @@ -127,10 +134,19 @@ impl RecordDecoder { } ReadRecordResult::Record => { if self.current_field != self.num_columns { - return Err(ArrowError::CsvError(format!( - "incorrect number of fields for line {}, expected {} got {}", - self.line_number, self.num_columns, self.current_field - ))); + if self.truncated_rows && self.current_field < self.num_columns { + // If the number of fields is less than expected, pad with nulls + let fill_count = self.num_columns - self.current_field; + let fill_value = self.offsets[self.offsets_len - 1]; + self.offsets[self.offsets_len..self.offsets_len + fill_count] + .fill(fill_value); + self.offsets_len += fill_count; + } else { + return Err(ArrowError::CsvError(format!( + "incorrect number of fields for line {}, expected {} got {}", + self.line_number, self.num_columns, self.current_field + ))); + } } read += 1; self.current_field = 0; @@ -299,7 +315,7 @@ mod tests { .into_iter(); let mut reader = BufReader::with_capacity(3, Cursor::new(csv.as_bytes())); - let mut decoder = RecordDecoder::new(Reader::new(), 3); + let mut decoder = RecordDecoder::new(Reader::new(), 3, false); loop { let to_read = 3; @@ -333,7 +349,7 @@ mod tests { #[test] fn test_invalid_fields() { let csv = "a,b\nb,c\na\n"; - let mut decoder = RecordDecoder::new(Reader::new(), 2); + let mut decoder = RecordDecoder::new(Reader::new(), 2, false); let err = decoder.decode(csv.as_bytes(), 4).unwrap_err().to_string(); let expected = "Csv error: incorrect number of fields for line 3, expected 2 got 1"; @@ -341,7 +357,7 @@ mod tests { assert_eq!(err, expected); // Test with initial skip - let mut decoder = RecordDecoder::new(Reader::new(), 2); + let mut decoder = RecordDecoder::new(Reader::new(), 2, false); let (skipped, bytes) = decoder.decode(csv.as_bytes(), 1).unwrap(); assert_eq!(skipped, 1); decoder.clear(); @@ -354,9 +370,18 @@ mod tests { #[test] fn test_skip_insufficient_rows() { let csv = "a\nv\n"; - let mut decoder = RecordDecoder::new(Reader::new(), 1); + let mut decoder = RecordDecoder::new(Reader::new(), 1, false); let (read, bytes) = decoder.decode(csv.as_bytes(), 3).unwrap(); assert_eq!(read, 2); assert_eq!(bytes, csv.len()); } + + #[test] + fn test_truncated_rows() { + let csv = "a,b\nv\n,1\n,2\n,3\n"; + let mut decoder = RecordDecoder::new(Reader::new(), 2, true); + let (read, bytes) = decoder.decode(csv.as_bytes(), 5).unwrap(); + assert_eq!(read, 5); + assert_eq!(bytes, csv.len()); + } } diff --git a/arrow-csv/test/data/truncated_rows.csv b/arrow-csv/test/data/truncated_rows.csv new file mode 100644 index 000000000000..0b2af5740095 --- /dev/null +++ b/arrow-csv/test/data/truncated_rows.csv @@ -0,0 +1,8 @@ +Name,Age,Occupation,DOB +A1,34,Engineer,1985-07-16 +B2,29,Doctor +, +C3,45,Artist,1964-11-30 + +D4 +E5,31,,