Skip to content

Commit

Permalink
Make division safer; address PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
milevin committed Dec 19, 2024
1 parent 255cf36 commit be54b4a
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions arrow-arith/src/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ fn arithmetic_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, A
DayTime => interval_mul_op::<IntervalDayTimeType>(op, l, l_scalar, r, r_scalar),
MonthDayNano => interval_mul_op::<IntervalMonthDayNanoType>(op, l, l_scalar, r, r_scalar),
},
(lhs, Interval(unit)) if lhs.is_integer() && matches!(op, Op::Mul | Op::MulWrapping) =>
(lhs, Interval(unit)) if lhs.is_numeric() && matches!(op, Op::Mul | Op::MulWrapping) =>
match unit {
YearMonth => interval_mul_op::<IntervalYearMonthType>(op, l, l_scalar, r, r_scalar),
DayTime => interval_mul_op::<IntervalDayTimeType>(op, l, l_scalar, r, r_scalar),
Expand Down Expand Up @@ -574,6 +574,17 @@ trait IntervalOp: ArrowPrimitiveType {
fn div_float(left: Self::Native, right: f64) -> Result<Self::Native, ArrowError>;
}

/// Helper function to safely convert f64 to i32, checking for overflow and invalid values
fn f64_to_i32(value: f64) -> Result<i32, ArrowError> {
if !value.is_finite() || value > i32::MAX as f64 || value < i32::MIN as f64 {
Err(ArrowError::ComputeError(
"Division result out of i32 range".to_string(),
))
} else {
Ok(value as i32)
}
}

impl IntervalOp for IntervalYearMonthType {
fn add(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
left.add_checked(right)
Expand All @@ -596,14 +607,18 @@ impl IntervalOp for IntervalYearMonthType {
if right == 0 {
return Err(ArrowError::DivideByZero);
}
Ok((left as f64 / right as f64).round() as i32)

let result = (left as f64 / right as f64).round();
f64_to_i32(result)
}

fn div_float(left: Self::Native, right: f64) -> Result<Self::Native, ArrowError> {
if right == 0.0 {
return Err(ArrowError::DivideByZero);
}
Ok((left as f64 / right).round() as i32)

let result = (left as f64 / right).round();
f64_to_i32(result)
}
}

Expand Down Expand Up @@ -664,7 +679,9 @@ impl IntervalOp for IntervalDayTimeType {
let result_days = result_ms / 86_400_000;
let result_ms = result_ms % 86_400_000;

Ok(Self::make_value(result_days as i32, result_ms as i32))
let result_days_i32 = f64_to_i32(result_days as f64)?;
let result_ms_i32 = f64_to_i32(result_ms as f64)?;
Ok(Self::make_value(result_days_i32, result_ms_i32))
}

fn div_float(left: Self::Native, right: f64) -> Result<Self::Native, ArrowError> {
Expand All @@ -680,10 +697,10 @@ impl IntervalOp for IntervalDayTimeType {
let result_days = (total_ms / 86_400_000.0).floor();
let result_ms = total_ms % 86_400_000.0;

Ok(Self::make_value(
result_days as i32,
result_ms.round() as i32,
))
let result_days_i32 = f64_to_i32(result_days)?;
let result_ms_i32 = f64_to_i32(result_ms)?;

Ok(Self::make_value(result_days_i32, result_ms_i32))
}
}

Expand Down

0 comments on commit be54b4a

Please sign in to comment.