Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relax absolute tolerance for failing tests involving chex.assert_trees_all_close. #1069

Conversation

carlosgmartin
Copy link
Contributor

No description provided.

Copy link
Member

@fabianp fabianp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@vroulet
Copy link
Collaborator

vroulet commented Sep 23, 2024

Thanks @carlosgmartin ! I don't see when or where these tests broke. Could you point out the reason for this change?

@carlosgmartin
Copy link
Contributor Author

carlosgmartin commented Sep 23, 2024

@vroulet Here's the output I get from running tests on my machine (after applying #1068 and #1071):

(venv) $ sh ./test.sh; tput bel

-------------------------------------------------------------------
Your code has been rated at 10.00/10 (previous run: 1.41/10, +8.59)


------------------------------------
Your code has been rated at 10.00/10

Collecting build
  Using cached build-1.2.2-py3-none-any.whl.metadata (6.2 kB)
Requirement already satisfied: packaging>=19.1 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from build) (24.1)
Collecting pyproject_hooks (from build)
  Using cached pyproject_hooks-1.1.0-py3-none-any.whl.metadata (1.3 kB)
Using cached build-1.2.2-py3-none-any.whl (22 kB)
Using cached pyproject_hooks-1.1.0-py3-none-any.whl (9.2 kB)
Installing collected packages: pyproject_hooks, build
Successfully installed build-1.2.2 pyproject_hooks-1.1.0
* Creating isolated environment: venv+pip...
* Installing packages in isolated environment:
  - flit_core >=3.2,<4
* Getting build dependencies for sdist...
* Building sdist...
Version number normalised: '0.2.4.dev' -> '0.2.4.dev0' (see PEP 440)
* Building wheel from sdist
* Creating isolated environment: venv+pip...
* Installing packages in isolated environment:
  - flit_core >=3.2,<4
* Getting build dependencies for wheel...
* Building wheel...
Version number normalised: '0.2.4.dev' -> '0.2.4.dev0' (see PEP 440)
Successfully built optax-0.2.4.dev0.tar.gz and optax-0.2.4.dev0-py3-none-any.whl
Processing ./dist/optax-0.2.4.dev0.tar.gz
  Running command pip subprocess to install build dependencies
  Using pip 24.2 from /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/pip (python 3.12)
  Collecting flit_core<4,>=3.2
    Obtaining dependency information for flit_core<4,>=3.2 from https://files.pythonhosted.org/packages/38/45/618e84e49a6c51e5dd15565ec2fcd82ab273434f236b8f108f065ded517a/flit_core-3.9.0-py3-none-any.whl.metadata
    Using cached flit_core-3.9.0-py3-none-any.whl.metadata (822 bytes)
  Using cached flit_core-3.9.0-py3-none-any.whl (63 kB)
  Installing collected packages: flit_core
  Successfully installed flit_core-3.9.0
  Installing build dependencies ... done
  Running command Getting requirements to build wheel
  Getting requirements to build wheel ... done
  Running command Preparing metadata (pyproject.toml)
  Version number normalised: '0.2.4.dev' -> '0.2.4.dev0' (see PEP 440)
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: optax
  Running command Building wheel for optax (pyproject.toml)
  Version number normalised: '0.2.4.dev' -> '0.2.4.dev0' (see PEP 440)
  Building wheel for optax (pyproject.toml) ... done
  Created wheel for optax: filename=optax-0.2.4.dev0-py3-none-any.whl size=302990 sha256=24c29cdf1a9e11eb65d0ec26ec9571a991c1fbacf5bd50f908d66c779f452462
  Stored in directory: /Users/carlos/Library/Caches/pip/wheels/12/8b/12/f73f3c3c2e327a62f3b987cf82f38ac6de00e4d01f1a02e9cd
Successfully built optax
Processing ./optax-0.2.4.dev0-py3-none-any.whl
Requirement already satisfied: absl-py>=0.7.1 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from optax==0.2.4.dev0) (1.4.0)
Requirement already satisfied: chex>=0.1.86 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from optax==0.2.4.dev0) (0.1.86)
Requirement already satisfied: jax>=0.4.27 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from optax==0.2.4.dev0) (0.4.33)
Requirement already satisfied: jaxlib>=0.4.27 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from optax==0.2.4.dev0) (0.4.33)
Requirement already satisfied: numpy>=1.18.0 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from optax==0.2.4.dev0) (1.26.4)
Requirement already satisfied: etils[epy] in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from optax==0.2.4.dev0) (1.9.4)
Requirement already satisfied: typing-extensions>=4.2.0 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from chex>=0.1.86->optax==0.2.4.dev0) (4.12.2)
Requirement already satisfied: toolz>=0.9.0 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from chex>=0.1.86->optax==0.2.4.dev0) (0.12.1)
Requirement already satisfied: setuptools in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from chex>=0.1.86->optax==0.2.4.dev0) (75.1.0)
Requirement already satisfied: ml-dtypes>=0.2.0 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from jax>=0.4.27->optax==0.2.4.dev0) (0.4.1)
Requirement already satisfied: opt-einsum in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from jax>=0.4.27->optax==0.2.4.dev0) (3.3.0)
Requirement already satisfied: scipy>=1.10 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from jax>=0.4.27->optax==0.2.4.dev0) (1.14.1)
Installing collected packages: optax
Successfully installed optax-0.2.4.dev0
Computing dependencies
Analyzing 48 sources with 20 local dependencies
ninja: Entering directory `.pytype'
[68/68] check constrain
Leaving directory '.pytype'
Success: no errors found
=============================== test session starts ================================
platform darwin -- Python 3.12.3, pytest-8.3.3, pluggy-1.5.0
rootdir: /Users/carlos/Desktop/optax
configfile: pyproject.toml
plugins: xdist-3.6.1
16 workers [3564 items]   
ss.....s..s........s...s.......s.......s..s......................ss......... [  2%]
.....s..ss....s......s....s...................s......s...s....s.......s..... [  4%]
.........s.....ss....s.s.....s.....s...s.s....s....s...s.s..s...........s... [  6%]
....s....s.s.s..s...s........s..s......s........s...s...s.s....s....s....s.. [  8%]
.s..s.s...s.....s.s....s.s.....s....s..s.....s..s......s.......s...s......s. [ 10%]
..s...s...ss.....s.s..s....s........s....s.....s.s...s...s.s..s.......s...ss [ 12%]
...s..........s.....s.s.......s.s.....s....s...s....s....s...s....s.s......s [ 14%]
.s.............s.s........s.s...s...sss.s....s....s...s.......s.........s.s. [ 17%]
....s.s..s.s..s.........s......s....s.........s..s...s.s...s....s........s.. [ 19%]
s....s....s.......s..s.....s.....s.s........s....s.s...s....s..s..........s. [ 21%]
...s.....s.....s.s....s.....s.....ss.......ss.s.......s..s.....s............ [ 23%]
.....ss.........s..s..s....s....s..s....s............s....s.....s.s......ss. [ 25%]
........s..........s..ss.........s...s....s......ss.s....s......s.....s..... [ 27%]
s..........s......s......s....s...s......s....s...........s......s....s..... [ 29%]
.....s.....s......s..s.......s.........s......s.s.s............s...s..ss.... [ 31%]
..s..s..............s......ss.....s.....s...s..s........s...........s..s.... [ 34%]
....s...s.........s.....s......s.......s..................s..s..s........... [ 36%]
..s.............s..............s..s........s......s......................s.. [ 38%]
s................s................s..s......s....s....s....s......s......... [ 40%]
.s........s....s..s...........s...s..................s.s..s...s..s.......... [ 42%]
...s........s.s...s..s..........s...s.s...........s.........ss....s......... [ 44%]
...............s..ss....s...s....s.s.s.s..........s............s..s......... [ 46%]
..s.s....s...s...s........s...s..s.....s.s.s.....s.s..s.s...........s....s.. [ 49%]
s.........s......s.......s.......s......s....s......s..s...s.....s.......... [ 51%]
..s.......s....s..s..s.......s...s.s.....sssss.s......s.s....s.............s [ 53%]
..s...s.....s...s.....s....s.s........s...........s...s.....s.....s.s....s.. [ 55%]
..s.....s.....s..s.s.....s.....s...s...s.........s......s...s.......s....... [ 57%]
...s.s.......s..ss....s.......s......s........s....s.s.........s...s...s.... [ 59%]
...s.....s.s..............s.s......s......s....s.............s.............. [ 61%]
..s....s.....s....s..........s.....s.s............s..s.s...........s...s.... [ 63%]
s.s.........s.....s..........s.........s....s..........s.s......s...s....s.. [ 66%]
...s.......s...s......s....s........s.....s.....s............s.s.....s...s.. [ 68%]
..s....s....s........s.....s.sF........s..............s.....s.......s.s...s. [ 70%]
........s........s.......s.s..........s.....s.....s.........s........s...s.. [ 72%]
...s..s.........s.........s....s....s.s........s...............s...........s [ 74%]
.s.........s.s...s........s.s..s..s.s..ss.....s.....ss.......s..........s.s. [ 76%]
..s....s......s....s.s.....s...s........s..s.....s.s.s.......s.......s.....s [ 78%]
.......s...........s...s........s....ss.s.......s....s.s.s.....s....s....... [ 81%]
..........s....s.....s.......s.s...s....s.s....s....s..s...s....s....s...... [ 83%]
...s.......ss.........s.....s.....s...sss.....s.....s.....s........s......s. [ 85%]
.s....s......s.....s.......s..s.....s.s...s...s.s.....s....s.....s......s..s [ 87%]
....s....s.s...............s...s.s........s....s...........s....s....s....s. [ 89%]
s..........s.s...s.s......s...s...s...s....s.s......s..s.s............s.s... [ 91%]
....s..s....ss.............s......s....s...s.ss..s...s..........s.....s..s.. [ 93%]
....s.s.....s......s......s......s......s.....s.....s.s........s.s.......ss. [ 95%]
......ss.s.....s............s............................................... [ 98%]
..................s...........s.........F................s........s.         [100%]
===================================== FAILURES =====================================
_______________________ ZoomLinesearchTest.test_linesearch7 ________________________
[gw1] darwin -- Python 3.12.3 /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/bin/python3

self = <linesearch_test.ZoomLinesearchTest testMethod=test_linesearch7>
problem_name = 'rosenbrock', seed = 1

    @parameterized.product(
        problem_name=[
            'polynomial', 'exponential', 'sinusoidal',
            'rosenbrock', 'himmelblau', 'matyas', 'eggholder'
            ],
        seed=[0, 1],
    )
    def test_linesearch(self, problem_name: str, seed: int):
      """Test backtracking linesearch (single update step)."""
      # Fixed tolerances, we check the behavior in standard conditions
      slope_rtol = 1e-4
      curv_rtol = 0.9
      tol = 0.
    
      key = jrd.PRNGKey(seed)
      params_key, precond_key = jrd.split(key, 2)
      problem = get_problem(problem_name)
      fn, input_shape = problem['fn'], problem['input_shape']
    
      init_params = jrd.normal(params_key, input_shape)
      precond_vec = jrd.uniform(precond_key, input_shape)
    
      # Mimics a preconditioning by a diagonal matrix with non-negative entries
      # (non-negativity ensures that we keep a descent direction)
      init_updates = -precond_vec*jax.grad(fn)(init_params)
    
      opt_args = dict(
          max_linesearch_steps=30,
          slope_rtol=slope_rtol,
          curv_rtol=curv_rtol,
          tol=tol,
          max_learning_rate=None
      )
    
      opt = _linesearch.scale_by_zoom_linesearch(**opt_args)
      final_params, final_state = _run_linesearch(
          opt, fn, init_params, init_updates
      )
    
      scipy_res = scipy_optimize.line_search(
          fn, jax.grad(fn), init_params, init_updates
      )
      with self.subTest('Check value and grad in zoom state'):
        self._check_value_and_grad_in_zoom_state(
            final_params, final_state, value_fn=fn
        )
      with self.subTest('Check linesearch conditions'):
        self._check_linesearch_conditions(
            fn, init_params, init_updates, final_params, final_state, opt_args
        )
      with self.subTest('Check against scipy'):
        stepsize = otu.tree_get(final_state, 'learning_rate')
        final_value = otu.tree_get(final_state, 'value')
        chex.assert_trees_all_close(scipy_res[0], stepsize, atol=1e-5)
>       chex.assert_trees_all_close(scipy_res[3], final_value, atol=1e-5)

optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/linesearch_test.py:483: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
optax-env.TZoLx2zo/lib/python3.12/site-packages/chex/_src/asserts_internal.py:278: in _chex_assert_fn
    host_assertion_fn(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

custom_message = None, custom_message_format_vars = ()
include_default_message = True, exception_type = <class 'AssertionError'>
args = (Array(3211.6665, dtype=float32), Array(3211.6628, dtype=float32))
kwargs = {'atol': 1e-05}
assertion_exc = AssertionError('[Chex] Assertion assert_trees_all_equal_comparator failed: Trees (arrays) 0 and 1 differ: \nNot equal ...34e-06\n x: array(3211.6665, dtype=float32)\n y: array(3211.6628, dtype=float32) \nOriginal dtypes: float32, float32.')
value_exc = None
error_msg = '[Chex] Assertion assert_trees_all_close failed:  Trees (arrays) 0 and 1 differ: \nNot equal to tolerance rtol=1e-06, ...534e-06\n x: array(3211.6665, dtype=float32)\n y: array(3211.6628, dtype=float32) \nOriginal dtypes: float32, float32.'
default_msg = 'Assertion assert_trees_all_close failed: '

    def _assert_on_host(*args,
                        custom_message: Optional[str] = None,
                        custom_message_format_vars: Sequence[Any] = (),
                        include_default_message: bool = True,
                        exception_type: Type[Exception] = AssertionError,
                        **kwargs) -> None:
      # Format error's stack trace to remove Chex' internal frames.
      assertion_exc = None
      value_exc = None
      try:
        assert_fn(*args, **kwargs)
      except AssertionError as e:
        assertion_exc = e
      except ValueError as e:
        value_exc = e
      finally:
        if value_exc is not None:
          raise ValueError(str(value_exc))
    
        if assertion_exc is not None:
          # Format the exception message.
          error_msg = str(assertion_exc)
    
          # Include only the name of the outermost chex assertion.
          if error_msg.startswith(ERR_PREFIX):
            error_msg = error_msg[error_msg.find("failed:") + len("failed:"):]
    
          # Whether to include the default error message.
          default_msg = (f"Assertion {name} failed: "
                         if include_default_message else "")
          error_msg = f"{ERR_PREFIX}{default_msg}{error_msg}"
    
          # Whether to include a custom error message.
          if custom_message:
            if custom_message_format_vars:
              custom_message = custom_message.format(*custom_message_format_vars)
            error_msg = f"{error_msg} [{custom_message}]"
    
>         raise exception_type(error_msg)
E         AssertionError: [Chex] Assertion assert_trees_all_close failed:  Trees (arrays) 0 and 1 differ: 
E         Not equal to tolerance rtol=1e-06, atol=1e-05
E         Error in value equality check: Values not approximately equal
E         Mismatched elements: 1 / 1 (100%)
E         Max absolute difference: 0.00366211
E         Max relative difference: 1.1402534e-06
E          x: array(3211.6665, dtype=float32)
E          y: array(3211.6628, dtype=float32) 
E         Original dtypes: float32, float32.

optax-env.TZoLx2zo/lib/python3.12/site-packages/chex/_src/asserts_internal.py:196: AssertionError
_______________________ LBFGSTest.test_plain_preconditioning _______________________
[gw12] darwin -- Python 3.12.3 /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/bin/python3

self = <alias_test.LBFGSTest testMethod=test_plain_preconditioning>

    def test_plain_preconditioning(self):
      key = jrd.PRNGKey(0)
      key_ws, key_us, key_vec = jrd.split(key, 3)
      m = 4
      d = 3
      dws = jrd.normal(key_ws, (m, d))
      dus = jrd.normal(key_us, (m, d))
      rhos = 1.0 / jnp.sum(dws * dus, axis=1)
      vec = jrd.normal(key_vec, (d,))
      plain_precond_vec = _plain_preconditioning(dws, dus, vec)
      precond_mat = _materialize_approx_inv_hessian(dws, dus, rhos, memory_idx=0)
      expected_precond_vec = precond_mat.dot(
          vec, precision=jax.lax.Precision.HIGHEST
      )
>     chex.assert_trees_all_close(plain_precond_vec, expected_precond_vec)

optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/alias_test.py:589: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
optax-env.TZoLx2zo/lib/python3.12/site-packages/chex/_src/asserts_internal.py:278: in _chex_assert_fn
    host_assertion_fn(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

custom_message = None, custom_message_format_vars = ()
include_default_message = True, exception_type = <class 'AssertionError'>
args = (Array([ 146.7743  ,   82.296425, -844.90216 ], dtype=float32), Array([ 146.77422,   82.29634, -844.9018 ], dtype=float32))
kwargs = {}
assertion_exc = AssertionError('[Chex] Assertion assert_trees_all_equal_comparator failed: Trees (arrays) 0 and 1 differ: \nNot equal ..., dtype=float32)\n y: array([ 146.77422,   82.29634, -844.9018 ], dtype=float32) \nOriginal dtypes: float32, float32.')
value_exc = None
error_msg = '[Chex] Assertion assert_trees_all_close failed:  Trees (arrays) 0 and 1 differ: \nNot equal to tolerance rtol=1e-06, ...], dtype=float32)\n y: array([ 146.77422,   82.29634, -844.9018 ], dtype=float32) \nOriginal dtypes: float32, float32.'
default_msg = 'Assertion assert_trees_all_close failed: '

    def _assert_on_host(*args,
                        custom_message: Optional[str] = None,
                        custom_message_format_vars: Sequence[Any] = (),
                        include_default_message: bool = True,
                        exception_type: Type[Exception] = AssertionError,
                        **kwargs) -> None:
      # Format error's stack trace to remove Chex' internal frames.
      assertion_exc = None
      value_exc = None
      try:
        assert_fn(*args, **kwargs)
      except AssertionError as e:
        assertion_exc = e
      except ValueError as e:
        value_exc = e
      finally:
        if value_exc is not None:
          raise ValueError(str(value_exc))
    
        if assertion_exc is not None:
          # Format the exception message.
          error_msg = str(assertion_exc)
    
          # Include only the name of the outermost chex assertion.
          if error_msg.startswith(ERR_PREFIX):
            error_msg = error_msg[error_msg.find("failed:") + len("failed:"):]
    
          # Whether to include the default error message.
          default_msg = (f"Assertion {name} failed: "
                         if include_default_message else "")
          error_msg = f"{ERR_PREFIX}{default_msg}{error_msg}"
    
          # Whether to include a custom error message.
          if custom_message:
            if custom_message_format_vars:
              custom_message = custom_message.format(*custom_message_format_vars)
            error_msg = f"{error_msg} [{custom_message}]"
    
>         raise exception_type(error_msg)
E         AssertionError: [Chex] Assertion assert_trees_all_close failed:  Trees (arrays) 0 and 1 differ: 
E         Not equal to tolerance rtol=1e-06, atol=0
E         Error in value equality check: Values not approximately equal
E         Mismatched elements: 1 / 3 (33.3%)
E         Max absolute difference: 0.00036621
E         Max relative difference: 1.01977e-06
E          x: array([ 146.7743  ,   82.296425, -844.90216 ], dtype=float32)
E          y: array([ 146.77422,   82.29634, -844.9018 ], dtype=float32) 
E         Original dtypes: float32, float32.

optax-env.TZoLx2zo/lib/python3.12/site-packages/chex/_src/asserts_internal.py:196: AssertionError
================================= warnings summary =================================
_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/alias_test.py: 6 warnings
_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/contrib/_schedule_free_test.py: 4 warnings
  /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:118: DeprecationWarning: Casting from complex to real dtypes will soon raise a ValueError. Please first use jnp.real or jnp.imag to take the real/imaginary component of your input.
    return lax_numpy.astype(self, dtype, copy=copy, device=device)

_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/alias_test.py: 16 warnings
  /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/transform.py:1521: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    lambda leaf: jnp.zeros((memory_size,) + leaf.shape, dtype=leaf.dtype),

_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/contrib/_common_test.py::ContribTest::test_optimizers_can_be_wrapped_in_inject_hyperparams_(opt_name='momo_adam', opt_kwargs={'learning_rate': 0.1}, wrapper_name=None, wrapper_kwargs=None)__without_device
  /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/contrib/_momo.py:301: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    bc2 = jnp.asarray(1 - b2 ** count_inc, dtype=barf.dtype)

_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/contrib/_common_test.py::ContribTest::test_optimizers_can_be_wrapped_in_inject_hyperparams_(opt_name='momo_adam', opt_kwargs={'learning_rate': 0.1}, wrapper_name=None, wrapper_kwargs=None)__without_device
  /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/contrib/_momo.py:310: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    bc1 = jnp.asarray(1 - b1 ** count_inc, dtype=barf.dtype)

_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/tree_utils/_tree_math_test.py::TreeUtilsTest::test_tree_add_scalar_mul
  /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:118: UserWarning: Explicitly requested dtype complex128 requested in astype is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    return lax_numpy.astype(self, dtype, copy=copy, device=device)

_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/tree_utils/_tree_math_test.py::TreeUtilsTest::test_tree_add_scalar_mul
  /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:118: UserWarning: Explicitly requested dtype float64 requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    return lax_numpy.astype(self, dtype, copy=copy, device=device)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================= short test summary info ==============================
FAILED optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/linesearch_test.py::ZoomLinesearchTest::test_linesearch7 - AssertionError: [Chex] Assertion assert_trees_all_close failed:  Trees (arrays)...
FAILED optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/alias_test.py::LBFGSTest::test_plain_preconditioning - AssertionError: [Chex] Assertion assert_trees_all_close failed:  Trees (arrays)...
============ 2 failed, 2955 passed, 607 skipped, 30 warnings in 34.51s =============
(venv) $ 

Do you get the same output from running tests on your machine?

@vroulet
Copy link
Collaborator

vroulet commented Sep 23, 2024

Indeed, I get the error too. I don't understand why it's not been caught on the tests in github.
Anyway, I'd prefer using small relative accuracies rather than slightly loose absolute differences (the functions that these tests verify are supposed to be use for high precision purposes).
So I'd prefer to use for the lbfgs test:

    chex.assert_trees_all_close(
        plain_precond_vec, expected_precond_vec, rtol=1e-5
    )

and for the linesearch test:

      chex.assert_trees_all_close(scipy_res[0], stepsize, rtol=1e-5)
      chex.assert_trees_all_close(scipy_res[3], final_value, rtol=1e-5)

I checked, the tests pass with these changes.

@vroulet
Copy link
Collaborator

vroulet commented Sep 23, 2024

The doctest of _dog also fails on my end.
As you are on it, could you also use ellipses on the doctest of dog, i.e., here use

    Objective function:  13.99...
    Objective function:  13.99...
    Objective function:  13.99...
    Objective function:  13.99...
    Objective function:  13.99...

(This algorithm needs to be completed with a proper evaluation function and fuse with a similar one in the alias/transform files.)

@carlosgmartin carlosgmartin force-pushed the relax_atol_assert_trees_all_close branch 2 times, most recently from 1017127 to ddbd589 Compare September 23, 2024 20:40
Copy link
Collaborator

@vroulet vroulet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small linting comment

optax/_src/alias_test.py Outdated Show resolved Hide resolved
@carlosgmartin carlosgmartin force-pushed the relax_atol_assert_trees_all_close branch from ddbd589 to ff12a82 Compare September 23, 2024 20:49
@vroulet vroulet self-requested a review September 23, 2024 23:59
@copybara-service copybara-service bot merged commit fea4745 into google-deepmind:main Sep 25, 2024
8 checks passed
@carlosgmartin carlosgmartin deleted the relax_atol_assert_trees_all_close branch September 25, 2024 15:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants