Skip to content

Commit

Permalink
fix dtype promotion table in array api module (Bears-R-Us#3655)
Browse files Browse the repository at this point in the history
Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>
  • Loading branch information
jeremiah-corrado authored Aug 14, 2024
1 parent 10ea6b4 commit bff8ca5
Showing 1 changed file with 90 additions and 89 deletions.
179 changes: 90 additions & 89 deletions arkouda/array_api/_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import numpy as np
import arkouda as ak
from arkouda import dtype as akdtype

# Note: we use dtype objects instead of dtype classes. The spec does not
# require any behavior on dtypes other than equality.
int8 = np.int8
int16 = np.int16
int32 = np.int32
int64 = np.int64
uint8 = np.uint8
uint16 = np.uint16
uint32 = np.uint32
uint64 = np.uint64
float32 = np.float32
float64 = np.float64
complex64 = np.complex64
complex128 = np.complex128
bool_ = np.bool_
int8 = ak.int8
int16 = ak.int16
int32 = ak.int32
int64 = ak.int64
uint8 = ak.uint8
uint16 = ak.uint16
uint32 = ak.uint32
uint64 = ak.uint64
float32 = ak.float32
float64 = ak.float64
complex64 = ak.complex64
complex128 = ak.complex128
bool_ = ak.bool_

_all_dtypes = (
int8,
Expand Down Expand Up @@ -97,83 +98,83 @@
# allowed to promote to floating-point dtypes, but only in array operators
# (see Array._promote_scalar) method in _array_object.py.
_promotion_table = {
(int8, int8): int8,
(int8, int16): int16,
(int8, int32): int32,
(int8, int64): int64,
(int16, int8): int16,
(int16, int16): int16,
(int16, int32): int32,
(int16, int64): int64,
(int32, int8): int32,
(int32, int16): int32,
(int32, int32): int32,
(int32, int64): int64,
(int64, int8): int64,
(int64, int16): int64,
(int64, int32): int64,
(int64, int64): int64,
(uint8, uint8): uint8,
(uint8, uint16): uint16,
(uint8, uint32): uint32,
(uint8, uint64): uint64,
(uint16, uint8): uint16,
(uint16, uint16): uint16,
(uint16, uint32): uint32,
(uint16, uint64): uint64,
(uint32, uint8): uint32,
(uint32, uint16): uint32,
(uint32, uint32): uint32,
(uint32, uint64): uint64,
(uint64, uint8): uint64,
(uint64, uint16): uint64,
(uint64, uint32): uint64,
(uint64, uint64): uint64,
(int8, uint8): int16,
(int8, uint16): int32,
(int8, uint32): int64,
(int16, uint8): int16,
(int16, uint16): int32,
(int16, uint32): int64,
(int32, uint8): int32,
(int32, uint16): int32,
(int32, uint32): int64,
(int64, uint8): int64,
(int64, uint16): int64,
(int64, uint32): int64,
(uint8, int8): int16,
(uint16, int8): int32,
(uint32, int8): int64,
(uint8, int16): int16,
(uint16, int16): int32,
(uint32, int16): int64,
(uint8, int32): int32,
(uint16, int32): int32,
(uint32, int32): int64,
(uint8, int64): int64,
(uint16, int64): int64,
(uint32, int64): int64,
(float32, float32): float32,
(float32, float64): float64,
(float64, float32): float64,
(float64, float64): float64,
(complex64, complex64): complex64,
(complex64, complex128): complex128,
(complex128, complex64): complex128,
(complex128, complex128): complex128,
(float32, complex64): complex64,
(float32, complex128): complex128,
(float64, complex64): complex128,
(float64, complex128): complex128,
(complex64, float32): complex64,
(complex64, float64): complex128,
(complex128, float32): complex128,
(complex128, float64): complex128,
(bool, bool): bool,
(akdtype(int8), akdtype(int8)): int8,
(akdtype(int8), akdtype(int16)): int16,
(akdtype(int8), akdtype(int32)): int32,
(akdtype(int8), akdtype(int64)): int64,
(akdtype(int16), akdtype(int8)): int16,
(akdtype(int16), akdtype(int16)): int16,
(akdtype(int16), akdtype(int32)): int32,
(akdtype(int16), akdtype(int64)): int64,
(akdtype(int32), akdtype(int8)): int32,
(akdtype(int32), akdtype(int16)): int32,
(akdtype(int32), akdtype(int32)): int32,
(akdtype(int32), akdtype(int64)): int64,
(akdtype(int64), akdtype(int8)): int64,
(akdtype(int64), akdtype(int16)): int64,
(akdtype(int64), akdtype(int32)): int64,
(akdtype(int64), akdtype(int64)): int64,
(akdtype(uint8), akdtype(uint8)): uint8,
(akdtype(uint8), akdtype(uint16)): uint16,
(akdtype(uint8), akdtype(uint32)): uint32,
(akdtype(uint8), akdtype(uint64)): uint64,
(akdtype(uint16), akdtype(uint8)): uint16,
(akdtype(uint16), akdtype(uint16)): uint16,
(akdtype(uint16), akdtype(uint32)): uint32,
(akdtype(uint16), akdtype(uint64)): uint64,
(akdtype(uint32), akdtype(uint8)): uint32,
(akdtype(uint32), akdtype(uint16)): uint32,
(akdtype(uint32), akdtype(uint32)): uint32,
(akdtype(uint32), akdtype(uint64)): uint64,
(akdtype(uint64), akdtype(uint8)): uint64,
(akdtype(uint64), akdtype(uint16)): uint64,
(akdtype(uint64), akdtype(uint32)): uint64,
(akdtype(uint64), akdtype(uint64)): uint64,
(akdtype(int8), akdtype(uint8)): int16,
(akdtype(int8), akdtype(uint16)): int32,
(akdtype(int8), akdtype(uint32)): int64,
(akdtype(int16), akdtype(uint8)): int16,
(akdtype(int16), akdtype(uint16)): int32,
(akdtype(int16), akdtype(uint32)): int64,
(akdtype(int32), akdtype(uint8)): int32,
(akdtype(int32), akdtype(uint16)): int32,
(akdtype(int32), akdtype(uint32)): int64,
(akdtype(int64), akdtype(uint8)): int64,
(akdtype(int64), akdtype(uint16)): int64,
(akdtype(int64), akdtype(uint32)): int64,
(akdtype(uint8), akdtype(int8)): int16,
(akdtype(uint16), akdtype(int8)): int32,
(akdtype(uint32), akdtype(int8)): int64,
(akdtype(uint8), akdtype(int16)): int16,
(akdtype(uint16), akdtype(int16)): int32,
(akdtype(uint32), akdtype(int16)): int64,
(akdtype(uint8), akdtype(int32)): int32,
(akdtype(uint16), akdtype(int32)): int32,
(akdtype(uint32), akdtype(int32)): int64,
(akdtype(uint8), akdtype(int64)): int64,
(akdtype(uint16), akdtype(int64)): int64,
(akdtype(uint32), akdtype(int64)): int64,
(akdtype(float32), akdtype(float32)): float32,
(akdtype(float32), akdtype(float64)): float64,
(akdtype(float64), akdtype(float32)): float64,
(akdtype(float64), akdtype(float64)): float64,
(akdtype(complex64), akdtype(complex64)): complex64,
(akdtype(complex64), akdtype(complex128)): complex128,
(akdtype(complex128), akdtype(complex64)): complex128,
(akdtype(complex128), akdtype(complex128)): complex128,
(akdtype(float32), akdtype(complex64)): complex64,
(akdtype(float32), akdtype(complex128)): complex128,
(akdtype(float64), akdtype(complex64)): complex128,
(akdtype(float64), akdtype(complex128)): complex128,
(akdtype(complex64), akdtype(float32)): complex64,
(akdtype(complex64), akdtype(float64)): complex128,
(akdtype(complex128), akdtype(float32)): complex128,
(akdtype(complex128), akdtype(float64)): complex128,
(akdtype(bool_), akdtype(bool_)): bool_,
}


def _result_type(type1, type2):
if (type1, type2) in _promotion_table:
return _promotion_table[type1, type2]
if (akdtype(type1), akdtype(type2)) in _promotion_table:
return _promotion_table[akdtype(type1), akdtype(type2)]
raise TypeError(f"{type1} and {type2} cannot be type promoted together")

0 comments on commit bff8ca5

Please sign in to comment.