Skip to content

Commit

Permalink
Merge pull request #168 from vfdev-5:enable-free-threading-mode
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 672562730
  • Loading branch information
The ml_dtypes Authors committed Sep 9, 2024
2 parents 17a83f1 + 9b22feb commit 6f02f77
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 3 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,31 @@ jobs:
- name: Run tests
run: |
pytest -n auto
build-free-threading:
# Later we can merge this job with build similarly to
# https://github.com/python-pillow/Pillow/blob/f0d8fd3059bc1b291563d8a0b1f224b6fd7d0b90/.github/workflows/test.yml#L56-L57
name: Python 3.13 with free-threading
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
with:
submodules: true
- name: Set up Python 3.13 with free-threading
# TODO: replace with setup-python when there is support
uses: deadsnakes/action@6c8b9b82fe0b4344f4b98f2775fcc395df45e494 # v3.1.0
with:
python-version: '3.13-dev'
nogil: true
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install setuptools wheel
python -m pip install -U --pre numpy \
-i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
python -c "import numpy; print(f'{numpy.__version__=}')"
- name: Build ml_dtypes
run: |
python -m pip install .[dev] --no-build-isolation
- name: Run tests
run: |
pytest -n auto
9 changes: 6 additions & 3 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,20 @@ jobs:
platforms: all

- name: Install cibuildwheel
run: python -m pip install cibuildwheel==2.15.0
run: python -m pip install cibuildwheel==2.20.0

- name: Build wheels
run: python -m cibuildwheel --output-dir wheelhouse
env:
CIBW_ARCHS_LINUX: auto aarch64
CIBW_ARCHS_MACOS: universal2
CIBW_BUILD: cp39-* cp310-* cp311-* cp312-*
CIBW_SKIP: "*musllinux* *i686* *win32*"
CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313*
CIBW_FREE_THREADED_SUPPORT: True
CIBW_PRERELEASE_PYTHONS: True
CIBW_SKIP: "*musllinux* *i686* *win32* *t-win*"
CIBW_TEST_REQUIRES: absl-py pytest pytest-xdist
CIBW_TEST_COMMAND: pytest -n auto {project}
CIBW_BUILD_VERBOSITY: 1

- uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
with:
Expand Down
5 changes: 5 additions & 0 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,11 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() {
reinterpret_cast<PyObject*>(TypeDescriptor<uint4>::type_ptr)) < 0) {
return nullptr;
}

#ifdef Py_GIL_DISABLED
PyUnstable_Module_SetGIL(m.get(), Py_MOD_GIL_NOT_USED);
#endif

return m.release();
}
} // namespace ml_dtypes
21 changes: 21 additions & 0 deletions ml_dtypes/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2024 The ml_dtypes Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""pytest configuration file."""

import pathlib
import sys

# Add ml_dtypes/tests folder to discover multi_thread_utils.py module
sys.path.insert(0, str(pathlib.Path(__file__).absolute().parent))
18 changes: 18 additions & 0 deletions ml_dtypes/tests/custom_float_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
from multi_thread_utils import multi_threaded
import numpy as np

bfloat16 = ml_dtypes.bfloat16
Expand Down Expand Up @@ -196,6 +197,10 @@ def dtype_has_inf(dtype):


# pylint: disable=g-complex-comprehension
@multi_threaded(
num_workers=3,
skip_tests=["testDiv", "testRoundTripNumpyTypes", "testRoundTripToNumpy"],
)
@parameterized.named_parameters(
(
{"testcase_name": "_" + dtype.__name__, "float_type": dtype}
Expand Down Expand Up @@ -604,6 +609,19 @@ def testDtypeFromString(self, float_type):


# pylint: disable=g-complex-comprehension
@multi_threaded(
num_workers=3,
skip_tests=[
"testBinaryUfunc",
"testConformNumpyComplex",
"testFloordivCornerCases",
"testDivmodCornerCases",
"testSpacing",
"testUnaryUfunc",
"testCasts",
"testLdexp",
],
)
@parameterized.named_parameters(
(
{"testcase_name": "_" + dtype.__name__, "float_type": dtype}
Expand Down
2 changes: 2 additions & 0 deletions ml_dtypes/tests/finfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
from multi_thread_utils import multi_threaded
import numpy as np

ALL_DTYPES = [
Expand All @@ -41,6 +42,7 @@
}


@multi_threaded(num_workers=3)
class FinfoTest(parameterized.TestCase):

def assertNanEqual(self, x, y):
Expand Down
2 changes: 2 additions & 0 deletions ml_dtypes/tests/iinfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
from multi_thread_utils import multi_threaded
import numpy as np


@multi_threaded(num_workers=3)
class IinfoTest(parameterized.TestCase):

def testIinfoInt2(self):
Expand Down
3 changes: 3 additions & 0 deletions ml_dtypes/tests/intn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
from multi_thread_utils import multi_threaded
import numpy as np

int2 = ml_dtypes.int2
Expand All @@ -48,6 +49,7 @@ def ignore_warning(**kw):


# Tests for the Python scalar type
@multi_threaded(num_workers=3)
class ScalarTest(parameterized.TestCase):

@parameterized.product(scalar_type=INTN_TYPES)
Expand Down Expand Up @@ -245,6 +247,7 @@ def testCanCast(self, a, b):


# Tests for the Python scalar type
@multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"])
class ArrayTest(parameterized.TestCase):

@parameterized.product(scalar_type=INTN_TYPES)
Expand Down
2 changes: 2 additions & 0 deletions ml_dtypes/tests/metadata_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

from absl.testing import absltest
import ml_dtypes
from multi_thread_utils import multi_threaded


@multi_threaded(num_workers=3)
class CustomFloatTest(absltest.TestCase):

def test_version_matches_package_metadata(self):
Expand Down
50 changes: 50 additions & 0 deletions ml_dtypes/tests/multi_thread_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2024 The ml_dtypes Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for multi-threaded tests."""

import concurrent.futures
import functools
from typing import Optional


def multi_threaded(*, num_workers: int, skip_tests: Optional[list[str]] = None):
"""Decorator that runs a test in a multi-threaded environment."""

def decorator(test_cls):
for name, test_fn in test_cls.__dict__.copy().items():
if not (name.startswith("test") and callable(test_fn)):
continue

if skip_tests is not None:
if any(test_name in name for test_name in skip_tests):
continue

@functools.wraps(test_fn) # pylint: disable=cell-var-from-loop
def multi_threaded_test_fn(*args, __test_fn__=test_fn, **kwargs):
with concurrent.futures.ThreadPoolExecutor(
max_workers=num_workers
) as executor:
futures = []
for _ in range(num_workers):
futures.append(executor.submit(__test_fn__, *args, **kwargs))
# We should call future.result() to re-raise an exception if test has
# failed
list(f.result() for f in futures)

setattr(test_cls, f"{name}_multi_threaded", multi_threaded_test_fn)

return test_cls

return decorator
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"numpy>=1.21.2; python_version>='3.10'",
"numpy>=1.23.3; python_version>='3.11'",
"numpy>=1.26.0; python_version>='3.12'",
"numpy>=2.1.0; python_version>='3.13'",
]

[project.urls]
Expand Down

0 comments on commit 6f02f77

Please sign in to comment.