diff --git a/flax/core/scope.py b/flax/core/scope.py index 910fd42c32..f9f98ccd09 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -34,6 +34,7 @@ Tuple, TypeVar, Union, + cast, overload, ) @@ -892,7 +893,11 @@ def variable( raise errors.ScopeVariableNotFoundError(name, col, self.path_text) init_value = init_fn(*init_args) self.put_variable(col, name, init_value) - return Variable(self, col, name, unbox=unbox) + # cast to make static analyzers happy + return cast( + Union[Variable[T], Variable[meta.AxisMetadata[T]]], + Variable(self, col, name, unbox=unbox), + ) @overload def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T: diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 3ecdc72ff1..2c6400be07 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -15,7 +15,7 @@ """Attention core modules for Flax.""" import functools -from typing import (Any, Callable, Optional, Tuple) +from typing import (Any, Callable, Optional, Tuple, Union) from flax.linen.dtypes import promote_dtype from flax.linen import initializers @@ -334,7 +334,11 @@ def __call__( ) # update key, value caches with our new 1d spatial slices cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + indices: tuple[Union[int, jax.Array], ...] = (0,) * len(batch_dims) + ( + cur_index, + 0, + 0, + ) key = lax.dynamic_update_slice(cached_key.value, key, indices) value = lax.dynamic_update_slice(cached_value.value, value, indices) cached_key.value = key