Commit

Merge pull request #3219 from google:fix-main
PiperOrigin-RevId: 550069914
Flax Authors committed Jul 21, 2023
2 parents 3ab11fe + c90c267 commit 4f1884c
Showing 43 changed files with 376 additions and 145 deletions.
8 changes: 6 additions & 2 deletions .github/analytics/get_repo_metrics.py
@@ -331,7 +331,9 @@ def main(_):
)
issues.get()

df_issues = df_issues0 = pd.DataFrame(list(_get_issues_features(issues.raw_data)))
df_issues = df_issues0 = pd.DataFrame(
list(_get_issues_features(issues.raw_data))
)
df_issues['issue_response_time'] = (
df_issues['time_labeled_or_converted'] - df_issues['created_at']
)
@@ -350,7 +352,9 @@ def main(_):
prs.get()

df_prs = df_prs0 = pd.DataFrame(list(_get_pr_features(prs.raw_data)))
time_response = df_prs[['time_labeled_or_assigned', 'time_review']].min(axis=1)
time_response = df_prs[['time_labeled_or_assigned', 'time_review']].min(
axis=1
)
df_prs['pr_response_time'] = time_response - df_prs['ready_for_review_at']
df_prs['pr_resolution_time'] = (
df_prs['time_merged_or_closed'] - df_prs['ready_for_review_at']
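Note: the hunks above only re-wrap the DataFrame construction and the `min(axis=1)` call; the metrics themselves are unchanged. A minimal sketch of the underlying pandas pattern, with made-up records standing in for the GitHub API data:

```python
import pandas as pd

# Made-up issue records; the real script builds these from the GitHub API.
df_issues = pd.DataFrame([
    {'created_at': pd.Timestamp('2023-07-01'),
     'time_labeled_or_converted': pd.Timestamp('2023-07-03')},
])
# Same pattern as the diff: the response time is a Timedelta column
# obtained by subtracting one timestamp column from another.
df_issues['issue_response_time'] = (
    df_issues['time_labeled_or_converted'] - df_issues['created_at']
)
print(df_issues['issue_response_time'].dt.days)  # -> 2
```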
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
@@ -13,12 +13,16 @@ repos:
hooks:
- id: jupytext
args: [--sync]
- repo: https://github.com/google/pyink
rev: 23.5.0
hooks:
- id: pyink
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
hooks:
- id: check-toml
- id: trailing-whitespace
exclude: ^docs/.*\.md$|^dev/.*\.py$
exclude: ^docs/.*\.md$
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
hooks:
@@ -29,7 +33,3 @@ repos:
--extra-keys,
"metadata.kernelspec metadata.vscode metadata.colab cell.metadata.executionInfo.user cell.metadata.executionInfo.user_tz cell.metadata.colab",
]
- repo: https://github.com/google/pyink
rev: 23.5.0
hooks:
- id: pyink
2 changes: 2 additions & 0 deletions dev/update_requirements.py
@@ -39,6 +39,8 @@
Alternatively, the list can also be provided from the local environment with:
python dev --versions="$(pip freeze | sed s/==/-/g) flax-0.3.6"
This will use the currently installed versions of all packages.
"""

import pathlib
4 changes: 3 additions & 1 deletion examples/imagenet/train_test.py
@@ -55,7 +55,9 @@ def test_create_model_local(self):
Uses smaller inputs than `test_create_model` to due to higher compute.
"""
model = train.create_model(model_cls=models._ResNet1Local, half_precision=False) # pylint: disable=protected-access
model = train.create_model(
model_cls=models._ResNet1Local, half_precision=False
) # pylint: disable=protected-access
params, batch_stats = train.initialized(random.PRNGKey(0), 64, model)
variables = {'params': params, 'batch_stats': batch_stats}
x = random.normal(random.PRNGKey(1), (1, 64, 64, 3))
4 changes: 3 additions & 1 deletion examples/lm1b/temperature_sampler_test.py
@@ -37,7 +37,9 @@ def tokens_to_logits(tokens, cache):
logits = logits.squeeze(axis=1)
return logits, cache

new_tokens = temperature_sample(tokens, cache, tokens_to_logits, key, topk=5)
new_tokens = temperature_sample(
tokens, cache, tokens_to_logits, key, topk=5
)

np.testing.assert_array_equal(new_tokens, [[5, 6, 7, 8]])

12 changes: 9 additions & 3 deletions examples/lm1b/train.py
@@ -285,7 +285,9 @@ def per_host_sum_pmap(in_tree):
host_psum = jax.pmap(lambda x: jax.lax.psum(x, "i"), "i", devices=devices)

def pre_pmap(xs):
return jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs)
return jax.tree_util.tree_map(
lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs
)

def post_pmap(xs):
return jax.tree_util.tree_map(lambda x: x[0], xs)
@@ -525,7 +527,9 @@ def encode_strings(strs, max_len):

# Shard data to devices and do a training step.
with jax.profiler.StepTraceAnnotation("train", step_num=step):
batch = common_utils.shard(jax.tree_util.tree_map(np.asarray, next(train_iter)))
batch = common_utils.shard(
jax.tree_util.tree_map(np.asarray, next(train_iter))
)
state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs)
train_metrics.append(metrics)

@@ -542,7 +546,9 @@ def encode_strings(strs, max_len):
lr = train_metrics.pop("learning_rate").mean()
metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
denominator = metrics_sums.pop("denominator")
summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop
summary = jax.tree_util.tree_map(
lambda x: x / denominator, metrics_sums
) # pylint: disable=cell-var-from-loop
summary["learning_rate"] = lr
summary["perplexity"] = jnp.clip(
jnp.exp(summary["loss"]), a_max=1.0e4
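Note: every hunk in this file re-wraps an existing `jax.tree_util.tree_map` call. A small self-contained sketch of the two patterns involved (toy values, not the real training metrics):

```python
import jax
import jax.numpy as jnp

# Pattern 1: normalize summed metrics by a shared denominator.
metrics_sums = {'loss': jnp.array(6.0), 'accuracy': jnp.array(3.0)}
denominator = jnp.array(2.0)
summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums)
# {'loss': 3.0, 'accuracy': 1.5}

# Pattern 2: add a leading axis of size 1 to every leaf (as in pre_pmap).
xs = {'a': jnp.zeros((4, 8)), 'b': jnp.zeros((3,))}
xs = jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs)
# leaf shapes become (1, 4, 8) and (1, 3)
```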
14 changes: 8 additions & 6 deletions examples/nlp_seq/train.py
@@ -42,9 +42,9 @@

FLAGS = flags.FLAGS

flags.DEFINE_string('model_dir', default='', help=('Directory for model data.'))
flags.DEFINE_string('model_dir', default='', help='Directory for model data.')

flags.DEFINE_string('experiment', default='xpos', help=('Experiment name.'))
flags.DEFINE_string('experiment', default='xpos', help='Experiment name.')

flags.DEFINE_integer('batch_size', default=64, help='Batch size for training.')

@@ -58,7 +58,7 @@
'num_train_steps', default=75000, help='Number of train steps.'
)

flags.DEFINE_float('learning_rate', default=0.05, help=('Learning rate.'))
flags.DEFINE_float('learning_rate', default=0.05, help='Learning rate.')

flags.DEFINE_float(
'weight_decay',
@@ -74,9 +74,9 @@
'random_seed', default=0, help='Integer for PRNG random seed.'
)

flags.DEFINE_string('train', default='', help=('Path to training data.'))
flags.DEFINE_string('train', default='', help='Path to training data.')

flags.DEFINE_string('dev', default='', help=('Path to development data.'))
flags.DEFINE_string('dev', default='', help='Path to development data.')


def create_learning_rate_scheduler(
@@ -354,7 +354,9 @@ def eval_step(params, batch):
tick = time.time()
best_dev_score = 0
for step, batch in zip(range(num_train_steps), train_iter):
batch = common_utils.shard(jax.tree_util.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access
batch = common_utils.shard(
jax.tree_util.tree_map(lambda x: x._numpy(), batch)
) # pylint: disable=protected-access

state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs)
metrics_all.append(metrics)
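Note: the re-wrapped line above feeds a host batch through `common_utils.shard` before `pmap`. A minimal sketch of what `shard` does, assuming a toy batch dict:

```python
import jax
import numpy as np
from flax.training import common_utils

# Toy batch: shard() reshapes the leading axis from (batch, ...) to
# (jax.local_device_count(), batch // device_count, ...).
batch = {
    'inputs': np.zeros((8, 16), np.int32),
    'targets': np.zeros((8, 16), np.int32),
}
sharded = common_utils.shard(batch)
print(jax.tree_util.tree_map(lambda x: x.shape, sharded))
# e.g. on a single device: {'inputs': (1, 8, 16), 'targets': (1, 8, 16)}
```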
8 changes: 6 additions & 2 deletions examples/ppo/ppo_lib.py
@@ -195,7 +195,9 @@ def get_experience(
sim_state = sim.conn.recv()
sim_states.append(sim_state)
sim_states = np.concatenate(sim_states, axis=0)
log_probs, values = agent.policy_action(state.apply_fn, state.params, sim_states)
log_probs, values = agent.policy_action(
state.apply_fn, state.params, sim_states
)
log_probs, values = jax.device_get((log_probs, values))
probs = np.exp(np.array(log_probs))
for i, sim in enumerate(simulators):
@@ -343,7 +345,9 @@ def train(
score = test_episodes.policy_test(1, state.apply_fn, state.params, game)
frames = step * config.num_agents * config.actor_steps
summary_writer.scalar('game_score', score, frames)
logging.info('Step %s:\nframes seen %s\nscore %s\n\n', step, frames, score)
logging.info(
'Step %s:\nframes seen %s\nscore %s\n\n', step, frames, score
)

# Core training code.
alpha = (
4 changes: 3 additions & 1 deletion examples/seq2seq/models.py
@@ -109,7 +109,9 @@ def __call__(
encoding format).
"""
# Encode inputs.
encoder = nn.RNN(nn.LSTMCell(self.hidden_size), return_carry=True, name='encoder')
encoder = nn.RNN(
nn.LSTMCell(self.hidden_size), return_carry=True, name='encoder'
)
decoder = nn.RNN(
DecoderLSTMCell(
decoder_inputs.shape[-1], self.teacher_force, self.vocab_size
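Note: the reformatted encoder wraps an `nn.LSTMCell` in the `nn.RNN` combinator. A minimal usage sketch, assuming the newer Flax recurrent API where the cell takes `features` directly (shapes are illustrative, not from the example):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# nn.RNN scans the cell over the time axis of (batch, time, features) inputs;
# with return_carry=True it returns the final carry alongside the outputs.
encoder = nn.RNN(nn.LSTMCell(features=8), return_carry=True)
x = jnp.ones((2, 5, 3))  # (batch, time, features)
variables = encoder.init(jax.random.PRNGKey(0), x)
carry, outputs = encoder.apply(variables, x)  # outputs: (2, 5, 8)
```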
8 changes: 6 additions & 2 deletions examples/sst2/models.py
@@ -197,12 +197,16 @@ def setup(self):

def __call__(self, embedded_inputs, lengths):
# Forward LSTM.
initial_state = self.forward_lstm.initialize_carry(embedded_inputs[:, 0].shape)
initial_state = self.forward_lstm.initialize_carry(
embedded_inputs[:, 0].shape
)
_, forward_outputs = self.forward_lstm(initial_state, embedded_inputs)

# Backward LSTM.
reversed_inputs = flip_sequences(embedded_inputs, lengths)
initial_state = self.backward_lstm.initialize_carry(reversed_inputs[:, 0].shape)
initial_state = self.backward_lstm.initialize_carry(
reversed_inputs[:, 0].shape
)
_, backward_outputs = self.backward_lstm(initial_state, reversed_inputs)
backward_outputs = flip_sequences(backward_outputs, lengths)

20 changes: 12 additions & 8 deletions examples/wmt/decode.py
@@ -155,7 +155,9 @@ def beam_init(batch_size, beam_size, max_decode_len, cache):
finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
# add beam dimension to attention cache pytree elements
beam_cache0 = jax.tree_util.tree_map(lambda x: add_beam_dim(x, beam_size), cache)
beam_cache0 = jax.tree_util.tree_map(
lambda x: add_beam_dim(x, beam_size), cache
)
return BeamState(
cur_index=cur_index0,
live_logprobs=live_logprobs0,
@@ -350,13 +352,15 @@ def beam_search_loop_body_fn(state):
[state.finished_flags, newly_finished], axis=1
)
# --> [batch, beams, length], [batch, beams], [batch, beams]
top_finished_seq, top_finished_scores, top_finished_flags = (
gather_topk_beams(
[finished_seqs, finished_scores, finished_flags],
finished_scores,
batch_size,
beam_size,
)
(
top_finished_seq,
top_finished_scores,
top_finished_flags,
) = gather_topk_beams(
[finished_seqs, finished_scores, finished_flags],
finished_scores,
batch_size,
beam_size,
)

return BeamState(
12 changes: 9 additions & 3 deletions examples/wmt/train.py
@@ -355,7 +355,9 @@ def per_host_sum_pmap(in_tree):
host_psum = jax.pmap(lambda x: jax.lax.psum(x, "i"), "i", devices=devices)

def pre_pmap(xs):
return jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs)
return jax.tree_util.tree_map(
lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs
)

def post_pmap(xs):
return jax.tree_util.tree_map(lambda x: x[0], xs)
@@ -626,7 +628,9 @@ def decode_tokens(toks):

# Shard data to devices and do a training step.
with jax.profiler.StepTraceAnnotation("train", step_num=step):
batch = common_utils.shard(jax.tree_util.tree_map(np.asarray, next(train_iter)))
batch = common_utils.shard(
jax.tree_util.tree_map(np.asarray, next(train_iter))
)
state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs)
train_metrics.append(metrics)

@@ -643,7 +647,9 @@ def decode_tokens(toks):
lr = train_metrics.pop("learning_rate").mean()
metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
denominator = metrics_sums.pop("denominator")
summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop
summary = jax.tree_util.tree_map(
lambda x: x / denominator, metrics_sums
) # pylint: disable=cell-var-from-loop
summary["learning_rate"] = lr
summary = {"train_" + k: v for k, v in summary.items()}
writer.write_scalars(step, summary)
4 changes: 3 additions & 1 deletion flax/core/axes_scan.py
@@ -156,7 +156,9 @@ def body_fn(c, xs, init_mode=False):
'broadcasted variable has a data dependency on the scan body.'
)
out_flat.append(const)
broadcast_in, constants_out = jax.tree_util.tree_unflatten(out_tree(), out_flat)
broadcast_in, constants_out = jax.tree_util.tree_unflatten(
out_tree(), out_flat
)

c, ys = lax.scan(
body_fn, init, xs, length=length, reverse=reverse, unroll=unroll
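Note: the hunk only re-wraps the `tree_unflatten` call; the surrounding `lax.scan` is unchanged. For reference, a tiny self-contained example of the `lax.scan(body_fn, init, xs)` pattern used here:

```python
import jax.numpy as jnp
from jax import lax

# body_fn threads a carry through the leading axis of xs and emits one output per step.
def body_fn(c, x):
  c = c + x
  return c, c  # (new carry, per-step output)

c, ys = lax.scan(body_fn, jnp.array(0.0), jnp.arange(4.0))
# c == 6.0, ys == [0., 1., 3., 6.]
```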
4 changes: 3 additions & 1 deletion flax/core/meta.py
@@ -237,7 +237,9 @@ def body(mdl, c):

value: Any
names: LogicalNames = struct.field(pytree_node=False)
mesh: Optional[jax.sharding.Mesh] = struct.field(default=None, pytree_node=False)
mesh: Optional[jax.sharding.Mesh] = struct.field(
default=None, pytree_node=False
)

def unbox(self, apply_constraint=True) -> Any:
"""Returns the wrapped value with the partitioning applied as a sharding constraint."""
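Note: the change re-wraps a `struct.field(..., pytree_node=False)` declaration. A small sketch of what that flag means, using a hypothetical dataclass (not Flax's actual `Partitioned` class):

```python
import jax
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class Boxed:  # hypothetical container for illustration
  value: jnp.ndarray
  names: tuple = struct.field(pytree_node=False, default=())

boxed = Boxed(value=jnp.ones((2, 3)), names=('batch', 'features'))
# Only `value` is a pytree leaf; `names` rides along as static metadata.
print(jax.tree_util.tree_leaves(boxed))
```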
13 changes: 9 additions & 4 deletions flax/errors.py
@@ -52,7 +52,9 @@ def __init__(self):
class FlaxError(Exception):

def __init__(self, message):
error_page = 'https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html'
error_page = (
'https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html'
)
module_name = self.__class__.__module__
class_name = self.__class__.__name__
error_msg = f'{message} ({error_page}#{module_name}.{class_name})'
@@ -706,7 +708,8 @@ def __call__(self, x):

def __init__(self):
super().__init__(
'Can only call init, init_with_output or apply methods on an instance of the Module class, not the Module class itself'
'Can only call init, init_with_output or apply methods on an instance'
' of the Module class, not the Module class itself'
)


@@ -755,7 +758,8 @@ def __call__(self, x):

def __init__(self):
super().__init__(
'Trying to access a property that is accessing a non-existent attribute.'
'Trying to access a property that is accessing a non-existent'
' attribute.'
)


@@ -770,7 +774,8 @@ class InvalidCheckpointError(FlaxError):

def __init__(self, path, step):
super().__init__(
f'Trying to save an outdated checkpoint at step: "{step}" and path: "{path}".'
f'Trying to save an outdated checkpoint at step: "{step}" and path:'
f' "{path}".'
)


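Note: these hunks only re-wrap long error strings. The `FlaxError` base class builds each message with a link to the error docs; a toy re-implementation of that pattern with hypothetical class names, not Flax's actual code:

```python
# Hypothetical names, shown only to illustrate the message/URL construction.
class DocLinkedError(Exception):
  ERROR_PAGE = 'https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html'

  def __init__(self, message):
    anchor = f'{self.__class__.__module__}.{self.__class__.__name__}'
    super().__init__(f'{message} ({self.ERROR_PAGE}#{anchor})')


class OutdatedCheckpointError(DocLinkedError):
  def __init__(self, path, step):
    super().__init__(
        f'Trying to save an outdated checkpoint at step: "{step}" and path:'
        f' "{path}".'
    )
```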
4 changes: 3 additions & 1 deletion flax/jax_utils.py
@@ -218,7 +218,9 @@ def transpose_out(x):

def body_wrapper(c, xs):
if keepdims:
xs = jax.tree_util.tree_map(lambda x: x.reshape((1,) * len(axis) + x.shape), xs)
xs = jax.tree_util.tree_map(
lambda x: x.reshape((1,) * len(axis) + x.shape), xs
)
xs = jax.tree_util.tree_map(transpose_out, xs)
c, ys = body_fn(c, xs)
if keepdims:
13 changes: 9 additions & 4 deletions flax/linen/attention.py
@@ -237,7 +237,9 @@ class MultiHeadDotProductAttention(Module):
deterministic: Optional[bool] = None
precision: PrecisionLike = None
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init()
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
initializers.zeros_init()
)
use_bias: bool = True
attention_fn: Callable[..., Array] = dot_product_attention
decode: bool = False
@@ -316,9 +318,12 @@ def __call__(
'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32)
)
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = (
cached_key.value.shape
)
(
*batch_dims,
max_length,
num_heads,
depth_per_head,
) = cached_key.value.shape
# shape check of cached keys against query input
expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
if expected_shape != query.shape:
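Note: the changes in this file re-wrap a default field and a tuple unpacking inside `MultiHeadDotProductAttention`. A minimal usage sketch of the module itself, with illustrative shapes and hyperparameters (not taken from the diff):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

attn = nn.MultiHeadDotProductAttention(num_heads=2, qkv_features=16)
x = jnp.ones((4, 10, 32))  # (batch, length, features)
variables = attn.init(jax.random.PRNGKey(0), x, x)  # self-attention: q == kv
y = attn.apply(variables, x, x)                     # shape (4, 10, 32)
```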