From eb7ab83aab6b54897b00da1eb739c863d61445a2 Mon Sep 17 00:00:00 2001 From: Himadri Pal Date: Thu, 12 Dec 2024 15:00:22 -0800 Subject: [PATCH] fix: decimal conversion loses value on lower precision (#6836) * decimal conversion loses value on lower precision, throws error now on overflow. * fix review comments and fix formatting. * for simple case of equal scale and bigger precision, no conversion needed. revert whitespace changes formatting check --------- Co-authored-by: himadripal --- arrow-cast/src/cast/decimal.rs | 49 +++++++++-------- arrow-cast/src/cast/mod.rs | 99 ++++++++++++++++++++++++++++++---- 2 files changed, 116 insertions(+), 32 deletions(-) diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index 74cd2637e6a..ba82ca9040c 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -111,9 +111,13 @@ where O::Native::from_decimal(adjusted) }; - Ok(match cast_options.safe { - true => array.unary_opt(f), - false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?, + Ok(if cast_options.safe { + array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) + } else { + array.try_unary(|x| { + f(x).ok_or_else(|| error(x)) + .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v)) + })? }) } @@ -137,15 +141,20 @@ where let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); - Ok(match cast_options.safe { - true => array.unary_opt(f), - false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?, + Ok(if cast_options.safe { + array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) + } else { + array.try_unary(|x| { + f(x).ok_or_else(|| error(x)) + .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v)) + })? 
}) } // Only support one type of decimal cast operations pub(crate) fn cast_decimal_to_decimal_same_type( array: &PrimitiveArray, + input_precision: u8, input_scale: i8, output_precision: u8, output_scale: i8, @@ -155,20 +164,11 @@ where T: DecimalType, T::Native: DecimalCast + ArrowNativeTypeOp, { - let array: PrimitiveArray = match input_scale.cmp(&output_scale) { - Ordering::Equal => { - // the scale doesn't change, the native value don't need to be changed + let array: PrimitiveArray = + if input_scale == output_scale && input_precision <= output_precision { array.clone() - } - Ordering::Greater => convert_to_smaller_scale_decimal::( - array, - input_scale, - output_precision, - output_scale, - cast_options, - )?, - Ordering::Less => { - // input_scale < output_scale + } else if input_scale <= output_scale { + // the scale grows, or stays equal while the precision shrinks; either way the value may overflow the output precision convert_to_bigger_or_equal_scale_decimal::( array, input_scale, @@ -176,8 +176,15 @@ where output_scale, cast_options, )? - } - }; + } else { + convert_to_smaller_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )? 
+ }; Ok(Arc::new(array.with_precision_and_scale( output_precision, diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 5c892783d4a..ba470635c6c 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -830,18 +830,20 @@ pub fn cast_with_options( (Map(_, ordered1), Map(_, ordered2)) if ordered1 == ordered2 => { cast_map_values(array.as_map(), to_type, cast_options, ordered1.to_owned()) } - (Decimal128(_, s1), Decimal128(p2, s2)) => { + (Decimal128(p1, s1), Decimal128(p2, s2)) => { cast_decimal_to_decimal_same_type::( array.as_primitive(), + *p1, *s1, *p2, *s2, cast_options, ) } - (Decimal256(_, s1), Decimal256(p2, s2)) => { + (Decimal256(p1, s1), Decimal256(p2, s2)) => { cast_decimal_to_decimal_same_type::( array.as_primitive(), + *p1, *s1, *p2, *s2, @@ -2694,13 +2696,16 @@ mod tests { // negative test let array = vec![Some(123456), None]; let array = create_decimal_array(array, 10, 0).unwrap(); - let result = cast(&array, &DataType::Decimal128(2, 2)); - assert!(result.is_ok()); - let array = result.unwrap(); - let array: &Decimal128Array = array.as_primitive(); - let err = array.validate_decimal_precision(2); + let result_safe = cast(&array, &DataType::Decimal128(2, 2)); + assert!(result_safe.is_ok()); + let options = CastOptions { + safe: false, + ..Default::default() + }; + + let result_unsafe = cast_with_options(&array, &DataType::Decimal128(2, 2), &options); assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal128 of precision 2. 
Max is 99", - err.unwrap_err().to_string()); + result_unsafe.unwrap_err().to_string()); } #[test] @@ -8460,7 +8465,7 @@ mod tests { let input_type = DataType::Decimal128(10, 3); let output_type = DataType::Decimal256(10, 5); assert!(can_cast_types(&input_type, &output_type)); - let array = vec![Some(i128::MAX), Some(i128::MIN)]; + let array = vec![Some(123456), Some(-123456)]; let input_decimal_array = create_decimal_array(array, 10, 3).unwrap(); let array = Arc::new(input_decimal_array) as ArrayRef; @@ -8470,8 +8475,8 @@ mod tests { Decimal256Array, &output_type, vec![ - Some(i256::from_i128(i128::MAX).mul_wrapping(hundred)), - Some(i256::from_i128(i128::MIN).mul_wrapping(hundred)) + Some(i256::from_i128(123456).mul_wrapping(hundred)), + Some(i256::from_i128(-123456).mul_wrapping(hundred)) ] ); } @@ -9935,4 +9940,76 @@ mod tests { "Cast non-nullable to non-nullable struct field returning null should fail", ); } + + #[test] + fn test_decimal_to_decimal_throw_error_on_precision_overflow_same_scale() { + let array = vec![Some(123456789)]; + let array = create_decimal_array(array, 24, 2).unwrap(); + println!("{:?}", array); + let input_type = DataType::Decimal128(24, 2); + let output_type = DataType::Decimal128(6, 2); + assert!(can_cast_types(&input_type, &output_type)); + + let options = CastOptions { + safe: false, + ..Default::default() + }; + let result = cast_with_options(&array, &output_type, &options); + assert_eq!(result.unwrap_err().to_string(), + "Invalid argument error: 123456789 is too large to store in a Decimal128 of precision 6. 
Max is 999999"); + } + + #[test] + fn test_decimal_to_decimal_throw_error_on_precision_overflow_lower_scale() { + let array = vec![Some(123456789)]; + let array = create_decimal_array(array, 24, 4).unwrap(); + println!("{:?}", array); + let input_type = DataType::Decimal128(24, 4); + let output_type = DataType::Decimal128(6, 2); + assert!(can_cast_types(&input_type, &output_type)); + + let options = CastOptions { + safe: false, + ..Default::default() + }; + let result = cast_with_options(&array, &output_type, &options); + assert_eq!(result.unwrap_err().to_string(), + "Invalid argument error: 1234568 is too large to store in a Decimal128 of precision 6. Max is 999999"); + } + + #[test] + fn test_decimal_to_decimal_throw_error_on_precision_overflow_greater_scale() { + let array = vec![Some(123456789)]; + let array = create_decimal_array(array, 24, 2).unwrap(); + println!("{:?}", array); + let input_type = DataType::Decimal128(24, 2); + let output_type = DataType::Decimal128(6, 3); + assert!(can_cast_types(&input_type, &output_type)); + + let options = CastOptions { + safe: false, + ..Default::default() + }; + let result = cast_with_options(&array, &output_type, &options); + assert_eq!(result.unwrap_err().to_string(), + "Invalid argument error: 1234567890 is too large to store in a Decimal128 of precision 6. Max is 999999"); + } + + #[test] + fn test_decimal_to_decimal_throw_error_on_precision_overflow_diff_type() { + let array = vec![Some(123456789)]; + let array = create_decimal_array(array, 24, 2).unwrap(); + println!("{:?}", array); + let input_type = DataType::Decimal128(24, 2); + let output_type = DataType::Decimal256(6, 2); + assert!(can_cast_types(&input_type, &output_type)); + + let options = CastOptions { + safe: false, + ..Default::default() + }; + let result = cast_with_options(&array, &output_type, &options); + assert_eq!(result.unwrap_err().to_string(), + "Invalid argument error: 123456789 is too large to store in a Decimal256 of precision 6. 
Max is 999999"); + } }