diff --git a/README.md b/README.md
index ba34d55e..74403006 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,7 @@ array([0, 0, 0, 0], dtype=bfloat16)
 ```
 Importing `ml_dtypes` also registers the data types with numpy, so that they
 may be referred to by their string name:
+
 ```python
 >>> np.dtype('bfloat16')
 dtype(bfloat16)
@@ -126,6 +127,7 @@ If you're exploring the use of low-precision dtypes in your code, you should be
 careful to anticipate when the precision loss might lead to surprising results.
 One example is the behavior of aggregations like `sum`; consider this `bfloat16`
 summation in NumPy (run with version 1.24.2):
+
 ```python
 >>> from ml_dtypes import bfloat16
 >>> import numpy as np
@@ -137,17 +139,20 @@ summation in NumPy (run with version 1.24.2):
 The true sum should be close to 5000, but numpy returns exactly 256: this is
 because `bfloat16` does not have the precision to increment `256` by values
 less than `1`:
+
 ```python
 >>> bfloat16(256) + bfloat16(1)
 256
 ```
 After 256, the next representable value in bfloat16 is 258:
+
 ```python
 >>> np.nextafter(bfloat16(256), bfloat16(np.inf))
 258
 ```
 For better results you can specify that the accumulation should happen in a
 higher-precision type like `float32`:
+
 ```python
 >>> vals.sum(dtype='float32').astype(bfloat16)
 4992
@@ -155,6 +160,7 @@ higher-precision type like `float32`:
 In contrast to NumPy, projects like [JAX](http://jax.readthedocs.io/) which
 support low-precision arithmetic more natively will often do these kinds of
 higher-precision accumulations automatically:
+
 ```python
 >>> import jax.numpy as jnp
 >>> jnp.array(vals).sum()
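
The hunks above cut off mid-example (as diff context naturally does), and the construction of `vals` falls outside the shown context, so here is a self-contained sketch of the accumulation pitfall the README describes. The `default_rng` seed and the uniform draw are assumptions rather than the README's exact setup; the printed values are the ones quoted in the README text (observed with NumPy 1.24.2).

```python
# Self-contained repro of the bfloat16 accumulation pitfall described above.
# Assumption: `vals` is built from a uniform draw; the README's exact
# construction is not visible in these hunks.
import numpy as np
from ml_dtypes import bfloat16

rng = np.random.default_rng(0)  # seed chosen here for reproducibility
vals = rng.uniform(size=10000).astype(bfloat16)  # true sum should be ~5000

# Naive bfloat16 accumulation stalls at 256: once the running total
# reaches 256, adding any value less than 1 rounds back to 256.
print(vals.sum())  # -> 256 (as quoted in the README)

# Accumulating in float32 and casting back gives a sensible answer.
print(vals.sum(dtype='float32').astype(bfloat16))  # -> 4992 (as quoted)
```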