Skip to content

Commit

Permalink
For decimals
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Nov 27, 2023
1 parent c7e573c commit c26baeb
Showing 1 changed file with 57 additions and 18 deletions.
75 changes: 57 additions & 18 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Time64(_), Time32(to_unit)) => {
matches!(to_unit, Second | Millisecond)
}
(Timestamp(_, _), _) if to_type.is_integer() || (to_type.is_floating() && to_type != &Float16) => true,
(_, Timestamp(_, _)) if from_type.is_integer() || (from_type.is_floating() && from_type != &Float16) => true,
(Timestamp(_, _), _) if to_type.is_numeric() && to_type != &Float16 => true,
(_, Timestamp(_, _)) if from_type.is_numeric() && from_type != &Float16 => true,
(Date64, Timestamp(_, None)) => true,
(Date32, Timestamp(_, None)) => true,
(
Expand Down Expand Up @@ -876,7 +876,7 @@ pub fn cast_with_options(
cast_options,
)
}
(Decimal128(_, scale), _) => {
(Decimal128(_, scale), _) if !to_type.is_temporal() => {
// cast decimal to other type
match to_type {
UInt8 => cast_decimal_to_integer::<Decimal128Type, UInt8Type>(
Expand Down Expand Up @@ -941,7 +941,7 @@ pub fn cast_with_options(
))),
}
}
(Decimal256(_, scale), _) => {
(Decimal256(_, scale), _) if !to_type.is_temporal() => {
// cast decimal to other type
match to_type {
UInt8 => cast_decimal_to_integer::<Decimal256Type, UInt8Type>(
Expand Down Expand Up @@ -1006,7 +1006,7 @@ pub fn cast_with_options(
))),
}
}
(_, Decimal128(precision, scale)) => {
(_, Decimal128(precision, scale)) if !from_type.is_temporal() => {
// cast data to decimal
match from_type {
UInt8 => cast_integer_to_decimal::<_, Decimal128Type, _>(
Expand Down Expand Up @@ -1095,7 +1095,7 @@ pub fn cast_with_options(
))),
}
}
(_, Decimal256(precision, scale)) => {
(_, Decimal256(precision, scale)) if !from_type.is_temporal() => {
// cast data to decimal
match from_type {
UInt8 => cast_integer_to_decimal::<_, Decimal256Type, _>(
Expand Down Expand Up @@ -1634,31 +1634,25 @@ pub fn cast_with_options(
.unary::<_, Time64MicrosecondType>(|x| x / (NANOSECONDS / MICROSECONDS)),
)),

// Timestamp to integer/floating
(Timestamp(TimeUnit::Second, _), _) if to_type.is_integer() || to_type.is_floating() => {
// Timestamp to integer/floating/decimals
(Timestamp(TimeUnit::Second, _), _) if to_type.is_numeric() => {
let array = cast_reinterpret_arrays::<TimestampSecondType, Int64Type>(array)?;
cast_with_options(&array, to_type, cast_options)
}
(Timestamp(TimeUnit::Millisecond, _), _)
if to_type.is_integer() || to_type.is_floating() =>
{
(Timestamp(TimeUnit::Millisecond, _), _) if to_type.is_numeric() => {
let array = cast_reinterpret_arrays::<TimestampMillisecondType, Int64Type>(array)?;
cast_with_options(&array, to_type, cast_options)
}
(Timestamp(TimeUnit::Microsecond, _), _)
if to_type.is_integer() || to_type.is_floating() =>
{
(Timestamp(TimeUnit::Microsecond, _), _) if to_type.is_numeric() => {
let array = cast_reinterpret_arrays::<TimestampMicrosecondType, Int64Type>(array)?;
cast_with_options(&array, to_type, cast_options)
}
(Timestamp(TimeUnit::Nanosecond, _), _)
if to_type.is_integer() || to_type.is_floating() =>
{
(Timestamp(TimeUnit::Nanosecond, _), _) if to_type.is_numeric() => {
let array = cast_reinterpret_arrays::<TimestampNanosecondType, Int64Type>(array)?;
cast_with_options(&array, to_type, cast_options)
}

(_, Timestamp(unit, tz)) if from_type.is_integer() || from_type.is_floating() => {
(_, Timestamp(unit, tz)) if from_type.is_numeric() => {
let array = cast_with_options(array, &Int64, cast_options)?;
Ok(make_timestamp_array(
array.as_primitive(),
Expand Down Expand Up @@ -4755,6 +4749,51 @@ mod tests {
assert_eq!(&actual, &expected);
}

#[test]
fn test_cast_decimal_to_timestamp() {
let array = Int64Array::from(vec![Some(2), Some(10), None]);
let expected = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap();

let array = Decimal128Array::from(vec![Some(200), Some(1000), None])
.with_precision_and_scale(4, 2)
.unwrap();
let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap();

assert_eq!(&actual, &expected);

let array = Decimal256Array::from(vec![
Some(i256::from_i128(2000)),
Some(i256::from_i128(10000)),
None,
])
.with_precision_and_scale(5, 3)
.unwrap();
let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap();

assert_eq!(&actual, &expected);
}

#[test]
fn test_cast_timestamp_to_decimal() {
let array = TimestampMillisecondArray::from(vec![Some(5), Some(1), None])
.with_timezone("UTC".to_string());
let expected = cast(&array, &DataType::Int64).unwrap();

let actual = cast(
&cast(&array, &DataType::Decimal128(5, 2)).unwrap(),
&DataType::Int64,
)
.unwrap();
assert_eq!(&actual, &expected);

let actual = cast(
&cast(&array, &DataType::Decimal256(10, 5)).unwrap(),
&DataType::Int64,
)
.unwrap();
assert_eq!(&actual, &expected);
}

#[test]
fn test_cast_list_i32_to_list_u16() {
let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 100000000]).into_data();
Expand Down

0 comments on commit c26baeb

Please sign in to comment.