Fix remaining pyink issues #3219

Merged 3 commits on Jul 21, 2023
8 changes: 6 additions & 2 deletions .github/analytics/get_repo_metrics.py
@@ -331,7 +331,9 @@ def main(_):
)
issues.get()

df_issues = df_issues0 = pd.DataFrame(list(_get_issues_features(issues.raw_data)))
df_issues = df_issues0 = pd.DataFrame(
list(_get_issues_features(issues.raw_data))
)
df_issues['issue_response_time'] = (
df_issues['time_labeled_or_converted'] - df_issues['created_at']
)
@@ -350,7 +352,9 @@ def main(_):
prs.get()

df_prs = df_prs0 = pd.DataFrame(list(_get_pr_features(prs.raw_data)))
time_response = df_prs[['time_labeled_or_assigned', 'time_review']].min(axis=1)
time_response = df_prs[['time_labeled_or_assigned', 'time_review']].min(
axis=1
)
df_prs['pr_response_time'] = time_response - df_prs['ready_for_review_at']
df_prs['pr_resolution_time'] = (
df_prs['time_merged_or_closed'] - df_prs['ready_for_review_at']
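Not part of the diff: a tiny pandas sketch with invented timestamps, illustrating the response-time computation the hunk above rewraps — the issue response time is simply the delta between the labeling/conversion timestamp and the creation timestamp.

```python
# Illustrative only; the DataFrame contents are made up, not repo data.
import pandas as pd

df_issues = pd.DataFrame({
    'created_at': pd.to_datetime(['2023-07-01 10:00', '2023-07-02 09:00']),
    'time_labeled_or_converted': pd.to_datetime(
        ['2023-07-01 12:30', '2023-07-03 09:00']
    ),
})
df_issues['issue_response_time'] = (
    df_issues['time_labeled_or_converted'] - df_issues['created_at']
)
print(df_issues['issue_response_time'])  # 0 days 02:30:00 and 1 days 00:00:00
```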
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
@@ -13,12 +13,16 @@ repos:
hooks:
- id: jupytext
args: [--sync]
- repo: https://github.com/google/pyink
rev: 23.5.0
hooks:
- id: pyink
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
hooks:
- id: check-toml
- id: trailing-whitespace
exclude: ^docs/.*\.md$|^dev/.*\.py$
exclude: ^docs/.*\.md$
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
hooks:
@@ -29,7 +33,3 @@ repos:
--extra-keys,
"metadata.kernelspec metadata.vscode metadata.colab cell.metadata.executionInfo.user cell.metadata.executionInfo.user_tz cell.metadata.colab",
]
- repo: https://github.com/google/pyink
rev: 23.5.0
hooks:
- id: pyink
2 changes: 2 additions & 0 deletions dev/update_requirements.py
@@ -39,6 +39,8 @@
Alternatively, the list can also be provided from the local environment with:

python dev --versions="$(pip freeze | sed s/==/-/g) flax-0.3.6"

This will use the currently installed versions of all packages.
"""

import pathlib
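The docstring addition above describes building the version list from the local environment via `pip freeze`. As a rough illustration only (not part of this PR; using `importlib.metadata` is an assumption, not what the script does), the same `name-version` pin string could be assembled in Python:

```python
# Sketch: builds "pkg-1.2.3 ..." pins, analogous to `pip freeze | sed s/==/-/g`.
from importlib import metadata

def local_version_pins(extra: str = 'flax-0.3.6') -> str:
  pins = sorted(
      f"{dist.metadata['Name']}-{dist.version}"
      for dist in metadata.distributions()
  )
  return ' '.join(pins + [extra])

print(local_version_pins())
```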
4 changes: 3 additions & 1 deletion examples/imagenet/train_test.py
@@ -55,7 +55,9 @@ def test_create_model_local(self):

Uses smaller inputs than `test_create_model` due to higher compute.
"""
model = train.create_model(model_cls=models._ResNet1Local, half_precision=False) # pylint: disable=protected-access
model = train.create_model(
model_cls=models._ResNet1Local, half_precision=False
) # pylint: disable=protected-access
params, batch_stats = train.initialized(random.PRNGKey(0), 64, model)
variables = {'params': params, 'batch_stats': batch_stats}
x = random.normal(random.PRNGKey(1), (1, 64, 64, 3))
4 changes: 3 additions & 1 deletion examples/lm1b/temperature_sampler_test.py
@@ -37,7 +37,9 @@ def tokens_to_logits(tokens, cache):
logits = logits.squeeze(axis=1)
return logits, cache

new_tokens = temperature_sample(tokens, cache, tokens_to_logits, key, topk=5)
new_tokens = temperature_sample(
tokens, cache, tokens_to_logits, key, topk=5
)

np.testing.assert_array_equal(new_tokens, [[5, 6, 7, 8]])

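For context, a single step of top-k temperature sampling keeps the k largest logits, scales them by the temperature, and samples from the resulting categorical. This sketch is an assumption about the algorithm, not the example's `temperature_sampler` implementation:

```python
import jax
import jax.numpy as jnp

def sample_top_k_step(rng, logits, k=5, temperature=1.0):
  top_vals, top_idx = jax.lax.top_k(logits, k)       # k best logits and their ids
  choice = jax.random.categorical(rng, top_vals / temperature)
  return top_idx[choice]                              # token id drawn from the top-k

rng = jax.random.PRNGKey(0)
logits = jnp.array([0.1, 2.0, -1.0, 3.0, 0.5, 1.5])
print(sample_top_k_step(rng, logits, k=3))
```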
12 changes: 9 additions & 3 deletions examples/lm1b/train.py
@@ -285,7 +285,9 @@ def per_host_sum_pmap(in_tree):
host_psum = jax.pmap(lambda x: jax.lax.psum(x, "i"), "i", devices=devices)

def pre_pmap(xs):
return jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs)
return jax.tree_util.tree_map(
lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs
)

def post_pmap(xs):
return jax.tree_util.tree_map(lambda x: x[0], xs)
@@ -525,7 +527,9 @@ def encode_strings(strs, max_len):

# Shard data to devices and do a training step.
with jax.profiler.StepTraceAnnotation("train", step_num=step):
batch = common_utils.shard(jax.tree_util.tree_map(np.asarray, next(train_iter)))
batch = common_utils.shard(
jax.tree_util.tree_map(np.asarray, next(train_iter))
)
state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs)
train_metrics.append(metrics)

@@ -542,7 +546,9 @@ def encode_strings(strs, max_len):
lr = train_metrics.pop("learning_rate").mean()
metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
denominator = metrics_sums.pop("denominator")
summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop
summary = jax.tree_util.tree_map(
lambda x: x / denominator, metrics_sums
) # pylint: disable=cell-var-from-loop
summary["learning_rate"] = lr
summary["perplexity"] = jnp.clip(
jnp.exp(summary["loss"]), a_max=1.0e4
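The `pre_pmap`/`post_pmap` helpers rewrapped above implement a common cross-host reduction trick: give every leaf a leading axis of size one so it can be fed through a `pmap` that performs a `psum`, then strip the axis again. A self-contained sketch (single-host here, which is an assumption; the real code picks one device per host):

```python
import jax
import jax.numpy as jnp
import numpy as np

def per_host_sum_pmap(in_tree):
  # psum over axis "i"; with one device this reduces to an identity plus a sync.
  host_psum = jax.pmap(
      lambda x: jax.lax.psum(x, 'i'), 'i', devices=jax.local_devices()[:1]
  )
  pre = jax.tree_util.tree_map(
      lambda x: jnp.broadcast_to(x, (1,) + x.shape), in_tree
  )
  return jax.tree_util.tree_map(lambda x: x[0], host_psum(pre))

print(per_host_sum_pmap({'tokens': np.array(1024.0), 'loss': np.array(2.5)}))
```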
14 changes: 8 additions & 6 deletions examples/nlp_seq/train.py
@@ -42,9 +42,9 @@

FLAGS = flags.FLAGS

flags.DEFINE_string('model_dir', default='', help=('Directory for model data.'))
flags.DEFINE_string('model_dir', default='', help='Directory for model data.')

flags.DEFINE_string('experiment', default='xpos', help=('Experiment name.'))
flags.DEFINE_string('experiment', default='xpos', help='Experiment name.')

flags.DEFINE_integer('batch_size', default=64, help='Batch size for training.')

@@ -58,7 +58,7 @@
'num_train_steps', default=75000, help='Number of train steps.'
)

flags.DEFINE_float('learning_rate', default=0.05, help=('Learning rate.'))
flags.DEFINE_float('learning_rate', default=0.05, help='Learning rate.')

flags.DEFINE_float(
'weight_decay',
@@ -74,9 +74,9 @@
'random_seed', default=0, help='Integer for PRNG random seed.'
)

flags.DEFINE_string('train', default='', help=('Path to training data.'))
flags.DEFINE_string('train', default='', help='Path to training data.')

flags.DEFINE_string('dev', default='', help=('Path to development data.'))
flags.DEFINE_string('dev', default='', help='Path to development data.')


def create_learning_rate_scheduler(
@@ -354,7 +354,9 @@ def eval_step(params, batch):
tick = time.time()
best_dev_score = 0
for step, batch in zip(range(num_train_steps), train_iter):
batch = common_utils.shard(jax.tree_util.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access
batch = common_utils.shard(
jax.tree_util.tree_map(lambda x: x._numpy(), batch)
) # pylint: disable=protected-access

state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs)
metrics_all.append(metrics)
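The `common_utils.shard` call wrapped above prepares a batch for the pmapped train step. A minimal sketch of what that sharding amounts to (an assumption about its behavior, not the library source): split the leading batch axis across the local devices.

```python
import jax
import numpy as np

def shard(xs):
  device_count = jax.local_device_count()
  return jax.tree_util.tree_map(
      lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs
  )

batch = {'inputs': np.zeros((8, 16), np.int32), 'labels': np.zeros((8,), np.int32)}
print(jax.tree_util.tree_map(lambda x: x.shape, shard(batch)))
# On a single-device machine: {'inputs': (1, 8, 16), 'labels': (1, 8)}
```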
8 changes: 6 additions & 2 deletions examples/ppo/ppo_lib.py
@@ -195,7 +195,9 @@ def get_experience(
sim_state = sim.conn.recv()
sim_states.append(sim_state)
sim_states = np.concatenate(sim_states, axis=0)
log_probs, values = agent.policy_action(state.apply_fn, state.params, sim_states)
log_probs, values = agent.policy_action(
state.apply_fn, state.params, sim_states
)
log_probs, values = jax.device_get((log_probs, values))
probs = np.exp(np.array(log_probs))
for i, sim in enumerate(simulators):
@@ -343,7 +345,9 @@ def train(
score = test_episodes.policy_test(1, state.apply_fn, state.params, game)
frames = step * config.num_agents * config.actor_steps
summary_writer.scalar('game_score', score, frames)
logging.info('Step %s:\nframes seen %s\nscore %s\n\n', step, frames, score)
logging.info(
'Step %s:\nframes seen %s\nscore %s\n\n', step, frames, score
)

# Core training code.
alpha = (
4 changes: 3 additions & 1 deletion examples/seq2seq/models.py
@@ -109,7 +109,9 @@ def __call__(
encoding format).
"""
# Encode inputs.
encoder = nn.RNN(nn.LSTMCell(self.hidden_size), return_carry=True, name='encoder')
encoder = nn.RNN(
nn.LSTMCell(self.hidden_size), return_carry=True, name='encoder'
)
decoder = nn.RNN(
DecoderLSTMCell(
decoder_inputs.shape[-1], self.teacher_force, self.vocab_size
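The hunk above rewraps the encoder construction; for readers unfamiliar with the API, `nn.RNN(..., return_carry=True)` returns the final carry alongside the outputs, which is what lets a decoder be seeded with the encoder state. A shape-only sketch with hypothetical sizes (not this example's model, and the decoder wiring here is an assumption):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class TinySeq2Seq(nn.Module):
  hidden_size: int = 16

  @nn.compact
  def __call__(self, enc_inputs, dec_inputs):
    encoder = nn.RNN(
        nn.LSTMCell(self.hidden_size), return_carry=True, name='encoder'
    )
    carry, _ = encoder(enc_inputs)                    # final carry of the encoder
    decoder = nn.RNN(nn.LSTMCell(self.hidden_size), name='decoder')
    return decoder(dec_inputs, initial_carry=carry)   # decode from that state

model = TinySeq2Seq()
enc = jnp.zeros((2, 5, 8))   # [batch, enc_len, features]
dec = jnp.zeros((2, 7, 8))   # [batch, dec_len, features]
out, _ = model.init_with_output(jax.random.PRNGKey(0), enc, dec)
print(out.shape)             # (2, 7, 16)
```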
8 changes: 6 additions & 2 deletions examples/sst2/models.py
@@ -197,12 +197,16 @@ def setup(self):

def __call__(self, embedded_inputs, lengths):
# Forward LSTM.
initial_state = self.forward_lstm.initialize_carry(embedded_inputs[:, 0].shape)
initial_state = self.forward_lstm.initialize_carry(
embedded_inputs[:, 0].shape
)
_, forward_outputs = self.forward_lstm(initial_state, embedded_inputs)

# Backward LSTM.
reversed_inputs = flip_sequences(embedded_inputs, lengths)
initial_state = self.backward_lstm.initialize_carry(reversed_inputs[:, 0].shape)
initial_state = self.backward_lstm.initialize_carry(
reversed_inputs[:, 0].shape
)
_, backward_outputs = self.backward_lstm(initial_state, reversed_inputs)
backward_outputs = flip_sequences(backward_outputs, lengths)

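`flip_sequences`, used around the backward LSTM above, reverses only the valid prefix of each padded sequence so the padding stays at the end. A standalone sketch of that behavior (an assumption about the helper, not the example's exact code):

```python
import jax.numpy as jnp

def flip_sequences(inputs, lengths):
  # inputs: [batch, time, ...]; lengths: [batch], valid steps per row.
  max_length = inputs.shape[1]
  idxs = (jnp.arange(max_length - 1, -1, -1)[None, :] + lengths[:, None]) % max_length
  return inputs[jnp.arange(inputs.shape[0])[:, None], idxs]

x = jnp.array([[1, 2, 3, 0, 0],
               [4, 5, 0, 0, 0]])
print(flip_sequences(x, jnp.array([3, 2])))
# [[3 2 1 0 0]
#  [5 4 0 0 0]]
```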
20 changes: 12 additions & 8 deletions examples/wmt/decode.py
@@ -155,7 +155,9 @@ def beam_init(batch_size, beam_size, max_decode_len, cache):
finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
# add beam dimension to attention cache pytree elements
beam_cache0 = jax.tree_util.tree_map(lambda x: add_beam_dim(x, beam_size), cache)
beam_cache0 = jax.tree_util.tree_map(
lambda x: add_beam_dim(x, beam_size), cache
)
return BeamState(
cur_index=cur_index0,
live_logprobs=live_logprobs0,
@@ -350,13 +352,15 @@ def beam_search_loop_body_fn(state):
[state.finished_flags, newly_finished], axis=1
)
# --> [batch, beams, length], [batch, beams], [batch, beams]
top_finished_seq, top_finished_scores, top_finished_flags = (
gather_topk_beams(
[finished_seqs, finished_scores, finished_flags],
finished_scores,
batch_size,
beam_size,
)
(
top_finished_seq,
top_finished_scores,
top_finished_flags,
) = gather_topk_beams(
[finished_seqs, finished_scores, finished_flags],
finished_scores,
batch_size,
beam_size,
)

return BeamState(
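The tuple-unpacking rewrap above feeds `gather_topk_beams`; conceptually it selects the `beam_size` highest-scoring beams per batch element and gathers every array in the nest along the beam axis. A small sketch of that idea (assumed behavior, not the example's implementation):

```python
import jax
import jax.numpy as jnp

def gather_topk_beams(nested, scores, batch_size, beam_size):
  _, topk_indices = jax.lax.top_k(scores, beam_size)   # [batch, beam_size]
  batch_indices = jnp.arange(batch_size)[:, None]       # [batch, 1]
  return jax.tree_util.tree_map(
      lambda x: x[batch_indices, topk_indices], nested
  )

seqs = jnp.arange(2 * 4 * 3).reshape(2, 4, 3)   # [batch=2, beams=4, length=3]
scores = jnp.array([[0.1, 0.9, 0.3, 0.2], [0.5, 0.4, 0.8, 0.7]])
flags = scores > 0.5
top_seqs, top_scores, top_flags = gather_topk_beams(
    [seqs, scores, flags], scores, batch_size=2, beam_size=2
)
print(top_scores)  # [[0.9 0.3] [0.8 0.7]]
```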
12 changes: 9 additions & 3 deletions examples/wmt/train.py
@@ -355,7 +355,9 @@ def per_host_sum_pmap(in_tree):
host_psum = jax.pmap(lambda x: jax.lax.psum(x, "i"), "i", devices=devices)

def pre_pmap(xs):
return jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs)
return jax.tree_util.tree_map(
lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs
)

def post_pmap(xs):
return jax.tree_util.tree_map(lambda x: x[0], xs)
@@ -626,7 +628,9 @@ def decode_tokens(toks):

# Shard data to devices and do a training step.
with jax.profiler.StepTraceAnnotation("train", step_num=step):
batch = common_utils.shard(jax.tree_util.tree_map(np.asarray, next(train_iter)))
batch = common_utils.shard(
jax.tree_util.tree_map(np.asarray, next(train_iter))
)
state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs)
train_metrics.append(metrics)

@@ -643,7 +647,9 @@ def decode_tokens(toks):
lr = train_metrics.pop("learning_rate").mean()
metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
denominator = metrics_sums.pop("denominator")
summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop
summary = jax.tree_util.tree_map(
lambda x: x / denominator, metrics_sums
) # pylint: disable=cell-var-from-loop
summary["learning_rate"] = lr
summary = {"train_" + k: v for k, v in summary.items()}
writer.write_scalars(step, summary)
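The summary computation rewrapped above normalizes summed metrics by the token denominator before deriving perplexity. A small numeric walk-through with invented values (not measurements from this PR):

```python
import jax
import jax.numpy as jnp

# Two accumulated training steps; "denominator" counts the tokens in each.
train_metrics = {
    'loss': jnp.array([10.0, 12.0]),
    'denominator': jnp.array([4.0, 4.0]),
}
metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
denominator = metrics_sums.pop('denominator')             # 8 tokens total
summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums)
summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), None, 1.0e4)
print(summary)  # loss = 22 / 8 = 2.75, perplexity ≈ e**2.75 ≈ 15.6
```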
4 changes: 3 additions & 1 deletion flax/core/axes_scan.py
@@ -156,7 +156,9 @@ def body_fn(c, xs, init_mode=False):
'broadcasted variable has a data dependency on the scan body.'
)
out_flat.append(const)
broadcast_in, constants_out = jax.tree_util.tree_unflatten(out_tree(), out_flat)
broadcast_in, constants_out = jax.tree_util.tree_unflatten(
out_tree(), out_flat
)

c, ys = lax.scan(
body_fn, init, xs, length=length, reverse=reverse, unroll=unroll
4 changes: 3 additions & 1 deletion flax/core/meta.py
@@ -237,7 +237,9 @@ def body(mdl, c):

value: Any
names: LogicalNames = struct.field(pytree_node=False)
mesh: Optional[jax.sharding.Mesh] = struct.field(default=None, pytree_node=False)
mesh: Optional[jax.sharding.Mesh] = struct.field(
default=None, pytree_node=False
)

def unbox(self, apply_constraint=True) -> Any:
"""Returns the wrapped value with the partitioning applied as a sharding constraint."""
13 changes: 9 additions & 4 deletions flax/errors.py
@@ -52,7 +52,9 @@ def __init__(self):
class FlaxError(Exception):

def __init__(self, message):
error_page = 'https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html'
error_page = (
'https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html'
)
module_name = self.__class__.__module__
class_name = self.__class__.__name__
error_msg = f'{message} ({error_page}#{module_name}.{class_name})'
@@ -706,7 +708,8 @@ def __call__(self, x):

def __init__(self):
super().__init__(
'Can only call init, init_with_output or apply methods on an instance of the Module class, not the Module class itself'
'Can only call init, init_with_output or apply methods on an instance'
' of the Module class, not the Module class itself'
)


@@ -755,7 +758,8 @@ def __call__(self, x):

def __init__(self):
super().__init__(
'Trying to access a property that is accessing a non-existent attribute.'
'Trying to access a property that is accessing a non-existent'
' attribute.'
)


@@ -770,7 +774,8 @@ class InvalidCheckpointError(FlaxError):

def __init__(self, path, step):
super().__init__(
f'Trying to save an outdated checkpoint at step: "{step}" and path: "{path}".'
f'Trying to save an outdated checkpoint at step: "{step}" and path:'
f' "{path}".'
)


4 changes: 3 additions & 1 deletion flax/jax_utils.py
@@ -218,7 +218,9 @@ def transpose_out(x):

def body_wrapper(c, xs):
if keepdims:
xs = jax.tree_util.tree_map(lambda x: x.reshape((1,) * len(axis) + x.shape), xs)
xs = jax.tree_util.tree_map(
lambda x: x.reshape((1,) * len(axis) + x.shape), xs
)
xs = jax.tree_util.tree_map(transpose_out, xs)
c, ys = body_fn(c, xs)
if keepdims:
13 changes: 9 additions & 4 deletions flax/linen/attention.py
@@ -237,7 +237,9 @@ class MultiHeadDotProductAttention(Module):
deterministic: Optional[bool] = None
precision: PrecisionLike = None
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init()
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
initializers.zeros_init()
)
use_bias: bool = True
attention_fn: Callable[..., Array] = dot_product_attention
decode: bool = False
@@ -316,9 +318,12 @@ def __call__(
'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32)
)
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = (
cached_key.value.shape
)
(
*batch_dims,
max_length,
num_heads,
depth_per_head,
) = cached_key.value.shape
# shape check of cached keys against query input
expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
if expected_shape != query.shape: