diff --git a/flax/struct.py b/flax/struct.py index f434593f7c..aeafdf4c15 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -86,7 +86,7 @@ class DirectionAndScaleKernel: @classmethod def create(cls, kernel): scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True) - directin = direction / scale + direction = direction / scale return cls(direction, scale) Args: