Skip to content

Commit

Permalink
Added CPython free-threading support and basic CI env
Browse files Browse the repository at this point in the history
+ multithread tests
+ wheels job
  • Loading branch information
vfdev-5 committed Aug 26, 2024
1 parent cfbd8ac commit 47a377a
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 3 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,31 @@ jobs:
- name: Run tests
run: |
pytest -n auto
build-free-threading:
# Later we can merge this job with build similarly to
# https://github.com/python-pillow/Pillow/blob/f0d8fd3059bc1b291563d8a0b1f224b6fd7d0b90/.github/workflows/test.yml#L56-L57
name: Python 3.13 with free-threading
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
with:
submodules: true
- name: Set up Python 3.13 with free-threading
# TODO: replace with setup-python when there is support
uses: deadsnakes/action@6c8b9b82fe0b4344f4b98f2775fcc395df45e494 # v3.1.0
with:
python-version: '3.13-dev'
nogil: true
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install setuptools wheel
python -m pip install -U --pre numpy \
-i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
python -c "import numpy; print(f'{numpy.__version__=}')"
- name: Build ml_dtypes
run: |
python -m pip install .[dev] --no-build-isolation
- name: Run tests
run: |
pytest -n auto
9 changes: 6 additions & 3 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,20 @@ jobs:
platforms: all

- name: Install cibuildwheel
run: python -m pip install cibuildwheel==2.15.0
run: python -m pip install cibuildwheel==2.20.0

- name: Build wheels
run: python -m cibuildwheel --output-dir wheelhouse
env:
CIBW_ARCHS_LINUX: auto aarch64
CIBW_ARCHS_MACOS: universal2
CIBW_BUILD: cp39-* cp310-* cp311-* cp312-*
CIBW_SKIP: "*musllinux* *i686* *win32*"
CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313*
CIBW_FREE_THREADED_SUPPORT: True
CIBW_PRERELEASE_PYTHONS: True
CIBW_SKIP: "*musllinux* *i686* *win32* *t-win*"
CIBW_TEST_REQUIRES: absl-py pytest pytest-xdist
CIBW_TEST_COMMAND: pytest -n auto {project}
CIBW_BUILD_VERBOSITY: 1

- uses: actions/upload-artifact@v3
with:
Expand Down
5 changes: 5 additions & 0 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,11 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() {
reinterpret_cast<PyObject*>(TypeDescriptor<uint4>::type_ptr)) < 0) {
return nullptr;
}

#ifdef Py_GIL_DISABLED
PyUnstable_Module_SetGIL(m.get(), Py_MOD_GIL_NOT_USED);
#endif

return m.release();
}
} // namespace ml_dtypes
Empty file added ml_dtypes/tests/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions ml_dtypes/tests/finfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from absl.testing import parameterized
import ml_dtypes
import numpy as np
from ml_dtypes.tests.multi_thread_test_mixin import MultiThreadTestMixin

ALL_DTYPES = [
ml_dtypes.bfloat16,
Expand Down Expand Up @@ -109,5 +110,9 @@ def assert_zero(val):
)


class FinfoMultiThreadTest(FinfoTest, MultiThreadTestMixin):
pass


if __name__ == "__main__":
absltest.main()
5 changes: 5 additions & 0 deletions ml_dtypes/tests/iinfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from absl.testing import parameterized
import ml_dtypes
import numpy as np
from ml_dtypes.tests.multi_thread_test_mixin import MultiThreadTestMixin


class IinfoTest(parameterized.TestCase):
Expand Down Expand Up @@ -79,5 +80,9 @@ def testIinfoNonInteger(self):
ml_dtypes.iinfo(bool)


class IinfoMultiThreadTest(IinfoTest, MultiThreadTestMixin):
pass


if __name__ == "__main__":
absltest.main()
9 changes: 9 additions & 0 deletions ml_dtypes/tests/intn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from absl.testing import parameterized
import ml_dtypes
import numpy as np
from ml_dtypes.tests.multi_thread_test_mixin import MultiThreadTestMixin

int2 = ml_dtypes.int2
int4 = ml_dtypes.int4
Expand Down Expand Up @@ -380,5 +381,13 @@ def testBinaryUfuncs(self, scalar_type, ufunc):
)


class ScalarMultiThreadTest(ScalarTest, MultiThreadTestMixin):
pass


class ArrayMultiThreadTest(ArrayTest, MultiThreadTestMixin):
pass


if __name__ == "__main__":
absltest.main()
5 changes: 5 additions & 0 deletions ml_dtypes/tests/metadata_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from absl.testing import absltest
import ml_dtypes
from ml_dtypes.tests.multi_thread_test_mixin import MultiThreadTestMixin


class CustomFloatTest(absltest.TestCase):
Expand All @@ -31,5 +32,9 @@ def test_version_matches_package_metadata(self):
self.assertEqual(metadata_version, package_version)


class CustomFloatMultiThreadTest(CustomFloatTest, MultiThreadTestMixin):
pass


if __name__ == "__main__":
absltest.main()
26 changes: 26 additions & 0 deletions ml_dtypes/tests/multi_thread_test_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
class MultiThreadTestMixin:
max_workers: int = 2

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
test_fn_names = [v for v in dir(self) if v.startswith("test")]

for test_fn_name in test_fn_names:
test_fn = getattr(self, test_fn_name)

def get_mt_test_fn(test_func):
from concurrent.futures import ThreadPoolExecutor
from functools import wraps

@wraps(test_func)
def wrapper(*args, **kwargs):
def test_func_noargs(_):
test_func(*args, **kwargs)

with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
list(executor.map(test_func_noargs, range(self.max_workers)))

wrapper.__doc__ == test_func.__doc__
return wrapper

setattr(self, test_fn.__name__, get_mt_test_fn(test_fn))
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"numpy>=1.21.2; python_version>='3.10'",
"numpy>=1.23.3; python_version>='3.11'",
"numpy>=1.26.0; python_version>='3.12'",
"numpy>=2.1.0; python_version>='3.13'",
]

[project.urls]
Expand Down

0 comments on commit 47a377a

Please sign in to comment.