Skip to content

Commit

Permalink
Fix markdown.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 586948558
  • Loading branch information
The ml_dtypes Authors committed Dec 1, 2023
1 parent e844eee commit d0519a5
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ array([0, 0, 0, 0], dtype=bfloat16)
```
Importing `ml_dtypes` also registers the data types with numpy, so that they may
be referred to by their string name:

```python
>>> np.dtype('bfloat16')
dtype(bfloat16)
Expand Down Expand Up @@ -126,6 +127,7 @@ If you're exploring the use of low-precision dtypes in your code, you should be
careful to anticipate when the precision loss might lead to surprising results.
One example is the behavior of aggregations like `sum`; consider this `bfloat16`
summation in NumPy (run with version 1.24.2):

```python
>>> from ml_dtypes import bfloat16
>>> import numpy as np
Expand All @@ -137,24 +139,28 @@ summation in NumPy (run with version 1.24.2):
The true sum should be close to 5000, but numpy returns exactly 256: this is
because `bfloat16` does not have the precision to increment `256` by values less than
`1`:

```python
>>> bfloat16(256) + bfloat16(1)
256
```
After 256, the next representable value in bfloat16 is 258:

```python
>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258
```
For better results you can specify that the accumulation should happen in a
higher-precision type like `float32`:

```python
>>> vals.sum(dtype='float32').astype(bfloat16)
4992
```
In contrast to NumPy, projects like [JAX](http://jax.readthedocs.io/) which support
low-precision arithmetic more natively will often do these kinds of higher-precision
accumulations automatically:

```python
>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Expand Down

0 comments on commit d0519a5

Please sign in to comment.