mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-23 09:38:08 +08:00
Compare commits
10 Commits
reintroduc
...
feature/cu
| Author | SHA1 | Date | |
|---|---|---|---|
| 7259e664ef | |||
| ae539cfa0a | |||
| 8f82b16993 | |||
| 72fe66a18b | |||
| 07ff14ae02 | |||
| ba1c039a04 | |||
| 6220400ad5 | |||
| af55a2308f | |||
| 3a649984f2 | |||
| a145651cc0 |
@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
|
||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
|
||||
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
|
||||
cm_group = parser.add_mutually_exclusive_group()
|
||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||
|
||||
@ -15,14 +15,13 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import torch
|
||||
from enum import Enum
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
@ -39,7 +38,7 @@ import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.ldm.qwen_image.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.hooks import HookGroup
|
||||
|
||||
@ -65,18 +64,6 @@ class StrengthType(Enum):
|
||||
CONSTANT = 1
|
||||
LINEAR_UP = 2
|
||||
|
||||
class ControlIsolation:
|
||||
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
|
||||
def __init__(self, control: ControlBase):
|
||||
self.control = control
|
||||
self.orig_previous_controlnet = control.previous_controlnet
|
||||
|
||||
def __enter__(self):
|
||||
self.control.previous_controlnet = None
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.control.previous_controlnet = self.orig_previous_controlnet
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self):
|
||||
self.cond_hint_original = None
|
||||
@ -90,7 +77,7 @@ class ControlBase:
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
self.previous_controlnet: Union[ControlBase, None] = None
|
||||
self.previous_controlnet = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
self.concat_mask = False
|
||||
@ -98,7 +85,6 @@ class ControlBase:
|
||||
self.extra_concat = None
|
||||
self.extra_hooks: HookGroup = None
|
||||
self.preprocess_image = lambda a: a
|
||||
self.multigpu_clones: dict[torch.device, ControlBase] = {}
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||
self.cond_hint_original = cond_hint
|
||||
@ -125,38 +111,17 @@ class ControlBase:
|
||||
def cleanup(self):
|
||||
if self.previous_controlnet is not None:
|
||||
self.previous_controlnet.cleanup()
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
with ControlIsolation(device_cnet):
|
||||
device_cnet.cleanup()
|
||||
|
||||
self.cond_hint = None
|
||||
self.extra_concat = None
|
||||
self.timestep_range = None
|
||||
|
||||
def get_models(self):
|
||||
out = []
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
out += device_cnet.get_models_only_self()
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
def get_models_only_self(self):
|
||||
'Calls get_models, but temporarily sets previous_controlnet to None.'
|
||||
with ControlIsolation(self):
|
||||
return self.get_models()
|
||||
|
||||
def get_instance_for_device(self, device):
|
||||
'Returns instance of this Control object intended for selected device.'
|
||||
return self.multigpu_clones.get(device, self)
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
'''
|
||||
Create deep clone of Control object where model(s) is set to other devices.
|
||||
|
||||
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
|
||||
'''
|
||||
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
|
||||
|
||||
def get_extra_hooks(self):
|
||||
out = []
|
||||
if self.extra_hooks is not None:
|
||||
@ -165,7 +130,7 @@ class ControlBase:
|
||||
out += self.previous_controlnet.get_extra_hooks()
|
||||
return out
|
||||
|
||||
def copy_to(self, c: ControlBase):
|
||||
def copy_to(self, c):
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
c.timestep_percent_range = self.timestep_percent_range
|
||||
@ -319,14 +284,6 @@ class ControlNet(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.control_model = copy.deepcopy(c.control_model)
|
||||
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def get_models(self):
|
||||
out = super().get_models()
|
||||
out.append(self.control_model_wrapped)
|
||||
@ -357,10 +314,6 @@ class QwenFunControlNet(ControlNet):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.set_extra_arg("base_model", model.diffusion_model)
|
||||
|
||||
def cleanup(self):
|
||||
self.extra_args.pop("base_model", None)
|
||||
super().cleanup()
|
||||
|
||||
def copy(self):
|
||||
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
@ -953,14 +906,6 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.t2i_model = copy.deepcopy(c.t2i_model)
|
||||
c.device = load_device
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
compression_ratio = 8
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
|
||||
@ -607,13 +607,9 @@ class HunYuanDiTPlain(nn.Module):
|
||||
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
||||
|
||||
x = x.movedim(-1, -2)
|
||||
|
||||
swap_cfg_halves = context.shape[0] >= 2
|
||||
|
||||
if swap_cfg_halves:
|
||||
first_half, second_half = context.chunk(2, dim = 0)
|
||||
context = torch.cat([second_half, first_half], dim = 0)
|
||||
|
||||
if context.shape[0] >= 2:
|
||||
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
||||
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
||||
main_condition = context
|
||||
|
||||
t = 1.0 - t
|
||||
@ -661,8 +657,8 @@ class HunYuanDiTPlain(nn.Module):
|
||||
output = self.final_layer(combined)
|
||||
output = output.movedim(-2, -1) * (-1.0)
|
||||
|
||||
if swap_cfg_halves:
|
||||
first_half, second_half = output.chunk(2, dim = 0)
|
||||
output = torch.cat([second_half, first_half], dim = 0)
|
||||
|
||||
return output
|
||||
if output.shape[0] >= 2:
|
||||
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
||||
return torch.cat([uncond_emb, cond_emb])
|
||||
else:
|
||||
return output
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import math
|
||||
import ctypes
|
||||
import threading
|
||||
import dataclasses
|
||||
import torch
|
||||
from typing import NamedTuple
|
||||
@ -9,7 +10,7 @@ from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
class TensorFileSlice(NamedTuple):
|
||||
file_ref: object
|
||||
lock: object
|
||||
thread_id: int
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
@ -42,6 +43,7 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
file_obj = info.file_ref
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or threading.get_ident() != info.thread_id
|
||||
or destination.numel() * destination.element_size() < info.size
|
||||
or tensor.numel() * tensor.element_size() != info.size
|
||||
or tensor.storage_offset() != 0
|
||||
@ -55,29 +57,27 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
if hostbuf is not None:
|
||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||
device_ptr = destination2.data_ptr() if destination2 is not None else 0
|
||||
with info.lock:
|
||||
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
||||
offset=destination.data_ptr() - hostbuf.get_raw_address(),
|
||||
stream=stream_ptr,
|
||||
device_ptr=device_ptr,
|
||||
device=None if destination2 is None else destination2.device.index)
|
||||
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
||||
offset=destination.data_ptr() - hostbuf.get_raw_address(),
|
||||
stream=stream_ptr,
|
||||
device_ptr=device_ptr,
|
||||
device=None if destination2 is None else destination2.device.index)
|
||||
return True
|
||||
|
||||
buf_type = ctypes.c_ubyte * info.size
|
||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||
|
||||
try:
|
||||
with info.lock:
|
||||
file_obj.seek(info.offset)
|
||||
done = 0
|
||||
while done < info.size:
|
||||
try:
|
||||
n = file_obj.readinto(view[done:])
|
||||
except OSError:
|
||||
return False
|
||||
if n <= 0:
|
||||
return False
|
||||
done += n
|
||||
file_obj.seek(info.offset)
|
||||
done = 0
|
||||
while done < info.size:
|
||||
try:
|
||||
n = file_obj.readinto(view[done:])
|
||||
except OSError:
|
||||
return False
|
||||
if n <= 0:
|
||||
return False
|
||||
done += n
|
||||
return True
|
||||
finally:
|
||||
view.release()
|
||||
|
||||
@ -15,7 +15,6 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
@ -28,18 +27,13 @@ import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from contextlib import nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.vram_buffer
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
@ -210,91 +204,6 @@ def get_torch_device():
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
def get_all_torch_devices(exclude_current=False):
|
||||
global cpu_state
|
||||
devices = []
|
||||
if cpu_state == CPUState.GPU:
|
||||
if is_nvidia():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
devices.append(torch.device("cuda", i))
|
||||
elif is_intel_xpu():
|
||||
for i in range(torch.xpu.device_count()):
|
||||
devices.append(torch.device("xpu", i))
|
||||
elif is_ascend_npu():
|
||||
for i in range(torch.npu.device_count()):
|
||||
devices.append(torch.device("npu", i))
|
||||
else:
|
||||
devices.append(get_torch_device())
|
||||
if exclude_current:
|
||||
current = get_torch_device()
|
||||
if current in devices:
|
||||
devices.remove(current)
|
||||
return devices
|
||||
|
||||
def get_gpu_device_options():
|
||||
"""Return list of device option strings for node widgets.
|
||||
|
||||
Always includes "default" and "cpu". When multiple GPUs are present,
|
||||
adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
|
||||
"""
|
||||
options = ["default", "cpu"]
|
||||
devices = get_all_torch_devices()
|
||||
if len(devices) > 1:
|
||||
for i in range(len(devices)):
|
||||
options.append(f"gpu:{i}")
|
||||
return options
|
||||
|
||||
def resolve_gpu_device_option(option: str):
|
||||
"""Resolve a device option string to a torch.device.
|
||||
|
||||
Returns None for "default" (let the caller use its normal default).
|
||||
Returns torch.device("cpu") for "cpu".
|
||||
For "gpu:N", returns the Nth torch device. Falls back to None if
|
||||
the index is out of range (caller should use default).
|
||||
"""
|
||||
if option is None or option == "default":
|
||||
return None
|
||||
if option == "cpu":
|
||||
return torch.device("cpu")
|
||||
if option.startswith("gpu:"):
|
||||
try:
|
||||
idx = int(option[4:])
|
||||
devices = get_all_torch_devices()
|
||||
if 0 <= idx < len(devices):
|
||||
return devices[idx]
|
||||
else:
|
||||
logging.warning(f"Device '{option}' not available (only {len(devices)} GPU(s)), using default.")
|
||||
return None
|
||||
except (ValueError, IndexError):
|
||||
logging.warning(f"Invalid device option '{option}', using default.")
|
||||
return None
|
||||
logging.warning(f"Unrecognized device option '{option}', using default.")
|
||||
return None
|
||||
|
||||
@contextmanager
|
||||
def cuda_device_context(device):
|
||||
"""Context manager that sets torch.cuda.current_device to match *device*.
|
||||
|
||||
Used when running operations on a non-default CUDA device so that custom
|
||||
CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
|
||||
device index. The previous device is restored on exit.
|
||||
|
||||
No-op when *device* is not CUDA, has no explicit index, or already matches
|
||||
the current device.
|
||||
"""
|
||||
prev = None
|
||||
if device.type == "cuda" and device.index is not None:
|
||||
prev = torch.cuda.current_device()
|
||||
if prev != device.index:
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
prev = None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if prev is not None:
|
||||
torch.cuda.set_device(prev)
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
@ -583,13 +492,9 @@ try:
|
||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
except:
|
||||
logging.warning("Could not pick default device.")
|
||||
try:
|
||||
for device in get_all_torch_devices(exclude_current=True):
|
||||
logging.info("Device: {}".format(get_torch_device_name(device)))
|
||||
except:
|
||||
pass
|
||||
|
||||
current_loaded_models: list[LoadedModel] = []
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
DIRTY_MMAPS = set()
|
||||
|
||||
@ -649,7 +554,7 @@ def ensure_pin_registerable(size, evict_active=False):
|
||||
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model: ModelPatcher):
|
||||
def __init__(self, model):
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
self.real_model = None
|
||||
@ -657,7 +562,7 @@ class LoadedModel:
|
||||
self.model_finalizer = None
|
||||
self._patcher_finalizer = None
|
||||
|
||||
def _set_model(self, model: ModelPatcher):
|
||||
def _set_model(self, model):
|
||||
self._model = weakref.ref(model)
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
@ -668,7 +573,6 @@ class LoadedModel:
|
||||
model = self._parent_model()
|
||||
if model is not None:
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
@ -1944,34 +1848,7 @@ def soft_empty_cache(force=False):
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def unload_all_models():
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device)
|
||||
|
||||
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
|
||||
'Unload only model and its clones - primarily for multigpu cloning purposes.'
|
||||
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
|
||||
additional_models = []
|
||||
if unload_additional_models:
|
||||
additional_models = model.get_nested_additional_models()
|
||||
keep_loaded = []
|
||||
for loaded_model in initial_keep_loaded:
|
||||
if loaded_model.model is not None:
|
||||
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
continue
|
||||
# check additional models if they are a match
|
||||
skip = False
|
||||
for add_model in additional_models:
|
||||
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
skip = True
|
||||
break
|
||||
if skip:
|
||||
continue
|
||||
keep_loaded.append(loaded_model)
|
||||
if not all_devices:
|
||||
free_memory(1e30, get_torch_device(), keep_loaded)
|
||||
else:
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device, keep_loaded)
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
def debug_memory_summary():
|
||||
if is_amd() or is_nvidia():
|
||||
|
||||
@ -23,7 +23,6 @@ import inspect
|
||||
import logging
|
||||
import math
|
||||
import uuid
|
||||
import copy
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
@ -79,15 +78,12 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
def create_model_options_clone(orig_model_options: dict):
|
||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||
|
||||
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
|
||||
def create_hook_patches_clone(orig_hook_patches):
|
||||
new_hook_patches = {}
|
||||
for hook_ref in orig_hook_patches:
|
||||
new_hook_patches[hook_ref] = {}
|
||||
for k in orig_hook_patches[hook_ref]:
|
||||
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
||||
if copy_tuples:
|
||||
for i in range(len(new_hook_patches[hook_ref][k])):
|
||||
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
|
||||
return new_hook_patches
|
||||
|
||||
def wipe_lowvram_weight(m):
|
||||
@ -333,10 +329,7 @@ class ModelPatcher:
|
||||
self.is_clip = False
|
||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||
|
||||
self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
|
||||
self.is_multigpu_base_clone = False
|
||||
self.clone_base_uuid = uuid.uuid4()
|
||||
|
||||
self.cached_patcher_init: tuple[Callable, tuple] | None = None
|
||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
@ -373,8 +366,7 @@ class ModelPatcher:
|
||||
#than pays for CFG. So return everything both torch and Aimdo could give us
|
||||
aimdo_mem = 0
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None
|
||||
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device)
|
||||
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
|
||||
return comfy.model_management.get_free_memory(device) + aimdo_mem
|
||||
|
||||
def get_clone_model_override(self):
|
||||
@ -388,8 +380,6 @@ class ModelPatcher:
|
||||
if self.cached_patcher_init is None:
|
||||
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
||||
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||
if len(self.cached_patcher_init) > 2:
|
||||
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||
model_override = temp_model_patcher.get_clone_model_override()
|
||||
if model_override is None:
|
||||
model_override = self.get_clone_model_override()
|
||||
@ -448,98 +438,19 @@ class ModelPatcher:
|
||||
n.hook_mode = self.hook_mode
|
||||
|
||||
n.cached_patcher_init = self.cached_patcher_init
|
||||
n.is_multigpu_base_clone = self.is_multigpu_base_clone
|
||||
n.clone_base_uuid = self.clone_base_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
||||
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
||||
comfy.model_management.unload_model_and_clones(self)
|
||||
n = self.clone()
|
||||
# set load device, if present
|
||||
if new_load_device is not None:
|
||||
n.load_device = new_load_device
|
||||
if self.cached_patcher_init is not None:
|
||||
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
|
||||
if len(self.cached_patcher_init) > 2:
|
||||
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||
n.model = temp_model_patcher.model
|
||||
else:
|
||||
n.model = copy.deepcopy(n.model)
|
||||
# unlike for normal clone, backup dicts that shared same ref should not;
|
||||
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
|
||||
n.backup = copy.deepcopy(n.backup)
|
||||
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
|
||||
n.hook_backup = copy.deepcopy(n.hook_backup)
|
||||
# multigpu clone should not have multigpu additional_models entry
|
||||
n.remove_additional_models("multigpu")
|
||||
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
||||
if models_cache is None:
|
||||
models_cache = {}
|
||||
for key, model_list in n.additional_models.items():
|
||||
for i in range(len(model_list)):
|
||||
add_model = n.additional_models[key][i]
|
||||
if add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
|
||||
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def match_multigpu_clones(self):
|
||||
multigpu_models = self.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
new_multigpu_models = []
|
||||
for mm in multigpu_models:
|
||||
# clone main model, but bring over relevant props from existing multigpu clone
|
||||
n = self.clone()
|
||||
n.load_device = mm.load_device
|
||||
n.backup = mm.backup
|
||||
n.object_patches_backup = mm.object_patches_backup
|
||||
n.hook_backup = mm.hook_backup
|
||||
n.model = mm.model
|
||||
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
|
||||
n.remove_additional_models("multigpu")
|
||||
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
|
||||
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
|
||||
# figure out which additional models are not present in multigpu clone
|
||||
models_cache = {}
|
||||
for mm_add_model in mm.get_additional_models():
|
||||
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
|
||||
remove_models_uuids = set(list(models_cache.keys()))
|
||||
for key, model_list in orig_additional_models.items():
|
||||
for orig_add_model in model_list:
|
||||
if orig_add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
|
||||
existing_list = n.get_additional_models_with_key(key)
|
||||
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
|
||||
n.set_additional_models(key, existing_list)
|
||||
if orig_add_model.clone_base_uuid in remove_models_uuids:
|
||||
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
|
||||
# remove duplicate additional models
|
||||
for key, model_list in n.additional_models.items():
|
||||
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
|
||||
n.set_additional_models(key, new_model_list)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
|
||||
callback(self, n)
|
||||
new_multigpu_models.append(n)
|
||||
self.set_additional_models("multigpu", new_multigpu_models)
|
||||
|
||||
def is_clone(self, other):
|
||||
if hasattr(other, 'model') and self.model is other.model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
|
||||
if allow_multigpu:
|
||||
if self.clone_base_uuid != clone.clone_base_uuid:
|
||||
return False
|
||||
else:
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
|
||||
if self.current_hooks != clone.current_hooks:
|
||||
return False
|
||||
@ -1321,7 +1232,7 @@ class ModelPatcher:
|
||||
return self.additional_models.get(key, [])
|
||||
|
||||
def get_additional_models(self):
|
||||
all_models: list[ModelPatcher] = []
|
||||
all_models = []
|
||||
for models in self.additional_models.values():
|
||||
all_models.extend(models)
|
||||
return all_models
|
||||
@ -1375,18 +1286,9 @@ class ModelPatcher:
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||
callback(self)
|
||||
|
||||
def prepare_state(self, timestep, model_options):
|
||||
ignore_multigpu = model_options.get("ignore_multigpu", False)
|
||||
def prepare_state(self, timestep):
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||
callback(self, timestep, model_options)
|
||||
if not ignore_multigpu and "multigpu_clones" in model_options:
|
||||
model_options["ignore_multigpu"] = True
|
||||
try:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p.prepare_state(timestep, model_options)
|
||||
finally:
|
||||
model_options.pop("ignore_multigpu", None)
|
||||
callback(self, timestep)
|
||||
|
||||
def restore_hook_patches(self):
|
||||
if self.hook_patches_backup is not None:
|
||||
@ -1399,18 +1301,12 @@ class ModelPatcher:
|
||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||
curr_t = t[0]
|
||||
reset_current_hooks = False
|
||||
multigpu_kf_changed_cache = None
|
||||
transformer_options = model_options.get("transformer_options", {})
|
||||
for hook in hook_group.hooks:
|
||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||
# this will cause the weights to be recalculated when sampling
|
||||
if changed:
|
||||
# cache changed for multigpu usage
|
||||
if "multigpu_clones" in model_options:
|
||||
if multigpu_kf_changed_cache is None:
|
||||
multigpu_kf_changed_cache = []
|
||||
multigpu_kf_changed_cache.append(hook)
|
||||
# reset current_hooks if contains hook that changed
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
@ -1422,28 +1318,6 @@ class ModelPatcher:
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
if "multigpu_clones" in model_options:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
|
||||
|
||||
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
|
||||
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
|
||||
if kf_changed_cache is None:
|
||||
return
|
||||
reset_current_hooks = False
|
||||
# reset current_hooks if contains hook that changed
|
||||
for hook in kf_changed_cache:
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
if current_hook == hook:
|
||||
reset_current_hooks = True
|
||||
break
|
||||
for cached_group in list(self.cached_hook_patches.keys()):
|
||||
if cached_group.contains(hook):
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
|
||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||
registered: comfy.hooks.HookGroup = None):
|
||||
|
||||
@ -1,312 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import copy
|
||||
import queue
|
||||
import threading
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.utils
|
||||
import comfy.patcher_extension
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class MultiGPUThreadPool:
|
||||
"""Persistent thread pool for multi-GPU work distribution.
|
||||
|
||||
Maintains one worker thread per extra GPU device. Each thread calls
|
||||
torch.cuda.set_device() once at startup so that compiled kernel caches
|
||||
(inductor/triton) stay warm across diffusion steps.
|
||||
"""
|
||||
|
||||
def __init__(self, devices: list[torch.device]):
|
||||
self._workers: list[threading.Thread] = []
|
||||
self._work_queues: dict[torch.device, queue.Queue] = {}
|
||||
self._result_queues: dict[torch.device, queue.Queue] = {}
|
||||
|
||||
for device in devices:
|
||||
wq = queue.Queue()
|
||||
rq = queue.Queue()
|
||||
self._work_queues[device] = wq
|
||||
self._result_queues[device] = rq
|
||||
t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
|
||||
t.start()
|
||||
self._workers.append(t)
|
||||
|
||||
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
|
||||
try:
|
||||
torch.cuda.set_device(device)
|
||||
except Exception as e:
|
||||
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
|
||||
while True:
|
||||
item = work_q.get()
|
||||
if item is None:
|
||||
return
|
||||
result_q.put((None, e))
|
||||
return
|
||||
while True:
|
||||
item = work_q.get()
|
||||
if item is None:
|
||||
break
|
||||
fn, args, kwargs = item
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
result_q.put((result, None))
|
||||
except Exception as e:
|
||||
result_q.put((None, e))
|
||||
|
||||
def submit(self, device: torch.device, fn, *args, **kwargs):
|
||||
self._work_queues[device].put((fn, args, kwargs))
|
||||
|
||||
def get_result(self, device: torch.device):
|
||||
return self._result_queues[device].get()
|
||||
|
||||
@property
|
||||
def devices(self) -> list[torch.device]:
|
||||
return list(self._work_queues.keys())
|
||||
|
||||
def shutdown(self):
|
||||
for wq in self._work_queues.values():
|
||||
wq.put(None) # sentinel
|
||||
for t in self._workers:
|
||||
t.join(timeout=5.0)
|
||||
|
||||
|
||||
class GPUOptions:
|
||||
def __init__(self, device_index: int, relative_speed: float):
|
||||
self.device_index = device_index
|
||||
self.relative_speed = relative_speed
|
||||
|
||||
def clone(self):
|
||||
return GPUOptions(self.device_index, self.relative_speed)
|
||||
|
||||
def create_dict(self):
|
||||
return {
|
||||
"relative_speed": self.relative_speed
|
||||
}
|
||||
|
||||
class GPUOptionsGroup:
|
||||
def __init__(self):
|
||||
self.options: dict[int, GPUOptions] = {}
|
||||
|
||||
def add(self, info: GPUOptions):
|
||||
self.options[info.device_index] = info
|
||||
|
||||
def clone(self):
|
||||
c = GPUOptionsGroup()
|
||||
for opt in self.options.values():
|
||||
c.add(opt)
|
||||
return c
|
||||
|
||||
def register(self, model: ModelPatcher):
|
||||
opts_dict = {}
|
||||
# get devices that are valid for this model
|
||||
devices: list[torch.device] = [model.load_device]
|
||||
for extra_model in model.get_additional_models_with_key("multigpu"):
|
||||
extra_model: ModelPatcher
|
||||
devices.append(extra_model.load_device)
|
||||
# create dictionary with actual device mapped to its GPUOptions
|
||||
device_opts_list: list[GPUOptions] = []
|
||||
for device in devices:
|
||||
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
|
||||
opts_dict[device] = device_opts.create_dict()
|
||||
device_opts_list.append(device_opts)
|
||||
# make relative_speed relative to 1.0
|
||||
min_speed = min([x.relative_speed for x in device_opts_list])
|
||||
for value in opts_dict.values():
|
||||
value['relative_speed'] /= min_speed
|
||||
model.model_options['multigpu_options'] = opts_dict
|
||||
|
||||
|
||||
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
||||
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
||||
model = model.clone()
|
||||
# check if multigpu is already prepared - get the load devices from them if possible to exclude
|
||||
skip_devices = set()
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
for mm in multigpu_models:
|
||||
skip_devices.add(mm.load_device)
|
||||
skip_devices = list(skip_devices)
|
||||
|
||||
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
|
||||
limit_extra_devices = full_extra_devices[:max_gpus-1]
|
||||
extra_devices = limit_extra_devices.copy()
|
||||
# exclude skipped devices
|
||||
for skip in skip_devices:
|
||||
if skip in extra_devices:
|
||||
extra_devices.remove(skip)
|
||||
# create new deepclones
|
||||
if len(extra_devices) > 0:
|
||||
for device in extra_devices:
|
||||
device_patcher = None
|
||||
if reuse_loaded:
|
||||
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
|
||||
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
|
||||
for lm in loaded_models:
|
||||
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
|
||||
device_patcher = lm.clone()
|
||||
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
|
||||
break
|
||||
if device_patcher is None:
|
||||
device_patcher = model.deepclone_multigpu(new_load_device=device)
|
||||
device_patcher.is_multigpu_base_clone = True
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
multigpu_models.append(device_patcher)
|
||||
model.set_additional_models("multigpu", multigpu_models)
|
||||
model.match_multigpu_clones()
|
||||
if gpu_options is None:
|
||||
gpu_options = GPUOptionsGroup()
|
||||
gpu_options.register(model)
|
||||
else:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
|
||||
# only keep model clones that don't go 'past' the intended max_gpu count;
|
||||
# this prunes any inherited multigpu clones whose load_device is no longer allowed
|
||||
# when max_gpus is lowered between runs.
|
||||
allowed_devices = set(limit_extra_devices)
|
||||
allowed_devices.add(model.load_device)
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices]
|
||||
if len(new_multigpu_models) != len(multigpu_models):
|
||||
model.set_additional_models("multigpu", new_multigpu_models)
|
||||
model.match_multigpu_clones()
|
||||
return model
|
||||
|
||||
|
||||
def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
|
||||
"""Return a shallow copy of ``upscale_model`` with a ``multigpu_clones`` dict of CPU-resident
|
||||
descriptor deepclones, one per extra CUDA device up to ``max_gpus``.
|
||||
"""
|
||||
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
|
||||
limit_extra_devices = full_extra_devices[:max_gpus - 1]
|
||||
cloned = copy.copy(upscale_model)
|
||||
existing = getattr(upscale_model, 'multigpu_clones', None)
|
||||
limit_extra_device_set = set(limit_extra_devices)
|
||||
clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {}
|
||||
if len(limit_extra_devices) == 0:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
|
||||
if hasattr(cloned, 'multigpu_clones'):
|
||||
del cloned.multigpu_clones
|
||||
return cloned
|
||||
|
||||
for device in limit_extra_devices:
|
||||
if device in clones:
|
||||
continue
|
||||
clone_source = copy.copy(upscale_model)
|
||||
if hasattr(clone_source, 'multigpu_clones'):
|
||||
del clone_source.multigpu_clones
|
||||
clone_desc = copy.deepcopy(clone_source)
|
||||
clone_desc.model.eval()
|
||||
for p in clone_desc.model.parameters():
|
||||
p.requires_grad_(False)
|
||||
clone_desc.to("cpu")
|
||||
clones[device] = clone_desc
|
||||
logging.info(f"Created CPU upscale_model descriptor deepclone for {device}")
|
||||
|
||||
cloned.multigpu_clones = clones
|
||||
return cloned
|
||||
|
||||
|
||||
def create_vae_multigpu_deepclones(vae, max_gpus: int):
|
||||
"""Return a shallow copy of ``vae`` with a ``multigpu_clones`` dict of CPU-resident VAE
|
||||
deepclones, one per extra CUDA device up to ``max_gpus``.
|
||||
"""
|
||||
vae.throw_exception_if_invalid()
|
||||
vae_device = torch.device(vae.device)
|
||||
cloned = copy.copy(vae)
|
||||
if hasattr(cloned, 'multigpu_clones'):
|
||||
del cloned.multigpu_clones
|
||||
if vae_device.type == "cpu":
|
||||
logging.info("CPU VAE selected, skipping initializing MultiGPU VAE clones.")
|
||||
return cloned
|
||||
|
||||
full_extra_devices = comfy.model_management.get_all_torch_devices()
|
||||
|
||||
def is_vae_device(device):
|
||||
return device.type == vae_device.type and device.index == vae_device.index
|
||||
|
||||
limit_extra_devices = [d for d in full_extra_devices if not is_vae_device(d)][:max_gpus - 1]
|
||||
if len(limit_extra_devices) == 0:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU VAE clones.")
|
||||
return cloned
|
||||
|
||||
existing = getattr(vae, 'multigpu_clones', None)
|
||||
limit_extra_device_set = set(limit_extra_devices)
|
||||
clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {}
|
||||
|
||||
for device in limit_extra_devices:
|
||||
if device in clones:
|
||||
continue
|
||||
cloned_patcher = vae.patcher.deepclone_multigpu(new_load_device=device)
|
||||
clone_vae = copy.copy(vae)
|
||||
if hasattr(clone_vae, 'multigpu_clones'):
|
||||
del clone_vae.multigpu_clones
|
||||
clone_vae.first_stage_model = cloned_patcher.model
|
||||
clone_vae.patcher = cloned_patcher
|
||||
clone_vae.first_stage_model.eval()
|
||||
for p in clone_vae.first_stage_model.parameters():
|
||||
p.requires_grad_(False)
|
||||
clone_vae.first_stage_model.to("cpu")
|
||||
clones[device] = clone_vae
|
||||
logging.info(f"Created CPU VAE deepclone for {device}")
|
||||
|
||||
cloned.multigpu_clones = clones
|
||||
return cloned
|
||||
|
||||
|
||||
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||
opts_dict = model_options['multigpu_options']
|
||||
devices = list(model_options['multigpu_clones'].keys())
|
||||
speed_per_device = []
|
||||
work_per_device = []
|
||||
# get sum of each device's relative_speed
|
||||
total_speed = 0.0
|
||||
for opts in opts_dict.values():
|
||||
total_speed += opts['relative_speed']
|
||||
# get relative work for each device;
|
||||
# obtained by w = (W*r)/R
|
||||
for device in devices:
|
||||
relative_speed = opts_dict[device]['relative_speed']
|
||||
relative_work = (total_work*relative_speed) / total_speed
|
||||
speed_per_device.append(relative_speed)
|
||||
work_per_device.append(relative_work)
|
||||
# relative work must be expressed in whole numbers, but likely is a decimal;
|
||||
# perform rounding while maintaining total sum equal to total work (sum of relative works)
|
||||
work_per_device = round_preserved(work_per_device)
|
||||
dict_work_per_device = {}
|
||||
for device, relative_work in zip(devices, work_per_device):
|
||||
dict_work_per_device[device] = relative_work
|
||||
if not return_idle_time:
|
||||
return LoadBalance(dict_work_per_device, None)
|
||||
# divide relative work by relative speed to get estimated completion time of said work by each device;
|
||||
# time here is relative and does not correspond to real-world units
|
||||
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
|
||||
# calculate relative time spent by the devices waiting on each other after their work is completed
|
||||
idle_time = abs(min(completion_time) - max(completion_time))
|
||||
# if need to compare work idle time, need to normalize to a common total work
|
||||
if work_normalized:
|
||||
idle_time *= (work_normalized/total_work)
|
||||
|
||||
return LoadBalance(dict_work_per_device, idle_time)
|
||||
|
||||
def round_preserved(values: list[float]):
|
||||
'Round all values in a list, preserving the combined sum of values.'
|
||||
# get floor of values; casting to int does it too
|
||||
floored = [int(x) for x in values]
|
||||
total_floored = sum(floored)
|
||||
# get remainder to distribute
|
||||
remainder = round(sum(values)) - total_floored
|
||||
# pair values with fractional portions
|
||||
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
|
||||
# sort by fractional part in descending order
|
||||
fractional.sort(key=lambda x: x[1], reverse=True)
|
||||
# distribute the remainder
|
||||
for i in range(remainder):
|
||||
index = fractional[i][0]
|
||||
floored[index] += 1
|
||||
return floored
|
||||
@ -3,8 +3,6 @@ from typing import Callable
|
||||
|
||||
class CallbacksMP:
|
||||
ON_CLONE = "on_clone"
|
||||
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
|
||||
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
|
||||
ON_LOAD = "on_load_after"
|
||||
ON_DETACH = "on_detach_after"
|
||||
ON_CLEANUP = "on_cleanup"
|
||||
|
||||
@ -1,18 +1,16 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import uuid
|
||||
import math
|
||||
import collections
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
@ -121,47 +119,6 @@ def cleanup_additional_models(models):
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
|
||||
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
|
||||
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) == 0:
|
||||
return
|
||||
extra_devices = [x.load_device for x in multigpu_models]
|
||||
# handle controlnets
|
||||
controlnets: set[ControlBase] = set()
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
if 'control' in kk:
|
||||
controlnets.add(kk['control'])
|
||||
if len(controlnets) > 0:
|
||||
# first, unload all controlnet clones
|
||||
for cnet in list(controlnets):
|
||||
cnet_models = cnet.get_models()
|
||||
for cm in cnet_models:
|
||||
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
|
||||
|
||||
# next, make sure each controlnet has a deepclone for all relevant devices
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
for device in extra_devices:
|
||||
if device not in curr_cnet.multigpu_clones:
|
||||
curr_cnet.deepclone_multigpu(device, autoregister=True)
|
||||
curr_cnet = curr_cnet.previous_controlnet
|
||||
# since all device clones are now present, recreate the linked list for cloned cnets per device
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
prev_cnet = curr_cnet.previous_controlnet
|
||||
for device in extra_devices:
|
||||
device_cnet = curr_cnet.get_instance_for_device(device)
|
||||
prev_device_cnet = None
|
||||
if prev_cnet is not None:
|
||||
prev_device_cnet = prev_cnet.get_instance_for_device(device)
|
||||
device_cnet.set_previous_controlnet(prev_device_cnet)
|
||||
curr_cnet = prev_cnet
|
||||
# potentially handle gligen - since not widely used, ignored for now
|
||||
|
||||
def estimate_memory(model, noise_shape, conds):
|
||||
cond_shapes = collections.defaultdict(list)
|
||||
cond_shapes_min = {}
|
||||
@ -186,8 +143,7 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||
model.match_multigpu_clones()
|
||||
preprocess_multigpu_conds(conds, model, model_options)
|
||||
real_model: BaseModel = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
@ -199,7 +155,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
|
||||
memory_required += inference_memory
|
||||
minimum_memory_required += inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||
real_model: BaseModel = model.model
|
||||
real_model = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
@ -245,18 +201,3 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
||||
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||
copy_dict1=False)
|
||||
return to_load_options
|
||||
|
||||
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
|
||||
'''
|
||||
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
|
||||
'''
|
||||
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
|
||||
if len(multigpu_patchers) > 0:
|
||||
multigpu_dict: dict[torch.device, ModelPatcher] = {}
|
||||
multigpu_dict[model_patcher.load_device] = model_patcher
|
||||
for x in multigpu_patchers:
|
||||
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
|
||||
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
|
||||
multigpu_dict[x.load_device] = x
|
||||
model_options["multigpu_clones"] = multigpu_dict
|
||||
return multigpu_patchers
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import comfy.model_management
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
@ -18,7 +16,6 @@ import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import comfy.multigpu
|
||||
import comfy.utils
|
||||
import scipy.stats
|
||||
import numpy
|
||||
@ -144,7 +141,7 @@ def can_concat_cond(c1, c2):
|
||||
|
||||
return cond_equal_size(c1.conditioning, c2.conditioning)
|
||||
|
||||
def cond_cat(c_list, device=None):
|
||||
def cond_cat(c_list):
|
||||
temp = {}
|
||||
for x in c_list:
|
||||
for k in x:
|
||||
@ -156,8 +153,6 @@ def cond_cat(c_list, device=None):
|
||||
for k in temp:
|
||||
conds = temp[k]
|
||||
out[k] = conds[0].concat(conds[1:])
|
||||
if device is not None and hasattr(out[k], 'to'):
|
||||
out[k] = out[k].to(device)
|
||||
|
||||
return out
|
||||
|
||||
@ -217,12 +212,7 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
# NOTE: keep in sync with _calc_cond_batch_multigpu below. Shared logic
|
||||
# (hooked_to_run accumulation, memory-fit batching, per-chunk output
|
||||
# aggregation) is duplicated there with per-device scheduling layered on top.
|
||||
if 'multigpu_clones' in model_options:
|
||||
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@ -254,7 +244,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
model.current_patcher.prepare_state(timestep)
|
||||
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
@ -355,239 +345,6 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
|
||||
return out_conds
|
||||
|
||||
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
# NOTE: keep in sync with _calc_cond_batch above. Same conds-by-hooks
|
||||
# accumulation, memory-fit batching, and output aggregation, but adds a
|
||||
# per-device scheduler, per-device patcher/control lookup, tensor .to(device)
|
||||
# placement, and MultiGPUThreadPool dispatch around the inner loop.
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||
default_conds = []
|
||||
has_default_conds = False
|
||||
|
||||
output_device = x_in.device
|
||||
|
||||
for i in range(len(conds)):
|
||||
out_conds.append(torch.zeros_like(x_in))
|
||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||
|
||||
cond = conds[i]
|
||||
default_c = []
|
||||
if cond is not None:
|
||||
for x in cond:
|
||||
if 'default' in x:
|
||||
default_c.append(x)
|
||||
has_default_conds = True
|
||||
continue
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
default_conds.append(default_c)
|
||||
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
|
||||
devices = list(model_options['multigpu_clones'].keys())
|
||||
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
||||
# Track conds currently scheduled per device; single source of truth for capacity checks.
|
||||
device_load: dict[torch.device, int] = {d: 0 for d in devices}
|
||||
|
||||
total_conds = sum(len(to_run) for to_run in hooked_to_run.values())
|
||||
conds_per_device = max(1, math.ceil(total_conds / len(devices)))
|
||||
|
||||
def next_available_device(start: int) -> tuple[int, torch.device]:
|
||||
"""Return (index, device) for the next device with remaining capacity, starting at `start`.
|
||||
|
||||
Scans at most len(devices) positions, so this always terminates. Raises if no device
|
||||
has remaining capacity, which would indicate a bug in conds_per_device accounting.
|
||||
"""
|
||||
for offset in range(len(devices)):
|
||||
i = (start + offset) % len(devices)
|
||||
if device_load[devices[i]] < conds_per_device:
|
||||
return i, devices[i]
|
||||
raise RuntimeError(
|
||||
f"MultiGPU scheduler: all {len(devices)} devices at capacity "
|
||||
f"({conds_per_device}) but conds remain to schedule"
|
||||
)
|
||||
|
||||
# run every hooked_to_run separately
|
||||
index_device = 0
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
while len(to_run) > 0:
|
||||
index_device, current_device = next_available_device(index_device)
|
||||
remaining_capacity = conds_per_device - device_load[current_device]
|
||||
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
# collect candidate indices that can be concatenated with `first`, up to remaining capacity
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = comfy.model_management.get_free_memory(current_device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
cond_shapes = collections.defaultdict(list)
|
||||
for tt in batch_amount:
|
||||
for k, v in to_run[tt][0].conditioning.items():
|
||||
cond_shapes[k].append(v.size())
|
||||
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
conds_to_batch = [to_run.pop(x) for x in to_batch]
|
||||
device_load[current_device] += len(conds_to_batch)
|
||||
device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch))
|
||||
|
||||
if device_load[current_device] >= conds_per_device:
|
||||
index_device += 1
|
||||
|
||||
class thread_result(NamedTuple):
|
||||
output: Any
|
||||
mult: Any
|
||||
area: Any
|
||||
batch_chunks: int
|
||||
cond_or_uncond: Any
|
||||
error: Exception = None
|
||||
|
||||
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||
try:
|
||||
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
|
||||
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
|
||||
# XPU/NPU/MPS/CPU/DirectML backends.
|
||||
torch.cuda.set_device(device)
|
||||
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||
# run every hooked_to_run separately
|
||||
with torch.no_grad():
|
||||
for hooks, to_batch in batch_tuple:
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
uuids = []
|
||||
area = []
|
||||
control: ControlBase = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = x
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
uuids.append(p.uuid)
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x).to(device)
|
||||
c = cond_cat(c, device=device)
|
||||
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
|
||||
|
||||
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||
model_options['transformer_options'],
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
|
||||
transformer_options.get("patches", {}),
|
||||
patches
|
||||
)
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
transformer_options["sigmas"] = timestep.to(device)
|
||||
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
|
||||
transformer_options["multigpu_thread_device"] = device
|
||||
|
||||
cast_transformer_options(transformer_options, device=device)
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if control is not None:
|
||||
device_control = control.get_instance_for_device(device)
|
||||
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
||||
else:
|
||||
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
|
||||
# TODO: non-NVIDIA support -- the `.to(output_device)` copies
|
||||
# above are async on CUDA, so the main thread's aggregation
|
||||
# could race with in-flight transfers. CUDA-only QA has not
|
||||
# surfaced this in practice, but before extending multigpu
|
||||
# beyond NVIDIA add a `torch.cuda.synchronize(output_device)`
|
||||
# here (guarded by `output_device.type == "cuda"`).
|
||||
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
|
||||
except Exception as e:
|
||||
results.append(thread_result(None, None, None, None, None, error=e))
|
||||
raise
|
||||
|
||||
|
||||
def _handle_batch_pooled(device, batch_tuple):
|
||||
worker_results = []
|
||||
_handle_batch(device, batch_tuple, worker_results)
|
||||
return worker_results
|
||||
|
||||
results: list[thread_result] = []
|
||||
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
|
||||
|
||||
# Submit all GPU work to pool threads
|
||||
pool_devices = []
|
||||
for device, batch_tuple in device_batched_hooked_to_run.items():
|
||||
if thread_pool is not None:
|
||||
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
|
||||
pool_devices.append(device)
|
||||
else:
|
||||
# Fallback: no pool, run everything on main thread
|
||||
_handle_batch(device, batch_tuple, results)
|
||||
|
||||
# Collect results from pool workers
|
||||
for device in pool_devices:
|
||||
worker_results, error = thread_pool.get_result(device)
|
||||
if error is not None:
|
||||
raise error
|
||||
results.extend(worker_results)
|
||||
|
||||
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
|
||||
if error is not None:
|
||||
raise error
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
|
||||
return out_conds
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||
@ -886,21 +643,12 @@ def calculate_start_end_timesteps(model, conds):
|
||||
|
||||
def pre_run_control(model, conds):
|
||||
s = model.model_sampling
|
||||
# Per-device model lookup so multigpu control clones get the matching
|
||||
# diffusion_model (e.g. QwenFunControlNet stashes it into extra_args).
|
||||
device_models: dict = {}
|
||||
patcher = getattr(model, "current_patcher", None)
|
||||
if patcher is not None:
|
||||
for p in patcher.get_additional_models_with_key("multigpu"):
|
||||
device_models[p.load_device] = p.model
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
|
||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||
if 'control' in x:
|
||||
x['control'].pre_run(model, percent_to_timestep_function)
|
||||
for device, device_cnet in x['control'].multigpu_clones.items():
|
||||
device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function)
|
||||
|
||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
cond_cnets = []
|
||||
@ -1143,9 +891,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
to_load_options = model_options.get("to_load_options", None)
|
||||
if to_load_options is None:
|
||||
return
|
||||
cast_transformer_options(to_load_options, device, dtype)
|
||||
|
||||
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
|
||||
casts = []
|
||||
if device is not None:
|
||||
casts.append(device)
|
||||
@ -1154,17 +900,18 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
|
||||
# if nothing to apply, do nothing
|
||||
if len(casts) == 0:
|
||||
return
|
||||
|
||||
# try to call .to on patches
|
||||
if "patches" in transformer_options:
|
||||
patches = transformer_options["patches"]
|
||||
if "patches" in to_load_options:
|
||||
patches = to_load_options["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
for cast in casts:
|
||||
patch_list[i] = patch_list[i].to(cast)
|
||||
if "patches_replace" in transformer_options:
|
||||
patches = transformer_options["patches_replace"]
|
||||
if "patches_replace" in to_load_options:
|
||||
patches = to_load_options["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
@ -1174,8 +921,8 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
|
||||
# try to call .to on any wrappers/callbacks
|
||||
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||
for wc_name in wrappers_and_callbacks:
|
||||
if wc_name in transformer_options:
|
||||
wc: dict[str, list] = transformer_options[wc_name]
|
||||
if wc_name in to_load_options:
|
||||
wc: dict[str, list] = to_load_options[wc_name]
|
||||
for wc_dict in wc.values():
|
||||
for wc_list in wc_dict.values():
|
||||
for i in range(len(wc_list)):
|
||||
@ -1183,6 +930,7 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
|
||||
for cast in casts:
|
||||
wc_list[i] = wc_list[i].to(cast)
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher: ModelPatcher):
|
||||
self.model_patcher = model_patcher
|
||||
@ -1237,32 +985,16 @@ class CFGGuider:
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
# Create persistent thread pool for all GPU devices (main + extras)
|
||||
if multigpu_patchers:
|
||||
extra_devices = [p.load_device for p in multigpu_patchers]
|
||||
all_devices = [device] + extra_devices
|
||||
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
|
||||
|
||||
with comfy.model_management.cuda_device_context(device):
|
||||
try:
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
self.model_patcher.pre_run()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
||||
if thread_pool is not None:
|
||||
thread_pool.shutdown()
|
||||
self.model_patcher.cleanup()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.cleanup()
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
|
||||
451
comfy/sd.py
451
comfy/sd.py
@ -335,43 +335,41 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model(tokens)
|
||||
device = self.patcher.load_device
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
all_hooks.reset()
|
||||
self.patcher.patch_hooks(None)
|
||||
if show_pbar:
|
||||
pbar = ProgressBar(len(scheduled_keyframes))
|
||||
|
||||
with model_management.cuda_device_context(device):
|
||||
for scheduled_opts in scheduled_keyframes:
|
||||
t_range = scheduled_opts[0]
|
||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||
if "start_percent" in add_dict:
|
||||
if t_range[1] < add_dict["start_percent"]:
|
||||
continue
|
||||
if "end_percent" in add_dict:
|
||||
if t_range[0] > add_dict["end_percent"]:
|
||||
continue
|
||||
hooks_keyframes = scheduled_opts[1]
|
||||
for hook, keyframe in hooks_keyframes:
|
||||
hook.hook_keyframe._current_keyframe = keyframe
|
||||
# apply appropriate hooks with values that match new hook_keyframe
|
||||
self.patcher.patch_hooks(all_hooks)
|
||||
# perform encoding as normal
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
pooled_dict = {"pooled_output": pooled}
|
||||
# add clip_start_percent and clip_end_percent in pooled
|
||||
pooled_dict["clip_start_percent"] = t_range[0]
|
||||
pooled_dict["clip_end_percent"] = t_range[1]
|
||||
# add/update any keys with the provided add_dict
|
||||
pooled_dict.update(add_dict)
|
||||
# add hooks stored on clip
|
||||
self.add_hooks_to_dict(pooled_dict)
|
||||
all_cond_pooled.append([cond, pooled_dict])
|
||||
if show_pbar:
|
||||
pbar.update(1)
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
for scheduled_opts in scheduled_keyframes:
|
||||
t_range = scheduled_opts[0]
|
||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||
if "start_percent" in add_dict:
|
||||
if t_range[1] < add_dict["start_percent"]:
|
||||
continue
|
||||
if "end_percent" in add_dict:
|
||||
if t_range[0] > add_dict["end_percent"]:
|
||||
continue
|
||||
hooks_keyframes = scheduled_opts[1]
|
||||
for hook, keyframe in hooks_keyframes:
|
||||
hook.hook_keyframe._current_keyframe = keyframe
|
||||
# apply appropriate hooks with values that match new hook_keyframe
|
||||
self.patcher.patch_hooks(all_hooks)
|
||||
# perform encoding as normal
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
pooled_dict = {"pooled_output": pooled}
|
||||
# add clip_start_percent and clip_end_percent in pooled
|
||||
pooled_dict["clip_start_percent"] = t_range[0]
|
||||
pooled_dict["clip_end_percent"] = t_range[1]
|
||||
# add/update any keys with the provided add_dict
|
||||
pooled_dict.update(add_dict)
|
||||
# add hooks stored on clip
|
||||
self.add_hooks_to_dict(pooled_dict)
|
||||
all_cond_pooled.append([cond, pooled_dict])
|
||||
if show_pbar:
|
||||
pbar.update(1)
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
all_hooks.reset()
|
||||
return all_cond_pooled
|
||||
|
||||
@ -385,12 +383,8 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model(tokens)
|
||||
device = self.patcher.load_device
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
|
||||
with model_management.cuda_device_context(device):
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
if return_dict:
|
||||
out = {"cond": cond, "pooled_output": pooled}
|
||||
@ -452,12 +446,9 @@ class CLIP:
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
self.load_model(tokens)
|
||||
device = self.patcher.load_device
|
||||
self.cond_stage_model.set_clip_options({"layer": None})
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
|
||||
with model_management.cuda_device_context(device):
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
@ -972,26 +963,6 @@ class VAE:
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
|
||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
functions = {self.device: decode_fn}
|
||||
try:
|
||||
for dev, c in multigpu_clones.items():
|
||||
model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
|
||||
c.first_stage_model.to(dev)
|
||||
for dev, c in multigpu_clones.items():
|
||||
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
|
||||
output = self.process_output(
|
||||
(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
|
||||
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
|
||||
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar))
|
||||
/ 3.0)
|
||||
return output
|
||||
finally:
|
||||
for c in multigpu_clones.values():
|
||||
c.first_stage_model.to("cpu")
|
||||
|
||||
output = self.process_output(
|
||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
@ -1001,49 +972,16 @@ class VAE:
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
||||
if samples.ndim == 3:
|
||||
memory_shape = samples.shape
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
|
||||
else:
|
||||
og_shape = samples.shape
|
||||
memory_shape = og_shape
|
||||
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
|
||||
|
||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
functions = {self.device: decode_fn}
|
||||
try:
|
||||
for dev, c in multigpu_clones.items():
|
||||
model_management.free_memory(c.model_size() + c.memory_used_decode(memory_shape, c.vae_dtype), dev)
|
||||
c.first_stage_model.to(dev)
|
||||
for dev, c in multigpu_clones.items():
|
||||
functions[dev] = clone_decode_fn_factory(c, dev)
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
finally:
|
||||
for c in multigpu_clones.values():
|
||||
c.first_stage_model.to("cpu")
|
||||
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
|
||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
functions = {self.device: decode_fn}
|
||||
try:
|
||||
for dev, c in multigpu_clones.items():
|
||||
model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
|
||||
c.first_stage_model.to(dev)
|
||||
for dev, c in multigpu_clones.items():
|
||||
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||
finally:
|
||||
for c in multigpu_clones.values():
|
||||
c.first_stage_model.to("cpu")
|
||||
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
@ -1053,25 +991,6 @@ class VAE:
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
|
||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
functions = {self.device: encode_fn}
|
||||
try:
|
||||
for dev, c in multigpu_clones.items():
|
||||
model_management.free_memory(c.model_size() + c.memory_used_encode(pixel_samples.shape, c.vae_dtype), dev)
|
||||
c.first_stage_model.to(dev)
|
||||
for dev, c in multigpu_clones.items():
|
||||
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
|
||||
samples = comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples /= 3.0
|
||||
return samples
|
||||
finally:
|
||||
for c in multigpu_clones.values():
|
||||
c.first_stage_model.to("cpu")
|
||||
|
||||
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
@ -1081,7 +1000,6 @@ class VAE:
|
||||
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
||||
if self.latent_dim == 1:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
|
||||
out_channels = self.latent_channels
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
else:
|
||||
@ -1091,24 +1009,8 @@ class VAE:
|
||||
overlap = overlap // extra_channel_size
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
|
||||
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).reshape(1, out_channels, -1).to(dtype=c.vae_output_dtype()))
|
||||
|
||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
functions = {self.device: encode_fn}
|
||||
try:
|
||||
for dev, c in multigpu_clones.items():
|
||||
model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
|
||||
c.first_stage_model.to(dev)
|
||||
for dev, c in multigpu_clones.items():
|
||||
functions[dev] = clone_encode_fn_factory(c, dev)
|
||||
out = comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||
finally:
|
||||
for c in multigpu_clones.values():
|
||||
c.first_stage_model.to("cpu")
|
||||
else:
|
||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||
|
||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||
if self.latent_dim == 1:
|
||||
return out
|
||||
else:
|
||||
@ -1116,21 +1018,6 @@ class VAE:
|
||||
|
||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
|
||||
multigpu_clones = getattr(self, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
functions = {self.device: encode_fn}
|
||||
try:
|
||||
for dev, c in multigpu_clones.items():
|
||||
model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
|
||||
c.first_stage_model.to(dev)
|
||||
for dev, c in multigpu_clones.items():
|
||||
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
|
||||
return comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
finally:
|
||||
for c in multigpu_clones.values():
|
||||
c.first_stage_model.to("cpu")
|
||||
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
@ -1139,52 +1026,50 @@ class VAE:
|
||||
do_tile = False
|
||||
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||
samples_in = samples_in[:, :, 0]
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
@ -1202,21 +1087,20 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
output = self.decode_tiled_(samples, **args)
|
||||
elif dims == 3:
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||
if tile_t is not None:
|
||||
args["tile_t"] = max(2, tile_t)
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
output = self.decode_tiled_(samples, **args)
|
||||
elif dims == 3:
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||
if tile_t is not None:
|
||||
args["tile_t"] = max(2, tile_t)
|
||||
|
||||
output = self.decode_tiled_3d(samples, **args)
|
||||
output = self.decode_tiled_3d(samples, **args)
|
||||
return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
@ -1229,46 +1113,44 @@ class VAE:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
else:
|
||||
pixel_samples = pixel_samples.unsqueeze(2)
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
return samples
|
||||
|
||||
@ -1294,27 +1176,26 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
|
||||
return samples
|
||||
|
||||
@ -1829,16 +1710,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
if out is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||
if out[0] is not None:
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
|
||||
if output_vae and out[2] is not None and hasattr(out[2], "patcher"):
|
||||
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options, disable_dynamic))
|
||||
if output_model and out[0] is not None:
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
if output_clip and out[1] is not None:
|
||||
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
return out
|
||||
|
||||
def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
_, _, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=False, embedding_directory=embedding_directory, output_model=False, model_options=model_options, te_model_options=te_model_options, disable_dynamic=disable_dynamic)
|
||||
return vae.patcher
|
||||
|
||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
||||
embedding_directory=embedding_directory,
|
||||
@ -1865,7 +1742,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
||||
load_device = model_options.get("load_device", model_management.get_torch_device())
|
||||
load_device = model_management.get_torch_device()
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
@ -1905,15 +1782,13 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
|
||||
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
||||
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||
vae_device = model_options.get("load_device", None)
|
||||
vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device)
|
||||
vae = VAE(sd=vae_sd, metadata=metadata)
|
||||
|
||||
if output_clip:
|
||||
if te_model_options.get("custom_operations", None) is None:
|
||||
@ -1997,7 +1872,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
load_device = model_options.get("load_device", model_management.get_torch_device())
|
||||
load_device = model_management.get_torch_device()
|
||||
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
||||
|
||||
if model_config is not None:
|
||||
@ -2022,7 +1897,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
else:
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
|
||||
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
|
||||
offload_device = model_management.unet_offload_device()
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.quant_config is not None:
|
||||
weight_dtype = None
|
||||
@ -2064,26 +1939,6 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
|
||||
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
||||
return model
|
||||
|
||||
def load_vae_patcher(vae_path, metadata=None, device=None):
|
||||
"""Reload a VAE from disk and return its patcher.
|
||||
|
||||
Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so that
|
||||
:meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a
|
||||
fresh VAE patcher with no inherited source-device storage tracking. The
|
||||
optional device matches the source loader's VAE initialization path; the
|
||||
cloned patcher's load_device still controls the device targeted by the
|
||||
multigpu clone. Without this, bare ``copy.deepcopy`` of the VAE wrapper
|
||||
carries dynamic-VRAM allocator state forward to the clone, which causes
|
||||
per-device worker threads in tiled encode/decode dispatch to access weights
|
||||
through the source-device buffer."""
|
||||
if metadata is None:
|
||||
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
||||
else:
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
vae = VAE(sd=sd, metadata=metadata, device=device)
|
||||
vae.throw_exception_if_invalid()
|
||||
return vae.patcher
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||
|
||||
160
comfy/utils.py
160
comfy/utils.py
@ -28,13 +28,13 @@ import numpy as np
|
||||
from PIL import Image
|
||||
import logging
|
||||
import itertools
|
||||
import threading
|
||||
from torch.nn.functional import interpolate
|
||||
from tqdm.auto import trange
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
@ -86,7 +86,6 @@ def load_safetensors(ckpt):
|
||||
import comfy_aimdo.model_mmap
|
||||
|
||||
f = open(ckpt, "rb", buffering=0)
|
||||
file_lock = threading.Lock()
|
||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||
file_size = os.path.getsize(ckpt)
|
||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||
@ -112,7 +111,7 @@ def load_safetensors(ckpt):
|
||||
storage = tensor.untyped_storage()
|
||||
setattr(storage,
|
||||
"_comfy_tensor_file_slice",
|
||||
comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start))
|
||||
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
||||
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||
sd[name] = tensor
|
||||
|
||||
@ -1187,161 +1186,6 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
||||
|
||||
|
||||
def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
|
||||
"""Multigpu variant of tiled_scale_multidim. ``functions`` is a dict[torch.device, callable].
|
||||
|
||||
Round-robin dispatches tile positions across devices via threading. Each thread maintains
|
||||
its own per-device CPU output and divisor buffer, applying the same feathered overlap mask
|
||||
formula as the single-device path. Buffers are summed at the end, producing output that is
|
||||
bit-equivalent to ``tiled_scale_multidim`` within fp32 add-order noise.
|
||||
|
||||
Falls back to ``tiled_scale_multidim`` with the only function when ``len(functions) < 2``.
|
||||
Falls back to single-device on the "whole input fits in one tile" branch (no parallelism
|
||||
available at that granularity).
|
||||
"""
|
||||
devices = list(functions.keys())
|
||||
if len(devices) < 2:
|
||||
only_fn = next(iter(functions.values())) if functions else None
|
||||
return tiled_scale_multidim(samples, only_fn, tile=tile, overlap=overlap,
|
||||
upscale_amount=upscale_amount, out_channels=out_channels,
|
||||
output_device=output_device, downscale=downscale,
|
||||
index_formulas=index_formulas, pbar=pbar)
|
||||
|
||||
dims = len(tile)
|
||||
|
||||
if not (isinstance(upscale_amount, (tuple, list))):
|
||||
upscale_amount = [upscale_amount] * dims
|
||||
if not (isinstance(overlap, (tuple, list))):
|
||||
overlap = [overlap] * dims
|
||||
if index_formulas is None:
|
||||
index_formulas = upscale_amount
|
||||
if not (isinstance(index_formulas, (tuple, list))):
|
||||
index_formulas = [index_formulas] * dims
|
||||
|
||||
def get_upscale(dim, val):
|
||||
up = upscale_amount[dim]
|
||||
return up(val) if callable(up) else up * val
|
||||
|
||||
def get_downscale(dim, val):
|
||||
up = upscale_amount[dim]
|
||||
return up(val) if callable(up) else val / up
|
||||
|
||||
def get_upscale_pos(dim, val):
|
||||
up = index_formulas[dim]
|
||||
return up(val) if callable(up) else up * val
|
||||
|
||||
def get_downscale_pos(dim, val):
|
||||
up = index_formulas[dim]
|
||||
return up(val) if callable(up) else val / up
|
||||
|
||||
if downscale:
|
||||
get_scale = get_downscale
|
||||
get_pos = get_downscale_pos
|
||||
else:
|
||||
get_scale = get_upscale
|
||||
get_pos = get_upscale_pos
|
||||
|
||||
def mult_list_upscale(a):
|
||||
return [round(get_scale(i, a[i])) for i in range(len(a))]
|
||||
|
||||
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
|
||||
merge_device = torch.device("cpu")
|
||||
|
||||
pbar_lock = threading.Lock() if pbar is not None else None
|
||||
primary_device = devices[0]
|
||||
|
||||
samples_staged = samples if samples.device.type == "cpu" else samples.to("cpu", non_blocking=False)
|
||||
|
||||
for b in range(samples_staged.shape[0]):
|
||||
s = samples_staged[b:b+1]
|
||||
|
||||
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
|
||||
with torch.inference_mode():
|
||||
output[b:b+1] = functions[primary_device](s.to(primary_device, non_blocking=True)).to(output_device)
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
split = {devices[i]: itertools.islice(itertools.product(*positions), i, None, len(devices)) for i in range(len(devices))}
|
||||
|
||||
out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:])
|
||||
div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:])
|
||||
bufs = {d: torch.zeros(out_shape, device=merge_device) for d in devices}
|
||||
divs = {d: torch.zeros(div_shape, device=merge_device) for d in devices}
|
||||
|
||||
worker_errors: list[BaseException] = []
|
||||
worker_lock = threading.Lock()
|
||||
|
||||
def worker(device, my_positions):
|
||||
try:
|
||||
if device.type == "cuda":
|
||||
torch.cuda.set_device(device)
|
||||
fn = functions[device]
|
||||
local_buf = bufs[device]
|
||||
local_div = divs[device]
|
||||
with torch.inference_mode():
|
||||
for it in my_positions:
|
||||
s_in = s
|
||||
upscaled = []
|
||||
for d in range(dims):
|
||||
pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
|
||||
l = min(tile[d], s.shape[d + 2] - pos)
|
||||
s_in = s_in.narrow(d + 2, pos, l)
|
||||
upscaled.append(round(get_pos(d, pos)))
|
||||
|
||||
s_in_dev = s_in.to(device, non_blocking=True)
|
||||
ps = fn(s_in_dev).to(merge_device)
|
||||
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=merge_device)
|
||||
|
||||
for d in range(2, dims + 2):
|
||||
feather = round(get_scale(d - 2, overlap[d - 2]))
|
||||
if feather >= mask.shape[d]:
|
||||
continue
|
||||
for t in range(feather):
|
||||
a = (t + 1) / feather
|
||||
mask.narrow(d, t, 1).mul_(a)
|
||||
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
|
||||
|
||||
o = local_buf
|
||||
o_d = local_div
|
||||
ps_view = ps
|
||||
mask_view = mask
|
||||
for d in range(dims):
|
||||
l = min(ps_view.shape[d + 2], o.shape[d + 2] - upscaled[d])
|
||||
o = o.narrow(d + 2, upscaled[d], l)
|
||||
o_d = o_d.narrow(d + 2, upscaled[d], l)
|
||||
if l < ps_view.shape[d + 2]:
|
||||
ps_view = ps_view.narrow(d + 2, 0, l)
|
||||
mask_view = mask_view.narrow(d + 2, 0, l)
|
||||
|
||||
o.add_(ps_view * mask_view)
|
||||
o_d.add_(mask_view)
|
||||
|
||||
if pbar is not None:
|
||||
with pbar_lock:
|
||||
pbar.update(1)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize(device)
|
||||
except BaseException as e:
|
||||
with worker_lock:
|
||||
worker_errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(d, split[d])) for d in devices]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
if worker_errors:
|
||||
raise worker_errors[0]
|
||||
|
||||
combined_buf = sum(bufs.values())
|
||||
combined_div = sum(divs.values())
|
||||
output[b:b+1] = combined_buf / combined_div
|
||||
|
||||
return output
|
||||
|
||||
def model_trange(*args, **kwargs):
|
||||
if not comfy.memory_management.aimdo_enabled:
|
||||
return trange(*args, **kwargs)
|
||||
|
||||
@ -276,7 +276,6 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
max_poll_attempts=280,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
@ -3066,7 +3065,6 @@ class KlingVideoNode(IO.ComfyNode):
|
||||
cls,
|
||||
ApiEndpoint(path=poll_path),
|
||||
response_model=TaskStatusResponse,
|
||||
max_poll_attempts=280,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
@ -3192,7 +3190,6 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
max_poll_attempts=280,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
@ -182,7 +182,7 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
|
||||
),
|
||||
io.Combo.Input(
|
||||
"device",
|
||||
options=comfy.model_management.get_gpu_device_options(),
|
||||
options=["default", "cpu"],
|
||||
advanced=True,
|
||||
)
|
||||
],
|
||||
@ -197,12 +197,8 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
|
||||
model_options = {}
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
if resolved is not None:
|
||||
if resolved.type == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = resolved
|
||||
else:
|
||||
model_options["load_device"] = resolved
|
||||
if device == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
||||
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||
return io.NodeOutput(clip)
|
||||
|
||||
@ -1,109 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import cleandoc
|
||||
from typing import TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.multigpu
|
||||
|
||||
|
||||
class MultiGPUCFGSplitNode(io.ComfyNode):
|
||||
"""
|
||||
Attaches per-device deepclones to any connected MODEL, UPSCALE_MODEL, and/or VAE so
|
||||
downstream nodes that recognize the attached state dispatch their work across multiple GPUs.
|
||||
|
||||
Place after nodes that modify the model object itself (compile, attention-switch, etc.).
|
||||
Otherwise position is not order-sensitive.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MultiGPU_WorkUnits",
|
||||
display_name="MultiGPU Work Units",
|
||||
category="advanced/multigpu",
|
||||
description=cleandoc(cls.__doc__),
|
||||
inputs=[
|
||||
io.Model.Input("model", optional=True),
|
||||
io.UpscaleModel.Input("upscale_model", optional=True),
|
||||
io.Vae.Input("vae", optional=True),
|
||||
io.Int.Input("max_gpus", default=2, min=1, step=1),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
io.UpscaleModel.Output(),
|
||||
io.Vae.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None, vae=None) -> io.NodeOutput:
|
||||
if model is not None:
|
||||
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
|
||||
if upscale_model is not None:
|
||||
upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus)
|
||||
if vae is not None:
|
||||
vae = comfy.multigpu.create_vae_multigpu_deepclones(vae, max_gpus)
|
||||
return io.NodeOutput(model, upscale_model, vae)
|
||||
|
||||
|
||||
class MultiGPUOptionsNode(io.ComfyNode):
|
||||
"""
|
||||
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
|
||||
|
||||
NOTE (not registered yet, see MultiGPUExtension.get_node_list below):
|
||||
The output GPUOptionsGroup is plumbed through create_multigpu_deepclones() and stored on
|
||||
model.model_options['multigpu_options'] via GPUOptionsGroup.register(), but the cond
|
||||
scheduler in comfy/samplers.py (calc_cond_batch_outer_multigpu) does NOT yet consult
|
||||
relative_speed when distributing conds across devices; it uses a uniform conds_per_device
|
||||
round-robin via next_available_device(). Before re-enabling this node, wire its
|
||||
relative_speed into the scheduler (e.g. via comfy.multigpu.load_balance_devices(),
|
||||
which already implements the proportional split) so the input actually affects work
|
||||
distribution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MultiGPU_Options",
|
||||
display_name="MultiGPU Options",
|
||||
category="advanced/multigpu",
|
||||
description=cleandoc(cls.__doc__),
|
||||
inputs=[
|
||||
io.Int.Input("device_index", default=0, min=0, max=64),
|
||||
io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01),
|
||||
io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Custom("GPU_OPTIONS").Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput:
|
||||
if not gpu_options:
|
||||
gpu_options = comfy.multigpu.GPUOptionsGroup()
|
||||
else:
|
||||
gpu_options = gpu_options.clone()
|
||||
|
||||
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
|
||||
gpu_options.add(opt)
|
||||
|
||||
return io.NodeOutput(gpu_options)
|
||||
|
||||
|
||||
class MultiGPUExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
MultiGPUCFGSplitNode,
|
||||
# MultiGPUOptionsNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MultiGPUExtension:
|
||||
return MultiGPUExtension()
|
||||
@ -81,33 +81,13 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
|
||||
output_device = comfy.model_management.intermediate_device()
|
||||
|
||||
multigpu_clones = getattr(upscale_model, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
for dev, desc in multigpu_clones.items():
|
||||
model_management.free_memory(memory_required, dev)
|
||||
desc.to(dev)
|
||||
|
||||
oom = True
|
||||
try:
|
||||
while oom:
|
||||
try:
|
||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
if multigpu_clones:
|
||||
functions = {device: lambda a: upscale_model(a.float())}
|
||||
for dev, desc in multigpu_clones.items():
|
||||
functions[dev] = lambda a, d=desc: d(a.float())
|
||||
s = comfy.utils.tiled_scale_multidim_multigpu(
|
||||
in_img,
|
||||
functions,
|
||||
tile=(tile, tile),
|
||||
overlap=overlap,
|
||||
upscale_amount=upscale_model.scale,
|
||||
pbar=pbar,
|
||||
output_device=output_device,
|
||||
)
|
||||
else:
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
|
||||
oom = False
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
@ -116,9 +96,6 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
raise e
|
||||
finally:
|
||||
upscale_model.to("cpu")
|
||||
if multigpu_clones:
|
||||
for desc in multigpu_clones.values():
|
||||
desc.to("cpu")
|
||||
|
||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
|
||||
return io.NodeOutput(s)
|
||||
|
||||
@ -23,69 +23,6 @@ class ImageOnlyCheckpointLoader:
|
||||
return (out[0], out[3], out[2])
|
||||
|
||||
|
||||
class ImageOnlyCheckpointLoaderDevice:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
device_options = comfy.model_management.get_gpu_device_options()
|
||||
return {
|
||||
"required": {
|
||||
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
||||
},
|
||||
"optional": {
|
||||
"model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}),
|
||||
"clip_vision_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP vision encoder."}),
|
||||
"vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
CATEGORY = "loaders/video_models"
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, model_device="default", clip_vision_device="default", vae_device="default"):
|
||||
return True
|
||||
|
||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True, model_device="default", clip_vision_device="default", vae_device="default"):
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
|
||||
model_options = {}
|
||||
resolved_model = comfy.model_management.resolve_gpu_device_option(model_device)
|
||||
if resolved_model is not None:
|
||||
if resolved_model.type == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = resolved_model
|
||||
else:
|
||||
model_options["load_device"] = resolved_model
|
||||
|
||||
cv_model_options = {}
|
||||
resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_vision_device)
|
||||
if resolved_clip is not None:
|
||||
if resolved_clip.type == "cpu":
|
||||
cv_model_options["load_device"] = cv_model_options["offload_device"] = resolved_clip
|
||||
else:
|
||||
cv_model_options["load_device"] = resolved_clip
|
||||
|
||||
# VAE device is passed via model_options["load_device"] which
|
||||
# load_state_dict_guess_config forwards to the VAE constructor.
|
||||
# If vae_device differs from model_device, we override after loading.
|
||||
resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device)
|
||||
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
model_patcher, clip, vae, clip_vision = out[:4]
|
||||
|
||||
# Apply VAE device override if it differs from the model device
|
||||
if resolved_vae is not None and vae is not None:
|
||||
vae.device = resolved_vae
|
||||
if resolved_vae.type == "cpu":
|
||||
offload = resolved_vae
|
||||
else:
|
||||
offload = comfy.model_management.vae_offload_device()
|
||||
vae.patcher.load_device = resolved_vae
|
||||
vae.patcher.offload_device = offload
|
||||
|
||||
return (model_patcher, clip_vision, vae)
|
||||
|
||||
|
||||
class SVD_img2vid_Conditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -212,7 +149,6 @@ class ConditioningSetAreaPercentageVideo:
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
|
||||
"ImageOnlyCheckpointLoaderDevice": ImageOnlyCheckpointLoaderDevice,
|
||||
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
|
||||
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
|
||||
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
|
||||
@ -222,7 +158,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageOnlyCheckpointLoader": "Load Checkpoint Image Only (img2vid model)",
|
||||
"ImageOnlyCheckpointLoaderDevice": "Image Only Checkpoint Loader (Device)",
|
||||
"VideoLinearCFGGuidance": "Video Linear CFG Guidance",
|
||||
"VideoTriangleCFGGuidance": "Video Triangle CFG Guidance",
|
||||
}
|
||||
|
||||
37
main.py
37
main.py
@ -27,6 +27,7 @@ from utils.mime_types import init_mime_types
|
||||
import faulthandler
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_api import feature_flags
|
||||
@ -135,7 +136,20 @@ def apply_custom_paths():
|
||||
folder_paths.set_user_directory(user_dir)
|
||||
|
||||
|
||||
# Buffer for prestartup failures. Recorded into `nodes.NODE_STARTUP_ERRORS`
|
||||
# only AFTER the normal `import nodes` line below, so a failing prestartup
|
||||
# script never triggers an early `import nodes` (and therefore `import torch`)
|
||||
# on the error path.
|
||||
_PRESTARTUP_FAILURES: list[dict] = []
|
||||
|
||||
|
||||
def execute_prestartup_script():
|
||||
"""Run every custom_nodes/*/prestartup_script.py once, before importing nodes.
|
||||
|
||||
Failures are buffered into the module-level ``_PRESTARTUP_FAILURES`` list and
|
||||
must be flushed via ``record_node_startup_error`` after ``import nodes`` has
|
||||
happened at its normal bootstrap point.
|
||||
"""
|
||||
if args.disable_all_custom_nodes and len(args.whitelist_custom_nodes) == 0:
|
||||
return
|
||||
|
||||
@ -148,6 +162,15 @@ def execute_prestartup_script():
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
|
||||
# Buffer the failure - do NOT `import nodes` here, that would drag
|
||||
# torch in before the intended bootstrap point.
|
||||
_PRESTARTUP_FAILURES.append({
|
||||
"module_path": os.path.dirname(script_path),
|
||||
"source": "custom_nodes",
|
||||
"phase": "prestartup",
|
||||
"error": e,
|
||||
"tb": traceback.format_exc(),
|
||||
})
|
||||
return False
|
||||
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
@ -200,13 +223,23 @@ import gc
|
||||
if 'torch' in sys.modules:
|
||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
|
||||
import execution
|
||||
import server
|
||||
from protocol import BinaryEventTypes
|
||||
import nodes
|
||||
|
||||
# Flush any prestartup failures that were buffered before `nodes` was
|
||||
# importable. Doing this here (rather than from the prestartup error
|
||||
# handler) keeps the bootstrap order deterministic: `nodes` (and torch)
|
||||
# import at this single line whether prestartup succeeded or failed.
|
||||
if _PRESTARTUP_FAILURES:
|
||||
for _failure in _PRESTARTUP_FAILURES:
|
||||
nodes.record_node_startup_error(**_failure)
|
||||
_PRESTARTUP_FAILURES.clear()
|
||||
|
||||
import comfy.model_management
|
||||
import comfyui_version
|
||||
import app.logger
|
||||
@ -218,7 +251,7 @@ import comfy.model_patcher
|
||||
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
|
||||
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
|
||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
elif comfy_aimdo.control.init_devices(range(torch.cuda.device_count())):
|
||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||
if args.verbose == 'DEBUG':
|
||||
comfy_aimdo.control.set_log_debug()
|
||||
elif args.verbose == 'CRITICAL':
|
||||
|
||||
217
nodes.py
217
nodes.py
@ -608,73 +608,6 @@ class CheckpointLoaderSimple:
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return out[:3]
|
||||
|
||||
|
||||
class CheckpointLoaderDevice:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
device_options = comfy.model_management.get_gpu_device_options()
|
||||
return {
|
||||
"required": {
|
||||
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
|
||||
},
|
||||
"optional": {
|
||||
"model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}),
|
||||
"clip_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP text encoder."}),
|
||||
"vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
|
||||
"The CLIP model used for encoding text prompts.",
|
||||
"The VAE model used for encoding and decoding images to and from latent space.")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
DESCRIPTION = "Loads a diffusion model checkpoint with per-component device selection for multi-GPU setups."
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, model_device="default", clip_device="default", vae_device="default"):
|
||||
return True
|
||||
|
||||
def load_checkpoint(self, ckpt_name, model_device="default", clip_device="default", vae_device="default"):
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
|
||||
model_options = {}
|
||||
resolved_model = comfy.model_management.resolve_gpu_device_option(model_device)
|
||||
if resolved_model is not None:
|
||||
if resolved_model.type == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = resolved_model
|
||||
else:
|
||||
model_options["load_device"] = resolved_model
|
||||
|
||||
te_model_options = {}
|
||||
resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_device)
|
||||
if resolved_clip is not None:
|
||||
if resolved_clip.type == "cpu":
|
||||
te_model_options["load_device"] = te_model_options["offload_device"] = resolved_clip
|
||||
else:
|
||||
te_model_options["load_device"] = resolved_clip
|
||||
|
||||
# VAE device is passed via model_options["load_device"] which
|
||||
# load_state_dict_guess_config forwards to the VAE constructor.
|
||||
# If vae_device differs from model_device, we override after loading.
|
||||
resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device)
|
||||
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options, te_model_options=te_model_options)
|
||||
model_patcher, clip, vae = out[:3]
|
||||
|
||||
# Apply VAE device override if it differs from the model device
|
||||
if resolved_vae is not None and vae is not None:
|
||||
vae.device = resolved_vae
|
||||
if resolved_vae.type == "cpu":
|
||||
offload = resolved_vae
|
||||
else:
|
||||
offload = comfy.model_management.vae_offload_device()
|
||||
vae.patcher.load_device = resolved_vae
|
||||
vae.patcher.offload_device = offload
|
||||
|
||||
return (model_patcher, clip, vae)
|
||||
|
||||
class DiffusersLoader:
|
||||
SEARCH_ALIASES = ["load diffusers model"]
|
||||
|
||||
@ -853,23 +786,15 @@ class VAELoader:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "vae_name": (s.vae_list(s), )},
|
||||
"optional": {
|
||||
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
|
||||
}}
|
||||
return {"required": { "vae_name": (s.vae_list(s), )}}
|
||||
RETURN_TYPES = ("VAE",)
|
||||
FUNCTION = "load_vae"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, device="default"):
|
||||
return True
|
||||
|
||||
#TODO: scale factor?
|
||||
def load_vae(self, vae_name, device="default"):
|
||||
def load_vae(self, vae_name):
|
||||
metadata = None
|
||||
vae_path = None
|
||||
if vae_name == "pixel_space":
|
||||
sd = {}
|
||||
sd["pixel_space_vae"] = torch.tensor(1.0)
|
||||
@ -886,14 +811,8 @@ class VAELoader:
|
||||
metadata = {"tae_latent_channels": 128}
|
||||
else:
|
||||
metadata["tae_latent_channels"] = 128
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved)
|
||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
||||
vae.throw_exception_if_invalid()
|
||||
# Register a reload factory on the patcher so MultiGPU work-units can use
|
||||
# ModelPatcher.deepclone_multigpu to produce per-device clones from the
|
||||
# same loader context (mirrors UNETLoader / CLIPLoader / checkpoint loader).
|
||||
if vae_path is not None:
|
||||
vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, resolved))
|
||||
return (vae,)
|
||||
|
||||
class ControlNetLoader:
|
||||
@ -1018,20 +937,13 @@ class UNETLoader:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
|
||||
},
|
||||
"optional": {
|
||||
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_unet"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, device="default"):
|
||||
return True
|
||||
|
||||
def load_unet(self, unet_name, weight_dtype, device="default"):
|
||||
def load_unet(self, unet_name, weight_dtype):
|
||||
model_options = {}
|
||||
if weight_dtype == "fp8_e4m3fn":
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
@ -1041,13 +953,6 @@ class UNETLoader:
|
||||
elif weight_dtype == "fp8_e5m2":
|
||||
model_options["dtype"] = torch.float8_e5m2
|
||||
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
if resolved is not None:
|
||||
if resolved.type == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = resolved
|
||||
else:
|
||||
model_options["load_device"] = resolved
|
||||
|
||||
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
|
||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||
return (model,)
|
||||
@ -1059,7 +964,7 @@ class CLIPLoader:
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
@ -1068,20 +973,12 @@ class CLIPLoader:
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, device="default"):
|
||||
return True
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
|
||||
model_options = {}
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
if resolved is not None:
|
||||
if resolved.type == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = resolved
|
||||
else:
|
||||
model_options["load_device"] = resolved
|
||||
if device == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
||||
|
||||
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||
@ -1095,7 +992,7 @@ class DualCLIPLoader:
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
@ -1104,10 +1001,6 @@ class DualCLIPLoader:
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, device="default"):
|
||||
return True
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, type, device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
|
||||
@ -1115,12 +1008,8 @@ class DualCLIPLoader:
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||
|
||||
model_options = {}
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
if resolved is not None:
|
||||
if resolved.type == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = resolved
|
||||
else:
|
||||
model_options["load_device"] = resolved
|
||||
if device == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
||||
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||
return (clip,)
|
||||
@ -2183,7 +2072,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"InpaintModelConditioning": InpaintModelConditioning,
|
||||
|
||||
"CheckpointLoader": CheckpointLoader,
|
||||
"CheckpointLoaderDevice": CheckpointLoaderDevice,
|
||||
"DiffusersLoader": DiffusersLoader,
|
||||
|
||||
"LoadLatent": LoadLatent,
|
||||
@ -2201,7 +2089,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
# Loaders
|
||||
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
|
||||
"CheckpointLoaderSimple": "Load Checkpoint",
|
||||
"CheckpointLoaderDevice": "Load Checkpoint (Device)",
|
||||
"VAELoader": "Load VAE",
|
||||
"LoraLoader": "Load LoRA (Model and CLIP)",
|
||||
"LoraLoaderModelOnly": "Load LoRA",
|
||||
@ -2271,6 +2158,71 @@ EXTENSION_WEB_DIRS = {}
|
||||
# Dictionary of successfully loaded module names and associated directories.
|
||||
LOADED_MODULE_DIRS = {}
|
||||
|
||||
# Dictionary of custom node startup errors, keyed by "<source>:<module_name>"
|
||||
# so that name collisions across custom_nodes / comfy_extras / comfy_api_nodes
|
||||
# do not overwrite each other. Each value contains: source, module_name,
|
||||
# module_path, error, traceback, phase.
|
||||
#
|
||||
# `source` is the same string as the internal `module_parent` used at load
|
||||
# time (e.g. "custom_nodes", "comfy_extras", "comfy_api_nodes"). It is
|
||||
# intentionally a free-form string rather than a fixed enum so the contract
|
||||
# survives node-source layouts evolving (e.g. comfy_api_nodes eventually
|
||||
# moving out of core). Consumers should treat any new value as a new bucket
|
||||
# rather than rejecting it.
|
||||
NODE_STARTUP_ERRORS: dict[str, dict] = {}
|
||||
|
||||
|
||||
def _read_pyproject_metadata(module_path: str) -> dict | None:
|
||||
"""Best-effort extraction of node-pack identity from pyproject.toml.
|
||||
|
||||
Returns a dict with the Comfy Registry-style identity (pack_id,
|
||||
display_name, publisher_id, version, repository) when the module
|
||||
directory contains a pyproject.toml. Returns None when no toml is
|
||||
present or parsing fails for any reason — startup-error tracking
|
||||
must never itself raise.
|
||||
"""
|
||||
if not module_path or not os.path.isdir(module_path):
|
||||
return None
|
||||
toml_path = os.path.join(module_path, "pyproject.toml")
|
||||
if not os.path.isfile(toml_path):
|
||||
return None
|
||||
try:
|
||||
from comfy_config import config_parser
|
||||
|
||||
cfg = config_parser.extract_node_configuration(module_path)
|
||||
if cfg is None:
|
||||
return None
|
||||
meta = {
|
||||
"pack_id": cfg.project.name or None,
|
||||
"display_name": cfg.tool_comfy.display_name or None,
|
||||
"publisher_id": cfg.tool_comfy.publisher_id or None,
|
||||
"version": cfg.project.version or None,
|
||||
"repository": cfg.project.urls.repository or None,
|
||||
}
|
||||
# Drop empty fields so the API payload stays compact.
|
||||
return {k: v for k, v in meta.items() if v}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def record_node_startup_error(
|
||||
*, module_path: str, source: str, phase: str, error: BaseException, tb: str
|
||||
) -> None:
|
||||
"""Record a startup error for a node module so it can be exposed via the API."""
|
||||
module_name = get_module_name(module_path)
|
||||
entry = {
|
||||
"source": source,
|
||||
"module_name": module_name,
|
||||
"module_path": module_path,
|
||||
"error": str(error),
|
||||
"traceback": tb,
|
||||
"phase": phase,
|
||||
}
|
||||
pyproject = _read_pyproject_metadata(module_path)
|
||||
if pyproject:
|
||||
entry["pyproject"] = pyproject
|
||||
NODE_STARTUP_ERRORS[f"{source}:{module_name}"] = entry
|
||||
|
||||
|
||||
def get_module_name(module_path: str) -> str:
|
||||
"""
|
||||
@ -2380,14 +2332,30 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
|
||||
return True
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
|
||||
record_node_startup_error(
|
||||
module_path=module_path,
|
||||
source=module_parent,
|
||||
phase="entrypoint",
|
||||
error=e,
|
||||
tb=tb,
|
||||
)
|
||||
return False
|
||||
else:
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or comfy_entrypoint (need one).")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.warning(traceback.format_exc())
|
||||
tb = traceback.format_exc()
|
||||
logging.warning(tb)
|
||||
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
|
||||
record_node_startup_error(
|
||||
module_path=module_path,
|
||||
source=module_parent,
|
||||
phase="import",
|
||||
error=e,
|
||||
tb=tb,
|
||||
)
|
||||
return False
|
||||
|
||||
async def init_external_custom_nodes():
|
||||
@ -2502,7 +2470,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_lt_audio.py",
|
||||
"nodes_lt.py",
|
||||
"nodes_hooks.py",
|
||||
"nodes_multigpu.py",
|
||||
"nodes_load_3d.py",
|
||||
"nodes_cosmos.py",
|
||||
"nodes_video.py",
|
||||
|
||||
@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.8
|
||||
comfy-aimdo==0.4.4
|
||||
comfy-aimdo==0.4.3
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
blake3
|
||||
|
||||
62
server.py
62
server.py
@ -646,37 +646,18 @@ class PromptServer():
|
||||
|
||||
@routes.get("/system_stats")
|
||||
async def system_stats(request):
|
||||
primary_device = comfy.model_management.get_torch_device()
|
||||
device = comfy.model_management.get_torch_device()
|
||||
device_name = comfy.model_management.get_torch_device_name(device)
|
||||
cpu_device = comfy.model_management.torch.device("cpu")
|
||||
ram_total = comfy.model_management.get_total_memory(cpu_device)
|
||||
ram_free = comfy.model_management.get_free_memory(cpu_device)
|
||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||
required_frontend_version = FrontendManager.get_required_frontend_version()
|
||||
installed_templates_version = FrontendManager.get_installed_templates_version()
|
||||
required_templates_version = FrontendManager.get_required_templates_version()
|
||||
comfy_package_versions = FrontendManager.get_comfy_package_versions()
|
||||
|
||||
# Report every torch device visible to multigpu, with the primary
|
||||
# device first so existing clients that read devices[0] keep working.
|
||||
torch_devices = comfy.model_management.get_all_torch_devices()
|
||||
if primary_device in torch_devices:
|
||||
torch_devices = [primary_device] + [d for d in torch_devices if d != primary_device]
|
||||
else:
|
||||
torch_devices = [primary_device] + list(torch_devices)
|
||||
|
||||
device_entries = []
|
||||
for d in torch_devices:
|
||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(d, torch_total_too=True)
|
||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(d, torch_free_too=True)
|
||||
device_entries.append({
|
||||
"name": comfy.model_management.get_torch_device_name(d),
|
||||
"type": d.type,
|
||||
"index": d.index,
|
||||
"vram_total": vram_total,
|
||||
"vram_free": vram_free,
|
||||
"torch_vram_total": torch_vram_total,
|
||||
"torch_vram_free": torch_vram_free,
|
||||
})
|
||||
|
||||
system_stats = {
|
||||
"system": {
|
||||
"os": sys.platform,
|
||||
@ -692,7 +673,17 @@ class PromptServer():
|
||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||
"argv": sys.argv
|
||||
},
|
||||
"devices": device_entries
|
||||
"devices": [
|
||||
{
|
||||
"name": device_name,
|
||||
"type": device.type,
|
||||
"index": device.index,
|
||||
"vram_total": vram_total,
|
||||
"vram_free": vram_free,
|
||||
"torch_vram_total": torch_vram_total,
|
||||
"torch_vram_free": torch_vram_free,
|
||||
}
|
||||
]
|
||||
}
|
||||
return web.json_response(system_stats)
|
||||
|
||||
@ -774,6 +765,29 @@ class PromptServer():
|
||||
out[node_class] = node_info(node_class)
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/node_startup_errors")
|
||||
async def get_node_startup_errors(request):
|
||||
"""Return startup errors recorded during node loading, grouped by source.
|
||||
|
||||
Group errors by source so the frontend/Manager can render them in
|
||||
distinct sections. ``source`` is the same string as the
|
||||
``module_parent`` used at load time (e.g. ``"custom_nodes"``,
|
||||
``"comfy_extras"``, ``"comfy_api_nodes"``) and is left as a
|
||||
free-form string so the contract survives node-source layouts
|
||||
evolving. The response only contains source buckets that actually
|
||||
had a failure; consumers should not assume any particular set of
|
||||
keys is always present.
|
||||
|
||||
``module_path`` is stripped because the absolute on-disk path is
|
||||
internal detail that the frontend has no use for.
|
||||
"""
|
||||
grouped: dict[str, dict[str, dict]] = {}
|
||||
for entry in nodes.NODE_STARTUP_ERRORS.values():
|
||||
source = entry.get("source", "custom_nodes")
|
||||
public_entry = {k: v for k, v in entry.items() if k != "module_path"}
|
||||
grouped.setdefault(source, {})[entry["module_name"]] = public_entry
|
||||
return web.json_response(grouped)
|
||||
|
||||
@routes.get("/api/jobs")
|
||||
async def get_jobs(request):
|
||||
"""List all jobs with filtering, sorting, and pagination.
|
||||
|
||||
@ -1,147 +0,0 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def install_fake_comfy_aimdo(monkeypatch):
|
||||
package = types.ModuleType("comfy_aimdo")
|
||||
package.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "comfy_aimdo", package)
|
||||
for name in ("vram_buffer", "host_buffer", "torch", "model_vbar", "model_mmap", "control"):
|
||||
module = types.ModuleType(f"comfy_aimdo.{name}")
|
||||
monkeypatch.setitem(sys.modules, f"comfy_aimdo.{name}", module)
|
||||
setattr(package, name, module)
|
||||
|
||||
|
||||
def test_tiled_scale_multidim_multigpu_clips_edge_tiles(monkeypatch):
|
||||
monkeypatch.setattr(torch.cuda, "set_device", lambda device: None)
|
||||
monkeypatch.setattr(torch.cuda, "synchronize", lambda device: None)
|
||||
|
||||
scale = 1.1
|
||||
|
||||
def upscale(a):
|
||||
return torch.ones((a.shape[0], 1, round(a.shape[-1] * scale)), dtype=a.dtype, device=a.device)
|
||||
|
||||
samples = torch.ones((1, 1, 11))
|
||||
devices = [torch.device("cpu:0"), torch.device("cpu:1")]
|
||||
|
||||
actual = comfy.utils.tiled_scale_multidim_multigpu(
|
||||
samples,
|
||||
{device: upscale for device in devices},
|
||||
tile=(5,),
|
||||
overlap=2,
|
||||
upscale_amount=scale,
|
||||
out_channels=1,
|
||||
output_device="cpu",
|
||||
)
|
||||
expected = comfy.utils.tiled_scale_multidim(
|
||||
samples,
|
||||
upscale,
|
||||
tile=(5,),
|
||||
overlap=2,
|
||||
upscale_amount=scale,
|
||||
out_channels=1,
|
||||
output_device="cpu",
|
||||
)
|
||||
|
||||
assert actual.shape == expected.shape == (1, 1, 12)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def test_upscale_model_deepclone_does_not_copy_existing_clone_graph(monkeypatch):
|
||||
class FakeModel:
|
||||
def __init__(self):
|
||||
self.param = torch.nn.Parameter(torch.ones(1))
|
||||
|
||||
def eval(self):
|
||||
return self
|
||||
|
||||
def parameters(self):
|
||||
return [self.param]
|
||||
|
||||
class FakeDescriptor:
|
||||
def __init__(self):
|
||||
self.model = FakeModel()
|
||||
self.device = None
|
||||
|
||||
def to(self, device):
|
||||
self.device = device
|
||||
return self
|
||||
|
||||
first_device = torch.device("cpu:0")
|
||||
second_device = torch.device("cpu:1")
|
||||
stale_device = torch.device("cpu:2")
|
||||
existing_clone = FakeDescriptor()
|
||||
stale_clone = FakeDescriptor()
|
||||
source = FakeDescriptor()
|
||||
source.multigpu_clones = {first_device: existing_clone, stale_device: stale_clone}
|
||||
fake_model_management = types.ModuleType("comfy.model_management")
|
||||
fake_model_management.get_all_torch_devices = lambda exclude_current=True: [first_device, second_device]
|
||||
monkeypatch.setitem(sys.modules, "comfy.model_management", fake_model_management)
|
||||
import comfy
|
||||
monkeypatch.setattr(comfy, "model_management", fake_model_management, raising=False)
|
||||
import comfy.multigpu
|
||||
importlib.reload(comfy.multigpu)
|
||||
|
||||
cloned = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=3)
|
||||
|
||||
assert cloned is not source
|
||||
assert cloned.multigpu_clones[first_device] is existing_clone
|
||||
assert stale_device not in cloned.multigpu_clones
|
||||
assert second_device in cloned.multigpu_clones
|
||||
assert not hasattr(cloned.multigpu_clones[second_device], "multigpu_clones")
|
||||
assert cloned.multigpu_clones[second_device].device == "cpu"
|
||||
assert not cloned.multigpu_clones[second_device].model.param.requires_grad
|
||||
|
||||
single_gpu_clone = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=1)
|
||||
assert single_gpu_clone is not source
|
||||
assert not hasattr(single_gpu_clone, "multigpu_clones")
|
||||
|
||||
|
||||
def test_checkpoint_loader_registers_vae_cached_patcher(monkeypatch):
|
||||
install_fake_comfy_aimdo(monkeypatch)
|
||||
import comfy.sd
|
||||
importlib.reload(comfy.sd)
|
||||
|
||||
class FakeVAE:
|
||||
def __init__(self):
|
||||
self.patcher = types.SimpleNamespace(cached_patcher_init=None)
|
||||
|
||||
model_patcher = types.SimpleNamespace(cached_patcher_init=None)
|
||||
vae = FakeVAE()
|
||||
metadata = {"format": "checkpoint"}
|
||||
monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata))
|
||||
monkeypatch.setattr(
|
||||
comfy.sd,
|
||||
"load_state_dict_guess_config",
|
||||
lambda *args, **kwargs: (model_patcher, None, vae, None),
|
||||
)
|
||||
|
||||
comfy.sd.load_checkpoint_guess_config("checkpoint.safetensors", output_vae=True)
|
||||
|
||||
assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config
|
||||
assert vae.patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_vae_patcher
|
||||
assert vae.patcher.cached_patcher_init[1][0] == "checkpoint.safetensors"
|
||||
|
||||
|
||||
def test_checkpoint_loader_skips_cached_patcher_for_placeholder_vae(monkeypatch):
|
||||
install_fake_comfy_aimdo(monkeypatch)
|
||||
import comfy.sd
|
||||
importlib.reload(comfy.sd)
|
||||
|
||||
model_patcher = types.SimpleNamespace(cached_patcher_init=None)
|
||||
placeholder_vae = types.SimpleNamespace()
|
||||
metadata = {"format": "checkpoint"}
|
||||
monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata))
|
||||
monkeypatch.setattr(
|
||||
comfy.sd,
|
||||
"load_state_dict_guess_config",
|
||||
lambda *args, **kwargs: (model_patcher, None, placeholder_vae, None),
|
||||
)
|
||||
|
||||
assert comfy.sd.load_checkpoint_guess_config("diffusion_only.safetensors", output_vae=True)[2] is placeholder_vae
|
||||
assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config
|
||||
146
tests-unit/node_startup_errors_test.py
Normal file
146
tests-unit/node_startup_errors_test.py
Normal file
@ -0,0 +1,146 @@
|
||||
"""Tests for the custom node startup error tracking introduced for
|
||||
Comfy-Org/ComfyUI-Launcher#303.
|
||||
|
||||
Covers:
|
||||
- load_custom_node populates NODE_STARTUP_ERRORS with the correct source
|
||||
for each module_parent (custom_nodes / comfy_extras / comfy_api_nodes).
|
||||
- Composite keying prevents collisions between modules with the same name
|
||||
in different sources.
|
||||
- record_node_startup_error stores the expected fields.
|
||||
- pyproject.toml metadata is attached when present and omitted when absent.
|
||||
"""
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
import nodes
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_startup_errors():
|
||||
nodes.NODE_STARTUP_ERRORS.clear()
|
||||
yield
|
||||
nodes.NODE_STARTUP_ERRORS.clear()
|
||||
|
||||
|
||||
def _write_broken_module(tmp_path, name: str) -> str:
|
||||
path = tmp_path / f"{name}.py"
|
||||
path.write_text(textwrap.dedent("""\
|
||||
# Deliberately broken module to exercise startup-error tracking.
|
||||
raise RuntimeError("boom from " + __name__)
|
||||
"""))
|
||||
return str(path)
|
||||
|
||||
|
||||
def test_record_node_startup_error_fields(tmp_path):
|
||||
err = ValueError("kaboom")
|
||||
nodes.record_node_startup_error(
|
||||
module_path=str(tmp_path / "my_pack"),
|
||||
source="custom_nodes",
|
||||
phase="import",
|
||||
error=err,
|
||||
tb="traceback-text",
|
||||
)
|
||||
assert "custom_nodes:my_pack" in nodes.NODE_STARTUP_ERRORS
|
||||
entry = nodes.NODE_STARTUP_ERRORS["custom_nodes:my_pack"]
|
||||
assert entry["source"] == "custom_nodes"
|
||||
assert entry["module_name"] == "my_pack"
|
||||
assert entry["phase"] == "import"
|
||||
assert entry["error"] == "kaboom"
|
||||
assert entry["traceback"] == "traceback-text"
|
||||
assert entry["module_path"].endswith("my_pack")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"module_parent",
|
||||
["custom_nodes", "comfy_extras", "comfy_api_nodes"],
|
||||
)
|
||||
async def test_load_custom_node_records_source(tmp_path, module_parent):
|
||||
# `source` in the entry should be the same string as `module_parent`.
|
||||
module_path = _write_broken_module(tmp_path, "broken_pack")
|
||||
|
||||
success = await nodes.load_custom_node(module_path, module_parent=module_parent)
|
||||
assert success is False
|
||||
|
||||
key = f"{module_parent}:broken_pack"
|
||||
assert key in nodes.NODE_STARTUP_ERRORS, nodes.NODE_STARTUP_ERRORS
|
||||
entry = nodes.NODE_STARTUP_ERRORS[key]
|
||||
assert entry["source"] == module_parent
|
||||
assert entry["module_name"] == "broken_pack"
|
||||
assert entry["phase"] == "import"
|
||||
assert "boom from" in entry["error"]
|
||||
assert "RuntimeError" in entry["traceback"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_node_collision_across_sources(tmp_path):
|
||||
# Same module name registered as both a custom node and a comfy_extra;
|
||||
# composite keying should keep both entries.
|
||||
cn_dir = tmp_path / "cn"
|
||||
extras_dir = tmp_path / "extras"
|
||||
cn_dir.mkdir()
|
||||
extras_dir.mkdir()
|
||||
cn_path = _write_broken_module(cn_dir, "nodes_audio")
|
||||
extras_path = _write_broken_module(extras_dir, "nodes_audio")
|
||||
|
||||
assert await nodes.load_custom_node(cn_path, module_parent="custom_nodes") is False
|
||||
assert await nodes.load_custom_node(extras_path, module_parent="comfy_extras") is False
|
||||
|
||||
assert "custom_nodes:nodes_audio" in nodes.NODE_STARTUP_ERRORS
|
||||
assert "comfy_extras:nodes_audio" in nodes.NODE_STARTUP_ERRORS
|
||||
assert (
|
||||
nodes.NODE_STARTUP_ERRORS["custom_nodes:nodes_audio"]["module_path"]
|
||||
!= nodes.NODE_STARTUP_ERRORS["comfy_extras:nodes_audio"]["module_path"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_node_attaches_pyproject_metadata(tmp_path):
|
||||
pack_dir = tmp_path / "MyCoolPack"
|
||||
pack_dir.mkdir()
|
||||
(pack_dir / "__init__.py").write_text("raise RuntimeError('boom')\n")
|
||||
(pack_dir / "pyproject.toml").write_text(textwrap.dedent("""\
|
||||
[project]
|
||||
name = "comfyui-mycoolpack"
|
||||
version = "1.2.3"
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/example/comfyui-mycoolpack"
|
||||
|
||||
[tool.comfy]
|
||||
PublisherId = "example"
|
||||
DisplayName = "My Cool Pack"
|
||||
"""))
|
||||
|
||||
success = await nodes.load_custom_node(str(pack_dir), module_parent="custom_nodes")
|
||||
assert success is False
|
||||
|
||||
entry = nodes.NODE_STARTUP_ERRORS["custom_nodes:MyCoolPack"]
|
||||
assert "pyproject" in entry, entry
|
||||
py = entry["pyproject"]
|
||||
assert py["pack_id"] == "comfyui-mycoolpack"
|
||||
assert py["display_name"] == "My Cool Pack"
|
||||
assert py["publisher_id"] == "example"
|
||||
assert py["version"] == "1.2.3"
|
||||
assert py["repository"] == "https://github.com/example/comfyui-mycoolpack"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_node_no_pyproject_skips_metadata(tmp_path):
|
||||
# Single-file extras-style module: no pyproject.toml exists alongside it,
|
||||
# so the entry must not contain a 'pyproject' key.
|
||||
module_path = _write_broken_module(tmp_path, "lonely")
|
||||
assert await nodes.load_custom_node(module_path, module_parent="comfy_extras") is False
|
||||
entry = nodes.NODE_STARTUP_ERRORS["comfy_extras:lonely"]
|
||||
assert "pyproject" not in entry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_node_arbitrary_module_parent_passes_through(tmp_path):
|
||||
# `source` is a free-form string — an unknown module_parent (e.g. a future
|
||||
# node-source bucket) should be recorded as-is, not coerced or rejected.
|
||||
module_path = _write_broken_module(tmp_path, "future_pack")
|
||||
assert await nodes.load_custom_node(module_path, module_parent="future_source") is False
|
||||
entry = nodes.NODE_STARTUP_ERRORS["future_source:future_pack"]
|
||||
assert entry["source"] == "future_source"
|
||||
Reference in New Issue
Block a user