Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat default dropout in doc #806

Open
wants to merge 8 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions DOCUMENTATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,19 @@ The currently eight fixed workloads are:
| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 18,477 |
| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 48,151 |

Default Dropout Values for Different Workloads:

| Workload | Dropout Values |
|------------------------|------------------------------------------------------------------------------------------------------|
| criteo 1tb | dropout_rate: 0.0 |
| fastmri | dropout_rate: 0.0 |
| imagenet_resnet | dropout not used |
| imagenet_vit | dropout_rate: 0.0 |
| librispeech_conformer | attention_dropout_rate: 0.0 <br> attention_residual_dropout_rate: 0.1 <br> conv_residual_dropout_rate: 0.0 <br> feed_forward_dropout_rate: 0.0 <br> feed_forward_residual_dropout_rate: 0.1 <br> input_dropout_rate: 0.1 |
| librispeech_deepspeech | input_dropout_rate: 0.1 <br> feed_forward_dropout_rate: 0.1 <br> (Only for JAX - dropout_rate in CudnnLSTM class: 0.0) |
| ogbg | dropout_rate: 0.1 |
| wmt | dropout_rate: 0.1 <br> attention_dropout_rate: 0.1 |

#### Randomized workloads

In addition to the [fixed and known workloads](#fixed-workloads), there will also be randomized workloads in our benchmark. These randomized workloads will introduce minor modifications to a fixed workload (e.g. small model changes). The exact instances of these randomized workloads will only be created after the submission deadline and are thus unknown to both the submitters as well as the benchmark organizers. The instructions for creating them, i.e. providing a set or distribution of workloads to sample from, will be defined by this working group and made public with the call for submissions, to allow the members of this working group to submit as well as ensure that they do not possess any additional information compared to other submitters. We will refer to the unspecific workloads as *randomized workloads*, e.g. the set or distribution. The specific instance of such a randomized workload we call a *held-out workload*. That is, a held-out workload is a specific sample of a randomized workload that is used for one iteration of the benchmark. While we may reuse randomized workloads between iterations of the benchmark, new held-out workloads will be sampled for each new benchmark iteration.
Expand Down
7 changes: 5 additions & 2 deletions scoring/performance_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ def compute_performance_profiles(submissions,
scale='linear',
verbosity=0,
strict=False,
self_tuning_ruleset=False):
self_tuning_ruleset=False,
output_dir=None):
"""Compute performance profiles for a set of submission by some time column.

Args:
Expand Down Expand Up @@ -320,7 +321,9 @@ def compute_performance_profiles(submissions,
df = df[BASE_WORKLOADS + HELDOUT_WORKLOADS]
# Sort workloads alphabetically (for better display)
df = df.reindex(sorted(df.columns), axis=1)


# Save time to target dataframe
df.to_csv(os.path.join(output_dir, 'time_to_targets.csv'))
# For each held-out workload set to inf if the base workload is inf or nan
for workload in df.keys():
if workload not in BASE_WORKLOADS:
Expand Down
3 changes: 2 additions & 1 deletion scoring/score_submissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def main(_):
scale='linear',
verbosity=0,
self_tuning_ruleset=FLAGS.self_tuning_ruleset,
strict=FLAGS.strict)
strict=FLAGS.strict,
output_dir=FLAGS.output_dir,)
if not os.path.exists(FLAGS.output_dir):
os.mkdir(FLAGS.output_dir)
performance_profile.plot_performance_profiles(
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ jax_core_deps =
chex==0.1.7
ml_dtypes==0.2.0
protobuf==4.25.3
scipy==1.11.4


# JAX CPU
Expand Down
Loading