From 1b1569d2770fc2fb27dd24bbd558c1604dd220c5 Mon Sep 17 00:00:00 2001
From: david
Date: Thu, 30 May 2024 15:27:04 +0800
Subject: [PATCH] Add model cache after loading

---
 comfy/cache.py            | 94 +++++++++++++++++++++++++++++++++++++++
 comfy/model_management.py | 21 ++++++++-
 comfy/sd.py               | 85 ++++++++++++++++++++++-------------
 comfy/utils.py            |  9 ++++
 4 files changed, 176 insertions(+), 33 deletions(-)
 create mode 100644 comfy/cache.py

diff --git a/comfy/cache.py b/comfy/cache.py
new file mode 100644
index 00000000000..af60e2193bb
--- /dev/null
+++ b/comfy/cache.py
@@ -0,0 +1,94 @@
+import gc
+from collections import OrderedDict
+import os
+import logging
+
+from comfy.cli_args import args
+
+
+class ModelCache:
+    def __init__(self):
+        self._cache = OrderedDict()
+        # Caching is disabled in highvram/gpu-only mode, where keeping extra models resident may cause OOM errors.
+        self._cache_state = not (args.highvram or args.gpu_only)
+
+    def _get_item(self, key):
+        if self._cache.get(key) is None:
+            self._cache[key] = {}
+        return self._cache[key]
+
+    def cache_model(self, key, model):
+        if not self._cache_state:
+            return
+        item = self._get_item(key)
+        item['model'] = model
+
+    def cache_clip_vision(self, key, clip_vision):
+        if not self._cache_state:
+            return
+        item = self._get_item(key)
+        item['clip_vision'] = clip_vision
+
+    def cache_vae(self, key, vae):
+        if not self._cache_state:
+            return
+        item = self._get_item(key)
+        item['vae'] = vae
+
+    def cache_sd(self, key, sd):
+        if not self._cache_state:
+            return
+        assert isinstance(sd, dict)
+        # Store keys and values separately so get_item() can hand out a fresh
+        # dict every time; callers pop entries from the returned state dict.
+        keys = list(sd.keys())
+        values = list(sd.values())
+        item = self._get_item(key)
+        item['sd'] = (keys, values)
+
+    def cache_clip(self, key, clip_key, clip):
+        if not self._cache_state:
+            return
+        item = self._get_item(key)
+        item[clip_key] = clip
+
+    def refresh_cache(self, key):
+        # Mark the entry as most recently used.
+        if key in self._cache:
+            self._cache.move_to_end(key)
+
+    def get_item(self, key, prop):
+        item = self._cache.get(key)
+        if item is None:
+            return None
+        if prop == "sd":
+            if item.get('sd') is None:
+                return None
+            keys, values = item.get('sd')
+            return dict(zip(keys, values))
+        return item.get(prop)
+
+    def __len__(self):
+        return len(self._cache)
+
+    @staticmethod
+    def unpatch_offload_model(model):
+        model.model_patches_to(model.offload_device)
+
+    def free_one_model_cache(self):
+        # Evict the least recently used entry and unpatch its models before dropping the references.
+        if len(self) == 0:
+            return
+        cache_k, item = self._cache.popitem(last=False)
+        item.pop("sd", None)
+
+        for k in list(item.keys()):
+            model = item.pop(k, None)
+            if model is not None:
+                if hasattr(model, "patcher"):
+                    self.unpatch_offload_model(model.patcher)
+                else:
+                    self.unpatch_offload_model(model)
+
+        item.clear()
+        model_dir, model_name = os.path.split(cache_k)
+        dir_name = os.path.basename(model_dir)
+        gc.collect()
+        logging.info(f"Drop model cache: {model_name} ({dir_name})")
+
+
+model_cache = ModelCache()
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 5c1afd3d658..a4a5ca55e87 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -5,6 +5,7 @@
 import torch
 import sys
 import platform
+from comfy.cache import model_cache
 
 class VRAMState(Enum):
     DISABLED = 0    #No vram present: no need to move models to vram
@@ -466,14 +467,30 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
 def load_model_gpu(model):
     return load_models_gpu([model])
 
+
+def get_cpu_memory():
+    # Free system RAM in GiB, rounded to two decimals.
+    return round(get_free_memory(torch.device("cpu")) / 1024 ** 3, 2)
+
+
+def check_and_free_cpu_memory():
+    # Evict cached models until at least 8 GiB of system RAM is free again.
+    current_free_memory_cpu = get_cpu_memory()
+    while current_free_memory_cpu <= 8 and len(model_cache) > 0:
+        model_cache.free_one_model_cache()
+        soft_empty_cache(force=True)
+        current_free_memory_cpu = get_cpu_memory()
+
+
 def cleanup_models(keep_clone_weights_loaded=False):
+    check_and_free_cpu_memory()
     to_delete = []
+
     for i in range(len(current_loaded_models)):
         if sys.getrefcount(current_loaded_models[i].model) <= 2:
             if not keep_clone_weights_loaded:
                 to_delete = [i] + to_delete
-        #TODO: find a less fragile way to do this.
-        elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+        # TODO: find a less fragile way to do this.
+        elif sys.getrefcount(current_loaded_models[i].real_model) <= 3:  # references from .real_model + the .model
             to_delete = [i] + to_delete
 
     for i in to_delete:
diff --git a/comfy/sd.py b/comfy/sd.py
index 343d2a02ccc..58c2550de0d 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -19,6 +19,7 @@
 from . import sd1_clip
 from . import sd2_clip
 from . import sdxl_clip
+from .cache import model_cache
 
 import comfy.model_patcher
 import comfy.lora
@@ -464,48 +465,70 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if model_config.clip_vision_prefix is not None:
         if output_clipvision:
-            clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
+            clipvision = model_cache.get_item(ckpt_path, 'clip_vision')
+            if clipvision is None:
+                clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
 
     if output_model:
+        model_patcher = model_cache.get_item(ckpt_path, 'model')
         inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
-        offload_device = model_management.unet_offload_device()
-        model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
-        model.load_model_weights(sd, "model.diffusion_model.")
+        if model_patcher is None:
+            offload_device = model_management.unet_offload_device()
+            model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
+            model.load_model_weights(sd, "model.diffusion_model.")
+            model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device, current_device=inital_load_device)
+
+        if inital_load_device != torch.device("cpu"):
+            logging.info("loaded straight to GPU")
+            model_management.load_model_gpu(model_patcher)
 
     if output_vae:
-        vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
-        vae_sd = model_config.process_vae_state_dict(vae_sd)
-        vae = VAE(sd=vae_sd)
+        vae = model_cache.get_item(ckpt_path, 'vae')
+        if vae is None:
+            vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
+            vae_sd = model_config.process_vae_state_dict(vae_sd)
+            vae = VAE(sd=vae_sd)
 
+    # CLIP objects depend on the embedding directory, so include it in the cache key.
+    clip_key = f"clip_{embedding_directory}"
     if output_clip:
-        clip_target = model_config.clip_target()
-        if clip_target is not None:
-            clip_sd = model_config.process_clip_state_dict(sd)
-            if len(clip_sd) > 0:
-                clip = CLIP(clip_target, embedding_directory=embedding_directory)
-                m, u = clip.load_sd(clip_sd, full_model=True)
-                if len(m) > 0:
-                    m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
-                    if len(m_filter) > 0:
-                        logging.warning("clip missing: {}".format(m))
-                    else:
-                        logging.debug("clip missing: {}".format(m))
-
-                if len(u) > 0:
-                    logging.debug("clip unexpected {}:".format(u))
-            else:
-                logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
+        clip = model_cache.get_item(ckpt_path, clip_key)
+        if clip is None:
+            clip_target = model_config.clip_target()
+            if clip_target is not None:
+                clip_sd = model_config.process_clip_state_dict(sd)
+                if len(clip_sd) > 0:
+                    clip = CLIP(clip_target, embedding_directory=embedding_directory)
+                    m, u = clip.load_sd(clip_sd, full_model=True)
+                    if len(m) > 0:
+                        m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
+                        if len(m_filter) > 0:
+                            logging.warning("clip missing: {}".format(m))
+                        else:
+                            logging.debug("clip missing: {}".format(m))
+
+                    if len(u) > 0:
+                        logging.debug("clip unexpected {}:".format(u))
+                else:
+                    logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
 
     left_over = sd.keys()
     if len(left_over) > 0:
         logging.debug("left over keys: {}".format(left_over))
-
-    if output_model:
-        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
-        if inital_load_device != torch.device("cpu"):
-            logging.info("loaded straight to GPU")
-            model_management.load_model_gpu(model_patcher)
-
+
+    # Only newly loaded objects need to be stored; cache hits are already present.
+    if model:
+        logging.debug(f"cache model of {ckpt_path}")
+        model_cache.cache_model(ckpt_path, model_patcher)
+    if clip:
+        logging.debug(f"cache clip of {ckpt_path}")
+        model_cache.cache_clip(ckpt_path, clip_key, clip)
+    if vae:
+        logging.debug(f"cache vae of {ckpt_path}")
+        model_cache.cache_vae(ckpt_path, vae)
+    if clipvision:
+        logging.debug(f"cache clipvision of {ckpt_path}")
+        model_cache.cache_clip_vision(ckpt_path, clipvision)
+
+    # Mark this checkpoint as most recently used so LRU eviction drops it last.
+    model_cache.refresh_cache(ckpt_path)
     return (model_patcher, clip, vae, clipvision)
diff --git a/comfy/utils.py b/comfy/utils.py
index ab47b8f28a2..4a8b88175c4 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -6,10 +6,17 @@
 import numpy as np
 from PIL import Image
 import logging
+from comfy.cache import model_cache
 
 def load_torch_file(ckpt, safe_load=False, device=None):
+    # Return the cached state dict (tensor references) if this file was loaded before.
+    cache_sd = model_cache.get_item(ckpt, 'sd')
+    if cache_sd:
+        return cache_sd
+
     if device is None:
         device = torch.device("cpu")
+
     if ckpt.lower().endswith(".safetensors"):
         sd = safetensors.torch.load_file(ckpt, device=device.type)
     else:
@@ -27,6 +34,8 @@ def load_torch_file(ckpt, safe_load=False, device=None):
             sd = pl_sd["state_dict"]
         else:
             sd = pl_sd
+    # Cache the tensor references so later loads of the same file skip disk I/O.
+    model_cache.cache_sd(ckpt, sd)
     return sd
 
 def save_torch_file(sd, ckpt, metadata=None):
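For reference, a minimal self-contained sketch of the LRU bookkeeping that ModelCache builds on; the checkpoint names below are placeholders, not files touched by this patch. OrderedDict preserves insertion order, move_to_end() is what refresh_cache() relies on to mark an entry as recently used, and popitem(last=False) is what free_one_model_cache() relies on to evict the oldest entry first:

from collections import OrderedDict

cache = OrderedDict()
for name in ("sd15.safetensors", "sdxl.safetensors", "flux.safetensors"):
    cache[name] = {"sd": None, "model": None}  # stand-ins for the cached objects

cache.move_to_end("sd15.safetensors")   # like refresh_cache(): mark as recently used
evicted, _ = cache.popitem(last=False)  # like free_one_model_cache(): oldest goes first
print(evicted)                          # -> sdxl.safetensors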