Skip to content

Commit

Permalink
Merge pull request #1071 from carlosgmartin:fix_test_sh
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 678248900
  • Loading branch information
OptaxDev committed Sep 24, 2024
2 parents 25f870b + 033ffbf commit f9807cc
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 27 deletions.
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,4 @@ valid-classmethod-first-arg=cls,

# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
disable=unnecessary-lambda-assignment,no-value-for-parameter,use-dict-literal
50 changes: 23 additions & 27 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,79 +14,75 @@
# ==============================================================================

# Runs CI tests on a local machine.
set -xeuo pipefail
set -euo pipefail

# Install deps in a virtual env.
rm -rf _testing
rm -rf dist/
rm -rf *.whl
rm -rf .pytype
mkdir -p _testing
readonly VENV_DIR="$(mktemp -d `pwd`/_testing/optax-env.XXXXXXXX)"
readonly VENV_DIR="$(mktemp -d $(pwd)/_testing/optax-env.XXXXXXXX)"
# in the unlikely case in which there was something in that directory
python3 -m venv "${VENV_DIR}"
source "${VENV_DIR}/bin/activate"
python --version

# Install dependencies.
pip install -q --upgrade pip setuptools wheel
pip install -q flake8 pytest-xdist pylint pylint-exit
pip install -q -e ".[test, examples]"
python3 -m pip install -q --upgrade pip setuptools wheel
python3 -m pip install -q flake8 pytest-xdist pylint pylint-exit
python3 -m pip install -q -e ".[test, examples]"

# Dp-accounting specifies exact minor versions as requirements which sometimes
# become incompatible with other libraries optax needs. We therefore install
# dependencies for dp-accounting manually.
# TODO(b/239416992): Remove this workaround if dp-accounting switches to minimum
# version requirements.
pip install -q -e ".[dp-accounting]"
pip install -q "dp-accounting>=0.1.1" --no-deps
python3 -m pip install -q -e ".[dp-accounting]"
python3 -m pip install -q "dp-accounting>=0.1.1" --no-deps

# Install the requested JAX version
if [ -z "${JAX_VERSION-}" ]; then
: # use version installed in requirements above
elif [ "$JAX_VERSION" = "newest" ]; then
pip install -U jax jaxlib
python3 -m pip install -U jax jaxlib
else
pip install "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}"
python3 -m pip install "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}"
fi

# Ensure optax was not installed by one of the dependencies above,
# since if it is, the tests below will be run against that version instead of
# the branch build.
pip uninstall -q -y optax || true
python3 -m pip uninstall -q -y optax || true

# Lint with flake8.
flake8 `find optax examples -name '*.py' | xargs` --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics
flake8 $(find optax examples -name '*.py' | xargs) --select=E9,F63,F7,F82,E225,E251 --show-source --statistics

# Lint with pylint.
PYLINT_ARGS="-efail -wfail -cfail -rfail"
# Append specific config lines.
echo "disable=unnecessary-lambda-assignment,no-value-for-parameter,use-dict-literal" >> .pylintrc
# Lint modules and tests separately.
pylint --rcfile=.pylintrc `find optax examples -name '*.py' | grep -v 'test.py' | xargs` -d E1102 || pylint-exit $PYLINT_ARGS $?
# Disable `protected-access` warnings for tests.
pylint --rcfile=.pylintrc `find optax examples -name '*_test.py' | xargs` -d W0212,E1102 || pylint-exit $PYLINT_ARGS $?
# Cleanup.
rm .pylintrc
pylint --rcfile=.pylintrc $(find optax examples -name '*.py' | grep -v 'test.py' | xargs) -d E1102 || pylint-exit $PYLINT_ARGS $?
# Disable protected-access warnings for tests.
pylint --rcfile=.pylintrc $(find optax examples -name '*_test.py' | xargs) -d W0212,E1102 || pylint-exit $PYLINT_ARGS $?

# Build the package.
pip install build
python -m build
pip wheel --verbose --no-deps --no-clean dist/optax*.tar.gz
pip install optax*.whl
python3 -m pip install build
python3 -m build
python3 -m pip wheel --verbose --no-deps --no-clean dist/optax*.tar.gz
python3 -m pip install optax*.whl

# Check types with pytype.
pip install -q pytype
pytype `find optax/_src examples optax/contrib -name '*.py' | xargs` -k -d import-error
python3 -m pip install -q pytype
pytype $(find optax/_src examples optax/contrib -name '*.py' | xargs) -k -d import-error

# Run tests using pytest.
# Change directory to avoid importing the package from repo root.
cd _testing
python -m pytest -n auto --pyargs optax
python3 -m pytest -n auto --pyargs optax
cd ..

# Build Sphinx docs.
pip install -q -e ".[docs]"
python3 -m pip install -q -e ".[docs]"
# NOTE(vroulet) We have dependencies issues:
# tensorflow > 2.13.1 requires ml-dtypes <= 0.3.2
# but jax requires ml-dtypes >= 0.4.0
Expand All @@ -96,7 +92,7 @@ pip install -q -e ".[docs]"
# bug (which issues conflict warnings but runs fine).
# A long term solution is probably to fully remove tensorflow from our
# dependencies.
pip install --upgrade -v typing_extensions
python3 -m pip install --upgrade -v typing_extensions
cd docs && make html
# run doctests
make doctest
Expand Down

0 comments on commit f9807cc

Please sign in to comment.