mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 19:35:57 +08:00
multigpu: refactor deepclone_multigpu + register cached_patcher_init for CLIP/VAE; Select*Device retargets via deepclone
- ModelPatcher.deepclone_multigpu: remove copy.deepcopy fallback. Require cached_patcher_init (raise a descriptive RuntimeError if missing) and always go through clone(model_override=...) with empty backup containers so the per-device clone owns a pristine, unpatched module instead of a deepcopy of an already-loaded/already-patched one. Also call register_load_device on the new patcher so ModelPatcherDynamic per-device bookkeeping (e.g. dynamic_pins) is populated for the requested load device.
- comfy/sd.py: register cached_patcher_init on the CLIP and VAE patchers returned by load_checkpoint_guess_config, and on the patcher returned by load_diffusion_model's companion paths. Add load_checkpoint_clip_patcher, load_checkpoint_vae_patcher, and load_vae_patcher reload helpers so the same loader context can be reused to produce per-device clones.
- nodes.py: VAELoader registers cached_patcher_init on the produced VAE's patcher when there is a single backing file (skip for pixel_space and composite image-TAESDs which aren't addressable by a single path).
- comfy_extras/nodes_multigpu.py: SelectModelDevice / SelectCLIPDevice / SelectVAEDevice now retarget via deepclone_multigpu when the requested device differs from the current load_device, so the consumed model is not just relabeled but actually rehomed onto the chosen device.

Verified on runner-2 (2x RTX 4090, comfy-aimdo 0.4.4):
- 10/10 focused unit tests (deepclone behavior, missing-factory error path, Select*Device behavior).
- Device-switch-after-consumption end-to-end (SD1.5) produces bit-identical PNGs on cuda:0 and cuda:1.
- Z Image multigpu CFG split: ~1.90x speedup (10.5s vs 19.9s steady).
- Qwen Image multigpu CFG split (real text negative, cfg=4): ~1.69x speedup (32.5s vs 54.8s steady) -- matches pre-refactor numbers.
- Baseline (patch stashed) and patched produce identical timings on both models, so the refactor is performance-neutral.
Amp-Thread-ID: https://ampcode.com/threads/T-019e5783-b810-74b1-8ca9-09d675de1479 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@ -457,23 +457,38 @@ class ModelPatcher:
|
||||
|
||||
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
||||
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
||||
if self.cached_patcher_init is None:
|
||||
raise RuntimeError(
|
||||
f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: "
|
||||
"the loader that produced this model does not support multigpu "
|
||||
"(cached_patcher_init is not initialized). Use a core loader "
|
||||
"(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), "
|
||||
"or have the custom loader register a cached_patcher_init factory."
|
||||
)
|
||||
comfy.model_management.unload_model_and_clones(self)
|
||||
n = self.clone()
|
||||
# Produce a freshly-loaded patcher from the loader factory so the multigpu
|
||||
# clone owns its own untainted model weights (rather than relying on
|
||||
# copy.deepcopy of an already-patched/already-loaded module).
|
||||
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
|
||||
if len(self.cached_patcher_init) > 2:
|
||||
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||
# Override clone()'s normal "share self.model + share backup containers" with
|
||||
# the pristine model from temp_model_patcher plus empty backup containers --
|
||||
# the fresh model has no patches applied, so any deepcopy of self's stale
|
||||
# backup/object_patches_backup/pinned would just propagate dead state that
|
||||
# no longer corresponds to anything in n.model.
|
||||
model_override = (temp_model_patcher.model, ({}, {}, {}, set()))
|
||||
n = self.clone(model_override=model_override)
|
||||
# clone() copies hook_backup by reference from self; reset since model is pristine.
|
||||
n.hook_backup = {}
|
||||
# set load device, if present
|
||||
if new_load_device is not None:
|
||||
n.load_device = new_load_device
|
||||
if self.cached_patcher_init is not None:
|
||||
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
|
||||
if len(self.cached_patcher_init) > 2:
|
||||
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||
n.model = temp_model_patcher.model
|
||||
else:
|
||||
n.model = copy.deepcopy(n.model)
|
||||
# unlike for normal clone, backup dicts that shared same ref should not;
|
||||
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
|
||||
n.backup = copy.deepcopy(n.backup)
|
||||
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
|
||||
n.hook_backup = copy.deepcopy(n.hook_backup)
|
||||
# Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins)
|
||||
# has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's
|
||||
# __init__ only registered its own (default) load_device.
|
||||
if hasattr(n, "register_load_device"):
|
||||
n.register_load_device(n.load_device)
|
||||
# multigpu clone should not have multigpu additional_models entry
|
||||
n.remove_additional_models("multigpu")
|
||||
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
||||
|
||||
62
comfy/sd.py
62
comfy/sd.py
@ -1727,8 +1727,50 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||
if out[0] is not None:
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
|
||||
# Register reload factories for the CLIP and VAE produced by the same checkpoint so
|
||||
# ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device,
|
||||
# MultiGPU work-units, etc.) without falling back to copy.deepcopy of an
|
||||
# already-loaded module.
|
||||
if out[1] is not None and getattr(out[1], "patcher", None) is not None:
|
||||
out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
if out[2] is not None and getattr(out[2], "patcher", None) is not None:
|
||||
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
return out
|
||||
|
||||
|
||||
def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
    """Reload a checkpoint with only the CLIP output enabled and return its patcher.

    Intended as the ``cached_patcher_init`` factory for the CLIP produced by
    ``load_checkpoint_guess_config``.

    Args:
        ckpt_path: Path of the checkpoint file to reload.
        embedding_directory: Forwarded unchanged to ``load_checkpoint_guess_config``.
        model_options: Forwarded unchanged to ``load_checkpoint_guess_config``.
        te_model_options: Forwarded unchanged to ``load_checkpoint_guess_config``.
        disable_dynamic: Forwarded unchanged to ``load_checkpoint_guess_config``.

    Returns:
        The ``patcher`` attribute of the reloaded CLIP object.
    """
    # Ask the loader for the text encoder only; the diffusion model, VAE and
    # CLIP-vision outputs are all switched off for this reload.
    wanted_outputs = dict(
        output_vae=False,
        output_clip=True,
        output_clipvision=False,
        output_model=False,
    )
    loaded = load_checkpoint_guess_config(
        ckpt_path,
        embedding_directory=embedding_directory,
        model_options=model_options,
        te_model_options=te_model_options,
        disable_dynamic=disable_dynamic,
        **wanted_outputs,
    )
    # The loader result is unpacked as a 4-tuple; the CLIP sits in slot 1.
    _model, clip, _vae, _clipvision = loaded
    return clip.patcher
|
||||
|
||||
|
||||
def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
    """Reload a checkpoint with only the VAE output enabled and return its patcher.

    Intended as the ``cached_patcher_init`` factory for the VAE produced by
    ``load_checkpoint_guess_config``.

    Args:
        ckpt_path: Path of the checkpoint file to reload.
        embedding_directory: Forwarded unchanged to ``load_checkpoint_guess_config``.
        model_options: Forwarded unchanged to ``load_checkpoint_guess_config``.
        te_model_options: Forwarded unchanged to ``load_checkpoint_guess_config``.
        disable_dynamic: Forwarded unchanged to ``load_checkpoint_guess_config``.

    Returns:
        The ``patcher`` attribute of the reloaded VAE object.
    """
    # Ask the loader for the VAE only; the diffusion model, CLIP and
    # CLIP-vision outputs are all switched off for this reload.
    wanted_outputs = dict(
        output_vae=True,
        output_clip=False,
        output_clipvision=False,
        output_model=False,
    )
    loaded = load_checkpoint_guess_config(
        ckpt_path,
        embedding_directory=embedding_directory,
        model_options=model_options,
        te_model_options=te_model_options,
        disable_dynamic=disable_dynamic,
        **wanted_outputs,
    )
    # The loader result is unpacked as a 4-tuple; the VAE sits in slot 2.
    _model, _clip, vae, _clipvision = loaded
    return vae.patcher
|
||||
|
||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
||||
embedding_directory=embedding_directory,
|
||||
@ -1954,6 +1996,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
|
||||
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
||||
return model
|
||||
|
||||
|
||||
def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False):
    """Reload the VAE stored at ``vae_path`` and return its patcher.

    Intended as the ``cached_patcher_init`` factory on ``VAE.patcher`` so
    :meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can build a
    freshly loaded VAE patcher instead of copying an existing one.

    Args:
        vae_path: Path of the VAE file to read.
        metadata: Metadata handed to the ``VAE`` constructor. When ``None``
            it is read from the file together with the state dict; otherwise
            the supplied value is used as-is and only the state dict is read.
        device: Passed to the ``VAE`` constructor as ``device``.
        disable_dynamic: Accepted for signature compatibility; this function
            body does not read it.

    Returns:
        The ``patcher`` attribute of the newly constructed ``VAE``.
    """
    if metadata is not None:
        # Caller already supplied metadata: only the weights are re-read.
        state_dict = comfy.utils.load_torch_file(vae_path)
    else:
        state_dict, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)

    reloaded = VAE(sd=state_dict, metadata=metadata, device=device)
    # Validate before exposing the patcher to the caller.
    reloaded.throw_exception_if_invalid()
    return reloaded.patcher
|
||||
|
||||
def load_unet(unet_path, dtype=None):
    """Deprecated wrapper around ``load_diffusion_model``.

    Logs a deprecation warning, then loads ``unet_path`` with ``dtype``
    passed through as the ``"dtype"`` entry of ``model_options``.
    """
    logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
    options = {"dtype": dtype}
    return load_diffusion_model(unet_path, model_options=options)
|
||||
|
||||
@ -49,48 +49,82 @@ class MultiGPUCFGSplitNode(io.ComfyNode):
|
||||
def _remember_base_devices(patcher: ModelPatcher):
    """Record the patcher's current load/offload devices on its model, once.

    The values are written to ``patcher.model`` as ``_select_base_load_device``
    and ``_select_base_offload_device`` so a later "default" selection can
    recover the loader's original routing. If the model already carries
    ``_select_base_load_device`` nothing is overwritten, so only the first
    call for a given model object has any effect.
    """
    model = patcher.model
    if hasattr(model, "_select_base_load_device"):
        # Already recorded by an earlier call; keep the first values.
        return
    model._select_base_load_device = patcher.load_device
    model._select_base_offload_device = patcher.offload_device
|
||||
|
||||
|
||||
def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None):
|
||||
"""Apply *resolved* to a freshly-cloned patcher; respect base devices on default.
|
||||
def _propagate_base_devices(src_model, dst_model):
    """Copy the remembered base load/offload devices from one model to another.

    Does nothing when ``src_model`` has no ``_select_base_load_device`` or
    when ``dst_model`` already has one; otherwise both
    ``_select_base_load_device`` and ``_select_base_offload_device`` are
    copied from ``src_model`` onto ``dst_model``.
    """
    if not hasattr(src_model, "_select_base_load_device"):
        # Nothing was ever remembered on the source.
        return
    if hasattr(dst_model, "_select_base_load_device"):
        # Destination already carries its own record; leave it untouched.
        return
    dst_model._select_base_load_device = src_model._select_base_load_device
    dst_model._select_base_offload_device = src_model._select_base_offload_device
|
||||
|
||||
Returns the (possibly newly-replaced) patcher. For CPU on a dynamic
|
||||
patcher, also tries to downgrade to a plain ModelPatcher so the
|
||||
dynamic-only code paths are bypassed (best-effort: silently keeps
|
||||
the dynamic patcher if downgrade is not supported).
|
||||
|
||||
def _retarget_patcher(patcher: ModelPatcher, target_load_device, target_offload_device):
    """Return a patcher routed to *target_load_device* / *target_offload_device*.

    Fast path: when ``patcher.load_device`` already equals
    *target_load_device*, only ``offload_device`` is updated, in place, and
    the same patcher object is returned. Its model is reused, so anything
    that has already mutated that model (e.g. an earlier consumer of the
    same model) remains visible. This reuse is deliberate.

    Otherwise ``patcher.deepclone_multigpu(new_load_device=...)`` is called
    to obtain a new patcher for the target device. The new patcher gets the
    requested offload device and inherits the remembered base devices of the
    original model. If it exposes ``register_load_device``, that is called
    with its load device.

    Args:
        patcher: Patcher to retarget; mutated in place on the fast path.
        target_load_device: Device the returned patcher should load onto.
        target_offload_device: Device the returned patcher should offload to.

    Returns:
        The input patcher (fast path) or the patcher returned by
        ``deepclone_multigpu``.
    """
    already_on_target = patcher.load_device == target_load_device
    if not already_on_target:
        # Keep a handle on the old model so its base-device record can be
        # carried over to the model owned by the new patcher.
        original_model = patcher.model
        patcher = patcher.deepclone_multigpu(new_load_device=target_load_device)
        patcher.offload_device = target_offload_device
        _propagate_base_devices(original_model, patcher.model)
        if hasattr(patcher, "register_load_device"):
            patcher.register_load_device(patcher.load_device)
        return patcher

    # Fast path: load device already matches, only the offload target changes.
    patcher.offload_device = target_offload_device
    return patcher
|
||||
|
||||
|
||||
def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None):
|
||||
"""Resolve the requested device and produce a patcher routed there.
|
||||
|
||||
For "default" we restore the loader's original load/offload pair.
|
||||
For CPU we pin both load and offload to CPU (and, on a dynamic
|
||||
patcher, downgrade to a plain ModelPatcher so the dynamic-only
|
||||
code paths are bypassed).
|
||||
For an explicit GPU we keep the loader's original offload but
|
||||
target the requested load device; if that differs from the current
|
||||
load device the patcher is deepcloned onto the new device.
|
||||
"""
|
||||
_remember_base_devices(patcher)
|
||||
base_load = patcher.model._select_base_load_device
|
||||
base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device
|
||||
|
||||
if resolved is None:
|
||||
# "default" -> reset routing to whatever the loader produced
|
||||
patcher.load_device = base_load
|
||||
patcher.offload_device = base_offload
|
||||
elif resolved.type == "cpu":
|
||||
# "default" -> route back to the loader's original devices.
|
||||
return _retarget_patcher(patcher, base_load, base_offload)
|
||||
if resolved.type == "cpu":
|
||||
if patcher.is_dynamic():
|
||||
try:
|
||||
patcher = patcher.clone(disable_dynamic=True)
|
||||
except Exception:
|
||||
# Downgrade unavailable (no cached_patcher_init); fall
|
||||
# back to the existing dynamic patcher.
|
||||
pass
|
||||
# clone(disable_dynamic=True) requires cached_patcher_init; let the
|
||||
# exception surface to the caller (Select*DeviceNode.execute), which
|
||||
# will translate it into a passthrough+log so unsupported loaders
|
||||
# don't hard-fail the workflow.
|
||||
patcher = patcher.clone(disable_dynamic=True)
|
||||
patcher.load_device = resolved
|
||||
patcher.offload_device = resolved
|
||||
else:
|
||||
patcher.load_device = resolved
|
||||
patcher.offload_device = base_offload
|
||||
|
||||
if hasattr(patcher, "register_load_device"):
|
||||
patcher.register_load_device(patcher.load_device)
|
||||
return patcher
|
||||
return patcher
|
||||
return _retarget_patcher(patcher, resolved, base_offload)
|
||||
|
||||
|
||||
def _prune_multigpu_collision(model: ModelPatcher, primary_device):
|
||||
@ -122,6 +156,12 @@ class SelectModelDeviceNode(io.ComfyNode):
|
||||
- "gpu:N" pins the load device to the Nth available GPU; the offload
|
||||
device is restored to the loader's original choice.
|
||||
|
||||
When the requested device differs from the device the input model is
|
||||
already on, a fresh model is spawned via the loader's reload factory
|
||||
(cached_patcher_init) so the new patcher owns independent weights on
|
||||
the new device. Loaders that don't support multigpu (no factory) will
|
||||
cause the node to pass through unchanged with a warning.
|
||||
|
||||
If the workflow already has MultiGPU CFG Split applied and the chosen
|
||||
GPU collides with one of the existing multigpu clones, that clone is
|
||||
dropped so two patchers don't end up bound to the same device.
|
||||
@ -130,6 +170,13 @@ class SelectModelDeviceNode(io.ComfyNode):
|
||||
(e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
|
||||
the node passes the model through unchanged and logs a message
|
||||
instead of failing.
|
||||
|
||||
NOTE: Placing Select Model Device *after* a node that has already
|
||||
consumed the same model (e.g. a KSampler that ran on this model on
|
||||
the original device) is not recommended -- any state the prior
|
||||
consumer mutated on the original model will be observed when the
|
||||
selected device matches the original (fast path). Place Select Model
|
||||
Device before any consumer of the model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@ -161,7 +208,11 @@ class SelectModelDeviceNode(io.ComfyNode):
|
||||
if resolved is None and device not in (None, "default"):
|
||||
logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.")
|
||||
return io.NodeOutput(model)
|
||||
model = _apply_patcher_device(model, resolved)
|
||||
try:
|
||||
model = _apply_patcher_device(model, resolved)
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})")
|
||||
return io.NodeOutput(model)
|
||||
if resolved is not None:
|
||||
_prune_multigpu_collision(model, model.load_device)
|
||||
return io.NodeOutput(model)
|
||||
@ -208,7 +259,10 @@ class SelectCLIPDeviceNode(io.ComfyNode):
|
||||
if resolved is None and device not in (None, "default"):
|
||||
logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.")
|
||||
return io.NodeOutput(clip)
|
||||
clip.patcher = _apply_patcher_device(clip.patcher, resolved)
|
||||
try:
|
||||
clip.patcher = _apply_patcher_device(clip.patcher, resolved)
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Select CLIP Device: cannot retarget CLIP, passing through unchanged. ({e})")
|
||||
return io.NodeOutput(clip)
|
||||
|
||||
|
||||
@ -263,13 +317,19 @@ class SelectVAEDeviceNode(io.ComfyNode):
|
||||
if resolved is not None and resolved.type == "cpu":
|
||||
logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.")
|
||||
return io.NodeOutput(vae)
|
||||
vae.patcher = _apply_patcher_device(
|
||||
vae.patcher, resolved,
|
||||
base_offload_override=comfy.model_management.vae_offload_device(),
|
||||
)
|
||||
# VAE caches the working device separately from its patcher.
|
||||
if not hasattr(vae, "_select_base_device"):
|
||||
vae._select_base_device = vae.device
|
||||
try:
|
||||
vae.patcher = _apply_patcher_device(
|
||||
vae.patcher, resolved,
|
||||
base_offload_override=comfy.model_management.vae_offload_device(),
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Select VAE Device: cannot retarget VAE, passing through unchanged. ({e})")
|
||||
return io.NodeOutput(vae)
|
||||
# Keep VAE wrapper in sync with whatever model the patcher now owns;
|
||||
# deepclone_multigpu may have produced a fresh first_stage_model.
|
||||
vae.first_stage_model = vae.patcher.model
|
||||
vae.device = vae._select_base_device if resolved is None else resolved
|
||||
return io.NodeOutput(vae)
|
||||
|
||||
|
||||
9
nodes.py
9
nodes.py
@ -795,6 +795,7 @@ class VAELoader:
|
||||
#TODO: scale factor?
|
||||
def load_vae(self, vae_name):
|
||||
metadata = None
|
||||
vae_path = None
|
||||
if vae_name == "pixel_space":
|
||||
sd = {}
|
||||
sd["pixel_space_vae"] = torch.tensor(1.0)
|
||||
@ -813,6 +814,14 @@ class VAELoader:
|
||||
metadata["tae_latent_channels"] = 128
|
||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
||||
vae.throw_exception_if_invalid()
|
||||
# Register a reload factory on the patcher so multigpu deepclones
|
||||
# (Select VAE Device, future MultiGPU VAE work-units) can produce
|
||||
# per-device clones from the same loader context. Only set when we
|
||||
# actually have a single backing file -- pixel_space and the
|
||||
# image TAESDs (composed from separate encoder/decoder files via
|
||||
# load_taesd) are not addressable by a single vae_path.
|
||||
if vae_path is not None:
|
||||
vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, None))
|
||||
return (vae,)
|
||||
|
||||
class ControlNetLoader:
|
||||
|
||||
Reference in New Issue
Block a user