Skip to content

Commit

Permalink
fix mypy + add .git-blame-ignore-revs
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jul 21, 2023
1 parent 40a6e07 commit 97d038c
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 22 deletions.
2 changes: 2 additions & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# apply pyink
40a6e074e5224d733f964be00e21e0a1cb98bd2e
4 changes: 2 additions & 2 deletions flax/linen/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,10 @@ class SelfAttention(MultiHeadDotProductAttention):
"""Self-attention special case of multi-head dot-product attention."""

@compact
def __call__(
def __call__( # type: ignore
self,
inputs_q: Array,
mask: Optional[Array] = None, # type: ignore
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
):
"""Applies multi-head dot product self-attention on the input data.
Expand Down
14 changes: 6 additions & 8 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,13 +828,11 @@ def _customized_dataclass_transform(cls, kw_only: bool):
for name, annotation, default in extra_fields: # pytype: disable=invalid-annotation
setattr(cls, name, default)
cls.__annotations__[name] = annotation
dataclasses.dataclass(
dataclasses.dataclass( # type: ignore[call-overload]
unsafe_hash='__hash__' not in cls.__dict__,
repr=False,
kw_only=True,
)(
cls
) # type: ignore[call-overload]
)(cls)
else:
raise TypeError('`kw_only` is not available before Py 3.10.')
else:
Expand Down Expand Up @@ -1900,8 +1898,8 @@ def sow(
name: str,
value: T,
reduce_fn: Callable[[K, T], K] = tuple_reduce,
init_fn: Callable[[], K] = tuple_init,
) -> bool: # type: ignore
init_fn: Callable[[], K] = tuple_init, # type: ignore
) -> bool:
...

def sow(
Expand All @@ -1910,8 +1908,8 @@ def sow(
name: str,
value: T,
reduce_fn: Callable[[K, T], K] = tuple_reduce,
init_fn: Callable[[], K] = tuple_init,
) -> bool: # type: ignore
init_fn: Callable[[], K] = tuple_init, # type: ignore
) -> bool:
"""Stores a value in a collection.
Collections can be used to collect intermediate values without
Expand Down
20 changes: 8 additions & 12 deletions flax/linen/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,20 +342,16 @@ def _concat_dense(
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
name=f'i{component}',
)(
inputs
) # type: ignore[call-arg]
name=f'i{component}', # type: ignore[call-arg]
)(inputs)
dense_params_h[component] = DenseParams(
features=hidden_features,
use_bias=True,
param_dtype=self.param_dtype,
kernel_init=self.recurrent_kernel_init,
bias_init=self.bias_init,
name=f'h{component}',
)(
h
) # type: ignore[call-arg]
name=f'h{component}', # type: ignore[call-arg]
)(h)
dense_h = _concat_dense(h, dense_params_h, use_bias=True)
dense_i = _concat_dense(inputs, dense_params_i, use_bias=False)

Expand Down Expand Up @@ -809,8 +805,8 @@ def __call__(
if reverse:
inputs = jax.tree_map(
lambda x: flip_sequences(
x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major
), # type: ignore
x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major # type: ignore
),
inputs,
)

Expand Down Expand Up @@ -867,8 +863,8 @@ def scan_fn(
if reverse and keep_order:
outputs = jax.tree_map(
lambda x: flip_sequences(
x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major
), # type: ignore
x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major # type: ignore
),
outputs,
)

Expand Down

0 comments on commit 97d038c

Please sign in to comment.