multigpu: refactor deepclone_multigpu + register cached_patcher_init for CLIP/VAE; Select*Device retargets via deepclone

- ModelPatcher.deepclone_multigpu: remove copy.deepcopy fallback. Require
  cached_patcher_init (raise a descriptive RuntimeError if missing) and
  always go through clone(model_override=...) with empty backup containers
  so the per-device clone owns a pristine, unpatched module instead of a
  deepcopy of an already-loaded/already-patched one. Also call
  register_load_device on the new patcher so ModelPatcherDynamic per-device
  bookkeeping (e.g. dynamic_pins) is populated for the requested load
  device.

- comfy/sd.py: register cached_patcher_init on the CLIP and VAE patchers
  returned by load_checkpoint_guess_config, and on the patcher returned by
  load_diffusion_model's companion paths. Add load_checkpoint_clip_patcher,
  load_checkpoint_vae_patcher, and load_vae_patcher reload helpers so the
  same loader context can be reused to produce per-device clones.

- nodes.py: VAELoader registers cached_patcher_init on the produced VAE's
  patcher when there is a single backing file (skip for pixel_space and
  composite image-TAESDs which aren't addressable by a single path).

- comfy_extras/nodes_multigpu.py: SelectModelDevice / SelectCLIPDevice /
  SelectVAEDevice now retarget via deepclone_multigpu when the requested
  device differs from the current load_device, so the consumed model is
  not just relabeled but actually rehomed onto the chosen device.

Verified on runner-2 (2x RTX 4090, comfy-aimdo 0.4.4):
- 10/10 focused unit tests (deepclone behavior, missing-factory error path,
  Select*Device behavior).
- Device-switch-after-consumption end-to-end (SD1.5) produces bit-identical
  PNGs on cuda:0 and cuda:1.
- Z Image multigpu CFG split: ~1.90x speedup (10.5s vs 19.9s steady).
- Qwen Image multigpu CFG split (real text negative, cfg=4): ~1.69x
  speedup (32.5s vs 54.8s steady) -- matches pre-refactor numbers.
- Baseline (patch stashed) and patched produce identical timings on both
  models, so the refactor is performance-neutral.

Amp-Thread-ID: https://ampcode.com/threads/T-019e5783-b810-74b1-8ca9-09d675de1479
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Jedrzej Kosinski
2026-05-23 19:11:48 -07:00
parent 5c2e34ca4e
commit bece6b2aec
4 changed files with 192 additions and 46 deletions

View File

@ -457,23 +457,38 @@ class ModelPatcher:
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
if self.cached_patcher_init is None:
raise RuntimeError(
f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: "
"the loader that produced this model does not support multigpu "
"(cached_patcher_init is not initialized). Use a core loader "
"(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), "
"or have the custom loader register a cached_patcher_init factory."
)
comfy.model_management.unload_model_and_clones(self)
n = self.clone()
# Produce a freshly-loaded patcher from the loader factory so the multigpu
# clone owns its own untainted model weights (rather than relying on
# copy.deepcopy of an already-patched/already-loaded module).
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
# Override clone()'s normal "share self.model + share backup containers" with
# the pristine model from temp_model_patcher plus empty backup containers --
# the fresh model has no patches applied, so any deepcopy of self's stale
# backup/object_patches_backup/pinned would just propagate dead state that
# no longer corresponds to anything in n.model.
model_override = (temp_model_patcher.model, ({}, {}, {}, set()))
n = self.clone(model_override=model_override)
# clone() copies hook_backup by reference from self; reset since model is pristine.
n.hook_backup = {}
# set load device, if present
if new_load_device is not None:
n.load_device = new_load_device
if self.cached_patcher_init is not None:
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
n.model = temp_model_patcher.model
else:
n.model = copy.deepcopy(n.model)
# unlike for normal clone, backup dicts that shared same ref should not;
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
n.backup = copy.deepcopy(n.backup)
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
n.hook_backup = copy.deepcopy(n.hook_backup)
# Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins)
# has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's
# __init__ only registered its own (default) load_device.
if hasattr(n, "register_load_device"):
n.register_load_device(n.load_device)
# multigpu clone should not have multigpu additional_models entry
n.remove_additional_models("multigpu")
# multigpu_clone all stored additional_models; make sure circular references are properly handled

View File

@ -1727,8 +1727,50 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
# Register reload factories for the CLIP and VAE produced by the same checkpoint so
# ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device,
# MultiGPU work-units, etc.) from a fresh reload of the checkpoint. Without a
# factory, deepclone_multigpu raises a RuntimeError (there is no copy.deepcopy fallback).
if out[1] is not None and getattr(out[1], "patcher", None) is not None:
out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
if out[2] is not None and getattr(out[2], "patcher", None) is not None:
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
return out
def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
    """Reload only the CLIP from a checkpoint and return its patcher.

    Used as the ``cached_patcher_init`` factory for the CLIP returned by
    ``load_checkpoint_guess_config``: the checkpoint is reloaded with the
    VAE, CLIP-vision and diffusion-model outputs disabled, and the patcher
    of the resulting CLIP is returned.

    Args:
        ckpt_path: Path of the checkpoint to reload.
        embedding_directory: Forwarded unchanged to ``load_checkpoint_guess_config``.
        model_options: Forwarded unchanged to ``load_checkpoint_guess_config``.
        te_model_options: Forwarded unchanged to ``load_checkpoint_guess_config``.
        disable_dynamic: Forwarded unchanged to ``load_checkpoint_guess_config``.

    Returns:
        The ``patcher`` attribute of the reloaded CLIP.

    Raises:
        RuntimeError: If the reload produced no CLIP, or a CLIP without a
            patcher. A RuntimeError (rather than the AttributeError that
            ``None.patcher`` would raise) is what the device-selection
            callers catch to fall back to passthrough.
    """
    # NOTE: the ``{}`` defaults are only forwarded here, never mutated by this function.
    _, clip, _, _ = load_checkpoint_guess_config(
        ckpt_path,
        output_vae=False,
        output_clip=True,
        output_clipvision=False,
        embedding_directory=embedding_directory,
        output_model=False,
        model_options=model_options,
        te_model_options=te_model_options,
        disable_dynamic=disable_dynamic,
    )
    # Same guard the registration site uses before attaching the factory.
    patcher = getattr(clip, "patcher", None)
    if patcher is None:
        raise RuntimeError(
            f"Cannot reload CLIP patcher: checkpoint {ckpt_path} did not produce a CLIP with a patcher."
        )
    return patcher
def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
    """Reload only the VAE from a checkpoint and return its patcher.

    Used as the ``cached_patcher_init`` factory for the VAE returned by
    ``load_checkpoint_guess_config``: the checkpoint is reloaded with the
    CLIP, CLIP-vision and diffusion-model outputs disabled, and the patcher
    of the resulting VAE is returned.

    Args:
        ckpt_path: Path of the checkpoint to reload.
        embedding_directory: Forwarded unchanged to ``load_checkpoint_guess_config``.
        model_options: Forwarded unchanged to ``load_checkpoint_guess_config``.
        te_model_options: Forwarded unchanged to ``load_checkpoint_guess_config``.
        disable_dynamic: Forwarded unchanged to ``load_checkpoint_guess_config``.

    Returns:
        The ``patcher`` attribute of the reloaded VAE.

    Raises:
        RuntimeError: If the reload produced no VAE, or a VAE without a
            patcher. A RuntimeError (rather than the AttributeError that
            ``None.patcher`` would raise) is what the device-selection
            callers catch to fall back to passthrough.
    """
    # NOTE: the ``{}`` defaults are only forwarded here, never mutated by this function.
    _, _, vae, _ = load_checkpoint_guess_config(
        ckpt_path,
        output_vae=True,
        output_clip=False,
        output_clipvision=False,
        embedding_directory=embedding_directory,
        output_model=False,
        model_options=model_options,
        te_model_options=te_model_options,
        disable_dynamic=disable_dynamic,
    )
    # Same guard the registration site uses before attaching the factory.
    patcher = getattr(vae, "patcher", None)
    if patcher is None:
        raise RuntimeError(
            f"Cannot reload VAE patcher: checkpoint {ckpt_path} did not produce a VAE with a patcher."
        )
    return patcher
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
embedding_directory=embedding_directory,
@ -1954,6 +1996,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
return model
def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False):
    """Build a VAE from the file at ``vae_path`` and return its patcher.

    Intended as the ``cached_patcher_init`` factory stored on ``VAE.patcher``,
    so a multigpu deepclone can obtain a freshly loaded VAE patcher instead
    of reusing an already-loaded one.

    Args:
        vae_path: Path of the VAE file to load.
        metadata: Metadata to construct the VAE with. When ``None``, the
            metadata is read from the file together with the state dict.
        device: Passed through to the ``VAE`` constructor.
        disable_dynamic: Accepted but not used by this function
            (presumably kept for signature parity with the other reload
            factories -- confirm against callers).

    Returns:
        The ``patcher`` attribute of the newly constructed VAE.

    Raises:
        Whatever ``VAE.throw_exception_if_invalid`` raises for an invalid VAE.
    """
    if metadata is not None:
        # Caller supplied metadata: only the weights are needed from disk.
        state_dict = comfy.utils.load_torch_file(vae_path)
    else:
        state_dict, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
    reloaded = VAE(sd=state_dict, metadata=metadata, device=device)
    reloaded.throw_exception_if_invalid()
    return reloaded.patcher
def load_unet(unet_path, dtype=None):
    """Deprecated wrapper around ``load_diffusion_model``.

    Logs a deprecation warning, then loads ``unet_path`` through
    ``load_diffusion_model`` with ``dtype`` passed as the ``"dtype"`` entry
    of ``model_options``.

    Returns:
        Whatever ``load_diffusion_model`` returns.
    """
    logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
    options = {"dtype": dtype}
    return load_diffusion_model(unet_path, model_options=options)

View File

@ -49,48 +49,82 @@ class MultiGPUCFGSplitNode(io.ComfyNode):
def _remember_base_devices(patcher: ModelPatcher):
"""Stash the original load/offload device on the underlying model.
Stored on patcher.model (which is shared across patcher clones), so
repeated selector applications can recover the loader's original
routing when the user picks "default".
Stored on patcher.model (which is shared with the input patcher), so
later "default" selections can recover the loader's original routing.
Only the first Select on a given chain writes these attrs; subsequent
deepclones inherit them onto their freshly-loaded model below.
"""
if not hasattr(patcher.model, "_select_base_load_device"):
patcher.model._select_base_load_device = patcher.load_device
patcher.model._select_base_offload_device = patcher.offload_device
def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None):
"""Apply *resolved* to a freshly-cloned patcher; respect base devices on default.
def _propagate_base_devices(src_model, dst_model):
"""Carry the loader-original device attrs onto the freshly-deepcloned model."""
if hasattr(src_model, "_select_base_load_device") and not hasattr(dst_model, "_select_base_load_device"):
dst_model._select_base_load_device = src_model._select_base_load_device
dst_model._select_base_offload_device = src_model._select_base_offload_device
Returns the (possibly newly-replaced) patcher. For CPU on a dynamic
patcher, also tries to downgrade to a plain ModelPatcher so the
dynamic-only code paths are bypassed (best-effort: silently keeps
the dynamic patcher if downgrade is not supported).
def _retarget_patcher(patcher: ModelPatcher, target_load_device, target_offload_device):
"""Return a patcher whose actual model weights live on *target_load_device*.
If *patcher* is already on *target_load_device* we just retarget the
(already-cloned) patcher's metadata in place. Otherwise we call
:meth:`ModelPatcher.deepclone_multigpu` to spawn a fresh model from
the loader's ``cached_patcher_init`` factory -- the only safe way to
move weights that may already be partially loaded onto another device.
NOTE: reusing the input patcher's model when the requested device
matches its current load_device is a deliberate fast path. Anything
that has already mutated the original model (e.g. a prior KSampler
invocation on the same model) will be observed here. This is by
design and documented on the SelectXDeviceNode docstrings -- placing
Select X Device after a node that consumes the same model is not
recommended.
"""
if patcher.load_device == target_load_device:
# Fast path: weights already on the desired device, just update offload.
patcher.offload_device = target_offload_device
return patcher
src_model = patcher.model
patcher = patcher.deepclone_multigpu(new_load_device=target_load_device)
patcher.offload_device = target_offload_device
_propagate_base_devices(src_model, patcher.model)
if hasattr(patcher, "register_load_device"):
patcher.register_load_device(patcher.load_device)
return patcher
def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None):
"""Resolve the requested device and produce a patcher routed there.
For "default" we restore the loader's original load/offload pair.
For CPU we pin both load and offload to CPU (and, on a dynamic
patcher, downgrade to a plain ModelPatcher so the dynamic-only
code paths are bypassed).
For an explicit GPU we keep the loader's original offload but
target the requested load device; if that differs from the current
load device the patcher is deepcloned onto the new device.
"""
_remember_base_devices(patcher)
base_load = patcher.model._select_base_load_device
base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device
if resolved is None:
# "default" -> reset routing to whatever the loader produced
patcher.load_device = base_load
patcher.offload_device = base_offload
elif resolved.type == "cpu":
# "default" -> route back to the loader's original devices.
return _retarget_patcher(patcher, base_load, base_offload)
if resolved.type == "cpu":
if patcher.is_dynamic():
try:
patcher = patcher.clone(disable_dynamic=True)
except Exception:
# Downgrade unavailable (no cached_patcher_init); fall
# back to the existing dynamic patcher.
pass
# clone(disable_dynamic=True) requires cached_patcher_init; let the
# exception surface to the caller (Select*DeviceNode.execute), which
# will translate it into a passthrough+log so unsupported loaders
# don't hard-fail the workflow.
patcher = patcher.clone(disable_dynamic=True)
patcher.load_device = resolved
patcher.offload_device = resolved
else:
patcher.load_device = resolved
patcher.offload_device = base_offload
if hasattr(patcher, "register_load_device"):
patcher.register_load_device(patcher.load_device)
return patcher
return patcher
return _retarget_patcher(patcher, resolved, base_offload)
def _prune_multigpu_collision(model: ModelPatcher, primary_device):
@ -122,6 +156,12 @@ class SelectModelDeviceNode(io.ComfyNode):
- "gpu:N" pins the load device to the Nth available GPU; the offload
device is restored to the loader's original choice.
When the requested device differs from the device the input model is
already on, a fresh model is spawned via the loader's reload factory
(cached_patcher_init) so the new patcher owns independent weights on
the new device. Loaders that don't support multigpu (no factory) will
cause the node to pass through unchanged with a warning.
If the workflow already has MultiGPU CFG Split applied and the chosen
GPU collides with one of the existing multigpu clones, that clone is
dropped so two patchers don't end up bound to the same device.
@ -130,6 +170,13 @@ class SelectModelDeviceNode(io.ComfyNode):
(e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
the node passes the model through unchanged and logs a message
instead of failing.
NOTE: Placing Select Model Device *after* a node that has already
consumed the same model (e.g. a KSampler that ran on this model on
the original device) is not recommended -- any state the prior
consumer mutated on the original model will be observed when the
selected device matches the original (fast path). Place Select Model
Device before any consumer of the model.
"""
@classmethod
@ -161,7 +208,11 @@ class SelectModelDeviceNode(io.ComfyNode):
if resolved is None and device not in (None, "default"):
logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.")
return io.NodeOutput(model)
model = _apply_patcher_device(model, resolved)
try:
model = _apply_patcher_device(model, resolved)
except RuntimeError as e:
logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})")
return io.NodeOutput(model)
if resolved is not None:
_prune_multigpu_collision(model, model.load_device)
return io.NodeOutput(model)
@ -208,7 +259,10 @@ class SelectCLIPDeviceNode(io.ComfyNode):
if resolved is None and device not in (None, "default"):
logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.")
return io.NodeOutput(clip)
clip.patcher = _apply_patcher_device(clip.patcher, resolved)
try:
clip.patcher = _apply_patcher_device(clip.patcher, resolved)
except RuntimeError as e:
logging.warning(f"Select CLIP Device: cannot retarget CLIP, passing through unchanged. ({e})")
return io.NodeOutput(clip)
@ -263,13 +317,19 @@ class SelectVAEDeviceNode(io.ComfyNode):
if resolved is not None and resolved.type == "cpu":
logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.")
return io.NodeOutput(vae)
vae.patcher = _apply_patcher_device(
vae.patcher, resolved,
base_offload_override=comfy.model_management.vae_offload_device(),
)
# VAE caches the working device separately from its patcher.
if not hasattr(vae, "_select_base_device"):
vae._select_base_device = vae.device
try:
vae.patcher = _apply_patcher_device(
vae.patcher, resolved,
base_offload_override=comfy.model_management.vae_offload_device(),
)
except RuntimeError as e:
logging.warning(f"Select VAE Device: cannot retarget VAE, passing through unchanged. ({e})")
return io.NodeOutput(vae)
# Keep VAE wrapper in sync with whatever model the patcher now owns;
# deepclone_multigpu may have produced a fresh first_stage_model.
vae.first_stage_model = vae.patcher.model
vae.device = vae._select_base_device if resolved is None else resolved
return io.NodeOutput(vae)

View File

@ -795,6 +795,7 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
metadata = None
vae_path = None
if vae_name == "pixel_space":
sd = {}
sd["pixel_space_vae"] = torch.tensor(1.0)
@ -813,6 +814,14 @@ class VAELoader:
metadata["tae_latent_channels"] = 128
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
vae.throw_exception_if_invalid()
# Register a reload factory on the patcher so multigpu deepclones
# (Select VAE Device, future MultiGPU VAE work-units) can produce
# per-device clones from the same loader context. Only set when we
# actually have a single backing file -- pixel_space and the
# image TAESDs (composed from separate encoder/decoder files via
# load_taesd) are not addressable by a single vae_path.
if vae_path is not None:
vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, None))
return (vae,)
class ControlNetLoader: