diff --git a/README.md b/README.md
index aa3c975b5..c25c07db2 100644
--- a/README.md
+++ b/README.md
@@ -183,7 +183,7 @@ numerical results as the naïve method.
 
 #### Epochs
 
-For larger datasets (eg Laion2B), we recommend setting --train-num-samples to a lower value than the full epoch, for example `--train-num-samples 135646078` to 1/16 of an epoch in conjunction with --dataset-resampled to do sampling with replacement. This allows having frequent checkpoints to evaluate more often.
+For larger datasets (eg Laion2B), we recommend setting `--train-num-samples` to a lower value than the full epoch, for example `--train-num-samples 135646078` to 1/16 of an epoch in conjunction with `--dataset-resampled` to do sampling with replacement. This allows having frequent checkpoints to evaluate more often.
 
 #### Patch Dropout
 
@@ -196,7 +196,7 @@ In the paper, they also finetuned without the patch dropout at the end. You can
 #### Multiple data sources
 
 OpenCLIP supports using multiple data sources, by separating different data paths with `::`.
-For instance, to train on CC12M and on LAION, one might use `--train-data '/data/cc12m/cc12m-train-{0000..2175}.tar'::/data/LAION-400M/{00000..41455}.tar"`.
+For instance, to train on CC12M and on LAION, one might use `--train-data "/data/cc12m/cc12m-train-{0000..2175}.tar::/data/LAION-400M/{00000..41455}.tar"`.
 Using `--dataset-resampled` is recommended for these cases.
 By default, on expectation the amount of times the model will see a sample from each source is proportional to the size of the source.
 
diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index ac8596eab..72a4e4d18 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -100,6 +100,10 @@ def load_checkpoint(model, checkpoint_path, strict=True):
     # detect old format and make compatible with new format
     if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
         state_dict = convert_to_custom_text_state_dict(state_dict)
+    # Certain text transformers no longer expect position_ids after transformers==4.31
+    position_id_key = 'text.transformer.embeddings.position_ids'
+    if position_id_key in state_dict and not hasattr(model, position_id_key):
+        del state_dict[position_id_key]
     resize_pos_embed(state_dict, model)
     incompatible_keys = model.load_state_dict(state_dict, strict=strict)
     return incompatible_keys
diff --git a/src/training/main.py b/src/training/main.py
index 2929d0121..4f2172808 100644
--- a/src/training/main.py
+++ b/src/training/main.py
@@ -232,7 +232,7 @@ def main(args):
         output_dict=True,
     )
     if args.distill:
-        # FIXME: currenlty assumes the model your distilling from has the same tokenizer & transforms.
+        # FIXME: currently assumes the model you're distilling from has the same tokenizer & transforms.
         dist_model, _, _ = create_model_and_transforms(
             args.distill_model,
             args.distill_pretrained,
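
As context for the README hunk above: the corrected command quotes the whole `::`-separated spec so the shell passes it to `--train-data` intact, and each source is brace-expanded into its own shard list. The sketch below is illustrative only (not part of the patch); it assumes the third-party `braceexpand` package and reuses the shard ranges from the README example to show how such a spec splits into per-source shard counts, which is why sampling is proportional to source size by default.

```python
# Illustrative sketch (not from the patch): split a `::`-separated
# --train-data spec into its sources and count shards per source.
from braceexpand import braceexpand  # pip install braceexpand

train_data = (
    "/data/cc12m/cc12m-train-{0000..2175}.tar"
    "::"
    "/data/LAION-400M/{00000..41455}.tar"
)

for source in train_data.split("::"):
    shards = list(braceexpand(source))  # expands {0000..2175} into 2176 paths
    print(f"{len(shards):>6} shards from {source}")
```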
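The `factory.py` hunk handles checkpoints saved when the HF text tower still registered a `position_ids` buffer; models built against `transformers>=4.31` no longer have it, so the leftover key would surface as an unexpected key in `load_state_dict`. Below is a standalone sketch of the same idea, offered as an illustration rather than the patched function: the helper name is made up, and it checks the model's own state-dict keys instead of using `hasattr`.

```python
# Illustrative helper (not part of the patch): drop a stale position_ids buffer
# from an old checkpoint before loading it into a model that no longer expects it.
def drop_stale_position_ids(
        state_dict,
        model,
        key='text.transformer.embeddings.position_ids',
):
    if key in state_dict and key not in model.state_dict():
        state_dict = dict(state_dict)  # copy so the caller's dict is untouched
        del state_dict[key]
    return state_dict

# Typical use, mirroring load_checkpoint:
#   state_dict = drop_stale_position_ids(state_dict, model)
#   model.load_state_dict(state_dict, strict=True)
```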