Merge pull request #682 from mlcommons/dev
dev -> main
priyakasimbeg authored Mar 6, 2024
2 parents 6b188ba + 68a0d18 commit 0618974
Showing 7 changed files with 39 additions and 17 deletions.
12 changes: 6 additions & 6 deletions DOCUMENTATION.md
@@ -400,12 +400,12 @@ The currently eight fixed workloads are:

| | **Task** | **Dataset** | **Model** | **Loss** | **Metric** | Validation<br>**Target** | Test<br>**Target** | Maximum<br>**Runtime** <br>(in secs) |
|------------|-------------------------------|-------------|-------------------------|----------|------------|--------------------------|----------------------|------------------------|
-| **1** | Clickthrough rate prediction | Criteo 1TB | DLRMsmall | CE | CE | 0.123649 | 0.126060 | 21,600 |
-| **2** | MRI reconstruction | fastMRI | U-Net | L1 | SSIM | 0.7344 | 0.741652 | 10,800 |
-| **3<br>4** | Image classification | ImageNet | ResNet-50<br>ViT | CE | ER | 0.22569<br>0.22691 | 0.3440<br>0.3481 | 111,600 <br> 111,600 |
-| **5<br>6** | Speech recognition | LibriSpeech | Conformer<br>DeepSpeech | CTC | WER | 0.078477<br>0.1162 | 0.046973<br>0.068093 | <br>72,000 |
-| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 12,000 |
-| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 80,000 |
+| **1** | Clickthrough rate prediction | Criteo 1TB | DLRMsmall | CE | CE | 0.123735 | 0.126041 | 7,703 |
+| **2** | MRI reconstruction | fastMRI | U-Net | L1 | SSIM | 0.723653 | 0.740633 | 8,859 |
+| **3<br>4** | Image classification | ImageNet | ResNet-50<br>ViT | CE | ER | 0.22569<br>0.22691 | 0.3440<br>0.3481 | 63,008 <br> 77,520 |
+| **5<br>6** | Speech recognition | LibriSpeech | Conformer<br>DeepSpeech | CTC | WER | 0.085884<br>0.119936 | 0.052981<br>0.074143 | 61,068<br>55,506 |
+| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 18,477 |
+| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 48,151 |
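
The validation and test targets above are the metric values a submission must reach within the maximum runtime; whether "reach" means falling below the target (CE, ER, WER) or rising above it (SSIM, mAP, BLEU) depends on the metric. A minimal, self-contained sketch of such a check; the helper name and signature are illustrative, not taken from the benchmark code:

```python
def has_reached_target(metric_value: float, target: float,
                       higher_is_better: bool) -> bool:
  """Illustrative target check; higher_is_better covers SSIM/mAP/BLEU."""
  if higher_is_better:
    return metric_value >= target
  return metric_value <= target


# fastMRI validation SSIM target from the updated table above:
assert has_reached_target(0.7301, 0.723653, higher_is_better=True)
# ImageNet ResNet-50 validation error-rate target (lower is better):
assert not has_reached_target(0.2301, 0.22569, higher_is_better=False)
```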

#### Randomized workloads

@@ -24,7 +24,10 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
def initialized(self, key: spec.RandomState,
model: nn.Module) -> spec.ModelInitState:
input_shape = (1, 224, 224, 3)
-variables = jax.jit(model.init)({'params': key}, jnp.ones(input_shape))
+params_rng, dropout_rng = jax.random.split(key)
+variables = jax.jit(
+    model.init)({'params': params_rng, 'dropout': dropout_rng},
+                jnp.ones(input_shape))
model_state, params = variables.pop('params')
return params, model_state
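
The change above splits the incoming RNG key so that Flax's 'params' and 'dropout' streams each receive their own key; modules that apply dropout while tracing `init` would otherwise fail for lack of a 'dropout' RNG. A minimal, self-contained sketch of the same pattern (TinyModel is a hypothetical stand-in, not code from this repository):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyModel(nn.Module):
  """Hypothetical model with dropout, standing in for the ViT workload model."""

  @nn.compact
  def __call__(self, x, train: bool = True):
    x = nn.Dense(8)(x)
    x = nn.Dropout(rate=0.1, deterministic=not train)(x)
    return nn.Dense(2)(x)


rng = jax.random.PRNGKey(0)
params_rng, dropout_rng = jax.random.split(rng)
# One key per RNG stream: 'params' for weight init, 'dropout' for the
# dropout masks drawn while tracing init.
variables = jax.jit(TinyModel().init)(
    {'params': params_rng, 'dropout': dropout_rng}, jnp.ones((1, 4)))
```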

9 changes: 5 additions & 4 deletions algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
@@ -234,10 +234,11 @@ def init_model_fn(
self._train_model = models.Transformer(model_config)
eval_config = replace(model_config, deterministic=True)
self._eval_model = models.Transformer(eval_config)
-initial_variables = jax.jit(self._eval_model.init)(
-    rng,
-    jnp.ones(input_shape, jnp.float32),
-    jnp.ones(target_shape, jnp.float32))
+params_rng, dropout_rng = jax.random.split(rng)
+initial_variables = jax.jit(
+    self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng},
+                           jnp.ones(input_shape, jnp.float32),
+                           jnp.ones(target_shape, jnp.float32))

initial_params = initial_variables['params']
self._param_shapes = param_utils.jax_param_shapes(initial_params)
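
For reference, the unchanged `eval_config = replace(model_config, deterministic=True)` line a few lines above relies on `dataclasses.replace`, which copies a (possibly frozen) config dataclass while overriding selected fields. A small illustrative sketch of that pattern, using a hypothetical config class rather than the repository's WMT Transformer config:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ModelConfig:
  """Hypothetical config; the real code uses the WMT Transformer config."""
  num_layers: int = 6
  dropout_rate: float = 0.1
  deterministic: bool = False


train_config = ModelConfig()
eval_config = replace(train_config, deterministic=True)  # new copy; train_config is untouched
```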
20 changes: 17 additions & 3 deletions datasets/dataset_setup.py
@@ -291,9 +291,22 @@ def download_criteo1tb(data_dir,
stream=True)

all_days_zip_filepath = os.path.join(tmp_criteo_dir, 'all_days.zip')
-with open(all_days_zip_filepath, 'wb') as f:
-  for chunk in download_request.iter_content(chunk_size=1024):
-    f.write(chunk)
+download = True
+if os.path.exists(all_days_zip_filepath):
+  while True:
+    overwrite = input('File already exists {}.\n Overwrite? (Y/n)'.format(
+        all_days_zip_filepath)).lower()
+    if overwrite in ['y', 'n']:
+      break
+    logging.info('Invalid response. Try again.')
+  if overwrite == 'n':
+    logging.info(f'Skipping download to {all_days_zip_filepath}')
+    download = False
+
+if download:
+  with open(all_days_zip_filepath, 'wb') as f:
+    for chunk in download_request.iter_content(chunk_size=1024):
+      f.write(chunk)

unzip_cmd = f'unzip {all_days_zip_filepath} -d {tmp_criteo_dir}'
logging.info(f'Running Criteo 1TB unzip command:\n{unzip_cmd}')
@@ -679,6 +692,7 @@ def main(_):
if any(s in tmp_dir for s in bad_chars):
raise ValueError(f'Invalid temp_dir: {tmp_dir}.')
data_dir = os.path.abspath(os.path.expanduser(data_dir))
+tmp_dir = os.path.abspath(os.path.expanduser(tmp_dir))
logging.info('Downloading data to %s...', data_dir)

if FLAGS.all or FLAGS.criteo1tb:
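
The download change above streams the Criteo archive in 1 KiB chunks and now asks before clobbering an existing file. A hedged sketch of the same idea as a reusable helper; the function name and defaults are illustrative and not part of datasets/dataset_setup.py:

```python
import os

import requests


def download_with_confirmation(url: str, dest_path: str,
                               chunk_size: int = 1024) -> None:
  """Streams url to dest_path, prompting before overwriting an existing file."""
  if os.path.exists(dest_path):
    answer = input(f'{dest_path} already exists. Overwrite? (y/n) ').lower()
    if answer != 'y':
      print(f'Skipping download to {dest_path}')
      return
  with requests.get(url, stream=True, timeout=60) as response:
    response.raise_for_status()
    with open(dest_path, 'wb') as f:
      for chunk in response.iter_content(chunk_size=chunk_size):
        f.write(chunk)
```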
4 changes: 2 additions & 2 deletions docker/scripts/startup.sh
@@ -157,8 +157,8 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_
"criteo1tb_resnet" "criteo1tb_layernorm" "criteo1tb_embed_init" \
"wmt" "wmt_post_ln" "wmt_attention_temp" "wmt_glu_tanh" \
"librispeech_deepspeech" "librispeech_conformer" "mnist" \
-"conformer_layernorm" "conformer_attention_temperature" \
-"conformer_gelu" "fastmri_model_size" "fastmri_tanh" \
+"librispeech_conformer_layernorm" "librispeech_conformer_attention_temperature" \
+"librispeech_conformer_gelu" "fastmri_model_size" "fastmri_tanh" \
"librispeech_deepspeech_tanh" \
"librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug"
"fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size")
@@ -15,6 +15,8 @@ def get_batch_size(workload_name):
return 512
elif workload_name == 'imagenet_vit':
return 1024
+elif workload_name == 'imagenet_vit_glu':
+  return 512
elif workload_name == 'librispeech_conformer':
return 256
elif workload_name == 'librispeech_deepspeech':
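
For context, `get_batch_size` is the hook a submission exposes so the harness can ask for the global batch size of each workload; the new `imagenet_vit_glu` branch simply extends that mapping. A minimal, illustrative dict-based equivalent (values and call site are examples, not quoted from the repository):

```python
_BATCH_SIZES = {
    'imagenet_vit': 1024,
    'imagenet_vit_glu': 512,  # the variant added in this commit
    'librispeech_conformer': 256,
}


def get_batch_size(workload_name: str) -> int:
  """Dict-based equivalent of the elif chain above."""
  try:
    return _BATCH_SIZES[workload_name]
  except KeyError:
    raise ValueError(f'Unsupported workload name: {workload_name}')


assert get_batch_size('imagenet_vit_glu') == 512
```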
4 changes: 3 additions & 1 deletion submission_runner.py
@@ -387,7 +387,9 @@ def train_once(
train_state['test_goal_reached'] = (
workload.has_reached_test_target(latest_eval_result) or
train_state['test_goal_reached'])

+goals_reached = (
+    train_state['validation_goal_reached'] and
+    train_state['test_goal_reached'])
# Save last eval time.
eval_end_time = get_time()
train_state['last_eval_time'] = eval_end_time
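
The added `goals_reached` flag is true only once both the validation and the test target have been hit. A tiny self-contained sketch of that combination (the dict layout mirrors the `train_state` keys above; the helper itself is illustrative):

```python
def both_goals_reached(train_state: dict) -> bool:
  """True only when the validation and test targets have both been reached."""
  return (train_state['validation_goal_reached'] and
          train_state['test_goal_reached'])


train_state = {'validation_goal_reached': True, 'test_goal_reached': False}
assert not both_goals_reached(train_state)
train_state['test_goal_reached'] = True
assert both_goals_reached(train_state)
```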
