From d6e2555107d0cbc2a804e24f60fb756f447adefa Mon Sep 17 00:00:00 2001
From: Juhan Bae
Date: Sun, 3 Mar 2024 19:32:10 -0500
Subject: [PATCH 01/12] Update BS

---
 .../target_setting_algorithms/get_batch_size.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/reference_algorithms/target_setting_algorithms/get_batch_size.py b/reference_algorithms/target_setting_algorithms/get_batch_size.py
index 4c7c96241..3bdc90f36 100644
--- a/reference_algorithms/target_setting_algorithms/get_batch_size.py
+++ b/reference_algorithms/target_setting_algorithms/get_batch_size.py
@@ -15,6 +15,8 @@ def get_batch_size(workload_name):
     return 512
   elif workload_name == 'imagenet_vit':
     return 1024
+  elif workload_name == 'imagenet_vit_glu':
+    return 512
   elif workload_name == 'librispeech_conformer':
     return 256
   elif workload_name == 'librispeech_deepspeech':

From e31f2cac8991fb0fb7b9ebdc9a7fd0861a2bf956 Mon Sep 17 00:00:00 2001
From: runame
Date: Mon, 4 Mar 2024 15:45:58 +0000
Subject: [PATCH 02/12] Add missing dropout rng for ImageNet-ViT and WMT

---
 .../workloads/imagenet_vit/imagenet_jax/workload.py      | 8 ++++++--
 algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py | 3 ++-
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
index 0bc1cd8cc..a9ac76ea2 100644
--- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -21,10 +21,14 @@
 # Make sure we inherit from the ViT base workload first.
 class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
 
-  def initialized(self, key: spec.RandomState,
+  def initialized(self,
+                  key: spec.RandomState,
                   model: nn.Module) -> spec.ModelInitState:
     input_shape = (1, 224, 224, 3)
-    variables = jax.jit(model.init)({'params': key}, jnp.ones(input_shape))
+    params_rng, dropout_rng = jax.random.split(key)
+    variables = jax.jit(
+        model.init)({'params': params_rng, 'dropout': dropout_rng},
+                    jnp.ones(input_shape))
     model_state, params = variables.pop('params')
     return params, model_state
 
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
index 0250206a6..1b6d6449e 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
@@ -234,8 +234,9 @@ def init_model_fn(
     self._train_model = models.Transformer(model_config)
     eval_config = replace(model_config, deterministic=True)
     self._eval_model = models.Transformer(eval_config)
+    params_rng, dropout_rng = jax.random.split(rng)
     initial_variables = jax.jit(self._eval_model.init)(
-        rng,
+        {'params': params_rng, 'dropout': dropout_rng},
         jnp.ones(input_shape, jnp.float32),
         jnp.ones(target_shape, jnp.float32))

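The 'dropout' RNG added in patch 02 matters because Flax draws dropout masks from a dedicated RNG stream: a module containing nn.Dropout needs both a 'params' and a 'dropout' key at init time unless dropout is disabled. A minimal sketch with a toy module (illustrative only, not the workload's actual ViT or Transformer):

    import flax.linen as nn
    import jax
    import jax.numpy as jnp


    class TinyModel(nn.Module):
      """Toy stand-in for a model that applies dropout."""

      @nn.compact
      def __call__(self, x, train=True):
        x = nn.Dense(8)(x)
        # Draws from the 'dropout' RNG collection when deterministic=False.
        x = nn.Dropout(rate=0.1, deterministic=not train)(x)
        return nn.Dense(1)(x)


    key = jax.random.PRNGKey(0)
    params_rng, dropout_rng = jax.random.split(key)
    # Passing only {'params': key} here fails with a missing-RNG error as soon
    # as the dropout layer asks for randomness.
    variables = TinyModel().init({'params': params_rng, 'dropout': dropout_rng},
                                 jnp.ones((1, 4)))
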
From 1d0e8a9067a4c7443b501b1652fe3ec1c6385cee Mon Sep 17 00:00:00 2001
From: runame
Date: Mon, 4 Mar 2024 15:50:02 +0000
Subject: [PATCH 03/12] Fix yapf

---
 .../workloads/imagenet_vit/imagenet_jax/workload.py      | 3 +--
 algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py | 8 ++++----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
index a9ac76ea2..93f603bca 100644
--- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -21,8 +21,7 @@
 # Make sure we inherit from the ViT base workload first.
 class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
 
-  def initialized(self,
-                  key: spec.RandomState,
+  def initialized(self, key: spec.RandomState,
                   model: nn.Module) -> spec.ModelInitState:
     input_shape = (1, 224, 224, 3)
     params_rng, dropout_rng = jax.random.split(key)
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
index 1b6d6449e..b10d4056d 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
@@ -235,10 +235,10 @@ def init_model_fn(
     eval_config = replace(model_config, deterministic=True)
     self._eval_model = models.Transformer(eval_config)
     params_rng, dropout_rng = jax.random.split(rng)
-    initial_variables = jax.jit(self._eval_model.init)(
-        {'params': params_rng, 'dropout': dropout_rng},
-        jnp.ones(input_shape, jnp.float32),
-        jnp.ones(target_shape, jnp.float32))
+    initial_variables = jax.jit(
+        self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng},
+                               jnp.ones(input_shape, jnp.float32),
+                               jnp.ones(target_shape, jnp.float32))
 
     initial_params = initial_variables['params']
     self._param_shapes = param_utils.jax_param_shapes(initial_params)

From 59734a471657ae01a19635bf07db1f6992f3dbd5 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Mon, 4 Mar 2024 22:35:40 -0800
Subject: [PATCH 04/12] Fix to submission_runner.py

Fix stopping condition for training.
---
 submission_runner.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/submission_runner.py b/submission_runner.py
index fad60e48b..7e3e8d5e4 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -387,7 +387,10 @@ def train_once(
           train_state['test_goal_reached'] = (
               workload.has_reached_test_target(latest_eval_result) or
               train_state['test_goal_reached'])
-
+          goals_reached = (
+              train_state['validation_goal_reached'] and
+              train_state['test_goal_reached'])
+
           # Save last eval time.
           eval_end_time = get_time()
           train_state['last_eval_time'] = eval_end_time

From ecac9c9e34b3bef1672ca90a6b6058e6604f845b Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Mon, 4 Mar 2024 22:38:48 -0800
Subject: [PATCH 05/12] remove trailing whitespaces

---
 submission_runner.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/submission_runner.py b/submission_runner.py
index 7e3e8d5e4..6381c97f1 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -390,7 +390,6 @@ def train_once(
           goals_reached = (
               train_state['validation_goal_reached'] and
               train_state['test_goal_reached'])
-
           # Save last eval time.
           eval_end_time = get_time()
           train_state['last_eval_time'] = eval_end_time

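For context on the stopping-condition fix in patches 04-05: the recomputed goals_reached flag is what lets the outer training loop exit once both targets are hit. Roughly, the surrounding logic looks like the sketch below; the loop body is a stand-in and the dict keys are paraphrased from submission_runner.py rather than copied verbatim:

    # Paraphrased sketch of how the flag feeds the stopping condition.
    train_state = {
        'validation_goal_reached': False,
        'test_goal_reached': False,
        'is_time_remaining': True,
        'training_complete': False,
    }
    goals_reached = (
        train_state['validation_goal_reached'] and
        train_state['test_goal_reached'])
    while (train_state['is_time_remaining'] and not goals_reached and
           not train_state['training_complete']):
      # A real iteration runs one training step and a periodic eval here.
      # After the fix, the eval branch also recomputes goals_reached, so
      # hitting both targets ends training instead of running out the clock.
      train_state['validation_goal_reached'] = True
      train_state['test_goal_reached'] = True
      goals_reached = (
          train_state['validation_goal_reached'] and
          train_state['test_goal_reached'])
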
From b14e05496d425301f8d8f52923d0228b2d279f07 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Mon, 4 Mar 2024 23:00:55 -0800
Subject: [PATCH 06/12] Expand ~ for tmp_dir arg

---
 datasets/dataset_setup.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py
index 1d3e5d2c7..d556d7c04 100644
--- a/datasets/dataset_setup.py
+++ b/datasets/dataset_setup.py
@@ -679,6 +679,7 @@ def main(_):
   if any(s in tmp_dir for s in bad_chars):
     raise ValueError(f'Invalid temp_dir: {tmp_dir}.')
   data_dir = os.path.abspath(os.path.expanduser(data_dir))
+  tmp_dir = os.path.abspath(os.path.expanduser(data_dir))
   logging.info('Downloading data to %s...', data_dir)
 
   if FLAGS.all or FLAGS.criteo1tb:

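A note on why patch 06 (and its correction in patch 10 below) expands the path at all: a quoted ~ on the command line is not expanded by the shell, so it reaches Python as a literal character and has to be expanded explicitly before it is used as a directory. Note also that the line added here still expands data_dir; patch 10 switches the argument to tmp_dir. A small illustration with a made-up path:

    import os

    raw = '~/algorithmic-efficiency/tmp'  # hypothetical value of --temp_dir
    expanded = os.path.expanduser(raw)    # '~' -> the user's home directory
    absolute = os.path.abspath(expanded)  # also normalizes relative paths
    print(absolute)
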
From 6a2fc60f0399542f5f8ba9e9c18268e6728450d2 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Mon, 4 Mar 2024 23:19:56 -0800
Subject: [PATCH 07/12] prompt user for criteo download

---
 datasets/dataset_setup.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py
index d556d7c04..24f7bf9cf 100644
--- a/datasets/dataset_setup.py
+++ b/datasets/dataset_setup.py
@@ -291,9 +291,22 @@ def download_criteo1tb(data_dir,
                                   stream=True)
   all_days_zip_filepath = os.path.join(tmp_criteo_dir, 'all_days.zip')
-  with open(all_days_zip_filepath, 'wb') as f:
-    for chunk in download_request.iter_content(chunk_size=1024):
-      f.write(chunk)
+  download = True
+  if os.path.exists(all_days_zip_filepath):
+    while True:
+      overwrite = input('File already exists {}.\n Overwrite? (Y/n)'.format(
+          all_days_zip_filepath)).lower()
+      if overwrite in ['y', 'n']:
+        break
+      logging.info('Invalid response. Try again.')
+    if overwrite == 'n':
+      logging.info(f'Skipping download URL {url} to {file_path}')
+      download = False
+
+  if download:
+    with open(all_days_zip_filepath, 'wb') as f:
+      for chunk in download_request.iter_content(chunk_size=1024):
+        f.write(chunk)
 
   unzip_cmd = f'unzip {all_days_zip_filepath} -d {tmp_criteo_dir}'
   logging.info(f'Running Criteo 1TB unzip command:\n{unzip_cmd}')

From c6072b7139668168450179484a27adbbbf9215a7 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Tue, 5 Mar 2024 07:30:27 +0000
Subject: [PATCH 08/12] formatting fix

---
 datasets/dataset_setup.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py
index 24f7bf9cf..b8c798157 100644
--- a/datasets/dataset_setup.py
+++ b/datasets/dataset_setup.py
@@ -302,8 +302,8 @@ def download_criteo1tb(data_dir,
     if overwrite == 'n':
       logging.info(f'Skipping download URL {url} to {file_path}')
       download = False
-
-  if download:
+
+  if download:
     with open(all_days_zip_filepath, 'wb') as f:
       for chunk in download_request.iter_content(chunk_size=1024):
         f.write(chunk)

From 1c7a7486f38852495be785a630aea1e1b51631d6 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Tue, 5 Mar 2024 07:33:57 +0000
Subject: [PATCH 09/12] fix

---
 datasets/dataset_setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py
index b8c798157..a171bc0c1 100644
--- a/datasets/dataset_setup.py
+++ b/datasets/dataset_setup.py
@@ -300,7 +300,7 @@ def download_criteo1tb(data_dir,
         break
       logging.info('Invalid response. Try again.')
     if overwrite == 'n':
-      logging.info(f'Skipping download URL {url} to {file_path}')
+      logging.info(f'Skipping download to {all_days_zip_filepath}')
       download = False
 
   if download:

From 3e4e95fd3bd4176cd923afad13bb3f28fa7a5cc6 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Tue, 5 Mar 2024 07:37:04 +0000
Subject: [PATCH 10/12] fix

---
 datasets/dataset_setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py
index a171bc0c1..ef8667660 100644
--- a/datasets/dataset_setup.py
+++ b/datasets/dataset_setup.py
@@ -692,7 +692,7 @@ def main(_):
   if any(s in tmp_dir for s in bad_chars):
     raise ValueError(f'Invalid temp_dir: {tmp_dir}.')
   data_dir = os.path.abspath(os.path.expanduser(data_dir))
-  tmp_dir = os.path.abspath(os.path.expanduser(data_dir))
+  tmp_dir = os.path.abspath(os.path.expanduser(tmp_dir))
   logging.info('Downloading data to %s...', data_dir)
 
   if FLAGS.all or FLAGS.criteo1tb:

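The interactive overwrite prompt added for the Criteo archive in patches 07-09 follows a reusable pattern: check whether the file exists, loop until the answer is valid, and skip the download on 'n'. A generic sketch of the same idea (this helper is illustrative and is not part of dataset_setup.py):

    import logging
    import os


    def confirm_overwrite(path):
      """Return True if path is absent or the user agrees to overwrite it."""
      if not os.path.exists(path):
        return True
      while True:
        answer = input(f'File already exists {path}.\n Overwrite? (Y/n) ').lower()
        if answer in ('y', 'n'):
          break
        logging.info('Invalid response. Try again.')
      if answer == 'n':
        logging.info('Skipping download to %s', path)
      return answer == 'y'
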
From 596274f530e6497b927a602df7aa9a74d60b7e05 Mon Sep 17 00:00:00 2001
From: runame
Date: Tue, 5 Mar 2024 22:34:30 +0000
Subject: [PATCH 11/12] Fix workload targets and max runtimes in docs

---
 DOCUMENTATION.md | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md
index 6037f285f..821902003 100644
--- a/DOCUMENTATION.md
+++ b/DOCUMENTATION.md
@@ -400,12 +400,12 @@ The currently eight fixed workloads are:
 
 |            | **Task**                      | **Dataset** | **Model**               | **Loss** | **Metric** | Validation<br>**Target** | Test<br>**Target**   | Maximum<br>**Runtime**<br>(in secs) |
 |------------|-------------------------------|-------------|-------------------------|----------|------------|--------------------------|----------------------|------------------------|
-| **1**      | Clickthrough rate prediction  | Criteo 1TB  | DLRMsmall               | CE       | CE         | 0.123649                 | 0.126060             | 21,600                 |
-| **2**      | MRI reconstruction            | fastMRI     | U-Net                   | L1       | SSIM       | 0.7344                   | 0.741652             | 10,800                 |
-| **3<br>4** | Image classification          | ImageNet    | ResNet-50<br>ViT        | CE       | ER         | 0.22569<br>0.22691       | 0.3440<br>0.3481     | 111,600<br>111,600     |
-| **5<br>6** | Speech recognition            | LibriSpeech | Conformer<br>DeepSpeech | CTC      | WER        | 0.078477<br>0.1162       | 0.046973<br>0.068093 | <br>72,000             |
-| **7**      | Molecular property prediction | OGBG        | GNN                     | CE       | mAP        | 0.28098                  | 0.268729             | 12,000                 |
-| **8**      | Translation                   | WMT         | Transformer             | CE       | BLEU       | 30.8491                  | 30.7219              | 80,000                 |
+| **1**      | Clickthrough rate prediction  | Criteo 1TB  | DLRMsmall               | CE       | CE         | 0.123735                 | 0.126041             | 7,703                  |
+| **2**      | MRI reconstruction            | fastMRI     | U-Net                   | L1       | SSIM       | 0.723653                 | 0.740633             | 8,859                  |
+| **3<br>4** | Image classification          | ImageNet    | ResNet-50<br>ViT        | CE       | ER         | 0.22569<br>0.22691       | 0.3440<br>0.3481     | 63,008<br>77,520       |
+| **5<br>6** | Speech recognition            | LibriSpeech | Conformer<br>DeepSpeech | CTC      | WER        | 0.085884<br>0.119936     | 0.052981<br>0.074143 | 61,068<br>55,506       |
+| **7**      | Molecular property prediction | OGBG        | GNN                     | CE       | mAP        | 0.28098                  | 0.268729             | 18,477                 |
+| **8**      | Translation                   | WMT         | Transformer             | CE       | BLEU       | 30.8491                  | 30.7219              | 48,151                 |
 
 #### Randomized workloads

From 7a3a2276ccfc991975f872f287f71cc079f699a9 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Tue, 5 Mar 2024 23:30:58 +0000
Subject: [PATCH 12/12] fix names in startup.sh for conformer variants

---
 docker/scripts/startup.sh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh
index e9baf2463..e5f059772 100644
--- a/docker/scripts/startup.sh
+++ b/docker/scripts/startup.sh
@@ -157,8 +157,8 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_
                  "criteo1tb_resnet" "criteo1tb_layernorm" "criteo1tb_embed_init" \
                  "wmt" "wmt_post_ln" "wmt_attention_temp" "wmt_glu_tanh" \
                  "librispeech_deepspeech" "librispeech_conformer" "mnist" \
-                 "conformer_layernorm" "conformer_attention_temperature" \
-                 "conformer_gelu" "fastmri_model_size" "fastmri_tanh" \
+                 "librispeech_conformer_layernorm" "librispeech_conformer_attention_temperature" \
+                 "librispeech_conformer_gelu" "fastmri_model_size" "fastmri_tanh" \
                  "librispeech_deepspeech_tanh" \
                  "librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug" \
                  "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size")