Skip to content

Commit

Permalink
convert_dinov2: tweak command-line args
Browse files Browse the repository at this point in the history
i.e. mimic the other conversion scripts
  • Loading branch information
deltheil committed Dec 16, 2023
1 parent 5ca1549 commit 5ce9515
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions scripts/conversion/convert_dinov2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
from pathlib import Path

import torch

Expand Down Expand Up @@ -124,12 +125,35 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:

def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--weights_path", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument(
"--from",
type=str,
required=True,
dest="source_path",
help=(
"Official checkpoint from https://github.com/facebookresearch/dinov2"
" e.g. /path/to/dinov2_vits14_pretrain.pth"
),
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Path to save the converted model. If not specified, the output path will be the source path with the"
" extension changed to .safetensors."
),
)
parser.add_argument("--half", action="store_true", dest="half")
args = parser.parse_args()

weights = torch.load(args.weights_path) # type: ignore
weights = torch.load(args.source_path) # type: ignore
convert_dinov2_facebook(weights)
if args.half:
weights = {key: value.half() for key, value in weights.items()}
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"
save_to_safetensors(path=args.output_path, tensors=weights)


Expand Down

0 comments on commit 5ce9515

Please sign in to comment.