Add model cache after loaded
david committed Jun 4, 2024
1 parent 71ec5b1 commit 1b1569d
Showing 4 changed files with 176 additions and 33 deletions.
94 changes: 94 additions & 0 deletions comfy/cache.py
@@ -0,0 +1,94 @@
import gc
from collections import OrderedDict
import os
import logging

from comfy.cli_args import args


class ModelCache:
def __init__(self):
self._cache = OrderedDict()
        # Caching is disabled in highvram / gpu-only mode, where keeping extra models resident may cause OOM errors.
        self._cache_state = not (args.highvram or args.gpu_only)

def _get_item(self, key):
if self._cache.get(key) is None:
self._cache[key] = {}
return self._cache[key]

def cache_model(self, key, model):
if not self._cache_state:
return
item = self._get_item(key)
item['model'] = model

def cache_clip_vision(self, key, clip_vision):
if not self._cache_state:
return
item = self._get_item(key)
item['clip_vision'] = clip_vision

def cache_vae(self, key, vae):
if not self._cache_state:
return
item = self._get_item(key)
item['vae'] = vae

def cache_sd(self, key, sd):
if not self._cache_state:
return
assert isinstance(sd, dict)
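        # Keep separate key/value lists of the state dict; get_item() later rebuilds a fresh dict from these references.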
keys = list(sd.keys())
values = list(sd.values())
item = self._get_item(key)
item['sd'] = (keys, values)

    def cache_clip(self, key, clip_key, clip):
        if not self._cache_state:
            return
        item = self._get_item(key)
        item[clip_key] = clip

def refresh_cache(self, key):
if key in self._cache:
self._cache.move_to_end(key)

def get_item(self, key, prop):
item = self._cache.get(key)
if item is None:
return None
if prop == "sd":
if item.get('sd') is None:
return None
k, values = item.get("sd")
return dict(zip(k, values))
return item.get(prop)

def __len__(self):
return len(self._cache)

@staticmethod
def unpatch_offload_model(model):
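        # Move model patches back to their offload device so load-device memory is released before the entry is dropped.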
model.model_patches_to(model.offload_device)

def free_one_model_cache(self):
if len(self) == 0:
return
cache_k, item = self._cache.popitem(last=False)
item.pop("sd", None)

for k in list(item.keys()):
model = item.pop(k, None)
if model is not None:
if hasattr(model, "patcher"):
self.unpatch_offload_model(model.patcher)
else:
self.unpatch_offload_model(model)

item.clear()
model_dir, model_name = os.path.split(cache_k)
dir_name = os.path.basename(model_dir)
gc.collect()
logging.info(f"Drop model cache: {model_name} ({dir_name})")


model_cache = ModelCache()
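
For orientation, here is a rough usage sketch of the new cache. The checkpoint path and the loader call are placeholders, not part of this commit, and caching is assumed to be enabled (no --highvram / --gpu-only):

from comfy.cache import model_cache

ckpt_path = "checkpoints/example.safetensors"        # hypothetical path

patcher = model_cache.get_item(ckpt_path, 'model')   # None on a cache miss
if patcher is None:
    patcher = load_checkpoint_somehow(ckpt_path)     # placeholder for the real loader
    model_cache.cache_model(ckpt_path, patcher)      # no-op with --highvram / --gpu-only
model_cache.refresh_cache(ckpt_path)                 # mark as most recently used

# When CPU memory runs low, the oldest entry can be evicted:
if len(model_cache) > 0:
    model_cache.free_one_model_cache()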
21 changes: 19 additions & 2 deletions comfy/model_management.py
@@ -5,6 +5,7 @@
import torch
import sys
import platform
from comfy.cache import model_cache

class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -466,14 +467,30 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
def load_model_gpu(model):
return load_models_gpu([model])


def get_cpu_memory():
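    # Free system (CPU) memory in GiB, rounded to two decimals.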
return round(get_free_memory(torch.device("cpu")) / 1024 ** 3, 2)


def check_and_free_cpu_memory():
    # Evict cached models until free CPU memory is back above the threshold (8 GB).
current_free_memory_cpu = get_cpu_memory()
while (current_free_memory_cpu <= 8) and len(model_cache) > 0:
model_cache.free_one_model_cache()
soft_empty_cache(force=True)
current_free_memory_cpu = get_cpu_memory()


def cleanup_models(keep_clone_weights_loaded=False):
check_and_free_cpu_memory()
to_delete = []

for i in range(len(current_loaded_models)):
if sys.getrefcount(current_loaded_models[i].model) <= 2:
if not keep_clone_weights_loaded:
to_delete = [i] + to_delete
-            #TODO: find a less fragile way to do this.
-            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+            # TODO: find a less fragile way to do this.
+            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3:  # references from .real_model + the .model
to_delete = [i] + to_delete

for i in to_delete:
85 changes: 54 additions & 31 deletions comfy/sd.py
@@ -19,6 +19,7 @@
from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
from .cache import model_cache

import comfy.model_patcher
import comfy.lora
@@ -464,48 +465,70 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):

if model_config.clip_vision_prefix is not None:
-        if output_clipvision:
-            clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
+        if output_clipvision:
+            clipvision = model_cache.get_item(ckpt_path, 'clip_vision')
+            if clipvision is None:
+                clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

if output_model:
+        model_patcher = model_cache.get_item(ckpt_path, 'model')
        inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
-        offload_device = model_management.unet_offload_device()
-        model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
-        model.load_model_weights(sd, "model.diffusion_model.")
+        if model_patcher is None:
+            offload_device = model_management.unet_offload_device()
+            model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
+            model.load_model_weights(sd, "model.diffusion_model.")
+            model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device, current_device=inital_load_device)

+        if inital_load_device != torch.device("cpu"):
+            logging.info("loaded straight to GPU")
+            model_management.load_model_gpu(model_patcher)

if output_vae:
-        vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
-        vae_sd = model_config.process_vae_state_dict(vae_sd)
-        vae = VAE(sd=vae_sd)
+        vae = model_cache.get_item(ckpt_path, 'vae')
+        if vae is None:
+            vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
+            vae_sd = model_config.process_vae_state_dict(vae_sd)
+            vae = VAE(sd=vae_sd)

+    clip_key = f"clip_{embedding_directory}"
    if output_clip:
-        clip_target = model_config.clip_target()
-        if clip_target is not None:
-            clip_sd = model_config.process_clip_state_dict(sd)
-            if len(clip_sd) > 0:
-                clip = CLIP(clip_target, embedding_directory=embedding_directory)
-                m, u = clip.load_sd(clip_sd, full_model=True)
-                if len(m) > 0:
-                    m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
-                    if len(m_filter) > 0:
-                        logging.warning("clip missing: {}".format(m))
-                    else:
-                        logging.debug("clip missing: {}".format(m))
-
-                if len(u) > 0:
-                    logging.debug("clip unexpected {}:".format(u))
-            else:
-                logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
+        clip = model_cache.get_item(ckpt_path, clip_key)
+        if clip is None:
+            clip_target = model_config.clip_target()
+            if clip_target is not None:
+                clip_sd = model_config.process_clip_state_dict(sd)
+                if len(clip_sd) > 0:
+                    clip = CLIP(clip_target, embedding_directory=embedding_directory)
+                    m, u = clip.load_sd(clip_sd, full_model=True)
+                    if len(m) > 0:
+                        m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
+                        if len(m_filter) > 0:
+                            logging.warning("clip missing: {}".format(m))
+                        else:
+                            logging.debug("clip missing: {}".format(m))
+
+                    if len(u) > 0:
+                        logging.debug("clip unexpected {}:".format(u))
+                else:
+                    logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")

left_over = sd.keys()
if len(left_over) > 0:
logging.debug("left over keys: {}".format(left_over))

-    if output_model:
-        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
-        if inital_load_device != torch.device("cpu"):
-            logging.info("loaded straight to GPU")
-            model_management.load_model_gpu(model_patcher)

+    if model:
+        logging.debug(f"cache model of : {ckpt_path}")
+        model_cache.cache_model(ckpt_path, model_patcher)
+    if clip:
+        logging.debug(f"cache clip of : {ckpt_path}")
+        model_cache.cache_clip(ckpt_path, clip_key, clip)
+    if vae:
+        logging.debug(f"cache vae of : {ckpt_path}")
+        model_cache.cache_vae(ckpt_path, vae)
+    if clipvision:
+        logging.debug(f"cache clipvision of : {ckpt_path}")
+        model_cache.cache_clip_vision(ckpt_path, clipvision)
+
+    model_cache.refresh_cache(ckpt_path)  # mark this checkpoint as most recently used in the LRU order
return (model_patcher, clip, vae, clipvision)


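As a rough illustration of the intended effect (the checkpoint path below is made up), loading the same checkpoint twice should now hit the cache on the second call:

# First call: reads the file, builds model / clip / vae and fills model_cache.
out_cold = load_checkpoint_guess_config("checkpoints/example.safetensors")

# Second call with the same path: load_torch_file() returns the cached state dict and the
# model / clip / vae objects are fetched from model_cache instead of being rebuilt.
out_warm = load_checkpoint_guess_config("checkpoints/example.safetensors")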
9 changes: 9 additions & 0 deletions comfy/utils.py
@@ -6,10 +6,17 @@
import numpy as np
from PIL import Image
import logging
from comfy.cache import model_cache

def load_torch_file(ckpt, safe_load=False, device=None):
# read the sd from cache
cache_sd = model_cache.get_item(ckpt, 'sd')
if cache_sd:
return cache_sd

if device is None:
device = torch.device("cpu")

if ckpt.lower().endswith(".safetensors"):
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
@@ -27,6 +34,8 @@ def load_torch_file(ckpt, safe_load=False, device=None):
sd = pl_sd["state_dict"]
else:
sd = pl_sd
    # keep references to the tensors in the cache so later loads of this checkpoint can skip disk I/O
model_cache.cache_sd(ckpt, sd)
return sd

def save_torch_file(sd, ckpt, metadata=None):
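For context, a minimal round trip of the state-dict caching added above (tensors and path are made up, and caching is assumed to be enabled, i.e. no --highvram / --gpu-only):

import torch
from comfy.cache import model_cache

sd = {"weight": torch.zeros(2, 2), "bias": torch.zeros(2)}
model_cache.cache_sd("checkpoints/example.safetensors", sd)             # stores key/value references
cached = model_cache.get_item("checkpoints/example.safetensors", "sd")
assert cached is not sd and cached.keys() == sd.keys()                  # a fresh dict over the same tensors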
