Merge branch 'main' into timm_update
EIFY authored Aug 28, 2023
2 parents 83b7895 + 579b6a9 commit ea6a9f2
Showing 3 changed files with 7 additions and 3 deletions.
README.md (4 changes: 2 additions & 2 deletions)
@@ -183,7 +183,7 @@ numerical results as the naïve method.
#### Epochs
-For larger datasets (eg Laion2B), we recommend setting --train-num-samples to a lower value than the full epoch, for example `--train-num-samples 135646078` to 1/16 of an epoch in conjunction with --dataset-resampled to do sampling with replacement. This allows having frequent checkpoints to evaluate more often.
+For larger datasets (eg Laion2B), we recommend setting `--train-num-samples` to a lower value than the full epoch, for example `--train-num-samples 135646078` to 1/16 of an epoch in conjunction with `--dataset-resampled` to do sampling with replacement. This allows having frequent checkpoints to evaluate more often.
#### Patch Dropout
@@ -196,7 +196,7 @@ In the paper, they also finetuned without the patch dropout at the end. You can
#### Multiple data sources
OpenCLIP supports using multiple data sources, by separating different data paths with `::`.
-For instance, to train on CC12M and on LAION, one might use `--train-data '/data/cc12m/cc12m-train-{0000..2175}.tar'::/data/LAION-400M/{00000..41455}.tar"`.
+For instance, to train on CC12M and on LAION, one might use `--train-data "/data/cc12m/cc12m-train-{0000..2175}.tar::/data/LAION-400M/{00000..41455}.tar"`.
Using `--dataset-resampled` is recommended for these cases.
By default, on expectation the amount of times the model will see a sample from each source is proportional to the size of the source.
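To show how these flags combine in practice, a hypothetical full invocation might look like the sketch below; the data paths are the ones quoted above, while the model name, batch size, and epoch count are illustrative and not taken from this commit:

```bash
python -m training.main \
    --model ViT-B-32 \
    --train-data "/data/cc12m/cc12m-train-{0000..2175}.tar::/data/LAION-400M/{00000..41455}.tar" \
    --dataset-type webdataset \
    --dataset-resampled \
    --train-num-samples 135646078 \
    --batch-size 256 \
    --epochs 16
```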
src/open_clip/factory.py (4 changes: 4 additions & 0 deletions)
@@ -100,6 +100,10 @@ def load_checkpoint(model, checkpoint_path, strict=True):
     # detect old format and make compatible with new format
     if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
         state_dict = convert_to_custom_text_state_dict(state_dict)
+    # Certain text transformers no longer expect position_ids after transformers==4.31
+    position_id_key = 'text.transformer.embeddings.position_ids'
+    if position_id_key in state_dict and not hasattr(model, position_id_key):
+        del state_dict[position_id_key]
     resize_pos_embed(state_dict, model)
     incompatible_keys = model.load_state_dict(state_dict, strict=strict)
     return incompatible_keys
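The four added lines work around checkpoints saved before `transformers==4.31`, which carry a `text.transformer.embeddings.position_ids` buffer that newer HF text towers no longer register; loading such a checkpoint with `strict=True` would otherwise fail on an unexpected key. A minimal standalone sketch of the same clean-up step, with an illustrative helper name that is not part of the commit:

```python
import torch

def drop_stale_position_ids(state_dict: dict, model: torch.nn.Module) -> dict:
    """Remove a position_ids buffer that an old checkpoint has but the model lacks.

    Hypothetical helper mirroring the check added to load_checkpoint; here the
    model's own state_dict keys are used as the reference for what is expected.
    """
    stale_key = 'text.transformer.embeddings.position_ids'
    if stale_key in state_dict and stale_key not in model.state_dict():
        state_dict = dict(state_dict)  # copy so the caller's dict is untouched
        del state_dict[stale_key]
    return state_dict

# Usage sketch (checkpoint path is a placeholder):
# state_dict = torch.load('old_openclip_checkpoint.pt', map_location='cpu')
# state_dict = drop_stale_position_ids(state_dict, model)
# model.load_state_dict(state_dict, strict=True)
```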
src/training/main.py (2 changes: 1 addition & 1 deletion)
@@ -232,7 +232,7 @@ def main(args):
         output_dict=True,
     )
     if args.distill:
-        # FIXME: currenlty assumes the model your distilling from has the same tokenizer & transforms.
+        # FIXME: currently assumes the model you're distilling from has the same tokenizer & transforms.
         dist_model, _, _ = create_model_and_transforms(
             args.distill_model,
             args.distill_pretrained,
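The corrected FIXME concerns the distillation path: the teacher model is built through the same `create_model_and_transforms` factory as the student, so the code implicitly assumes both resolve to the same tokenizer and preprocessing. A hedged sketch of that assumption outside the training script; the specific model names and pretrained tags are placeholders, not from this commit:

```python
import open_clip

# Student and teacher built through the same factory; the distillation code path
# assumes they resolve to the same tokenizer and image transforms.
student, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32', pretrained='laion2b_s34b_b79k')
teacher, _, _ = open_clip.create_model_and_transforms(
    'ViT-L-14', pretrained='laion2b_s32b_b82k')

# Both models above use the default CLIP tokenizer, so one tokenizer serves both.
tokenizer = open_clip.get_tokenizer('ViT-B-32')
```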
