Skip to content

Commit

Permalink
Update resuming log for missing/unexpected keys
Browse files Browse the repository at this point in the history
  • Loading branch information
khanrc committed Apr 2, 2023
1 parent f204666 commit 762ea71
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved.
# ------------------------------------------------------------------------------
import os
import re
from collections import defaultdict

import torch
Expand All @@ -15,23 +16,47 @@

from .logger import get_logger

# Regex patterns for missing state-dict keys that load_checkpoint counts
# instead of logging one by one when resuming from a checkpoint.
missing_keys_whitelist = [r"clip_.+_encoder\..+", r".+_loss\..+"]


def check_whitelist(key, whitelist):
    """Tell whether ``key`` is covered by at least one whitelist pattern.

    Each entry of ``whitelist`` is applied with ``re.match``, so a pattern
    has to match starting at the first character of ``key`` but does not
    need to consume the whole string.

    Args:
        key: Name to test (e.g. a state-dict key).
        whitelist: Iterable of regular-expression patterns.

    Returns:
        True if any pattern matches, False otherwise (including when the
        whitelist is empty).
    """
    return any(re.match(regex, key) for regex in whitelist)


def load_checkpoint(config, model, optimizer, lr_scheduler, scaler):
logger = get_logger()
logger.info(f"==============> Resuming form {config.checkpoint.resume}....................")
checkpoint = CheckpointLoader.load_checkpoint(config.checkpoint.resume, map_location="cpu")
msg = model.load_state_dict(checkpoint["model"], strict=False)
logger.info(msg)
whitelist_cnt = 0

if msg.missing_keys or msg.unexpected_keys:
logger.info("#" * 80)
if msg.missing_keys:
logger.info("!!! Missing keys !!!")
for key in msg.missing_keys:
if check_whitelist(key, missing_keys_whitelist):
whitelist_cnt += 1
continue

logger.info(f"\t {key}")

logger.info(f"Whitelist of missing keys: {missing_keys_whitelist}")
logger.info(f"# of Whitelisted missing keys = {whitelist_cnt}")
if len(msg.missing_keys) > whitelist_cnt:
logger.warning("Please check the missing keys.")

if msg.unexpected_keys:
logger.info("!!! Unexpected keys !!!")
for key in msg.unexpected_keys:
logger.info(f"\t {key}")

raise ValueError("Unexpected keys are found in the checkpoint.")
logger.info("#" * 80)

metrics = defaultdict(float)
Expand Down

0 comments on commit 762ea71

Please sign in to comment.