Skip to content

Commit

Permalink
fix: handle unexpected keys in weights better
Browse files Browse the repository at this point in the history
  • Loading branch information
brycedrennan committed Jan 3, 2024
1 parent f3f7331 commit 1bf8fb3
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
2 changes: 2 additions & 0 deletions imaginairy/cli/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def imaginairy_click_context(log_level="INFO"):
yield
except errors_to_catch as e:
logger.error(e)
# import traceback
# traceback.print_exc()


def _imagine_cmd(
Expand Down
5 changes: 3 additions & 2 deletions imaginairy/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import importlib
import logging
import numpy as np
import platform
import random
import re
Expand All @@ -9,6 +8,7 @@
from functools import lru_cache
from typing import Any, List, Optional

import numpy as np
import torch
from torch import Tensor, autocast
from torch.nn import functional
Expand Down Expand Up @@ -337,11 +337,12 @@ def clear_gpu_cache():
if torch.cuda.is_available():
torch.cuda.empty_cache()


def seed_everything(seed: int | None = None) -> None:
if seed is None:
seed = random.randint(0, 2**32 - 1)
logger.info(f"Using random seed: {seed}")
random.seed(a=seed)
np.random.seed(seed=seed)
torch.manual_seed(seed=seed)
torch.cuda.manual_seed_all(seed=seed)
torch.cuda.manual_seed_all(seed=seed)
5 changes: 4 additions & 1 deletion imaginairy/weight_management/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def could_convert(self, source_weights):
def cast_weights(self, source_weights) -> dict[str, "Tensor"]:
converted_state_dict: dict[str, "Tensor"] = {}
for source_key in source_weights:
source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1)
try:
source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1)
except ValueError:
continue
# handle aliases
source_prefix = self.source_aliases.get(source_prefix, source_prefix)
try:
Expand Down

0 comments on commit 1bf8fb3

Please sign in to comment.