diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index f57fe251d2..65d6053acc 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -1837,9 +1837,26 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k] + attention_blocks = [k.replace(".attn.to_k.weight", "") for k in new_state_dict if "attn.to_k.weight" in k] for block in attention_blocks: + new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight") + new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight") + new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight") + new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias") + new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias") + new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias") + # projection + new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight") + new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias") + + # fix the cross attention blocks + cross_attention_blocks = [ + k.replace(".out_proj.weight", "") + for k in new_state_dict + if "out_proj.weight" in k and "transformer_blocks" in k + ] + for block in cross_attention_blocks: new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight") new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias")