Skip to content

Commit

Permalink
Merge pull request #172 from jakevdp:divmod
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666881270
  • Loading branch information
The ml_dtypes Authors committed Aug 23, 2024
2 parents 30f2497 + 2975e8e commit 6c9775f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

* Added new 8-bit float type following IEEE 754 convention:
`ml_dtypes.float8_e4m3`.
* Fix outputs of float `divmod` and `floor_divide` when denominator is zero.

## [0.4.0] - 2024-04-1

Expand Down
8 changes: 7 additions & 1 deletion ml_dtypes/_src/ufuncs.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,13 @@ struct TrueDivide {
inline std::pair<float, float> divmod(float a, float b) {
if (b == 0.0f) {
float nan = std::numeric_limits<float>::quiet_NaN();
return {nan, nan};
float inf = std::numeric_limits<float>::infinity();

if (std::isnan(a) || (a == 0.0f)) {
return {nan, nan};
} else {
return {std::signbit(a) == std::signbit(b) ? inf : -inf, nan};
}
}
float mod = std::fmod(a, b);
float div = (a - mod) / b;
Expand Down
41 changes: 41 additions & 0 deletions ml_dtypes/tests/custom_float_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,47 @@ def testDivmod(self, float_type):
float_type=float_type,
)

@ignore_warning(category=RuntimeWarning, message="invalid value encountered")
@ignore_warning(category=RuntimeWarning, message="divide by zero encountered")
def testDivmodCornerCases(self, float_type):
x = np.array(
[-np.nan, -np.inf, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan],
dtype=float_type,
)
xf32 = x.astype("float32")
out = np.divmod.outer(x, x)
expected = np.divmod.outer(xf32, xf32)
numpy_assert_allclose(
out[0],
truncate(expected[0], float_type=float_type),
rtol=0.0,
float_type=float_type,
)
numpy_assert_allclose(
out[1],
truncate(expected[1], float_type=float_type),
rtol=0.0,
float_type=float_type,
)

@ignore_warning(category=RuntimeWarning, message="invalid value encountered")
@ignore_warning(category=RuntimeWarning, message="divide by zero encountered")
def testFloordivCornerCases(self, float_type):
# Regression test for https://github.com/jax-ml/ml_dtypes/issues/170
x = np.array(
[-np.nan, -np.inf, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan],
dtype=float_type,
)
xf32 = x.astype("float32")
out = np.floor_divide.outer(x, x)
expected = np.floor_divide.outer(xf32, xf32)
numpy_assert_allclose(
out,
truncate(expected, float_type=float_type),
rtol=0.0,
float_type=float_type,
)

def testModf(self, float_type):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(float_type)
Expand Down

0 comments on commit 6c9775f

Please sign in to comment.