diff --git a/README.md b/README.md index d05b44b..97d97db 100644 --- a/README.md +++ b/README.md @@ -73,3 +73,4 @@ pip install pre-commit pre-commit run --all-files pytest -v ./tests ``` +Python wheel can be built with the usual command `python -m build`. diff --git a/jax_scaled_arithmetics/__init__.py b/jax_scaled_arithmetics/__init__.py index bbc5164..74eea10 100644 --- a/jax_scaled_arithmetics/__init__.py +++ b/jax_scaled_arithmetics/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -from . import lax, ops +from . import core, lax, ops from ._version import __version__ from .core import ( # noqa: F401 AutoScaleConfig, diff --git a/pyproject.toml b/pyproject.toml index 9b171d6..b093c06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ version = "0.1" description="JAX Scaled Arithmetics." readme = "README.md" authors = [ - { name = "Graphcore Research", email = "contact@graphcore.ai" }, + { name = "Graphcore Research", email = "paulb@graphcore.ai" }, ] requires-python = ">=3.8" classifiers = [ @@ -31,7 +31,7 @@ Website = "https://github.com/graphcore-research/jax-scaled-arithmetics/#readme" test = ["pytest"] [tool.setuptools] -packages = ["jax_scaled_arithmetics"] +packages = ["jax_scaled_arithmetics", "jax_scaled_arithmetics.core", "jax_scaled_arithmetics.lax", "jax_scaled_arithmetics.ops"] [tool.pytest.ini_options] minversion = "6.0"