add model cache after loaded #3605

Closed
wants to merge 15 commits
100 changes: 100 additions & 0 deletions comfy/cache.py
@@ -0,0 +1,100 @@
import gc
from collections import OrderedDict
import os
import logging

from comfy.cli_args import args


class ModelCache:
def __init__(self):
self._cache = OrderedDict()
# caching is disabled in --highvram / --gpu-only mode, where keeping extra copies around could cause OOM
self._cache_state = not (args.highvram or args.gpu_only)

def _get_item(self, key):
if self._cache.get(key) is None:
self._cache[key] = {}
return self._cache[key]

def cache_model(self, key, model):
if not self._cache_state:
return
item = self._get_item(key)
item['model'] = model

def cache_clip_vision(self, key, clip_vision):
if not self._cache_state:
return
item = self._get_item(key)
item['clip_vision'] = clip_vision

def cache_vae(self, key, vae):
if not self._cache_state:
return
item = self._get_item(key)
item['vae'] = vae

def cache_sd(self, key, sd):
if not self._cache_state:
return
assert isinstance(sd, dict)
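# keep the keys and tensor values in two separate lists; get_item(key, 'sd') zips them
# back into a dict, and holding the value list keeps the tensors alive between loads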
keys = list(sd.keys())
values = list(sd.values())
item = self._get_item(key)
item['sd'] = (keys, values)

def cache_clip(self, key, clip_key, clip):
if not self._cache_state:
return
item = self._get_item(key)
item[clip_key] = clip

def refresh_cache(self, key):
if key in self._cache:
self._cache.move_to_end(key)

def get_item(self, key, prop):
item = self._cache.get(key)
if item is None:
return None
if prop == "sd":
if item.get('sd') is None:
return None
k, values = item.get("sd")
return dict(zip(k, values))
return item.get(prop)

def __len__(self):
return len(self._cache)

@property
def cache(self):
return self._cache

@staticmethod
def unpatch_offload_model(model):
model.model_patches_to(model.offload_device)

def free_one_model_cache(self):
if len(self) == 0:
return
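# evict the least-recently-used entry: refresh_cache() moves a key to the end of the
# OrderedDict on every reuse, so the first item is always the oldest one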
cache_k, item = self._cache.popitem(last=False)
item.pop("sd", None)

for k in list(item.keys()):
model = item.pop(k, None)
if model is not None:
if hasattr(model, "patcher"):
self.unpatch_offload_model(model.patcher)
else:
self.unpatch_offload_model(model)

item.clear()
model_dir, model_name = os.path.split(cache_k)
dir_name = os.path.basename(model_dir)
gc.collect()
logging.info(f"Drop model cache: {model_name} ({dir_name})")


model_cache = ModelCache()
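
A rough usage sketch of the ModelCache API defined above; the checkpoint path and the patcher object are placeholders, not real ComfyUI models:

from comfy.cache import model_cache

class _FakePatcher:
    # stand-in for comfy.model_patcher.ModelPatcher with just enough surface for eviction
    offload_device = "cpu"
    def model_patches_to(self, device):
        pass

ckpt_path = "/models/checkpoints/example.safetensors"  # hypothetical path

model_cache.cache_model(ckpt_path, _FakePatcher())  # store after a fresh load
model_cache.refresh_cache(ckpt_path)                # mark as most recently used

patcher = model_cache.get_item(ckpt_path, 'model')  # cache hit on the next load
if patcher is None:
    pass  # cache miss: the normal load_checkpoint_guess_config() path would run here

if len(model_cache) > 0:
    model_cache.free_one_model_cache()              # LRU eviction under memory pressure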
21 changes: 19 additions & 2 deletions comfy/model_management.py
@@ -23,6 +23,7 @@
import torch
import sys
import platform
from comfy.cache import model_cache

class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -553,6 +554,19 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
def load_model_gpu(model):
return load_models_gpu([model])

def get_cpu_memory():
return round(get_free_memory(torch.device("cpu")) / 1024 ** 3, 2)


def check_and_free_cpu_memory():
# evict cached models, oldest first, until at least 8 GiB of CPU RAM is free again
current_free_memory_cpu = get_cpu_memory()
while (current_free_memory_cpu <= 8) and len(model_cache) > 0:
model_cache.free_one_model_cache()
soft_empty_cache(force=True)
current_free_memory_cpu = get_cpu_memory()


def loaded_models(only_currently_used=False):
output = []
for m in current_loaded_models:
@@ -563,16 +577,19 @@ def loaded_models(only_currently_used=False):
output.append(m.model)
return output


def cleanup_models(keep_clone_weights_loaded=False):
check_and_free_cpu_memory()
to_delete = []

for i in range(len(current_loaded_models)):
#TODO: very fragile function needs improvement
num_refs = sys.getrefcount(current_loaded_models[i].model)
if num_refs <= 2:
if not keep_clone_weights_loaded:
to_delete = [i] + to_delete
# TODO: find a less fragile way to do this.
elif sys.getrefcount(current_loaded_models[i].real_model) <= 4: # references from .real_model + the .model
to_delete = [i] + to_delete

for i in to_delete:
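The policy in check_and_free_cpu_memory is deliberately simple: every time cleanup_models runs, cached models are evicted oldest-first until free CPU RAM climbs back above a hardcoded 8 GiB. The refcount threshold in cleanup_models is also bumped from 3 to 4, presumably to account for the extra reference the new cache keeps alive. A standalone sketch of the eviction loop, with the threshold pulled out into a named constant (the constant and function names here are illustrative, not part of the PR):

CPU_FREE_MIN_GIB = 8  # the PR hardcodes this value inside the while condition

def free_until_enough_cpu_memory(cache, get_free_gib, empty_cache=lambda: None):
    # drop least-recently-used entries until free CPU RAM exceeds the threshold;
    # `cache` is expected to look like ModelCache and `get_free_gib` like get_cpu_memory()
    while get_free_gib() <= CPU_FREE_MIN_GIB and len(cache) > 0:
        cache.free_one_model_cache()
        empty_cache()  # stands in for soft_empty_cache(force=True)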
73 changes: 60 additions & 13 deletions comfy/sd.py
@@ -18,6 +18,7 @@

from . import sd1_clip
from . import sdxl_clip
from .cache import model_cache
import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
@@ -26,6 +27,7 @@
import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl


import comfy.model_patcher
import comfy.lora
import comfy.t2i_adapter.adapter
@@ -522,12 +524,12 @@ class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.mo

def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
sd = comfy.utils.load_torch_file(ckpt_path)
out = load_state_dict_guess_config(ckpt_path, sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return out

def load_state_dict_guess_config(ckpt_path, sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
clip = None
clipvision = None
vae = None
@@ -558,20 +560,56 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c

if model_config.clip_vision_prefix is not None:
if output_clipvision:
    clipvision = model_cache.get_item(ckpt_path, 'clipvision')
    if clipvision is None:
        clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

if output_model:
model_patcher = model_cache.get_item(ckpt_path, 'model')
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
offload_device = model_management.unet_offload_device()
if model_patcher is None:
    model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
    model.load_model_weights(sd, diffusion_model_prefix)
    model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device,
                                                     offload_device=offload_device)

if inital_load_device != torch.device("cpu"):
logging.info("loaded straight to GPU")
model_management.load_models_gpu([model_patcher], force_full_load=True)

if output_vae:
vae = model_cache.get_item(ckpt_path, 'vae')
if vae is None:
    vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
    vae_sd = model_config.process_vae_state_dict(vae_sd)
    vae = VAE(sd=vae_sd)

clip_key = f"clip_{embedding_directory}"
if output_clip:
clip = model_cache.get_item(ckpt_path, clip_key)
if clip is None:
clip_target = model_config.clip_target()
if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
clip = CLIP(clip_target, embedding_directory=embedding_directory)
m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
if len(m_filter) > 0:
logging.warning("clip missing: {}".format(m))
else:
logging.debug("clip missing: {}".format(m))

if len(u) > 0:
logging.debug("clip unexpected {}:".format(u))
else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")

clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd)
@@ -595,11 +633,20 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if len(left_over) > 0:
logging.debug("left over keys: {}".format(left_over))

if model:
    logging.debug(f"cache model of : {ckpt_path}")
    model_cache.cache_model(ckpt_path, model_patcher)
if clip:
    logging.debug(f"cache clip of : {ckpt_path}")
    model_cache.cache_clip(ckpt_path, clip_key, clip)
if vae:
    logging.debug(f"cache vae of : {ckpt_path}")
    model_cache.cache_vae(ckpt_path, vae)
if clipvision:
    logging.debug(f"cache clipvision of : {ckpt_path}")
    model_cache.cache_clip_vision(ckpt_path, clipvision)

model_cache.refresh_cache(ckpt_path)

return (model_patcher, clip, vae, clipvision)

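Taken together, the sd.py changes thread ckpt_path through load_state_dict_guess_config as the cache key and apply the same look-up-then-build pattern to the diffusion model, VAE, CLIP and CLIP vision components; note that the CLIP entry is keyed as f"clip_{embedding_directory}", so the same checkpoint loaded with different embedding directories gets separate CLIP cache entries. A condensed sketch of that pattern (the helper below is illustrative, not code from the PR):

from comfy.cache import model_cache

def load_cached_component(ckpt_path, prop, build_fn, cache_fn):
    # `build_fn` builds the component from the state dict on a cache miss and
    # `cache_fn` stores it; both stand in for inline code in load_state_dict_guess_config
    obj = model_cache.get_item(ckpt_path, prop)
    if obj is None:
        obj = build_fn()
        if obj is not None:
            cache_fn(ckpt_path, obj)
    model_cache.refresh_cache(ckpt_path)  # keep this checkpoint at the most-recently-used end
    return obj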
8 changes: 8 additions & 0 deletions comfy/utils.py
@@ -26,8 +26,14 @@
from PIL import Image
import logging
import itertools
from comfy.cache import model_cache

def load_torch_file(ckpt, safe_load=False, device=None):
# return the cached state dict if this checkpoint was loaded before
cache_sd = model_cache.get_item(ckpt, 'sd')
if cache_sd:
return cache_sd

if device is None:
device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
@@ -47,6 +53,8 @@ def load_torch_file(ckpt, safe_load=False, device=None):
sd = pl_sd["state_dict"]
else:
sd = pl_sd
# keep references to the loaded tensors so later loads of the same file skip disk I/O
model_cache.cache_sd(ckpt, sd)
return sd

def save_torch_file(sd, ckpt, metadata=None):
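With the utils.py change, load_torch_file first consults the cache and otherwise stores the freshly loaded state dict via cache_sd, which keeps the keys and tensor values as two lists and rebuilds the dict on get_item(ckpt, 'sd'); repeat loads of the same checkpoint then skip disk I/O, at the cost of keeping the tensors resident in CPU RAM, which is exactly what check_and_free_cpu_memory guards against. A minimal, self-contained sketch of that round trip, using plain Python objects in place of tensors:

from collections import OrderedDict

_store = OrderedDict()

def cache_sd(key, sd):
    _store[key] = (list(sd.keys()), list(sd.values()))  # hold references, not a copy

def get_cached_sd(key):
    cached = _store.get(key)
    if cached is None:
        return None
    keys, values = cached
    return dict(zip(keys, values))  # rebuild a fresh dict over the same value objects

sd = {"model.diffusion_model.weight": [1.0, 2.0]}  # stand-in for a tensor
cache_sd("/models/checkpoints/example.safetensors", sd)
assert get_cached_sd("/models/checkpoints/example.safetensors") == sd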