Skip to content

Commit

Permalink
Fix negative decimal string (#5128)
Browse files Browse the repository at this point in the history
* Fix negative cases

* Fix

* Fix

* Fix clippy

* Update arrow-cast/src/cast.rs

Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>

* Update arrow-cast/src/cast.rs

Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>

---------

Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>
  • Loading branch information
viirya and tustvold authored Nov 28, 2023
1 parent e26fa4f commit 8a0b5cb
Showing 1 changed file with 88 additions and 2 deletions.
90 changes: 88 additions & 2 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2596,11 +2596,33 @@ where
)));
}

let integers = parts[0].trim_start_matches('0');
let (negative, first_part) = if parts[0].is_empty() {
(false, parts[0])
} else {
match parts[0].as_bytes()[0] {
b'-' => (true, &parts[0][1..]),
b'+' => (false, &parts[0][1..]),
_ => (false, parts[0]),
}
};

let integers = first_part.trim_start_matches('0');
let decimals = if parts.len() == 2 { parts[1] } else { "" };

if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() {
return Err(ArrowError::InvalidArgumentError(format!(
"Invalid decimal format: {value_str:?}"
)));
}

if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() {
return Err(ArrowError::InvalidArgumentError(format!(
"Invalid decimal format: {value_str:?}"
)));
}

// Adjust decimal based on scale
let number_decimals = if decimals.len() > scale {
let mut number_decimals = if decimals.len() > scale {
let decimal_number = i256::from_string(decimals).ok_or_else(|| {
ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}"))
})?;
Expand Down Expand Up @@ -2640,6 +2662,10 @@ where
format!("{integers}{decimals}")
};

if negative {
number_decimals.insert(0, '-');
}

let value = i256::from_string(number_decimals.as_str()).ok_or_else(|| {
ArrowError::InvalidArgumentError(format!(
"Cannot convert {} to {}: Overflow",
Expand Down Expand Up @@ -8256,6 +8282,21 @@ mod tests {
assert_eq!("0.00", decimal_arr.value_as_string(10));
assert_eq!("0.00", decimal_arr.value_as_string(11));
assert!(decimal_arr.is_null(12));
assert_eq!("-1.23", decimal_arr.value_as_string(13));
assert_eq!("-1.24", decimal_arr.value_as_string(14));
assert_eq!("0.00", decimal_arr.value_as_string(15));
assert_eq!("-123.00", decimal_arr.value_as_string(16));
assert_eq!("-123.23", decimal_arr.value_as_string(17));
assert_eq!("-0.12", decimal_arr.value_as_string(18));
assert_eq!("1.23", decimal_arr.value_as_string(19));
assert_eq!("1.24", decimal_arr.value_as_string(20));
assert_eq!("0.00", decimal_arr.value_as_string(21));
assert_eq!("123.00", decimal_arr.value_as_string(22));
assert_eq!("123.23", decimal_arr.value_as_string(23));
assert_eq!("0.12", decimal_arr.value_as_string(24));
assert!(decimal_arr.is_null(25));
assert!(decimal_arr.is_null(26));
assert!(decimal_arr.is_null(27));

// Decimal256
let output_type = DataType::Decimal256(76, 3);
Expand All @@ -8277,6 +8318,21 @@ mod tests {
assert_eq!("0.000", decimal_arr.value_as_string(10));
assert_eq!("0.000", decimal_arr.value_as_string(11));
assert!(decimal_arr.is_null(12));
assert_eq!("-1.235", decimal_arr.value_as_string(13));
assert_eq!("-1.236", decimal_arr.value_as_string(14));
assert_eq!("0.000", decimal_arr.value_as_string(15));
assert_eq!("-123.000", decimal_arr.value_as_string(16));
assert_eq!("-123.234", decimal_arr.value_as_string(17));
assert_eq!("-0.123", decimal_arr.value_as_string(18));
assert_eq!("1.235", decimal_arr.value_as_string(19));
assert_eq!("1.236", decimal_arr.value_as_string(20));
assert_eq!("0.000", decimal_arr.value_as_string(21));
assert_eq!("123.000", decimal_arr.value_as_string(22));
assert_eq!("123.234", decimal_arr.value_as_string(23));
assert_eq!("0.123", decimal_arr.value_as_string(24));
assert!(decimal_arr.is_null(25));
assert!(decimal_arr.is_null(26));
assert!(decimal_arr.is_null(27));
}

#[test]
Expand All @@ -8295,6 +8351,21 @@ mod tests {
Some(""),
Some(" "),
None,
Some("-1.23499999"),
Some("-1.23599999"),
Some("-0.00001"),
Some("-123"),
Some("-123.234000"),
Some("-000.123"),
Some("+1.23499999"),
Some("+1.23599999"),
Some("+0.00001"),
Some("+123"),
Some("+123.234000"),
Some("+000.123"),
Some("1.-23499999"),
Some("-1.-23499999"),
Some("--1.23499999"),
]);
let array = Arc::new(str_array) as ArrayRef;

Expand All @@ -8317,6 +8388,21 @@ mod tests {
Some(""),
Some(" "),
None,
Some("-1.23499999"),
Some("-1.23599999"),
Some("-0.00001"),
Some("-123"),
Some("-123.234000"),
Some("-000.123"),
Some("+1.23499999"),
Some("+1.23599999"),
Some("+0.00001"),
Some("+123"),
Some("+123.234000"),
Some("+000.123"),
Some("1.-23499999"),
Some("-1.-23499999"),
Some("--1.23499999"),
]);
let array = Arc::new(str_array) as ArrayRef;

Expand Down

0 comments on commit 8a0b5cb

Please sign in to comment.