diff --git a/.pylintrc b/.pylintrc index 5c771e62..b26aeee4 100644 --- a/.pylintrc +++ b/.pylintrc @@ -397,3 +397,4 @@ valid-classmethod-first-arg=cls, # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=mcs +disable=unnecessary-lambda-assignment,no-value-for-parameter,use-dict-literal diff --git a/test.sh b/test.sh index 4781f4c3..b1b0c016 100755 --- a/test.sh +++ b/test.sh @@ -14,7 +14,7 @@ # ============================================================================== # Runs CI tests on a local machine. -set -xeuo pipefail +set -euo pipefail # Install deps in a virtual env. rm -rf _testing @@ -22,71 +22,67 @@ rm -rf dist/ rm -rf *.whl rm -rf .pytype mkdir -p _testing -readonly VENV_DIR="$(mktemp -d `pwd`/_testing/optax-env.XXXXXXXX)" +readonly VENV_DIR="$(mktemp -d $(pwd)/_testing/optax-env.XXXXXXXX)" # in the unlikely case in which there was something in that directory python3 -m venv "${VENV_DIR}" source "${VENV_DIR}/bin/activate" -python --version # Install dependencies. -pip install -q --upgrade pip setuptools wheel -pip install -q flake8 pytest-xdist pylint pylint-exit -pip install -q -e ".[test, examples]" +python3 -m pip install -q --upgrade pip setuptools wheel +python3 -m pip install -q flake8 pytest-xdist pylint pylint-exit +python3 -m pip install -q -e ".[test, examples]" # Dp-accounting specifies exact minor versions as requirements which sometimes # become incompatible with other libraries optax needs. We therefore install # dependencies for dp-accounting manually. # TODO(b/239416992): Remove this workaround if dp-accounting switches to minimum # version requirements. -pip install -q -e ".[dp-accounting]" -pip install -q "dp-accounting>=0.1.1" --no-deps +python3 -m pip install -q -e ".[dp-accounting]" +python3 -m pip install -q "dp-accounting>=0.1.1" --no-deps # Install the requested JAX version if [ -z "${JAX_VERSION-}" ]; then : # use version installed in requirements above elif [ "$JAX_VERSION" = "newest" ]; then - pip install -U jax jaxlib + python3 -m pip install -U jax jaxlib else - pip install "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}" + python3 -m pip install "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}" fi # Ensure optax was not installed by one of the dependencies above, # since if it is, the tests below will be run against that version instead of # the branch build. -pip uninstall -q -y optax || true +python3 -m pip uninstall -q -y optax || true # Lint with flake8. -flake8 `find optax examples -name '*.py' | xargs` --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics +flake8 $(find optax examples -name '*.py' | xargs) --select=E9,F63,F7,F82,E225,E251 --show-source --statistics # Lint with pylint. PYLINT_ARGS="-efail -wfail -cfail -rfail" # Append specific config lines. -echo "disable=unnecessary-lambda-assignment,no-value-for-parameter,use-dict-literal" >> .pylintrc # Lint modules and tests separately. -pylint --rcfile=.pylintrc `find optax examples -name '*.py' | grep -v 'test.py' | xargs` -d E1102 || pylint-exit $PYLINT_ARGS $? -# Disable `protected-access` warnings for tests. -pylint --rcfile=.pylintrc `find optax examples -name '*_test.py' | xargs` -d W0212,E1102 || pylint-exit $PYLINT_ARGS $? -# Cleanup. -rm .pylintrc +pylint --rcfile=.pylintrc $(find optax examples -name '*.py' | grep -v 'test.py' | xargs) -d E1102 || pylint-exit $PYLINT_ARGS $? +# Disable protected-access warnings for tests. +pylint --rcfile=.pylintrc $(find optax examples -name '*_test.py' | xargs) -d W0212,E1102 || pylint-exit $PYLINT_ARGS $? # Build the package. -pip install build -python -m build -pip wheel --verbose --no-deps --no-clean dist/optax*.tar.gz -pip install optax*.whl +python3 -m pip install build +python3 -m build +python3 -m pip wheel --verbose --no-deps --no-clean dist/optax*.tar.gz +python3 -m pip install optax*.whl # Check types with pytype. -pip install -q pytype -pytype `find optax/_src examples optax/contrib -name '*.py' | xargs` -k -d import-error +python3 -m pip install -q pytype +pytype $(find optax/_src examples optax/contrib -name '*.py' | xargs) -k -d import-error # Run tests using pytest. # Change directory to avoid importing the package from repo root. cd _testing -python -m pytest -n auto --pyargs optax +python3 -m pytest -n auto --pyargs optax cd .. # Build Sphinx docs. -pip install -q -e ".[docs]" +python3 -m pip install -q -e ".[docs]" # NOTE(vroulet) We have dependencies issues: # tensorflow > 2.13.1 requires ml-dtypes <= 0.3.2 # but jax requires ml-dtypes >= 0.4.0 @@ -96,7 +92,7 @@ pip install -q -e ".[docs]" # bug (which issues conflict warnings but runs fine). # A long term solution is probably to fully remove tensorflow from our # dependencies. -pip install --upgrade -v typing_extensions +python3 -m pip install --upgrade -v typing_extensions cd docs && make html # run doctests make doctest