From 1bf8fb3cbceaa1f4cb960dfbb914cd4f9df36e34 Mon Sep 17 00:00:00 2001 From: Bryce Date: Tue, 2 Jan 2024 08:56:17 -0800 Subject: [PATCH] fix: handle unexpected keys in weights better --- imaginairy/cli/shared.py | 2 ++ imaginairy/utils/__init__.py | 5 +++-- imaginairy/weight_management/conversion.py | 5 ++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/imaginairy/cli/shared.py b/imaginairy/cli/shared.py index 0ccd55a9..220ddb46 100644 --- a/imaginairy/cli/shared.py +++ b/imaginairy/cli/shared.py @@ -23,6 +23,8 @@ def imaginairy_click_context(log_level="INFO"): yield except errors_to_catch as e: logger.error(e) + # import traceback + # traceback.print_exc() def _imagine_cmd( diff --git a/imaginairy/utils/__init__.py b/imaginairy/utils/__init__.py index cd4146df..a41cb81a 100644 --- a/imaginairy/utils/__init__.py +++ b/imaginairy/utils/__init__.py @@ -1,6 +1,5 @@ import importlib import logging -import numpy as np import platform import random import re @@ -9,6 +8,7 @@ from functools import lru_cache from typing import Any, List, Optional +import numpy as np import torch from torch import Tensor, autocast from torch.nn import functional @@ -337,6 +337,7 @@ def clear_gpu_cache(): if torch.cuda.is_available(): torch.cuda.empty_cache() + def seed_everything(seed: int | None = None) -> None: if seed is None: seed = random.randint(0, 2**32 - 1) @@ -344,4 +345,4 @@ def seed_everything(seed: int | None = None) -> None: random.seed(a=seed) np.random.seed(seed=seed) torch.manual_seed(seed=seed) - torch.cuda.manual_seed_all(seed=seed) \ No newline at end of file + torch.cuda.manual_seed_all(seed=seed) diff --git a/imaginairy/weight_management/conversion.py b/imaginairy/weight_management/conversion.py index b26e3195..14670683 100644 --- a/imaginairy/weight_management/conversion.py +++ b/imaginairy/weight_management/conversion.py @@ -74,7 +74,10 @@ def could_convert(self, source_weights): def cast_weights(self, source_weights) -> dict[str, "Tensor"]: converted_state_dict: dict[str, "Tensor"] = {} for source_key in source_weights: - source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1) + try: + source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1) + except ValueError: + continue # handle aliases source_prefix = self.source_aliases.get(source_prefix, source_prefix) try: