Compare commits

..

158 Commits

Author SHA1 Message Date
b64112870c Revert "Add UPSCALE_MODEL lane to MultiGPU CFG Split"
This reverts commit 74b0a826ea.

The UPSCALE_MODEL lane will land as part of a follow-up PR alongside the
tiled VAE lane, separated from the threaded-loader fix PR (#14052) to
keep the upstream merge focused.
2026-05-22 16:37:16 -07:00
40bf50b32c Revert "Add tiled VAE lane to MultiGPU Work Units"
This reverts commit 4d3d68e473.

The tiled VAE lane will land as part of a follow-up PR alongside the
UPSCALE_MODEL lane, separated from the threaded-loader fix PR (#14052)
to keep the upstream merge focused.
2026-05-22 16:37:10 -07:00
cb83c41db7 Merge pull request #14052 from rattus128/prs/worksplit-t-load-fix
fixup threaded loader with worksplit multi-gpu
2026-05-22 16:36:33 -07:00
4d3d68e473 Add tiled VAE lane to MultiGPU Work Units 2026-05-22 13:42:21 -05:00
74b0a826ea Add UPSCALE_MODEL lane to MultiGPU CFG Split
Introduce tiled_scale_multidim_multigpu in comfy/utils.py: a tile scheduler
that dispatches per-device tile functions through the existing
MultiGPUThreadPool and merges per-device CPU output buffers in deterministic
key order. The worker only catches BaseException at the thread boundary to
funnel errors to the main thread; bare torch.cuda.set_device and
torch.cuda.synchronize calls inside the worker fail loud if the device is
not CUDA, which is part of the primitive's contract.

Add UPSCALE_MODEL input on the MultiGPU CFG Split node and an upscale-model
descriptor deepclone helper in comfy/multigpu.py. Clones stay CPU-resident
until execute time and are returned to CPU afterward.

ImageUpscaleWithModel dispatches through tiled_scale_multidim_multigpu when
a multigpu descriptor is attached; the single-device path runs unchanged
when no clones are present.
2026-05-22 13:41:48 -05:00
7a18f9affb comfy-aimdo 0.4.4
Comfy-aimdo 0.4.4 contains a small bugfix to allow recovery of a hostbuf
after full truncation.

This pattern doesnt happen as a general rule, but does happen in the
upcoming worksplit-multigpu branch.
2026-05-23 01:00:30 +10:00
df17b560c5 memory_management: replace thread refusal with mutex
This was an attempt to be a fast path by ensuring the file slice was
created by the owning thread and refusing without needing ot mutex
but worksplit-multigpu doesnt work that way. Go mutex.

Shoot me for overthinking next time.
2026-05-23 01:00:30 +10:00
b649502c9c Report all torch devices from /system_stats
The /system_stats endpoint was returning a hardcoded single-element
devices list built from get_torch_device(), which only reflects the
primary CUDA device. On multi-GPU systems this hides the additional
devices from frontends / tooling (the API surface that enables multigpu
support discovery). Switch to iterating get_all_torch_devices(), with
the primary device kept first so existing clients reading devices[0]
keep working.

(Worksplit-multigpu-only: get_all_torch_devices is the multigpu helper
introduced on this branch; master's /system_stats remains unchanged.)

Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
2026-05-21 13:04:54 -07:00
2ed396c769 Mark non-NVIDIA multigpu gaps with TODOs in _handle_batch
Two CodeRabbit findings from #7063 (#13 and #14) are deferred because
worksplit-multigpu's initial release scope is NVIDIA-only QA. Leave a
TODO at the unconditional torch.cuda.set_device call and at the
post-aggregation point so the required guards/synchronize are easy to
find when multigpu support is extended to XPU/NPU/MPS/CPU/DirectML.

Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
2026-05-21 12:47:43 -07:00
d0b9dbb5a6 Merge remote-tracking branch 'origin/master' into worksplit-multigpu
Brings in 18 commits from master so worksplit-multigpu does not regress
fixes that landed on main since the last sync:

- #13699 Hunyuan 3D 2.1 batch-size fixes (overlap with our own backport;
  conflict resolved in favor of the shape>=2 gate that binds
  swap_cfg_halves once and reuses it for the output swap-back)
- #14031 ModelPatcherDynamic lora reshape / backup restore fix
- #13802 Multi-threaded model load (memory_management / pinned_memory /
  model_management / aimdo plumbing)
- #12679 lanczos single-channel tensor fix
- #14010 Stable Audio 3 support
- assorted partner-node, openapi, workflow-template, and tooling updates

Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
2026-05-21 12:17:59 -07:00
fd79f22bdf Backport Hunyuan 3D 2.1 attention batch-size fixes from #13699
CrossAttention.kv.view and Attention.qkv_combined.view both hardcoded
batch=1 in the reshape, crashing or silently mis-shaping whenever the
actual batch dimension was greater than 1. These were fixed on master
in #13699 as part of the same patch that gated the chunk(2) swap, but
worksplit-multigpu only picked up the chunk(2) gate. Bring the two
view() fixes over so we have parity with master.

Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
2026-05-21 12:17:24 -07:00
019261ed96 Simplify Hunyuan 3D 2.1 swap_cfg_halves gate to a shape check
The previous gate (len(cond_or_uncond) == 2 and set == {0, 1}) was
intended to skip the cond/uncond swap when only one half was present
under MultiGPU CFG Split, but it was too restrictive: it also skipped
batch_size > 1 + CFG (cond_or_uncond like [0, 0, 1, 1] or [0,0,0,0,
1,1,1,1]), where chunk(2) still splits the batch cleanly into a cond
half and an uncond half and the swap is still required.

Switch to context.shape[0] >= 2, matching the parallel fix landed on
master in #13699. The swap is a permutation-invariant no-op when the
two halves don't form a CFG pair (since the output swap_cfg_halves
block immediately undoes the permutation), so the only thing the gate
actually needs to do is guard against chunk(2) on a batch of one.

Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
2026-05-21 12:14:02 -07:00
822a3ecf73 Note _calc_cond_batch and _calc_cond_batch_multigpu must stay in sync
Per review feedback on #7063. The two functions share the conds-by-hooks
accumulation, memory-fit batching, and per-chunk output aggregation; the
multigpu variant adds per-device scheduling, .to(device) placement,
per-device patcher/control lookup, and thread-pool dispatch around the
inner loop. Documenting the relationship without extracting helpers --
extraction can land after the initial worksplit-multigpu release once
both paths have settled.

Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
2026-05-21 11:47:53 -07:00
1417b711ce Fix CodeRabbit findings in worksplit-multigpu (#14017)
Fix CodeRabbit findings in worksplit-multigpu
2026-05-21 11:42:08 -07:00
a18dd219d5 Pass per-device model to multigpu control clones in pre_run_control
QwenFunControlNet.pre_run stashes model.diffusion_model into extra_args,
which the control_model then uses for forward passes (img_in, txt_in,
pe_embedder, time_text_embed). With multigpu, every per-device control
clone was being pre_run with the base model on GPU0, so secondary
devices would invoke those modules with parameters on GPU0 and inputs
on their own device, raising 'Expected all tensors to be on the same
device'. Build a device -> per-device BaseModel lookup from the
patcher's additional multigpu models and pass each clone the model on
its own device. Falls back to the base model when no per-device match
is found (single-GPU path and the case where cnet.multigpu_clones lags
the patcher's clone set).

Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
2026-05-21 11:40:49 -07:00
963621603c Free QwenFunControlNet base_model reference in cleanup
QwenFunControlNet.pre_run stashes the model's diffusion_model into
self.extra_args['base_model'], but ControlBase.cleanup never clears
extra_args. The diffusion_model reference therefore lingered between
sampling runs, blocking ComfyUI's model offload/eviction logic from
freeing the UNet and -- for multigpu -- holding one such reference per
per-device control clone (defeating the max_gpus pruning added in this
PR). Override cleanup to drop the entry; super().cleanup() already
recurses into multigpu_clones so each per-device clone pops its own.

Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
2026-05-21 11:35:54 -07:00
adde1239b1 Restore prepare_state backward-compatible signature
Drop the new ignore_multigpu positional argument from prepare_state and
from the ON_PREPARE_STATE callbacks; pass the flag via model_options
instead. This restores the original 3-arg callback signature so existing
custom-node ON_PREPARE_STATE handlers keep working unchanged, while
still letting prepare_state's recursive call into multigpu_clones
short-circuit.

Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
2026-05-21 11:35:39 -07:00
4d9106dced Document --cuda-device comma format and MultiGPU Options relative_speed gap
Two doc-only changes addressing minor CodeRabbit findings on PR #7063:

* cli_args.py: clarify --cuda-device help text to document the required comma-separated format ('0' or '0,1'), matching how the value is consumed by CUDA_VISIBLE_DEVICES in main.py.

* nodes_multigpu.py: add a docstring NOTE on the (currently unregistered) MultiGPUOptionsNode explaining that its relative_speed input is plumbed through to model_options['multigpu_options'] but is not yet consulted by the cond scheduler, which still uses uniform round-robin via next_available_device(). Wire relative_speed into the scheduler before re-enabling the node.

Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5
Co-authored-by: Amp <amp@ampcode.com>
2026-05-20 20:48:59 -07:00
ac0a90c323 Use cond_shapes in multigpu memory-fit check (parity with single-GPU path)
The multigpu cond-batching loop called model.memory_required(input_shape) without conditioning shapes, while the single-GPU path at line 279 passes cond_shapes. Large conditioning tensors (e.g. video prompts, control inputs) were therefore under-counted, risking OOM at runtime when the chosen batch size was too large. Match the single-GPU pattern by building cond_shapes from each batched cond's conditioning dict and passing it to memory_required.

Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5
Co-authored-by: Amp <amp@ampcode.com>
2026-05-20 19:52:03 -07:00
dd85851efe Prune inherited multigpu clones when max_gpus is lowered
create_multigpu_deepclones cloned the existing 'multigpu' additional_models list verbatim and never pruned entries beyond limit_extra_devices. If a workflow was previously prepared for more GPUs, reducing max_gpus would leave stale clones attached and eligible for later scheduling. Replace the TODO block with a real prune that keeps only clones whose load_device is either the model's load_device or in limit_extra_devices, and re-match clones if anything was removed.

Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5
Co-authored-by: Amp <amp@ampcode.com>
2026-05-20 16:46:45 -07:00
ba417750a7 Fix get_all_torch_devices for XPU/NPU and guard remove()
torch.device(i) defaults to CUDA, so XPU/NPU branches were producing 'cuda:N' devices that don't match get_torch_device() output ('xpu:N'/'npu:N'). This caused devices.remove(get_torch_device()) to raise ValueError when exclude_current=True on non-NVIDIA hardware. Use explicit device strings, and guard the remove() with a membership check for safety.

Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5
Co-authored-by: Amp <amp@ampcode.com>
2026-05-20 16:46:38 -07:00
9a681ccfc9 Guard cached_patcher_init when output_model is False
load_checkpoint_guess_config_clip_only() calls load_checkpoint_guess_config() with output_model=False, leaving out[0] as None. The subsequent unconditional assignment of cached_patcher_init crashed with AttributeError, breaking CLIP-only checkpoint loading entirely. Guard the assignment with a None check.

Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5
Co-authored-by: Amp <amp@ampcode.com>
2026-05-20 16:46:31 -07:00
50d1dd6273 Fix MultiGPU Options node discarding cloned GPUOptionsGroup
GPUOptionsGroup.clone() returns a new instance, but the return value was discarded, causing the node to mutate the upstream caller's group in-place. When multiple MultiGPU Options nodes share an input group, each node's additions would leak into earlier siblings. Assign the clone result back to gpu_options so each node owns its own copy.

Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5
Co-authored-by: Amp <amp@ampcode.com>
2026-05-20 16:46:23 -07:00
ff766e5cfa Merge remote-tracking branch 'origin/master' into merge-master-into-worksplit-multigpu
Amp-Thread-ID: https://ampcode.com/threads/T-019e4352-d45e-75bc-8ed7-ed3a7f6d129a
Co-authored-by: Amp <amp@ampcode.com>

# Conflicts:
#	comfy/ldm/sam3/detector.py
#	comfy/ldm/sam3/tracker.py
#	comfy/model_base.py
#	comfy/quant_ops.py
#	comfy/supported_models.py
#	comfy_api_nodes/apis/bytedance.py
#	comfy_api_nodes/nodes_bytedance.py
#	comfy_api_nodes/nodes_openai.py
#	comfy_extras/frame_interpolation_models/film_net.py
#	comfy_extras/frame_interpolation_models/ifnet.py
#	comfy_extras/nodes_ace.py
#	comfy_extras/nodes_frame_interpolation.py
#	comfy_extras/nodes_lt_audio.py
#	comfy_extras/nodes_sam3.py
#	comfy_extras/nodes_video_model.py
#	folder_paths.py
#	nodes.py
#	requirements.txt
2026-05-19 21:43:51 -07:00
819c7c0702 Refactor MultiGPU scheduler for readability and termination safety (#14001)
Behaviour-equivalent cleanup of _calc_cond_batch_multigpu device
scheduling. No change to batching decisions or memory checks for any
valid input.

Changes:

* Replace re-summed batched_to_run_length with a per-device load
  dict (device_load), so capacity checks are O(1) and use a single
  source of truth.
* Extract device selection into next_available_device(), which scans
  at most len(devices) positions and raises if no device has
  remaining capacity. This makes the 'skip a full device' rule live
  in one place instead of two and guarantees the outer while loop
  cannot spin forever on a scheduling bug.
* Drop the unused current_device assignment before the outer loop
  and the index_device % len(devices) modulo dance (now handled
  inside next_available_device).
* Minor cleanups: list comprehensions for total_conds, conds_to_batch,
  and the devices list.
2026-05-19 21:23:56 -07:00
9e3ede1406 Fix MultiGPU scheduler capacity accounting (#14000)
Fixes _calc_cond_batch_multigpu so that:

1. conds_per_device uses real division before math.ceil. The previous
   expression math.ceil(total_conds // len(devices)) applied integer
   floor division first, making ceil a no-op. For 3 conds across 2
   devices this produced conds_per_device=1 instead of 2.

2. The scheduling loop skips devices that have already reached
   capacity instead of appending empty batch groups. Without this
   guard, the loop could repeatedly emit zero-length groups for a
   full device, leaving sampling stuck at 0/N until timeout.

Reproduces with an Omnigen2 image workflow that produces three
condition entries scheduled across two CUDA devices. With the fix
the scheduler assigns conds_per_device=2 and splits the batches as
2 + 1 across the two devices, allowing sampling to complete.

Original fix authored and validated by @pollockjj in
pollockjj/ComfyUI#64.

Co-authored-by: John Pollock <pollockjj@gmail.com>
2026-05-19 20:11:53 -07:00
a61e2bbb85 Add device selection on Image Only Load Checkpoint (CORE-158) (#13748)
* Add device selection on Image Only Load Checkpoint

* Rename variables

* Update variable name

* Fix linting
2026-05-06 21:49:23 -07:00
1b96430c60 Merge master into worksplit-multigpu (#13546)
* fix: pin SQLAlchemy>=2.0 in requirements.txt (fixes #13036) (#13316)

* Refactor io to IO in nodes_ace.py (#13485)

* Bump comfyui-frontend-package to 1.42.12 (#13489)

* Make the ltx audio vae more native. (#13486)

* feat(api-nodes): add automatic downscaling of videos for ByteDance 2 nodes (#13465)

* Support standalone LTXV audio VAEs (#13499)

* [Partner Nodes]  added 4K resolution for Veo models; added Veo 3 Lite model (#13330)

* feat(api nodes): added 4K resolution for Veo models; added Veo 3 Lite model

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* increase poll_interval from 5 to 9

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>

* Bump comfyui-frontend-package to 1.42.14 (#13493)

* Add gpt-image-2 as version option (#13501)

* Allow logging in comfy app files. (#13505)

* chore: update workflow templates to v0.9.59 (#13507)

* fix(veo): reject 4K resolution for veo-3.0 models in Veo3VideoGenerationNode (#13504)

The tooltip on the resolution input states that 4K is not available for
veo-3.1-lite or veo-3.0 models, but the execute guard only rejected the
lite combination. Selecting 4K with veo-3.0-generate-001 or
veo-3.0-fast-generate-001 would fall through and hit the upstream API
with an invalid request.

Broaden the guard to match the documented behavior and update the error
message accordingly.

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>

* feat: RIFE and FILM frame interpolation model support (CORE-29) (#13258)

* initial RIFE support

* Also support FILM

* Better RAM usage, reduce FILM VRAM peak

* Add model folder placeholder

* Fix oom fallback frame loss

* Remove torch.compile for now

* Rename model input

* Shorter input type name

---------

* fix: use Parameter assignment for Stable_Zero123 cc_projection weights (fixes #13492) (#13518)

On Windows with aimdo enabled, disable_weight_init.Linear uses lazy
initialization that sets weight and bias to None to avoid unnecessary
memory allocation. This caused a crash when copy_() was called on the
None weight attribute in Stable_Zero123.__init__.

Replace copy_() with direct torch.nn.Parameter assignment, which works
correctly on both Windows (aimdo enabled) and other platforms.

* Derive InterruptProcessingException from BaseException (#13523)

* bump manager version to 4.2.1 (#13516)

* ModelPatcherDynamic: force cast stray weights on comfy layers (#13487)

the mixed_precision ops can have input_scale parameters that are used
in tensor math but arent a weight or bias so dont get proper VRAM
management. Treat these as force-castable parameters like the non comfy
weight, random params are buffers already are.

* Update logging level for invalid version format (#13526)

* [Partner Nodes] add SD2 real human support (#13509)

* feat(api-nodes): add SD2 real human support

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* fix: add validation before uploading Assets

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* Add asset_id and group_id displaying on the node

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* extend poll_op to use instead of custom async cycle

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* added the polling for the "Active" status after asset creation

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* updated tooltip for group_id

* allow usage of real human in the ByteDance2FirstLastFrame node

* add reference count limits

* corrected price in status when input assets contain video

Signed-off-by: bigcat88 <bigcat88@icloud.com>

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* feat: SAM (segment anything) 3.1 support (CORE-34) (#13408)

* [Partner Nodes] GPTImage: fix price badges, add new resolutions (#13519)

* fix(api-nodes): fixed price badges, add new resolutions

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* proper calculate the total run cost when "n > 1"

Signed-off-by: bigcat88 <bigcat88@icloud.com>

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* chore: update workflow templates to v0.9.61 (#13533)

* chore: update embedded docs to v0.4.4 (#13535)

* add 4K resolution to Kling nodes (#13536)

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* Fix LTXV Reference Audio node (#13531)

* comfy-aimdo 0.2.14: Hotfix async allocator estimations (#13534)

This was doing an over-estimate of VRAM used by the async allocator when lots
of little small tensors were in play.

Also change the versioning scheme to == so we can roll forward aimdo without
worrying about stable regressions downstream in comfyUI core.

* Disable sageattention for SAM3 (#13529)

Causes Nans

* execution: Add anti-cycle validation (#13169)

Currently if the graph contains a cycle, the just inifitiate recursions,
hits a catch all then throws a generic error against the output node
that seeded the validation. Instead, fail the offending cycling mode
chain and handlng it as an error in its own right.

Co-authored-by: guill <jacob.e.segal@gmail.com>

* chore: update workflow templates to v0.9.62 (#13539)

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Octopus <liyuan851277048@icloud.com>
Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Co-authored-by: Comfy Org PR Bot <snomiao+comfy-pr@gmail.com>
Co-authored-by: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Co-authored-by: Jukka Seppänen <40791699+kijai@users.noreply.github.com>
Co-authored-by: AustinMroz <austin@comfy.org>
Co-authored-by: Daxiong (Lin) <contact@comfyui-wiki.com>
Co-authored-by: Matt Miller <matt@miller-media.com>
Co-authored-by: blepping <157360029+blepping@users.noreply.github.com>
Co-authored-by: Dr.Lt.Data <128333288+ltdrdata@users.noreply.github.com>
Co-authored-by: rattus <46076784+rattus128@users.noreply.github.com>
Co-authored-by: guill <jacob.e.segal@gmail.com>
2026-04-23 19:20:14 -07:00
aa464b36b3 Multi-GPU device selection for loader nodes + CUDA context fixes (#13483)
* Fix Hunyuan 3D 2.1 multi-GPU worksplit: use cond_or_uncond instead of hardcoded chunk(2)

Amp-Thread-ID: https://ampcode.com/threads/T-019da964-2cc8-77f9-9aae-23f65da233db
Co-authored-by: Amp <amp@ampcode.com>

* Add GPU device selection to all loader nodes

- Add get_gpu_device_options() and resolve_gpu_device_option() helpers
  in model_management.py for vendor-agnostic GPU device selection
- Add device widget to CheckpointLoaderSimple, UNETLoader, VAELoader
- Expand device options in CLIPLoader, DualCLIPLoader, LTXAVTextEncoderLoader
  from [default, cpu] to include gpu:0, gpu:1, etc. on multi-GPU systems
- Wire load_diffusion_model_state_dict and load_state_dict_guess_config
  to respect model_options['load_device']
- Graceful fallback: unrecognized devices (e.g. gpu:1 on single-GPU)
  silently fall back to default

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>

* Add VALIDATE_INPUTS to skip device combo validation for workflow portability

When a workflow saved on a 2-GPU machine (with device=gpu:1) is loaded
on a 1-GPU machine, the combo validation would reject the unknown value.
VALIDATE_INPUTS with the device parameter bypasses combo validation for
that input only, allowing resolve_gpu_device_option to handle the
graceful fallback at runtime.

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>

* Set CUDA device context in outer_sample to match model load_device

Custom CUDA kernels (comfy_kitchen fp8 quantization) use
torch.cuda.current_device() for DLPack tensor export. When a model is
loaded on a non-default GPU (e.g. cuda:1), the CUDA context must match
or the kernel fails with 'Can't export tensors on a different CUDA
device index'. Save and restore the previous device around sampling.

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>

* Fix code review bugs: negative index guard, CPU offload_device, checkpoint te_model_options

- resolve_gpu_device_option: reject negative indices (gpu:-1)
- UNETLoader: set offload_device when cpu is selected
- CheckpointLoaderSimple: pass te_model_options for CLIP device,
  set offload_device for cpu, pass load_device to VAE
- load_diffusion_model_state_dict: respect offload_device from model_options
- load_state_dict_guess_config: respect offload_device, pass load_device to VAE

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>

* Fix CUDA device context for CLIP encoding and VAE encode/decode

Add torch.cuda.set_device() calls to match model's load device in:
- CLIP.encode_from_tokens: fixes 'Can't export tensors on a different
  CUDA device index' when CLIP is loaded on a non-default GPU
- CLIP.encode_from_tokens_scheduled: same fix for the hooks code path
- CLIP.generate: same fix for text generation
- VAE.decode: fixes VAE decoding on non-default GPU
- VAE.encode: fixes VAE encoding on non-default GPU

Same pattern as the existing outer_sample fix in samplers.py - saves
and restores previous CUDA device in a try/finally block.

Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>

* Extract cuda_device_context manager, fix tiled VAE methods

Add model_management.cuda_device_context() — a context manager that
saves/restores torch.cuda.current_device when operating on a non-default
GPU. Replaces 6 copies of the manual save/set/restore boilerplate.

Refactored call sites:
- CLIP.encode_from_tokens
- CLIP.encode_from_tokens_scheduled (hooks path)
- CLIP.generate
- VAE.decode
- VAE.encode
- samplers.outer_sample

Bug fixes (newly wrapped):
- VAE.decode_tiled: was missing device context entirely, would fail
  on non-default GPU when called from 'VAE Decode (Tiled)' node
- VAE.encode_tiled: same issue for 'VAE Encode (Tiled)' node

Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>

* Restore CheckpointLoaderSimple, add CheckpointLoaderDevice

Revert CheckpointLoaderSimple to its original form (no device input)
so it remains the simple default loader.

Add new CheckpointLoaderDevice node (advanced/loaders) with separate
model_device, clip_device, and vae_device inputs for per-component
GPU placement in multi-GPU setups.

Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>

---------

Co-authored-by: Amp <amp@ampcode.com>
2026-04-23 19:10:33 -07:00
7b8b3673ff comfy-aimdo: 0.0.214 (#13532)
Cut pre-release 0.0.214 off aimdo master to pickup async mem accounting
fix.
2026-04-23 19:09:56 -07:00
b502bcfff9 Merge remote-tracking branch 'origin/master' into worksplit-multigpu 2026-04-20 02:38:33 -07:00
37deccb0d4 Fix Hunyuan 3D 2.1 multi-GPU worksplit: use cond_or_uncond instead of hardcoded chunk(2) (#13478) 2026-04-20 02:37:18 -07:00
f0d550bd02 Minor updates for worksplit_gpu with comfy-aimdo (#13419)
* main: init all visible cuda devices in aimdo

* mp: call vbars_analyze for the GPU in question

* requirements: bump aimdo to pre-release version
2026-04-15 22:49:01 -07:00
48deb15c0e Simplify multigpu dispatch: run all devices on pool threads (#13340)
Benchmarked hybrid (main thread + pool) vs all-pool on 2x RTX 4090
with SD1.5 and NetaYume models. No meaningful performance difference
(within noise). All-pool is simpler: eliminates the main_device
special case, main_batch_tuple deferred execution, and the 3-way
branch in the dispatch loop.
2026-04-09 01:15:57 -07:00
4b93c4360f Implement persistent thread pool for multi-GPU CFG splitting (#13329)
Replace per-step thread create/destroy in _calc_cond_batch_multigpu with a
persistent MultiGPUThreadPool. Each worker thread calls torch.cuda.set_device()
once at startup, preserving compiled kernel caches across diffusion steps.

- Add MultiGPUThreadPool class in comfy/multigpu.py
- Create pool in CFGGuider.outer_sample(), shut down in finally block
- Main thread handles its own device batch directly for zero overhead
- Falls back to sequential execution if no pool is available
2026-04-08 05:39:07 -07:00
da3864436c Merge remote-tracking branch 'origin/master' into worksplit-multigpu 2026-04-08 05:08:38 -07:00
b418fb1582 Fix device mismatch: update LoadedModel.device when _switch_parent swaps to parent patcher
When a multigpu clone ModelPatcher is garbage collected, LoadedModel._switch_parent
switches the weakref to point at the parent (main) ModelPatcher. However, it was not
updating LoadedModel.device, leaving it with the old clone's device (e.g., cuda:1).
On subsequent runs, this stale device was passed to ModelPatcherDynamic.load(), causing
an assertion failure (device_to != self.load_device).

Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:59:38 -07:00
20803749c3 Add detailed multigpu debug logging to load_models_gpu
Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:53:36 -07:00
3fab720be9 Add debug logging for device mismatch in ModelPatcherDynamic.load
Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:45:55 -07:00
afdddcee66 Re-enable comfy-kitchen cuda backend for multigpu testing
Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:32:52 -07:00
1d8e379f41 Rename MultiGPU Work Units to MultiGPU CFG Split
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 08:00:20 -07:00
5f4fcd19e7 Simplify multigpu nodes: default max_gpus=2, remove gpu_options input, disable Options node
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 07:30:32 -07:00
d52dcbc88f Rewrite multigpu nodes to V3 format
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 07:23:13 -07:00
84f465e791 Set CUDA device at start of multigpu threads to avoid multithreading bugs
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>
2026-03-30 07:07:54 -07:00
be35378986 Merge branch 'master' into worksplit-multigpu
Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b
Co-authored-by: Amp <amp@ampcode.com>

# Conflicts:
#	comfy/samplers.py
2026-03-30 06:24:55 -07:00
f410d28b33 Merge origin/master into worksplit-multigpu
Amp-Thread-ID: https://ampcode.com/threads/T-019d009d-e059-7623-85ca-401168168516
Co-authored-by: Amp <amp@ampcode.com>
2026-03-18 04:21:30 -07:00
f4b99bc623 Made multigpu deepclone load model from disk to avoid needing to deepclone actual model object, fixed issues with merge, turn off cuda backend as it causes device mismatch issue with rope (and potentially other ops), will investigate 2026-02-17 04:55:00 -08:00
df2fd4c869 Merge branch 'master' into worksplit-multigpu 2026-02-17 02:53:06 -08:00
4661d1db5a Bring patches changes from _calc_cond_batch into _calc_cond_batch_multigpu 2025-10-15 17:34:36 -07:00
b326a544d5 Merge branch 'master' into worksplit-multigpu 2025-10-15 17:33:02 -07:00
d89dd5f0b0 Satisfy ruff 2025-10-13 22:00:34 -07:00
8cbbf0be6c Merge branch 'master' into worksplit-multigpu 2025-10-13 21:53:14 -07:00
c2115a4bac Merge branch 'master' into worksplit-multigpu 2025-09-24 23:45:26 -07:00
bb44c2ecb9 Merge branch 'master' into worksplit-multigpu 2025-09-18 14:20:27 -07:00
efcd8280d6 Merge branch 'master' into worksplit-multigpu 2025-09-11 20:59:47 -07:00
9e9c129cd0 Merge remote-tracking branch 'origin/master' into worksplit-multigpu 2025-08-29 23:36:19 -07:00
ac14ee68c0 Merge branch 'master' into worksplit-multigpu 2025-08-18 19:51:24 -07:00
2c8f485434 Merge branch 'master' into worksplit-multigpu 2025-08-18 00:29:52 -07:00
383f9b34cb Merge branch 'master' into worksplit-multigpu 2025-08-17 16:02:44 -07:00
b0741c7e5b Merge branch 'master' into worksplit-multigpu 2025-08-15 16:50:04 -07:00
1489399cb5 Merge branch 'master' into worksplit-multigpu 2025-08-13 19:47:08 -07:00
3677943fa5 Merge branch 'master' into worksplit-multigpu 2025-08-13 14:06:09 -07:00
cfb63bfcd7 Merge branch 'worksplit-multigpu' of https://github.com/comfyanonymous/ComfyUI into worksplit-multigpu 2025-08-11 14:09:58 -07:00
962c3c832c Merge branch 'master' into worksplit-multigpu 2025-08-11 14:09:41 -07:00
6ea69369ce Merge branch 'master' into worksplit-multigpu 2025-08-07 23:24:02 -07:00
b4f559b34d Merge branch 'master' into worksplit-multigpu 2025-08-04 20:23:19 -07:00
df122a7dba Merge branch 'master' into worksplit-multigpu 2025-08-01 12:31:57 -07:00
67e906aa64 Merge branch 'master' into worksplit-multigpu 2025-07-31 04:00:22 -07:00
382f84a826 Merge branch 'master' into worksplit-multigpu 2025-07-29 17:17:29 -07:00
9cca36fa2b Merge branch 'master' into worksplit-multigpu 2025-07-29 12:47:36 -07:00
5d5024296d Merge branch 'master' into worksplit-multigpu 2025-07-28 06:17:24 -07:00
3b90a30178 Merge branch 'master' into worksplit-multigpu-wip 2025-07-27 01:03:25 -07:00
3c4104652b Merge branch 'master' into worksplit-multigpu-wip 2025-07-22 11:42:23 -07:00
9855baaab3 Merge branch 'master' into worksplit-multigpu 2025-07-09 03:57:30 -05:00
d53479a197 Merge branch 'master' into worksplit-multigpu 2025-07-01 17:33:05 -05:00
443a795850 Merge branch 'master' into worksplit-multigpu 2025-06-24 00:49:24 -05:00
431dec8e53 Merge branch 'worksplit-multigpu' of https://github.com/comfyanonymous/ComfyUI into worksplit-multigpu 2025-06-24 00:48:58 -05:00
44e053c26d Improve error handling for multigpu threads 2025-06-24 00:48:51 -05:00
1ae98932f1 Merge branch 'master' into worksplit-multigpu 2025-06-17 04:58:56 -05:00
0336b0ace8 Merge branch 'master' into worksplit-multigpu 2025-06-01 02:39:26 -07:00
8ae25235ec Merge branch 'master' into worksplit-multigpu 2025-05-21 12:01:27 -07:00
9726eac475 Merge branch 'master' into worksplit-multigpu 2025-05-12 19:29:13 -05:00
272e8d42c1 Merge branch 'master' into worksplit-multigpu 2025-04-22 22:40:00 -05:00
6211d2be5a Merge branch 'master' into worksplit-multigpu 2025-04-19 17:36:23 -05:00
8be711715c Make unload_all_models account for all devices 2025-04-19 17:35:54 -05:00
b5cccf1325 Merge branch 'master' into worksplit-multigpu 2025-04-18 15:39:34 -05:00
2a54a904f4 Merge branch 'master' into worksplit-multigpu 2025-04-16 19:26:48 -05:00
ed6f92c975 Merge branch 'master' into worksplit-multigpu 2025-04-16 16:53:57 -05:00
adc66c0698 Merge branch 'master' into worksplit-multigpu 2025-04-16 14:23:56 -05:00
ccd5c01e5a Merge branch 'master' into worksplit-multigpu 2025-04-09 09:17:12 -05:00
2fa9affcc1 Merge branch 'master' into worksplit-multigpu 2025-04-08 22:52:17 -05:00
407a5a656f Rollback core of last commit due to weird behavior 2025-03-28 02:48:11 -05:00
9ce9ff8ef8 Allow chained MultiGPU Work Unit nodes to affect max_gpus present on ModelPatcher clone 2025-03-28 15:29:44 +08:00
63567c0ce8 Merge branch 'master' into worksplit-multigpu 2025-03-27 22:36:46 -05:00
a786ce5ead Merge branch 'master' into worksplit-multigpu 2025-03-26 22:26:26 -05:00
4879b47648 Merge branch 'master' into worksplit-multigpu 2025-03-18 22:19:32 -05:00
5ccec33c22 Merge branch 'worksplit-multigpu' of https://github.com/comfyanonymous/ComfyUI into worksplit-multigpu 2025-03-17 14:27:39 -05:00
219d3cd0d0 Merge branch 'master' into worksplit-multigpu 2025-03-17 14:26:35 -05:00
c4ba399475 Merge branch 'master' into worksplit-multigpu 2025-03-15 09:12:09 -05:00
cc928a786d Merge branch 'master' into worksplit-multigpu 2025-03-13 20:59:11 -05:00
6e144b98c4 Merge branch 'master' into worksplit-multigpu 2025-03-09 00:00:38 -06:00
6dca17bd2d Satisfy ruff linting 2025-03-03 23:08:29 -06:00
5080105c23 Merge branch 'master' into worksplit-multigpu 2025-03-03 22:56:53 -06:00
093914a247 Made MultiGPU Work Units node more robust by forcing ModelPatcher clones to match at sample time, reuse loaded MultiGPU clones, finalize MultiGPU Work Units node ID and name, small refactors/cleanup of logging and multigpu-related code 2025-03-03 22:56:13 -06:00
605893d3cf Merge branch 'master' into worksplit-multigpu 2025-02-24 19:23:16 -06:00
048f4f0b3a Merge branch 'master' into worksplit-multigpu 2025-02-17 19:35:58 -06:00
d2504fb701 Merge branch 'master' into worksplit-multigpu 2025-02-11 22:34:51 -06:00
b03763bca6 Merge branch 'multigpu_support' into worksplit-multigpu 2025-02-07 13:27:49 -06:00
476aa79b64 Let --cuda-device take in a string to allow multiple devices (or device order) to be chosen, print available devices on startup, potentially support MultiGPU Intel and Ascend setups 2025-02-06 08:44:07 -06:00
441cfd1a7a Merge branch 'master' into multigpu_support 2025-02-06 08:10:48 -06:00
99a5c1068a Merge branch 'master' into multigpu_support 2025-02-02 03:19:18 -06:00
02747cde7d Carry over change from _calc_cond_batch into _calc_cond_batch_multigpu 2025-01-29 11:10:23 -06:00
0b3233b4e2 Merge remote-tracking branch 'origin/master' into multigpu_support 2025-01-28 06:11:07 -06:00
eda866bf51 Extracted multigpu core code into multigpu.py, added load_balance_devices to get subdivision of work based on available devices and splittable work item count, added MultiGPU Options nodes to set relative_speed of specific devices; does not change behavior yet 2025-01-27 06:25:48 -06:00
e3298b84de Create proper MultiGPU Initialize node, create gpu_options to create scaffolding for asymmetrical GPU support 2025-01-26 09:34:20 -06:00
c7feef9060 Cast transformer_options for multigpu 2025-01-26 05:29:27 -06:00
51af7fa1b4 Fix multigpu ControlBase get_models and cleanup calls to avoid multiple calls of functions on multigpu_clones versions of controlnets 2025-01-25 06:05:01 -06:00
46969c380a Initial MultiGPU support for controlnets 2025-01-24 05:39:38 -06:00
5db4277449 Make sure additional_models are unloaded as well when perform 2025-01-23 19:06:05 -06:00
02a4d0ad7d Added unload_model_and_clones to model_management.py to allow unloading only relevant models 2025-01-23 01:20:00 -06:00
ef137ac0b6 Merge branch 'multigpu_support' of https://github.com/kosinkadink/ComfyUI into multigpu_support 2025-01-20 04:34:39 -06:00
328d4f16a9 Make WeightHooks compatible with MultiGPU, clean up some code 2025-01-20 04:34:26 -06:00
bdbcb85b8d Merge branch 'multigpu_support' of https://github.com/Kosinkadink/ComfyUI into multigpu_support 2025-01-20 00:51:42 -06:00
6c9e94bae7 Merge branch 'master' into multigpu_support 2025-01-20 00:51:37 -06:00
bfce723311 Initial work on multigpu_clone function, which will account for additional_models getting cloned 2025-01-17 03:31:28 -06:00
31f5458938 Merge branch 'master' into multigpu_support 2025-01-16 18:25:05 -06:00
2145a202eb Merge branch 'master' into multigpu_support 2025-01-15 19:58:28 -06:00
25818dc848 Added a 'max_gpus' input 2025-01-14 13:45:14 -06:00
198953cd08 Add nodes_multigpu.py to loaded nodes 2025-01-14 12:24:55 -06:00
ec16ee2f39 Merge branch 'master' into multigpu_support 2025-01-13 20:21:06 -06:00
d5088072fb Make test node for multigpu instead of storing it in just a local __init__.py 2025-01-13 20:20:25 -06:00
8d4b50158e Merge branch 'master' into multigpu_support 2025-01-11 20:16:42 -06:00
e88c6c03ff Fix cond_cat to not try to cast anything that doesn't have a 'to' function 2025-01-10 23:05:24 -06:00
d3cf2b7b24 Merge branch 'comfyanonymous:master' into multigpu_support 2025-01-10 20:24:37 -06:00
7448f02b7c Initial proof of concept of giving splitting cond sampling between multiple GPUs 2025-01-08 03:33:05 -06:00
871258aa72 Add get_all_torch_devices to get detected devices intended for current torch hardware device 2025-01-07 21:06:03 -06:00
66838ebd39 Merge branch 'comfyanonymous:master' into multigpu_support 2025-01-07 20:11:27 -06:00
7333281698 Clean up a typehint 2025-01-07 02:58:59 -06:00
3cd4c5cb0a Rename AddModelsHooks to AdditionalModelsHook, rename SetInjectionsHook to InjectionsHook (not yet implemented, but at least getting the naming figured out) 2025-01-07 02:22:49 -06:00
11c6d56037 Merge branch 'master' into hooks_part2 2025-01-07 01:01:53 -06:00
216fea15ee Made TransformerOptionsHook contribute to registered hooks properly, added some doc strings and removed a so-far unused variable 2025-01-07 00:59:18 -06:00
58bf8815c8 Add a get_injections function to ModelPatcher 2025-01-06 20:34:30 -06:00
1b38f5bf57 removed 4 whitespace lines to satisfy Ruff, 2025-01-06 17:11:12 -06:00
2724ac4a60 Merge branch 'master' into hooks_part2 2025-01-06 17:04:24 -06:00
f48f90e471 Make hook_scope functional for TransformerOptionsHook 2025-01-06 02:23:04 -06:00
6463c39ce0 Merge branch 'master' into hooks_part2 2025-01-06 01:28:26 -06:00
0a7e2ae787 Filter only registered hooks on self.conds in CFGGuider.sample 2025-01-06 01:04:29 -06:00
03a97b604a Fix performance of hooks when hooks are appended via Cond Pair Set Props nodes by properly caching between positive and negative conds, make hook_patches_backup behave as intended (in the case that something pre-registers WeightHooks on the ModelPatcher instead of registering it at sample time) 2025-01-06 01:03:59 -06:00
4446c86052 Made hook clone code sane, made clear ObjectPatchHook and SetInjectionsHook are not yet operational 2025-01-05 22:25:51 -06:00
8270ff312f Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time 2025-01-05 21:07:02 -06:00
db2d7ad9ba Merge branch 'add_sample_sigmas' into hooks_part2 2025-01-05 15:45:13 -06:00
6620d86318 In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_options to not conflict with the "sigmas" that will overwrite "sigmas" in _calc_cond_batch 2025-01-05 15:26:22 -06:00
111fd0cadf Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type 2025-01-04 02:04:07 -06:00
776aa734e1 Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options) 2025-01-04 01:02:21 -06:00
5a2ad032cb Cleaned up hooks.py, refactored Hook.should_register and add_hook_patches to use target_dict instead of target so that more information can be provided about the current execution environment if needed 2025-01-03 20:02:27 -06:00
d44295ef71 Merge branch 'master' into hooks_part2 2025-01-03 18:28:31 -06:00
bf21be066f Merge branch 'master' into hooks_part2 2024-12-30 14:16:22 -06:00
72bbf49349 Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas' 2024-12-29 15:49:09 -06:00
29 changed files with 1636 additions and 4638 deletions

View File

@ -1,519 +0,0 @@
name: Backport Release
on:
workflow_dispatch:
inputs:
commit:
description: 'Full 40-char SHA of the tip commit of the backport source branch (the PR head commit that passed tests). The branch is resolved from this SHA and must be unique.'
required: true
type: string
permissions:
contents: read
pull-requests: read
checks: read
jobs:
backport-release:
name: Create backport release
runs-on: ubuntu-latest
environment: backport release
steps:
- name: Generate GitHub App token
id: app-token
uses: actions/create-github-app-token@bcd2ba49218906704ab6c1aa796996da409d3eb1
with:
app-id: ${{ secrets.FEN_RELEASE_APP_ID }}
private-key: ${{ secrets.FEN_RELEASE_PRIVATE_KEY }}
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
with:
token: ${{ steps.app-token.outputs.token }}
fetch-depth: 0
fetch-tags: true
- name: Configure git
run: |
git config user.name "fen-release[bot]"
git config user.email "fen-release[bot]@users.noreply.github.com"
- name: Resolve source branch from commit SHA
id: resolve
env:
SOURCE_COMMIT: ${{ inputs.commit }}
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
set -euo pipefail
# Require a full 40-char lowercase-hex SHA. Short SHAs are ambiguous
# and we will be comparing this value against API responses (PR head
# SHA, ref tips) that always return the full form.
if [[ ! "${SOURCE_COMMIT}" =~ ^[0-9a-f]{40}$ ]]; then
echo "::error::Input commit '${SOURCE_COMMIT}' is not a full 40-char lowercase hex SHA."
exit 1
fi
# Fetch all remote branches so we can search for which one(s) point
# at this SHA. `actions/checkout` with fetch-depth: 0 fetches full
# history of the checked-out ref but does not necessarily populate
# every refs/remotes/origin/*, so do it explicitly.
git fetch --prune origin '+refs/heads/*:refs/remotes/origin/*'
# Verify the commit actually exists in this repo's object DB.
if ! git cat-file -e "${SOURCE_COMMIT}^{commit}" 2>/dev/null; then
echo "::error::Commit ${SOURCE_COMMIT} was not found in the repository."
exit 1
fi
# Find every remote branch whose tip == SOURCE_COMMIT. Exactly one
# branch must point at it. If zero, the commit isn't anyone's tip
# (likely stale, force-pushed past, or never the PR head). If more
# than one, the (branch -> SHA) mapping is ambiguous and we refuse
# to guess — the operator must give us a unique branch to release.
mapfile -t matching_branches < <(
git for-each-ref \
--format='%(refname:strip=3)' \
--points-at="${SOURCE_COMMIT}" \
refs/remotes/origin/ \
| grep -vx 'HEAD' || true
)
if [[ "${#matching_branches[@]}" -eq 0 ]]; then
echo "::error::No branch on origin has ${SOURCE_COMMIT} as its tip."
echo "::error::Either the branch was updated after you copied this SHA, or this commit was never the head of a branch."
exit 1
fi
if [[ "${#matching_branches[@]}" -gt 1 ]]; then
echo "::error::More than one branch on origin has ${SOURCE_COMMIT} as its tip; cannot pick one:"
for b in "${matching_branches[@]}"; do
echo "::error:: - ${b}"
done
echo "::error::Refusing to proceed with an ambiguous source branch."
exit 1
fi
source_branch="${matching_branches[0]}"
if [[ "${source_branch}" == "${DEFAULT_BRANCH}" ]]; then
echo "::error::Source branch must not be the default branch ('${DEFAULT_BRANCH}')."
exit 1
fi
echo "Resolved commit ${SOURCE_COMMIT} to branch '${source_branch}'."
echo "source_branch=${source_branch}" >> "$GITHUB_OUTPUT"
- name: Determine latest stable release
id: latest
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
run: |
set -euo pipefail
# List all tags matching vMAJOR.MINOR.PATCH and pick the highest by numeric
# comparison of each component. We DO NOT use `sort -V` because it treats
# v0.19.99 as higher than v0.20.1.
latest_tag="$(
git tag --list 'v[0-9]*.[0-9]*.[0-9]*' \
| grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \
| awk -F'[v.]' '{ printf "%010d %010d %010d %s\n", $2, $3, $4, $0 }' \
| sort -k1,1n -k2,2n -k3,3n \
| tail -n1 \
| awk '{print $4}'
)"
if [[ -z "${latest_tag}" ]]; then
echo "::error::No stable release tags (vMAJOR.MINOR.PATCH) were found."
exit 1
fi
# Parse components
ver="${latest_tag#v}"
major="${ver%%.*}"
rest="${ver#*.}"
minor="${rest%%.*}"
patch="${rest#*.}"
new_patch=$((patch + 1))
new_version="v${major}.${minor}.${new_patch}"
release_branch="release/v${major}.${minor}"
latest_sha="$(git rev-list -n 1 "refs/tags/${latest_tag}")"
echo "latest_tag=${latest_tag}" >> "$GITHUB_OUTPUT"
echo "latest_sha=${latest_sha}" >> "$GITHUB_OUTPUT"
echo "major=${major}" >> "$GITHUB_OUTPUT"
echo "minor=${minor}" >> "$GITHUB_OUTPUT"
echo "patch=${patch}" >> "$GITHUB_OUTPUT"
echo "new_version=${new_version}" >> "$GITHUB_OUTPUT"
echo "new_version_no_v=${major}.${minor}.${new_patch}" >> "$GITHUB_OUTPUT"
echo "release_branch=${release_branch}" >> "$GITHUB_OUTPUT"
echo "Latest stable release: ${latest_tag} (${latest_sha})"
echo "New version will be: ${new_version}"
echo "Release branch: ${release_branch}"
- name: Validate source branch is cut directly from the latest stable release
env:
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
run: |
set -euo pipefail
# Use the user-provided SHA directly rather than re-resolving the branch
# tip — the resolve step already proved the branch tip equals SOURCE_COMMIT,
# and pinning to the SHA here makes the rest of the job TOCTOU-safe against
# someone pushing to the branch mid-run.
source_sha="${SOURCE_COMMIT}"
# Walking first-parent from the source tip must reach LATEST_TAG_SHA.
# We capture rev-list into a variable and grep against a here-string
# rather than piping `rev-list | grep -q`: under `set -o pipefail`,
# `grep -q` would exit on first match and SIGPIPE the still-streaming
# `rev-list`, propagating exit 141 as a spurious "not found".
first_parent_chain="$(git rev-list --first-parent "${source_sha}")"
if ! grep -Fxq "${LATEST_TAG_SHA}" <<< "${first_parent_chain}"; then
echo "::error::Source branch '${SOURCE_BRANCH}' is not cut from '${LATEST_TAG}'."
echo "::error::Its first-parent history does not include ${LATEST_TAG_SHA}."
exit 1
fi
# Additionally, every commit added on top of the tag (the set we are
# about to publish) must itself be a descendant of the tag along
# first-parent — i.e. no sibling commits from master sneak in via a
# non-first-parent path. Enforce by requiring that the symmetric
# difference is empty in one direction: commits in source that are
# NOT first-parent-reachable from source starting at the tag.
# We do this by intersecting:
# A = commits reachable from source but not from tag (full DAG)
# B = commits on the first-parent chain from source down to tag
# and requiring A == B.
all_added="$(git rev-list "${LATEST_TAG_SHA}..${source_sha}" | sort)"
first_parent_added="$(
git rev-list --first-parent "${LATEST_TAG_SHA}..${source_sha}" | sort
)"
if [[ "${all_added}" != "${first_parent_added}" ]]; then
echo "::error::Source branch '${SOURCE_BRANCH}' contains commits not on its first-parent chain from '${LATEST_TAG}'."
echo "::error::This usually means the branch was cut from master (not from the tag) or contains a merge from master."
echo "Commits reachable but not on first-parent chain:"
comm -23 <(printf '%s\n' "${all_added}") <(printf '%s\n' "${first_parent_added}") \
| while read -r sha; do
echo " $(git log -1 --format='%h %s' "${sha}")"
done
exit 1
fi
added_count="$(printf '%s\n' "${all_added}" | grep -c . || true)"
echo "Source branch is cut directly from ${LATEST_TAG} with ${added_count} commit(s) on top."
- name: Validate PR exists, is open, named correctly, has latest commit, and checks pass
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
REPO: ${{ github.repository }}
run: |
set -euo pipefail
expected_title="ComfyUI backport release ${NEW_VERSION}"
# Find open PRs from this branch into master. The --state open filter
# is load-bearing: a closed/merged PR with passing checks must not be
# accepted as authorization for a new release.
pr_json="$(
gh pr list \
--repo "${REPO}" \
--state open \
--head "${SOURCE_BRANCH}" \
--base master \
--json number,title,headRefOid,state \
--limit 10
)"
pr_count="$(echo "${pr_json}" | jq 'length')"
if [[ "${pr_count}" -eq 0 ]]; then
echo "::error::No open PR found from '${SOURCE_BRANCH}' into 'master'. The PR must exist and be open."
exit 1
fi
# Pick the PR matching the expected title
pr_number="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
map(select(.title == $t)) | .[0].number // empty
')"
pr_head_sha="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
map(select(.title == $t)) | .[0].headRefOid // empty
')"
if [[ -z "${pr_number}" ]]; then
echo "::error::No open PR from '${SOURCE_BRANCH}' into 'master' is titled '${expected_title}'."
echo "Found PRs:"
echo "${pr_json}" | jq -r '.[] | " #\(.number): \(.title)"'
exit 1
fi
# The PR's current head commit must equal the SHA the operator gave us.
# This is what closes the door on releasing stale code: if anyone has
# pushed to the branch since the operator validated tests passed, the
# PR head will have advanced past SOURCE_COMMIT and we abort. (The
# resolve step already proved the branch tip == SOURCE_COMMIT; this
# ties that same SHA to the PR that authorizes the release.)
if [[ "${pr_head_sha}" != "${SOURCE_COMMIT}" ]]; then
echo "::error::PR #${pr_number} head commit is ${pr_head_sha}, but the operator-provided commit is ${SOURCE_COMMIT}."
echo "::error::The PR has new commits since this release was authorized. Re-run with the new head SHA after verifying its checks."
exit 1
fi
echo "Found open PR #${pr_number} titled '${expected_title}' at head ${pr_head_sha} (matches operator-provided commit)."
# Verify all check runs on the head commit have completed successfully.
# A check is considered passing if conclusion is success, neutral, or skipped.
checks_json="$(
gh api \
--paginate \
"repos/${REPO}/commits/${pr_head_sha}/check-runs" \
--jq '.check_runs[] | {name: .name, status: .status, conclusion: .conclusion}'
)"
if [[ -z "${checks_json}" ]]; then
echo "::error::No check runs found on PR head commit ${pr_head_sha}."
exit 1
fi
echo "Check runs on ${pr_head_sha}:"
echo "${checks_json}" | jq -s '.'
failing="$(echo "${checks_json}" | jq -s '
map(select(
.status != "completed"
or (.conclusion as $c
| ["success","neutral","skipped"]
| index($c) | not)
))
')"
failing_count="$(echo "${failing}" | jq 'length')"
if [[ "${failing_count}" -gt 0 ]]; then
echo "::error::One or more checks have not passed on PR head commit ${pr_head_sha}:"
echo "${failing}" | jq -r '.[] | " - \(.name): status=\(.status) conclusion=\(.conclusion)"'
exit 1
fi
echo "All checks have passed on ${pr_head_sha}."
- name: Prepare release branch
id: prepare
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
REPO: ${{ github.repository }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
PATCH: ${{ steps.latest.outputs.patch }}
run: |
set -euo pipefail
# Try to fetch the release branch. If patch == 0, it shouldn't exist yet
# and we'll create it from the latest stable tag. If patch > 0, it must
# already exist and its tip must equal the latest stable tag commit (i.e.
# the previous patch release).
if git ls-remote --exit-code --heads origin "${RELEASE_BRANCH}" >/dev/null 2>&1; then
echo "Release branch '${RELEASE_BRANCH}' already exists on origin."
git fetch origin "refs/heads/${RELEASE_BRANCH}:refs/remotes/origin/${RELEASE_BRANCH}"
git checkout -B "${RELEASE_BRANCH}" "refs/remotes/origin/${RELEASE_BRANCH}"
current_tip="$(git rev-parse HEAD)"
if [[ "${current_tip}" != "${LATEST_TAG_SHA}" ]]; then
echo "::error::Release branch '${RELEASE_BRANCH}' tip (${current_tip}) is not at the latest stable release '${LATEST_TAG}' (${LATEST_TAG_SHA})."
echo "::error::Refusing to release on top of a divergent branch."
exit 1
fi
echo "branch_existed=true" >> "$GITHUB_OUTPUT"
else
if [[ "${PATCH}" != "0" ]]; then
echo "::error::Release branch '${RELEASE_BRANCH}' does not exist on origin, but the latest stable release '${LATEST_TAG}' has patch=${PATCH} (>0). This is inconsistent."
exit 1
fi
echo "Release branch '${RELEASE_BRANCH}' does not exist. Creating from ${LATEST_TAG}."
git checkout -B "${RELEASE_BRANCH}" "refs/tags/${LATEST_TAG}"
echo "branch_existed=false" >> "$GITHUB_OUTPUT"
fi
- name: Fast-forward merge source branch into release branch
env:
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
run: |
set -euo pipefail
# --ff-only guarantees no merge commit is created. If a fast-forward is
# not possible (i.e. the release branch has commits the source branch
# doesn't), the merge will fail and we abort. Because we already validated
# that the source branch is rooted on the latest stable tag, and the
# release branch tip equals that same tag, this fast-forward should
# always succeed for a well-formed backport branch.
#
# We merge the operator-provided SHA, not the branch ref, so a push to
# the branch in the window between resolve and now cannot smuggle new
# commits into the release.
if ! git merge --ff-only "${SOURCE_COMMIT}"; then
echo "::error::Cannot fast-forward '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}'). A merge commit would be required. Aborting."
exit 1
fi
echo "Fast-forwarded '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}')."
- name: Bump version files
env:
NEW_VERSION_NO_V: ${{ steps.latest.outputs.new_version_no_v }}
run: |
set -euo pipefail
if [[ ! -f comfyui_version.py ]]; then
echo "::error::comfyui_version.py not found in repo root."
exit 1
fi
if [[ ! -f pyproject.toml ]]; then
echo "::error::pyproject.toml not found in repo root."
exit 1
fi
# Replace the version string in comfyui_version.py.
# Expected format: __version__ = "X.Y.Z"
python3 - "$NEW_VERSION_NO_V" <<'PY'
import re, sys, pathlib
new = sys.argv[1]
p = pathlib.Path("comfyui_version.py")
src = p.read_text()
new_src, n = re.subn(
r'(__version__\s*=\s*[\'"])[^\'"]+([\'"])',
lambda m: f'{m.group(1)}{new}{m.group(2)}',
src,
count=1,
)
if n != 1:
sys.exit("Could not find __version__ assignment in comfyui_version.py")
p.write_text(new_src)
p = pathlib.Path("pyproject.toml")
src = p.read_text()
# Replace the first `version = "..."` inside [project] or [tool.poetry].
new_src, n = re.subn(
r'(?m)^(version\s*=\s*")[^"]+(")',
lambda m: f'{m.group(1)}{new}{m.group(2)}',
src,
count=1,
)
if n != 1:
sys.exit("Could not find version assignment in pyproject.toml")
p.write_text(new_src)
PY
echo "Updated version to ${NEW_VERSION_NO_V} in comfyui_version.py and pyproject.toml."
git --no-pager diff -- comfyui_version.py pyproject.toml
- name: Commit version bump and tag release
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
git add comfyui_version.py pyproject.toml
git commit -m "ComfyUI ${NEW_VERSION}"
if git rev-parse -q --verify "refs/tags/${NEW_VERSION}" >/dev/null; then
echo "::error::Tag ${NEW_VERSION} already exists locally."
exit 1
fi
git tag "${NEW_VERSION}"
- name: Verify tag does not already exist on origin
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
if git ls-remote --exit-code --tags origin "refs/tags/${NEW_VERSION}" >/dev/null 2>&1; then
echo "::error::Tag ${NEW_VERSION} already exists on origin. Aborting."
exit 1
fi
- name: Push release branch and tag
env:
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
# Push the branch first, then the tag. Atomic-ish: if the branch push
# fails we never publish the tag.
git push origin "refs/heads/${RELEASE_BRANCH}:refs/heads/${RELEASE_BRANCH}"
git push origin "refs/tags/${NEW_VERSION}"
echo "Released ${NEW_VERSION} on ${RELEASE_BRANCH}."
- name: Delete remote source branch
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
REPO: ${{ github.repository }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
set -euo pipefail
# Belt-and-braces: the resolve step already refuses the default branch,
# but never delete the default or the release branch under any
# circumstances.
if [[ "${SOURCE_BRANCH}" == "${DEFAULT_BRANCH}" || "${SOURCE_BRANCH}" == "${RELEASE_BRANCH}" ]]; then
echo "::error::Refusing to delete '${SOURCE_BRANCH}' (matches default or release branch)."
exit 1
fi
# Delete the source branch on origin, but only if its tip is still the
# SHA we released from. If someone pushed new commits to it after we
# resolved it, leave it alone — those commits would be silently lost.
current_tip="$(git ls-remote origin "refs/heads/${SOURCE_BRANCH}" | awk '{print $1}')"
if [[ -z "${current_tip}" ]]; then
echo "Source branch '${SOURCE_BRANCH}' no longer exists on origin; nothing to delete."
exit 0
fi
if [[ "${current_tip}" != "${SOURCE_COMMIT}" ]]; then
echo "::warning::Source branch '${SOURCE_BRANCH}' tip (${current_tip}) no longer matches released commit (${SOURCE_COMMIT}). Leaving it in place."
exit 0
fi
git push origin --delete "refs/heads/${SOURCE_BRANCH}"
echo "Deleted remote branch '${SOURCE_BRANCH}'."
- name: Summary
if: always()
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
run: |
# SOURCE_BRANCH is empty if the resolve step never produced an output
# (e.g. the workflow failed in or before that step). Show a placeholder
# in that case so the summary table still renders cleanly.
source_branch_display="${SOURCE_BRANCH:-(unresolved)}"
{
echo "## Backport release"
echo ""
echo "| Field | Value |"
echo "|---|---|"
echo "| Source commit | \`${SOURCE_COMMIT}\` |"
echo "| Source branch | \`${source_branch_display}\` |"
echo "| Previous stable | \`${LATEST_TAG}\` |"
echo "| New version | \`${NEW_VERSION}\` |"
echo "| Release branch | \`${RELEASE_BRANCH}\` |"
} >> "$GITHUB_STEP_SUMMARY"

View File

@ -20,7 +20,7 @@
[website-url]: https://www.comfy.org/
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://discord.com/invite/comfyorg
[discord-url]: https://www.comfy.org/discord
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
[twitter-url]: https://x.com/ComfyUI

View File

@ -62,8 +62,6 @@ def get_comfy_package_versions():
def check_comfy_packages_versions():
"""Warn for every comfy* package whose installed version is below requirements.txt."""
from packaging.version import InvalidVersion, parse as parse_pep440
outdated_packages = []
for pkg in get_comfy_package_versions():
installed_str = pkg["installed"]
required_str = pkg["required"]
@ -75,26 +73,19 @@ def check_comfy_packages_versions():
logging.error(f"Failed to check {pkg['name']} version: {e}")
continue
if outdated:
outdated_packages.append((pkg["name"], installed_str, required_str))
else:
logging.info("{} version: {}".format(pkg["name"], installed_str))
if outdated_packages:
package_warnings = "\n".join(
f"Installed {name} version {installed} is lower than the recommended version {required}."
for name, installed, required in outdated_packages
)
app.logger.log_startup_warning(
f"""
app.logger.log_startup_warning(
f"""
________________________________________________________________________
WARNING WARNING WARNING WARNING WARNING
{package_warnings}
Installed {pkg["name"]} version {installed_str} is lower than the recommended version {required_str}.
{get_missing_requirements_message()}
________________________________________________________________________
""".strip()
)
)
else:
logging.info("{} version: {}".format(pkg["name"], installed_str))
REQUEST_TIMEOUT = 10 # seconds

View File

@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")

View File

@ -15,13 +15,14 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import torch
from enum import Enum
import math
import os
import logging
import copy
import comfy.utils
import comfy.model_management
import comfy.model_detection
@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.ldm.qwen_image.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from comfy.hooks import HookGroup
@ -64,6 +65,18 @@ class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
class ControlIsolation:
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
def __init__(self, control: ControlBase):
self.control = control
self.orig_previous_controlnet = control.previous_controlnet
def __enter__(self):
self.control.previous_controlnet = None
def __exit__(self, *args):
self.control.previous_controlnet = self.orig_previous_controlnet
class ControlBase:
def __init__(self):
self.cond_hint_original = None
@ -77,7 +90,7 @@ class ControlBase:
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
self.previous_controlnet = None
self.previous_controlnet: Union[ControlBase, None] = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
@ -85,6 +98,7 @@ class ControlBase:
self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
self.multigpu_clones: dict[torch.device, ControlBase] = {}
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
@ -111,17 +125,38 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
for device_cnet in self.multigpu_clones.values():
with ControlIsolation(device_cnet):
device_cnet.cleanup()
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
out = []
for device_cnet in self.multigpu_clones.values():
out += device_cnet.get_models_only_self()
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def get_models_only_self(self):
'Calls get_models, but temporarily sets previous_controlnet to None.'
with ControlIsolation(self):
return self.get_models()
def get_instance_for_device(self, device):
'Returns instance of this Control object intended for selected device.'
return self.multigpu_clones.get(device, self)
def deepclone_multigpu(self, load_device, autoregister=False):
'''
Create deep clone of Control object where model(s) is set to other devices.
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
'''
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
@ -130,7 +165,7 @@ class ControlBase:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c):
def copy_to(self, c: ControlBase):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
@ -284,6 +319,14 @@ class ControlNet(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.control_model = copy.deepcopy(c.control_model)
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
if autoregister:
self.multigpu_clones[load_device] = c
return c
def get_models(self):
out = super().get_models()
out.append(self.control_model_wrapped)
@ -314,6 +357,10 @@ class QwenFunControlNet(ControlNet):
super().pre_run(model, percent_to_timestep_function)
self.set_extra_arg("base_model", model.diffusion_model)
def cleanup(self):
self.extra_args.pop("base_model", None)
super().cleanup()
def copy(self):
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
@ -906,6 +953,14 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.t2i_model = copy.deepcopy(c.t2i_model)
c.device = load_device
if autoregister:
self.multigpu_clones[load_device] = c
return c
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'

View File

@ -607,9 +607,13 @@ class HunYuanDiTPlain(nn.Module):
def forward(self, x, t, context, transformer_options = {}, **kwargs):
x = x.movedim(-1, -2)
if context.shape[0] >= 2:
uncond_emb, cond_emb = context.chunk(2, dim = 0)
context = torch.cat([cond_emb, uncond_emb], dim = 0)
swap_cfg_halves = context.shape[0] >= 2
if swap_cfg_halves:
first_half, second_half = context.chunk(2, dim = 0)
context = torch.cat([second_half, first_half], dim = 0)
main_condition = context
t = 1.0 - t
@ -657,8 +661,8 @@ class HunYuanDiTPlain(nn.Module):
output = self.final_layer(combined)
output = output.movedim(-2, -1) * (-1.0)
if output.shape[0] >= 2:
cond_emb, uncond_emb = output.chunk(2, dim = 0)
return torch.cat([uncond_emb, cond_emb])
else:
return output
if swap_cfg_halves:
first_half, second_half = output.chunk(2, dim = 0)
output = torch.cat([second_half, first_half], dim = 0)
return output

View File

@ -1,6 +1,5 @@
import math
import ctypes
import threading
import dataclasses
import torch
from typing import NamedTuple
@ -10,7 +9,7 @@ from comfy.quant_ops import QuantizedTensor
class TensorFileSlice(NamedTuple):
file_ref: object
thread_id: int
lock: object
offset: int
size: int
@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size
or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0
@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
if hostbuf is not None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
device_ptr = destination2.data_ptr() if destination2 is not None else 0
hostbuf.read_file_slice(file_obj, info.offset, info.size,
offset=destination.data_ptr() - hostbuf.get_raw_address(),
stream=stream_ptr,
device_ptr=device_ptr,
device=None if destination2 is None else destination2.device.index)
with info.lock:
hostbuf.read_file_slice(file_obj, info.offset, info.size,
offset=destination.data_ptr() - hostbuf.get_raw_address(),
stream=stream_ptr,
device_ptr=device_ptr,
device=None if destination2 is None else destination2.device.index)
return True
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))
try:
file_obj.seek(info.offset)
done = 0
while done < info.size:
try:
n = file_obj.readinto(view[done:])
except OSError:
return False
if n <= 0:
return False
done += n
with info.lock:
file_obj.seek(info.offset)
done = 0
while done < info.size:
try:
n = file_obj.readinto(view[done:])
except OSError:
return False
if n <= 0:
return False
done += n
return True
finally:
view.release()

View File

@ -15,6 +15,7 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import psutil
import logging
@ -27,13 +28,18 @@ import platform
import weakref
import gc
import os
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.host_buffer
import comfy_aimdo.vram_buffer
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@ -204,6 +210,91 @@ def get_torch_device():
else:
return torch.device(torch.cuda.current_device())
def get_all_torch_devices(exclude_current=False):
global cpu_state
devices = []
if cpu_state == CPUState.GPU:
if is_nvidia():
for i in range(torch.cuda.device_count()):
devices.append(torch.device("cuda", i))
elif is_intel_xpu():
for i in range(torch.xpu.device_count()):
devices.append(torch.device("xpu", i))
elif is_ascend_npu():
for i in range(torch.npu.device_count()):
devices.append(torch.device("npu", i))
else:
devices.append(get_torch_device())
if exclude_current:
current = get_torch_device()
if current in devices:
devices.remove(current)
return devices
def get_gpu_device_options():
"""Return list of device option strings for node widgets.
Always includes "default" and "cpu". When multiple GPUs are present,
adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
"""
options = ["default", "cpu"]
devices = get_all_torch_devices()
if len(devices) > 1:
for i in range(len(devices)):
options.append(f"gpu:{i}")
return options
def resolve_gpu_device_option(option: str):
"""Resolve a device option string to a torch.device.
Returns None for "default" (let the caller use its normal default).
Returns torch.device("cpu") for "cpu".
For "gpu:N", returns the Nth torch device. Falls back to None if
the index is out of range (caller should use default).
"""
if option is None or option == "default":
return None
if option == "cpu":
return torch.device("cpu")
if option.startswith("gpu:"):
try:
idx = int(option[4:])
devices = get_all_torch_devices()
if 0 <= idx < len(devices):
return devices[idx]
else:
logging.warning(f"Device '{option}' not available (only {len(devices)} GPU(s)), using default.")
return None
except (ValueError, IndexError):
logging.warning(f"Invalid device option '{option}', using default.")
return None
logging.warning(f"Unrecognized device option '{option}', using default.")
return None
@contextmanager
def cuda_device_context(device):
"""Context manager that sets torch.cuda.current_device to match *device*.
Used when running operations on a non-default CUDA device so that custom
CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
device index. The previous device is restored on exit.
No-op when *device* is not CUDA, has no explicit index, or already matches
the current device.
"""
prev = None
if device.type == "cuda" and device.index is not None:
prev = torch.cuda.current_device()
if prev != device.index:
torch.cuda.set_device(device)
else:
prev = None
try:
yield
finally:
if prev is not None:
torch.cuda.set_device(prev)
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:
@ -492,9 +583,13 @@ try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except:
logging.warning("Could not pick default device.")
try:
for device in get_all_torch_devices(exclude_current=True):
logging.info("Device: {}".format(get_torch_device_name(device)))
except:
pass
current_loaded_models = []
current_loaded_models: list[LoadedModel] = []
DIRTY_MMAPS = set()
@ -554,7 +649,7 @@ def ensure_pin_registerable(size, evict_active=False):
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
class LoadedModel:
def __init__(self, model):
def __init__(self, model: ModelPatcher):
self._set_model(model)
self.device = model.load_device
self.real_model = None
@ -562,7 +657,7 @@ class LoadedModel:
self.model_finalizer = None
self._patcher_finalizer = None
def _set_model(self, model):
def _set_model(self, model: ModelPatcher):
self._model = weakref.ref(model)
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
@ -573,6 +668,7 @@ class LoadedModel:
model = self._parent_model()
if model is not None:
self._set_model(model)
self.device = model.load_device
@property
def model(self):
@ -1848,7 +1944,34 @@ def soft_empty_cache(force=False):
torch.cuda.ipc_collect()
def unload_all_models():
free_memory(1e30, get_torch_device())
for device in get_all_torch_devices():
free_memory(1e30, device)
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
'Unload only model and its clones - primarily for multigpu cloning purposes.'
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
additional_models = []
if unload_additional_models:
additional_models = model.get_nested_additional_models()
keep_loaded = []
for loaded_model in initial_keep_loaded:
if loaded_model.model is not None:
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
continue
# check additional models if they are a match
skip = False
for add_model in additional_models:
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
skip = True
break
if skip:
continue
keep_loaded.append(loaded_model)
if not all_devices:
free_memory(1e30, get_torch_device(), keep_loaded)
else:
for device in get_all_torch_devices():
free_memory(1e30, device, keep_loaded)
def debug_memory_summary():
if is_amd() or is_nvidia():

View File

@ -23,6 +23,7 @@ import inspect
import logging
import math
import uuid
import copy
from typing import Callable, Optional
import torch
@ -78,12 +79,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches):
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
new_hook_patches[hook_ref] = {}
for k in orig_hook_patches[hook_ref]:
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
if copy_tuples:
for i in range(len(new_hook_patches[hook_ref][k])):
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
return new_hook_patches
def wipe_lowvram_weight(m):
@ -329,7 +333,10 @@ class ModelPatcher:
self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.cached_patcher_init: tuple[Callable, tuple] | None = None
self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
self.is_multigpu_base_clone = False
self.clone_base_uuid = uuid.uuid4()
if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
@ -366,7 +373,8 @@ class ModelPatcher:
#than pays for CFG. So return everything both torch and Aimdo could give us
aimdo_mem = 0
if comfy.memory_management.aimdo_enabled:
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device)
return comfy.model_management.get_free_memory(device) + aimdo_mem
def get_clone_model_override(self):
@ -380,6 +388,8 @@ class ModelPatcher:
if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
model_override = temp_model_patcher.get_clone_model_override()
if model_override is None:
model_override = self.get_clone_model_override()
@ -438,19 +448,98 @@ class ModelPatcher:
n.hook_mode = self.hook_mode
n.cached_patcher_init = self.cached_patcher_init
n.is_multigpu_base_clone = self.is_multigpu_base_clone
n.clone_base_uuid = self.clone_base_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
comfy.model_management.unload_model_and_clones(self)
n = self.clone()
# set load device, if present
if new_load_device is not None:
n.load_device = new_load_device
if self.cached_patcher_init is not None:
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
n.model = temp_model_patcher.model
else:
n.model = copy.deepcopy(n.model)
# unlike for normal clone, backup dicts that shared same ref should not;
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
n.backup = copy.deepcopy(n.backup)
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
n.hook_backup = copy.deepcopy(n.hook_backup)
# multigpu clone should not have multigpu additional_models entry
n.remove_additional_models("multigpu")
# multigpu_clone all stored additional_models; make sure circular references are properly handled
if models_cache is None:
models_cache = {}
for key, model_list in n.additional_models.items():
for i in range(len(model_list)):
add_model = n.additional_models[key][i]
if add_model.clone_base_uuid not in models_cache:
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
callback(self, n)
return n
def match_multigpu_clones(self):
multigpu_models = self.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
new_multigpu_models = []
for mm in multigpu_models:
# clone main model, but bring over relevant props from existing multigpu clone
n = self.clone()
n.load_device = mm.load_device
n.backup = mm.backup
n.object_patches_backup = mm.object_patches_backup
n.hook_backup = mm.hook_backup
n.model = mm.model
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
n.remove_additional_models("multigpu")
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
# figure out which additional models are not present in multigpu clone
models_cache = {}
for mm_add_model in mm.get_additional_models():
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
remove_models_uuids = set(list(models_cache.keys()))
for key, model_list in orig_additional_models.items():
for orig_add_model in model_list:
if orig_add_model.clone_base_uuid not in models_cache:
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
existing_list = n.get_additional_models_with_key(key)
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
n.set_additional_models(key, existing_list)
if orig_add_model.clone_base_uuid in remove_models_uuids:
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
# remove duplicate additional models
for key, model_list in n.additional_models.items():
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
n.set_additional_models(key, new_model_list)
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
callback(self, n)
new_multigpu_models.append(n)
self.set_additional_models("multigpu", new_multigpu_models)
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def clone_has_same_weights(self, clone: 'ModelPatcher'):
if not self.is_clone(clone):
return False
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
if allow_multigpu:
if self.clone_base_uuid != clone.clone_base_uuid:
return False
else:
if not self.is_clone(clone):
return False
if self.current_hooks != clone.current_hooks:
return False
@ -1232,7 +1321,7 @@ class ModelPatcher:
return self.additional_models.get(key, [])
def get_additional_models(self):
all_models = []
all_models: list[ModelPatcher] = []
for models in self.additional_models.values():
all_models.extend(models)
return all_models
@ -1286,9 +1375,18 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
def prepare_state(self, timestep):
def prepare_state(self, timestep, model_options):
ignore_multigpu = model_options.get("ignore_multigpu", False)
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep)
callback(self, timestep, model_options)
if not ignore_multigpu and "multigpu_clones" in model_options:
model_options["ignore_multigpu"] = True
try:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p.prepare_state(timestep, model_options)
finally:
model_options.pop("ignore_multigpu", None)
def restore_hook_patches(self):
if self.hook_patches_backup is not None:
@ -1301,12 +1399,18 @@ class ModelPatcher:
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0]
reset_current_hooks = False
multigpu_kf_changed_cache = None
transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling
if changed:
# cache changed for multigpu usage
if "multigpu_clones" in model_options:
if multigpu_kf_changed_cache is None:
multigpu_kf_changed_cache = []
multigpu_kf_changed_cache.append(hook)
# reset current_hooks if contains hook that changed
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
@ -1318,6 +1422,28 @@ class ModelPatcher:
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
if "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
if kf_changed_cache is None:
return
reset_current_hooks = False
# reset current_hooks if contains hook that changed
for hook in kf_changed_cache:
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
if current_hook == hook:
reset_current_hooks = True
break
for cached_group in list(self.cached_hook_patches.keys()):
if cached_group.contains(hook):
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):

230
comfy/multigpu.py Normal file
View File

@ -0,0 +1,230 @@
from __future__ import annotations
import queue
import threading
import torch
import logging
from collections import namedtuple
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.utils
import comfy.patcher_extension
import comfy.model_management
class MultiGPUThreadPool:
"""Persistent thread pool for multi-GPU work distribution.
Maintains one worker thread per extra GPU device. Each thread calls
torch.cuda.set_device() once at startup so that compiled kernel caches
(inductor/triton) stay warm across diffusion steps.
"""
def __init__(self, devices: list[torch.device]):
self._workers: list[threading.Thread] = []
self._work_queues: dict[torch.device, queue.Queue] = {}
self._result_queues: dict[torch.device, queue.Queue] = {}
for device in devices:
wq = queue.Queue()
rq = queue.Queue()
self._work_queues[device] = wq
self._result_queues[device] = rq
t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
t.start()
self._workers.append(t)
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
try:
torch.cuda.set_device(device)
except Exception as e:
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
while True:
item = work_q.get()
if item is None:
return
result_q.put((None, e))
return
while True:
item = work_q.get()
if item is None:
break
fn, args, kwargs = item
try:
result = fn(*args, **kwargs)
result_q.put((result, None))
except Exception as e:
result_q.put((None, e))
def submit(self, device: torch.device, fn, *args, **kwargs):
self._work_queues[device].put((fn, args, kwargs))
def get_result(self, device: torch.device):
return self._result_queues[device].get()
@property
def devices(self) -> list[torch.device]:
return list(self._work_queues.keys())
def shutdown(self):
for wq in self._work_queues.values():
wq.put(None) # sentinel
for t in self._workers:
t.join(timeout=5.0)
class GPUOptions:
def __init__(self, device_index: int, relative_speed: float):
self.device_index = device_index
self.relative_speed = relative_speed
def clone(self):
return GPUOptions(self.device_index, self.relative_speed)
def create_dict(self):
return {
"relative_speed": self.relative_speed
}
class GPUOptionsGroup:
def __init__(self):
self.options: dict[int, GPUOptions] = {}
def add(self, info: GPUOptions):
self.options[info.device_index] = info
def clone(self):
c = GPUOptionsGroup()
for opt in self.options.values():
c.add(opt)
return c
def register(self, model: ModelPatcher):
opts_dict = {}
# get devices that are valid for this model
devices: list[torch.device] = [model.load_device]
for extra_model in model.get_additional_models_with_key("multigpu"):
extra_model: ModelPatcher
devices.append(extra_model.load_device)
# create dictionary with actual device mapped to its GPUOptions
device_opts_list: list[GPUOptions] = []
for device in devices:
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
opts_dict[device] = device_opts.create_dict()
device_opts_list.append(device_opts)
# make relative_speed relative to 1.0
min_speed = min([x.relative_speed for x in device_opts_list])
for value in opts_dict.values():
value['relative_speed'] /= min_speed
model.model_options['multigpu_options'] = opts_dict
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
model = model.clone()
# check if multigpu is already prepared - get the load devices from them if possible to exclude
skip_devices = set()
multigpu_models = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
for mm in multigpu_models:
skip_devices.add(mm.load_device)
skip_devices = list(skip_devices)
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
limit_extra_devices = full_extra_devices[:max_gpus-1]
extra_devices = limit_extra_devices.copy()
# exclude skipped devices
for skip in skip_devices:
if skip in extra_devices:
extra_devices.remove(skip)
# create new deepclones
if len(extra_devices) > 0:
for device in extra_devices:
device_patcher = None
if reuse_loaded:
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
for lm in loaded_models:
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
device_patcher = lm.clone()
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
break
if device_patcher is None:
device_patcher = model.deepclone_multigpu(new_load_device=device)
device_patcher.is_multigpu_base_clone = True
multigpu_models = model.get_additional_models_with_key("multigpu")
multigpu_models.append(device_patcher)
model.set_additional_models("multigpu", multigpu_models)
model.match_multigpu_clones()
if gpu_options is None:
gpu_options = GPUOptionsGroup()
gpu_options.register(model)
else:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
# only keep model clones that don't go 'past' the intended max_gpu count;
# this prunes any inherited multigpu clones whose load_device is no longer allowed
# when max_gpus is lowered between runs.
allowed_devices = set(limit_extra_devices)
allowed_devices.add(model.load_device)
multigpu_models = model.get_additional_models_with_key("multigpu")
new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices]
if len(new_multigpu_models) != len(multigpu_models):
model.set_additional_models("multigpu", new_multigpu_models)
model.match_multigpu_clones()
return model
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
opts_dict = model_options['multigpu_options']
devices = list(model_options['multigpu_clones'].keys())
speed_per_device = []
work_per_device = []
# get sum of each device's relative_speed
total_speed = 0.0
for opts in opts_dict.values():
total_speed += opts['relative_speed']
# get relative work for each device;
# obtained by w = (W*r)/R
for device in devices:
relative_speed = opts_dict[device]['relative_speed']
relative_work = (total_work*relative_speed) / total_speed
speed_per_device.append(relative_speed)
work_per_device.append(relative_work)
# relative work must be expressed in whole numbers, but likely is a decimal;
# perform rounding while maintaining total sum equal to total work (sum of relative works)
work_per_device = round_preserved(work_per_device)
dict_work_per_device = {}
for device, relative_work in zip(devices, work_per_device):
dict_work_per_device[device] = relative_work
if not return_idle_time:
return LoadBalance(dict_work_per_device, None)
# divide relative work by relative speed to get estimated completion time of said work by each device;
# time here is relative and does not correspond to real-world units
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
# calculate relative time spent by the devices waiting on each other after their work is completed
idle_time = abs(min(completion_time) - max(completion_time))
# if need to compare work idle time, need to normalize to a common total work
if work_normalized:
idle_time *= (work_normalized/total_work)
return LoadBalance(dict_work_per_device, idle_time)
def round_preserved(values: list[float]):
'Round all values in a list, preserving the combined sum of values.'
# get floor of values; casting to int does it too
floored = [int(x) for x in values]
total_floored = sum(floored)
# get remainder to distribute
remainder = round(sum(values)) - total_floored
# pair values with fractional portions
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
# sort by fractional part in descending order
fractional.sort(key=lambda x: x[1], reverse=True)
# distribute the remainder
for i in range(remainder):
index = fractional[i][0]
floored[index] += 1
return floored

View File

@ -3,6 +3,8 @@ from typing import Callable
class CallbacksMP:
ON_CLONE = "on_clone"
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"

View File

@ -1,16 +1,18 @@
from __future__ import annotations
import torch
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
import comfy.model_patcher
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase
def prepare_mask(noise_mask, shape, device):
@ -119,6 +121,47 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) == 0:
return
extra_devices = [x.load_device for x in multigpu_models]
# handle controlnets
controlnets: set[ControlBase] = set()
for k in conds:
for kk in conds[k]:
if 'control' in kk:
controlnets.add(kk['control'])
if len(controlnets) > 0:
# first, unload all controlnet clones
for cnet in list(controlnets):
cnet_models = cnet.get_models()
for cm in cnet_models:
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
# next, make sure each controlnet has a deepclone for all relevant devices
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
for device in extra_devices:
if device not in curr_cnet.multigpu_clones:
curr_cnet.deepclone_multigpu(device, autoregister=True)
curr_cnet = curr_cnet.previous_controlnet
# since all device clones are now present, recreate the linked list for cloned cnets per device
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
prev_cnet = curr_cnet.previous_controlnet
for device in extra_devices:
device_cnet = curr_cnet.get_instance_for_device(device)
prev_device_cnet = None
if prev_cnet is not None:
prev_device_cnet = prev_cnet.get_instance_for_device(device)
device_cnet.set_previous_controlnet(prev_device_cnet)
curr_cnet = prev_cnet
# potentially handle gligen - since not widely used, ignored for now
def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list)
cond_shapes_min = {}
@ -143,7 +186,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
real_model: BaseModel = None
model.match_multigpu_clones()
preprocess_multigpu_conds(conds, model, model_options)
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
@ -155,7 +199,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
memory_required += inference_memory
minimum_memory_required += inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
real_model = model.model
real_model: BaseModel = model.model
return real_model, conds, models
@ -201,3 +245,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
'''
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
'''
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
if len(multigpu_patchers) > 0:
multigpu_dict: dict[torch.device, ModelPatcher] = {}
multigpu_dict[model_patcher.load_device] = model_patcher
for x in multigpu_patchers:
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
multigpu_dict[x.load_device] = x
model_options["multigpu_clones"] = multigpu_dict
return multigpu_patchers

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import comfy.model_management
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
from typing import TYPE_CHECKING, Callable, NamedTuple
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
@ -16,6 +18,7 @@ import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import comfy.multigpu
import comfy.utils
import scipy.stats
import numpy
@ -141,7 +144,7 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list):
def cond_cat(c_list, device=None):
temp = {}
for x in c_list:
for k in x:
@ -153,6 +156,8 @@ def cond_cat(c_list):
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
if device is not None and hasattr(out[k], 'to'):
out[k] = out[k].to(device)
return out
@ -212,7 +217,12 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
)
return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
# NOTE: keep in sync with _calc_cond_batch_multigpu below. Shared logic
# (hooked_to_run accumulation, memory-fit batching, per-chunk output
# aggregation) is duplicated there with per-device scheduling layered on top.
if 'multigpu_clones' in model_options:
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
out_conds = []
out_counts = []
# separate conds by matching hooks
@ -244,7 +254,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep)
model.current_patcher.prepare_state(timestep, model_options)
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
@ -345,6 +355,239 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
return out_conds
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
# NOTE: keep in sync with _calc_cond_batch above. Same conds-by-hooks
# accumulation, memory-fit batching, and output aggregation, but adds a
# per-device scheduler, per-device patcher/control lookup, tensor .to(device)
# placement, and MultiGPUThreadPool dispatch around the inner loop.
out_conds = []
out_counts = []
# separate conds by matching hooks
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
default_conds = []
has_default_conds = False
output_device = x_in.device
for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37)
cond = conds[i]
default_c = []
if cond is not None:
for x in cond:
if 'default' in x:
default_c.append(x)
has_default_conds = True
continue
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep, model_options)
devices = list(model_options['multigpu_clones'].keys())
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
# Track conds currently scheduled per device; single source of truth for capacity checks.
device_load: dict[torch.device, int] = {d: 0 for d in devices}
total_conds = sum(len(to_run) for to_run in hooked_to_run.values())
conds_per_device = max(1, math.ceil(total_conds / len(devices)))
def next_available_device(start: int) -> tuple[int, torch.device]:
"""Return (index, device) for the next device with remaining capacity, starting at `start`.
Scans at most len(devices) positions, so this always terminates. Raises if no device
has remaining capacity, which would indicate a bug in conds_per_device accounting.
"""
for offset in range(len(devices)):
i = (start + offset) % len(devices)
if device_load[devices[i]] < conds_per_device:
return i, devices[i]
raise RuntimeError(
f"MultiGPU scheduler: all {len(devices)} devices at capacity "
f"({conds_per_device}) but conds remain to schedule"
)
# run every hooked_to_run separately
index_device = 0
for hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
index_device, current_device = next_available_device(index_device)
remaining_capacity = conds_per_device - device_load[current_device]
first = to_run[0]
first_shape = first[0][0].shape
# collect candidate indices that can be concatenated with `first`, up to remaining capacity
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = comfy.model_management.get_free_memory(current_device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
cond_shapes = collections.defaultdict(list)
for tt in batch_amount:
for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size())
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
to_batch = batch_amount
break
conds_to_batch = [to_run.pop(x) for x in to_batch]
device_load[current_device] += len(conds_to_batch)
device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch))
if device_load[current_device] >= conds_per_device:
index_device += 1
class thread_result(NamedTuple):
output: Any
mult: Any
area: Any
batch_chunks: int
cond_or_uncond: Any
error: Exception = None
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
try:
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
# XPU/NPU/MPS/CPU/DirectML backends.
torch.cuda.set_device(device)
model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately
with torch.no_grad():
for hooks, to_batch in batch_tuple:
input_x = []
mult = []
c = []
cond_or_uncond = []
uuids = []
area = []
control: ControlBase = None
patches = None
for x in to_batch:
o = x
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
uuids.append(p.uuid)
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x).to(device)
c = cond_cat(c, device=device)
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
copy_dict1=False)
if patches is not None:
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
transformer_options.get("patches", {}),
patches
)
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
transformer_options["sigmas"] = timestep.to(device)
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
transformer_options["multigpu_thread_device"] = device
cast_transformer_options(transformer_options, device=device)
c['transformer_options'] = transformer_options
if control is not None:
device_control = control.get_instance_for_device(device)
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
else:
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
# TODO: non-NVIDIA support -- the `.to(output_device)` copies
# above are async on CUDA, so the main thread's aggregation
# could race with in-flight transfers. CUDA-only QA has not
# surfaced this in practice, but before extending multigpu
# beyond NVIDIA add a `torch.cuda.synchronize(output_device)`
# here (guarded by `output_device.type == "cuda"`).
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
except Exception as e:
results.append(thread_result(None, None, None, None, None, error=e))
raise
def _handle_batch_pooled(device, batch_tuple):
worker_results = []
_handle_batch(device, batch_tuple, worker_results)
return worker_results
results: list[thread_result] = []
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
# Submit all GPU work to pool threads
pool_devices = []
for device, batch_tuple in device_batched_hooked_to_run.items():
if thread_pool is not None:
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
pool_devices.append(device)
else:
# Fallback: no pool, run everything on main thread
_handle_batch(device, batch_tuple, results)
# Collect results from pool workers
for device in pool_devices:
worker_results, error = thread_pool.get_result(device)
if error is not None:
raise error
results.extend(worker_results)
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
if error is not None:
raise error
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
return out_conds
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
@ -643,12 +886,21 @@ def calculate_start_end_timesteps(model, conds):
def pre_run_control(model, conds):
s = model.model_sampling
# Per-device model lookup so multigpu control clones get the matching
# diffusion_model (e.g. QwenFunControlNet stashes it into extra_args).
device_models: dict = {}
patcher = getattr(model, "current_patcher", None)
if patcher is not None:
for p in patcher.get_additional_models_with_key("multigpu"):
device_models[p.load_device] = p.model
for t in range(len(conds)):
x = conds[t]
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
for device, device_cnet in x['control'].multigpu_clones.items():
device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
@ -891,7 +1143,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
cast_transformer_options(to_load_options, device, dtype)
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
casts = []
if device is not None:
casts.append(device)
@ -900,18 +1154,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# if nothing to apply, do nothing
if len(casts) == 0:
return
# try to call .to on patches
if "patches" in to_load_options:
patches = to_load_options["patches"]
if "patches" in transformer_options:
patches = transformer_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
if "patches_replace" in transformer_options:
patches = transformer_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
@ -921,8 +1174,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
if wc_name in transformer_options:
wc: dict[str, list] = transformer_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
@ -930,7 +1183,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider:
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
@ -985,16 +1237,32 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
try:
self.model_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
self.model_patcher.cleanup()
# Create persistent thread pool for all GPU devices (main + extras)
if multigpu_patchers:
extra_devices = [p.load_device for p in multigpu_patchers]
all_devices = [device] + extra_devices
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
with comfy.model_management.cuda_device_context(device):
try:
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
self.model_patcher.pre_run()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
if thread_pool is not None:
thread_pool.shutdown()
self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model

View File

@ -335,41 +335,43 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"execution_device": device})
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
pbar = ProgressBar(len(scheduled_keyframes))
for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0]
# don't bother encoding any conds outside of start_percent and end_percent bounds
if "start_percent" in add_dict:
if t_range[1] < add_dict["start_percent"]:
continue
if "end_percent" in add_dict:
if t_range[0] > add_dict["end_percent"]:
continue
hooks_keyframes = scheduled_opts[1]
for hook, keyframe in hooks_keyframes:
hook.hook_keyframe._current_keyframe = keyframe
# apply appropriate hooks with values that match new hook_keyframe
self.patcher.patch_hooks(all_hooks)
# perform encoding as normal
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
pooled_dict = {"pooled_output": pooled}
# add clip_start_percent and clip_end_percent in pooled
pooled_dict["clip_start_percent"] = t_range[0]
pooled_dict["clip_end_percent"] = t_range[1]
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
# add hooks stored on clip
self.add_hooks_to_dict(pooled_dict)
all_cond_pooled.append([cond, pooled_dict])
if show_pbar:
pbar.update(1)
model_management.throw_exception_if_processing_interrupted()
with model_management.cuda_device_context(device):
for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0]
# don't bother encoding any conds outside of start_percent and end_percent bounds
if "start_percent" in add_dict:
if t_range[1] < add_dict["start_percent"]:
continue
if "end_percent" in add_dict:
if t_range[0] > add_dict["end_percent"]:
continue
hooks_keyframes = scheduled_opts[1]
for hook, keyframe in hooks_keyframes:
hook.hook_keyframe._current_keyframe = keyframe
# apply appropriate hooks with values that match new hook_keyframe
self.patcher.patch_hooks(all_hooks)
# perform encoding as normal
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
pooled_dict = {"pooled_output": pooled}
# add clip_start_percent and clip_end_percent in pooled
pooled_dict["clip_start_percent"] = t_range[0]
pooled_dict["clip_end_percent"] = t_range[1]
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
# add hooks stored on clip
self.add_hooks_to_dict(pooled_dict)
all_cond_pooled.append([cond, pooled_dict])
if show_pbar:
pbar.update(1)
model_management.throw_exception_if_processing_interrupted()
all_hooks.reset()
return all_cond_pooled
@ -383,8 +385,12 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"execution_device": device})
with model_management.cuda_device_context(device):
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
if return_dict:
out = {"cond": cond, "pooled_output": pooled}
@ -446,9 +452,12 @@ class CLIP:
self.cond_stage_model.reset_clip_options()
self.load_model(tokens)
device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
self.cond_stage_model.set_clip_options({"execution_device": device})
with model_management.cuda_device_context(device):
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@ -1026,50 +1035,52 @@ class VAE:
do_tile = False
if self.latent_dim == 2 and samples_in.ndim == 5:
samples_in = samples_in[:, :, 0]
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True
with model_management.cuda_device_context(self.device):
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
if preallocated:
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
else:
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number].copy_(out)
del out
self.process_output(pixel_samples[x:x+batch_number])
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True
if do_tile:
comfy.model_management.soft_empty_cache()
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
if preallocated:
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
else:
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number].copy_(out)
del out
self.process_output(pixel_samples[x:x+batch_number])
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
@ -1087,20 +1098,21 @@ class VAE:
if overlap is not None:
args["overlap"] = overlap
if dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
output = self.decode_tiled_(samples, **args)
elif dims == 3:
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (max(1, overlap_t), overlap, overlap)
if tile_t is not None:
args["tile_t"] = max(2, tile_t)
with model_management.cuda_device_context(self.device):
if dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
output = self.decode_tiled_(samples, **args)
elif dims == 3:
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (max(1, overlap_t), overlap, overlap)
if tile_t is not None:
args["tile_t"] = max(2, tile_t)
output = self.decode_tiled_3d(samples, **args)
output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1)
def encode(self, pixel_samples):
@ -1113,44 +1125,46 @@ class VAE:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else:
pixel_samples = pixel_samples.unsqueeze(2)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
out = self.first_stage_model.encode(pixels_in, device=self.device)
with model_management.cuda_device_context(self.device):
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
out = self.first_stage_model.encode(pixels_in, device=self.device)
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)
samples = self.encode_tiled_(pixel_samples)
return samples
@ -1176,26 +1190,27 @@ class VAE:
if overlap is not None:
args["overlap"] = overlap
if dims == 1:
args.pop("tile_y")
samples = self.encode_tiled_1d(pixel_samples, **args)
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
with model_management.cuda_device_context(self.device):
if dims == 1:
args.pop("tile_y")
samples = self.encode_tiled_1d(pixel_samples, **args)
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
return samples
@ -1710,10 +1725,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if output_model and out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
if output_clip and out[1] is not None:
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
if out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
return out
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
@ -1742,7 +1755,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
load_device = model_options.get("load_device", model_management.get_torch_device())
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
@ -1782,13 +1795,15 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd, metadata=metadata)
vae_device = model_options.get("load_device", None)
vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device)
if output_clip:
if te_model_options.get("custom_operations", None) is None:
@ -1872,7 +1887,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
load_device = model_options.get("load_device", model_management.get_torch_device())
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None:
@ -1897,7 +1912,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
else:
logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device()
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.quant_config is not None:
weight_dtype = None

View File

@ -86,6 +86,7 @@ def load_safetensors(ckpt):
import comfy_aimdo.model_mmap
f = open(ckpt, "rb", buffering=0)
file_lock = threading.Lock()
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
@ -111,7 +112,7 @@ def load_safetensors(ckpt):
storage = tensor.untyped_storage()
setattr(storage,
"_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
sd[name] = tensor

View File

@ -1,5 +1,7 @@
from enum import Enum
from __future__ import annotations
from enum import Enum
from typing import Optional, List
from pydantic import BaseModel, Field
@ -9,76 +11,44 @@ class Rodin3DGenerateRequest(BaseModel):
material: str = Field(..., description="The material type.")
quality_override: int = Field(..., description="The poly count of the mesh.")
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
TAPose: bool | None = Field(None, description="")
class Rodin3DGen25Request(BaseModel):
tier: str = Field(..., description="Gen-2.5 tier (e.g. Gen-2.5-High).")
prompt: str | None = Field(None, description="Required for Text-to-3D; ignored otherwise.")
seed: int | None = Field(None, description="0-65535.")
material: str | None = Field(None, description="PBR | Shaded | All | None.")
geometry_file_format: str | None = Field(None, description="glb | usdz | fbx | obj | stl.")
texture_mode: str | None = Field(None, description="legacy | extreme-low | low | medium | high.")
mesh_mode: str | None = Field(None, description="Raw (triangular) | Quad.")
quality_override: int | None = Field(None, description="Mesh face count override.")
geometry_instruct_mode: str | None = Field(None, description="faithful | creative.")
bbox_condition: list[int] | None = Field(None, description="Bounding box [Width(Y), Height(Z), Length(X)] in cm.")
height: int | None = Field(None, description="Approximate model height in cm.")
TAPose: bool | None = Field(None, description="T/A pose for human-like models.")
hd_texture: bool | None = Field(None, description="Enhanced texture quality.")
texture_delight: bool | None = Field(None, description="Remove baked lighting from textures.")
is_micro: bool | None = Field(None, description="Micro detail (Extreme-High only).")
use_original_alpha: bool | None = Field(None, description="Preserve image transparency.")
preview_render: bool | None = Field(None, description="Generate high-quality preview render.")
addons: list[str] | None = Field(None, description='Optional addons, e.g. ["HighPack"].')
TAPose: Optional[bool] = Field(None, description="")
class GenerateJobsData(BaseModel):
uuids: list[str] = Field(..., description="str LIST")
uuids: List[str] = Field(..., description="str LIST")
subscription_key: str = Field(..., description="subscription key")
class Rodin3DGenerateResponse(BaseModel):
message: str | None = Field(None, description="Return message.")
prompt: str | None = Field(None, description="Generated Prompt from image.")
submit_time: str | None = Field(None, description="Submit Time")
uuid: str | None = Field(None, description="Task str")
jobs: GenerateJobsData | None = Field(None, description="Details of jobs")
message: Optional[str] = Field(None, description="Return message.")
prompt: Optional[str] = Field(None, description="Generated Prompt from image.")
submit_time: Optional[str] = Field(None, description="Submit Time")
uuid: Optional[str] = Field(None, description="Task str")
jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")
class JobStatus(str, Enum):
"""
Status for jobs
"""
Done = "Done"
Failed = "Failed"
Generating = "Generating"
Waiting = "Waiting"
class Rodin3DCheckStatusRequest(BaseModel):
subscription_key: str = Field(..., description="subscription from generate endpoint")
class JobItem(BaseModel):
uuid: str = Field(..., description="uuid")
status: JobStatus = Field(..., description="Status Currently")
status: JobStatus = Field(...,description="Status Currently")
class Rodin3DCheckStatusResponse(BaseModel):
jobs: list[JobItem] = Field(..., description="Job status List")
jobs: List[JobItem] = Field(..., description="Job status List")
class Rodin3DDownloadRequest(BaseModel):
task_uuid: str = Field(..., description="Task str")
class RodinResourceItem(BaseModel):
url: str = Field(..., description="Download Url")
name: str = Field(..., description="File name with ext")
class Rodin3DDownloadResponse(BaseModel):
items: list[RodinResourceItem] = Field(..., alias="list", description="Source List")
list: List[RodinResourceItem] = Field(..., description="Source List")

View File

@ -276,6 +276,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@ -3065,6 +3066,7 @@ class KlingVideoNode(IO.ComfyNode):
cls,
ApiEndpoint(path=poll_path),
response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@ -3190,6 +3192,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))

View File

@ -5,37 +5,32 @@ Rodin API docs: https://developer.hyper3d.ai/
"""
from inspect import cleandoc
import folder_paths as comfy_paths
import os
import logging
import math
import os
from inspect import cleandoc
from io import BytesIO
from typing import Any
import aiohttp
from PIL import Image
from typing_extensions import override
import folder_paths as comfy_paths
from comfy_api.latest import IO, ComfyExtension, Types
from PIL import Image
from comfy_api_nodes.apis.rodin import (
JobStatus,
Rodin3DGenerateRequest,
Rodin3DGenerateResponse,
Rodin3DCheckStatusRequest,
Rodin3DCheckStatusResponse,
Rodin3DDownloadRequest,
Rodin3DDownloadResponse,
Rodin3DGen25Request,
Rodin3DGenerateRequest,
Rodin3DGenerateResponse,
JobStatus,
)
from comfy_api_nodes.util import (
sync_op,
poll_op,
ApiEndpoint,
download_url_to_bytesio,
download_url_to_file_3d,
poll_op,
sync_op,
validate_string,
)
from comfy_api.latest import ComfyExtension, IO, Types
COMMON_PARAMETERS = [
IO.Int.Input(
@ -56,30 +51,40 @@ COMMON_PARAMETERS = [
]
_QUALITY_MESH_OPTIONS: dict[str, tuple[str, int]] = {
"4K-Quad": ("Quad", 4000),
"8K-Quad": ("Quad", 8000),
"18K-Quad": ("Quad", 18000),
"50K-Quad": ("Quad", 50000),
"200K-Quad": ("Quad", 200000),
"2K-Triangle": ("Raw", 2000),
"20K-Triangle": ("Raw", 20000),
"150K-Triangle": ("Raw", 150000),
"200K-Triangle": ("Raw", 200000),
"500K-Triangle": ("Raw", 500000),
"1M-Triangle": ("Raw", 1000000),
}
def get_quality_mode(poly_count):
polycount = poly_count.split("-")
poly = polycount[1]
count = polycount[0]
if poly == "Triangle":
mesh_mode = "Raw"
elif poly == "Quad":
mesh_mode = "Quad"
else:
mesh_mode = "Quad"
if count == "4K":
quality_override = 4000
elif count == "8K":
quality_override = 8000
elif count == "18K":
quality_override = 18000
elif count == "50K":
quality_override = 50000
elif count == "2K":
quality_override = 2000
elif count == "20K":
quality_override = 20000
elif count == "150K":
quality_override = 150000
elif count == "500K":
quality_override = 500000
else:
quality_override = 18000
return mesh_mode, quality_override
def get_quality_mode(poly_count: str) -> tuple[str, int]:
"""Map a polygon-count preset like '18K-Quad' to (mesh_mode, quality_override).
Falls back to ('Quad', 18000) for unknown labels; legacy parity.
"""
return _QUALITY_MESH_OPTIONS.get(poly_count, ("Quad", 18000))
def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048):
def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
"""
Converts a PyTorch tensor to a file-like object.
@ -91,8 +96,8 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048):
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype("uint8")
image = Image.fromarray(array, "RGB")
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
original_width, original_height = image.size
original_pixels = original_width * original_height
@ -107,7 +112,7 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048):
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = BytesIO()
image.save(img_byte_arr, format="PNG") # PNG is used for lossless compression
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
@ -140,9 +145,11 @@ async def create_generate_task(
TAPose=ta_pose,
),
files=[
("images", open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image))
for image in images
if image is not None
(
"images",
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image)
)
for image in images if image is not None
],
content_type="multipart/form-data",
)
@ -170,7 +177,6 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
return "DONE"
return "Generating"
def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None:
if not response.jobs:
return None
@ -208,7 +214,7 @@ async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.Fi
model_file_path = None
file_3d = None
for i in url_list.items:
for i in url_list.list:
file_path = os.path.join(save_path, i.name)
if i.name.lower().endswith(".glb"):
model_file_path = os.path.join(result_folder_name, i.name)
@ -483,16 +489,7 @@ class Rodin3D_Gen2(IO.ComfyNode):
IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
IO.Combo.Input(
"Polygon_count",
options=[
"4K-Quad",
"8K-Quad",
"18K-Quad",
"50K-Quad",
"2K-Triangle",
"20K-Triangle",
"150K-Triangle",
"500K-Triangle",
],
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
default="500K-Triangle",
optional=True,
),
@ -545,566 +542,6 @@ class Rodin3D_Gen2(IO.ComfyNode):
return IO.NodeOutput(model_path, file_3d)
def _rodin_multipart_parser(data: dict[str, Any]) -> aiohttp.FormData:
"""Convert a Rodin request dict to an aiohttp form, fixing bool/list serialization.
Booleans --> "true"/"false". Lists --> one field per element.
"""
form = aiohttp.FormData(default_to_multipart=True)
for key, value in data.items():
if value is None:
continue
if isinstance(value, bool):
form.add_field(key, "true" if value else "false")
elif isinstance(value, list):
for item in value:
form.add_field(key, str(item))
elif isinstance(value, (bytes, bytearray)):
form.add_field(key, value)
else:
form.add_field(key, str(value))
return form
async def _create_gen25_task(
cls: type[IO.ComfyNode],
request: Rodin3DGen25Request,
images: list | None,
) -> tuple[str, str]:
"""Submit a Gen-2.5 generate job; returns (task_uuid, subscription_key)."""
if images is not None and len(images) > 5:
raise ValueError("Rodin Gen-2.5 supports at most 5 input images.")
files = None
if images:
files = [
(
"images",
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image),
)
for image in images
if image is not None
]
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"),
response_model=Rodin3DGenerateResponse,
data=request,
files=files,
content_type="multipart/form-data",
multipart_parser=_rodin_multipart_parser,
)
if not response.uuid or not response.jobs or not response.jobs.subscription_key:
raise RuntimeError(f"Rodin Gen-2.5 submit failed: message={response.message!r}")
return response.uuid, response.jobs.subscription_key
_PREVIEWABLE_3D_EXTS = {".glb", ".obj", ".fbx", ".stl", ".gltf"}
async def _download_gen25_files(
download_list: Rodin3DDownloadResponse,
task_uuid: str,
geometry_file_format: str,
) -> Types.File3D | None:
"""Download every file in the list; return the File3D matching the chosen format."""
folder_name = f"Rodin3D_Gen25_{task_uuid}"
save_dir = os.path.join(comfy_paths.get_output_directory(), folder_name)
os.makedirs(save_dir, exist_ok=True)
target_ext = f".{geometry_file_format.lower().lstrip('.')}"
file_3d: Types.File3D | None = None
for item in download_list.items:
file_path = os.path.join(save_dir, item.name)
ext = os.path.splitext(item.name.lower())[1]
# Prefer the file matching the user's chosen format; fall back below.
if file_3d is None and ext == target_ext and ext in _PREVIEWABLE_3D_EXTS:
file_3d = await download_url_to_file_3d(item.url, target_ext.lstrip("."))
with open(file_path, "wb") as f:
f.write(file_3d.get_bytes())
continue
await download_url_to_bytesio(item.url, file_path)
# If the chosen format wasn't found, surface any model file we did get.
if file_3d is None:
for item in download_list.items:
ext = os.path.splitext(item.name.lower())[1]
if ext in _PREVIEWABLE_3D_EXTS:
file_3d = await download_url_to_file_3d(item.url, ext.lstrip("."))
break
return file_3d
_MODE_REGULAR = "Regular"
_MODE_FAST = "Fast"
_MODE_EXTREME_HIGH = "Extreme-High"
_REGULAR_POLY_OPTIONS = [
"Default",
"4K-Quad",
"8K-Quad",
"18K-Quad",
"50K-Quad",
"2K-Triangle",
"20K-Triangle",
"150K-Triangle",
"500K-Triangle",
"1M-Triangle",
]
_TEXTURE_MODE_OPTIONS = ["Default", "legacy", "extreme-low", "low", "medium", "high"]
_GEOMETRY_FORMAT_OPTIONS = ["glb", "fbx", "obj", "stl"]
_MATERIAL_OPTIONS = ["PBR", "Shaded", "All", "None"]
def _build_mode_input(name: str = "mode") -> IO.DynamicCombo.Input:
return IO.DynamicCombo.Input(
name,
options=[
IO.DynamicCombo.Option(
_MODE_REGULAR,
[
IO.Combo.Input(
"tier",
options=["Gen-2.5-Low", "Gen-2.5-Medium", "Gen-2.5-High"],
default="Gen-2.5-High",
tooltip="Quality tier. Higher tiers produce higher-fidelity geometry.",
),
IO.Combo.Input(
"polygon_count",
options=_REGULAR_POLY_OPTIONS,
default="Default",
tooltip="Preset face count. 'Default' uses the server's default for the selected tier.",
),
IO.Boolean.Input(
"creative",
default=False,
tooltip="Creative mode (Medium/High only). Enhances generative robustness.",
),
],
),
IO.DynamicCombo.Option(
_MODE_FAST,
[
IO.Combo.Input(
"tier",
options=[
"Gen-2.5-Extreme-Low",
"Gen-2.5-Low",
"Gen-2.5-Medium",
"Gen-2.5-High",
],
default="Gen-2.5-Low",
),
IO.Int.Input(
"mesh_faces",
default=20000,
min=1000,
max=20000,
display_mode=IO.NumberDisplay.number,
tooltip="Mesh face count (1K-20K in Fast mode).",
),
],
),
IO.DynamicCombo.Option(
_MODE_EXTREME_HIGH,
[
IO.Combo.Input("mesh_mode", options=["Raw", "Quad"], default="Raw"),
IO.Int.Input(
"mesh_faces",
default=1000000,
min=20000,
max=2000000,
display_mode=IO.NumberDisplay.number,
tooltip=(
"Mesh face count. Raw mode: 20K-2M. "
"Quad mode: keep under 200K (upstream may reject higher values)."
),
),
IO.Boolean.Input(
"is_micro",
default=False,
tooltip="Enable micro detail (Extreme-High only).",
),
IO.Boolean.Input(
"creative",
default=False,
tooltip="Creative mode. Enhances generative robustness.",
),
],
),
],
tooltip=(
"Generation mode. Regular = balanced. Fast = 1K-20K faces for rapid prototyping. "
"Extreme-High = 20K-2M faces with optional micro details."
),
)
def _build_common_inputs(*, include_image_only: bool) -> list:
inputs: list = [
IO.Combo.Input("material", options=_MATERIAL_OPTIONS, default="Shaded"),
IO.Combo.Input("geometry_file_format", options=_GEOMETRY_FORMAT_OPTIONS, default="glb"),
IO.Combo.Input(
"texture_mode",
options=_TEXTURE_MODE_OPTIONS,
default="Default",
optional=True,
tooltip="Texture quality preset. 'Default' uses the server's default for the selected tier.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=65535,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
optional=True,
),
IO.Boolean.Input(
"TAPose", default=False, optional=True, advanced=True, tooltip="T/A pose for human-like models."
),
IO.Boolean.Input(
"hd_texture", default=False, optional=True, advanced=True, tooltip="High-quality texture enhancement."
),
IO.Boolean.Input(
"texture_delight",
default=False,
optional=True,
advanced=True,
tooltip="Remove baked lighting from textures.",
),
]
if include_image_only:
inputs.append(
IO.Boolean.Input(
"use_original_alpha",
default=False,
optional=True,
advanced=True,
tooltip="Preserve image transparency.",
)
)
inputs.extend(
[
IO.Boolean.Input(
"addon_highpack",
default=False,
optional=True,
advanced=True,
tooltip="HighPack addon: 4K textures and ~16x faces in Quad mode.",
),
IO.Int.Input(
"bbox_width",
default=0,
min=0,
max=300,
display_mode=IO.NumberDisplay.number,
optional=True,
advanced=True,
tooltip="Bounding-box width (Y axis). Set to 0 with the others to skip bbox.",
),
IO.Int.Input(
"bbox_height",
default=0,
min=0,
max=300,
display_mode=IO.NumberDisplay.number,
optional=True,
advanced=True,
tooltip="Bounding-box height (Z axis).",
),
IO.Int.Input(
"bbox_length",
default=0,
min=0,
max=300,
display_mode=IO.NumberDisplay.number,
optional=True,
advanced=True,
tooltip="Bounding-box length (X axis).",
),
IO.Int.Input(
"height_cm",
default=0,
min=0,
max=10000,
display_mode=IO.NumberDisplay.number,
optional=True,
advanced=True,
tooltip="Approximate model height in centimeters (0 to skip).",
),
]
)
return inputs
_PRICE_EXPR = """
(
$baseCredits := widgets.mode = "extreme-high" ? 1.0 : 0.5;
$addonCredits := widgets.addon_highpack ? 1.0 : 0.0;
$total := ($baseCredits * 1.5) + ($addonCredits * 0.8);
{"type":"usd","usd": $total}
)
"""
def _resolve_mode_params(mode_input: dict) -> dict:
"""Translate the DynamicCombo `mode` payload into Gen-2.5 request fields.
Returns a dict with: tier, quality_override, mesh_mode, geometry_instruct_mode, is_micro.
Missing keys mean "do not send" (so we don't override server defaults).
"""
selected = mode_input["mode"]
out: dict = {}
if selected == _MODE_REGULAR:
out["tier"] = mode_input["tier"]
polygon = mode_input.get("polygon_count", "Default")
if polygon != "Default":
mesh_mode, faces = get_quality_mode(polygon)
out["mesh_mode"] = mesh_mode
out["quality_override"] = faces
if mode_input.get("creative"):
out["geometry_instruct_mode"] = "creative"
elif selected == _MODE_FAST:
out["tier"] = mode_input["tier"]
out["mesh_mode"] = "Raw"
out["quality_override"] = int(mode_input["mesh_faces"])
elif selected == _MODE_EXTREME_HIGH:
out["tier"] = "Gen-2.5-Extreme-High"
out["mesh_mode"] = mode_input["mesh_mode"]
out["quality_override"] = int(mode_input["mesh_faces"])
if mode_input.get("is_micro"):
out["is_micro"] = True
if mode_input.get("creative"):
out["geometry_instruct_mode"] = "creative"
return out
def _build_request(
*,
mode_input: dict,
material: str,
geometry_file_format: str,
texture_mode: str,
seed: int,
TAPose: bool,
hd_texture: bool,
texture_delight: bool,
addon_highpack: bool,
bbox_width: int,
bbox_height: int,
bbox_length: int,
height_cm: int,
prompt: str | None = None,
use_original_alpha: bool = False,
) -> Rodin3DGen25Request:
mode_params = _resolve_mode_params(mode_input)
bbox = None
if bbox_width and bbox_height and bbox_length:
bbox = [bbox_width, bbox_height, bbox_length]
return Rodin3DGen25Request(
tier=mode_params["tier"],
prompt=prompt or None,
seed=seed,
material=material,
geometry_file_format=geometry_file_format,
texture_mode=None if texture_mode == "Default" else texture_mode,
mesh_mode=mode_params.get("mesh_mode"),
quality_override=mode_params.get("quality_override"),
geometry_instruct_mode=mode_params.get("geometry_instruct_mode"),
bbox_condition=bbox,
height=height_cm or None,
TAPose=TAPose or None,
hd_texture=hd_texture or None,
texture_delight=texture_delight or None,
is_micro=mode_params.get("is_micro"),
use_original_alpha=use_original_alpha or None,
addons=["HighPack"] if addon_highpack else None,
)
class Rodin3D_Gen25_Image(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Rodin3D_Gen25_Image",
display_name="Rodin 3D Gen-2.5 - Image to 3D",
category="api node/3d/Rodin",
description=(
"Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
),
inputs=[
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplatePrefix(IO.Image.Input("image"), prefix="image", min=1, max=5),
tooltip="1-5 images. The first image is used for materials when multi-view.",
),
_build_mode_input(),
*_build_common_inputs(include_image_only=True),
],
outputs=[IO.File3DAny.Output(display_name="model_file")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]),
expr=_PRICE_EXPR,
),
)
@classmethod
async def execute(
cls,
images: IO.Autogrow.Type,
mode: dict,
material: str,
geometry_file_format: str,
texture_mode: str,
seed: int,
TAPose: bool,
hd_texture: bool,
texture_delight: bool,
use_original_alpha: bool,
addon_highpack: bool,
bbox_width: int,
bbox_height: int,
bbox_length: int,
height_cm: int,
) -> IO.NodeOutput:
image_tensors = [img for img in images.values() if img is not None]
if not image_tensors:
raise ValueError("Rodin Gen-2.5 Image-to-3D requires at least one image.")
# Flatten multi-image tensors into individual frames; the API accepts each as a separate part.
flat_images: list = []
for tensor in image_tensors:
if hasattr(tensor, "shape") and len(tensor.shape) == 4:
for i in range(tensor.shape[0]):
flat_images.append(tensor[i])
else:
flat_images.append(tensor)
if len(flat_images) > 5:
raise ValueError(f"Rodin Gen-2.5 accepts at most 5 images; received {len(flat_images)}.")
request = _build_request(
mode_input=mode,
material=material,
geometry_file_format=geometry_file_format,
texture_mode=texture_mode,
seed=seed,
TAPose=TAPose,
hd_texture=hd_texture,
texture_delight=texture_delight,
addon_highpack=addon_highpack,
bbox_width=bbox_width,
bbox_height=bbox_height,
bbox_length=bbox_length,
height_cm=height_cm,
prompt=None,
use_original_alpha=use_original_alpha,
)
task_uuid, subscription_key = await _create_gen25_task(cls, request, flat_images)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format)
return IO.NodeOutput(file_3d)
class Rodin3D_Gen25_Text(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Rodin3D_Gen25_Text",
display_name="Rodin 3D Gen-2.5 - Text to 3D",
category="api node/3d/Rodin",
description=(
"Generate a 3D model from a text prompt via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the 3D model.",
),
_build_mode_input(),
*_build_common_inputs(include_image_only=False),
],
outputs=[IO.File3DAny.Output(display_name="model_file")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]),
expr=_PRICE_EXPR,
),
)
@classmethod
async def execute(
cls,
prompt: str,
mode: dict,
material: str,
geometry_file_format: str,
texture_mode: str,
seed: int,
TAPose: bool,
hd_texture: bool,
texture_delight: bool,
addon_highpack: bool,
bbox_width: int,
bbox_height: int,
bbox_length: int,
height_cm: int,
) -> IO.NodeOutput:
validate_string(prompt, field_name="prompt", min_length=1, max_length=2500)
request = _build_request(
mode_input=mode,
material=material,
geometry_file_format=geometry_file_format,
texture_mode=texture_mode,
seed=seed,
TAPose=TAPose,
hd_texture=hd_texture,
texture_delight=texture_delight,
addon_highpack=addon_highpack,
bbox_width=bbox_width,
bbox_height=bbox_height,
bbox_length=bbox_length,
height_cm=height_cm,
prompt=prompt,
)
task_uuid, subscription_key = await _create_gen25_task(cls, request, images=None)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format)
return IO.NodeOutput(file_3d)
class Rodin3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -1114,8 +551,6 @@ class Rodin3DExtension(ComfyExtension):
Rodin3D_Smooth,
Rodin3D_Sketch,
Rodin3D_Gen2,
Rodin3D_Gen25_Image,
Rodin3D_Gen25_Text,
]

View File

@ -8,82 +8,6 @@ from comfy_api.latest import _io
MISSING = object()
class NotNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ComfyNotNode",
display_name="Not",
category="utils/logic",
description="Logical NOT operation. Returns true if the value is falsy. Uses Python's rules for truthiness.",
search_aliases=["invert", "toggle", "negate", "flip boolean"],
inputs=[
io.AnyType.Input("value"),
],
outputs=[
io.Boolean.Output(),
],
)
@classmethod
def execute(cls, value) -> io.NodeOutput:
return io.NodeOutput(not value)
class AndNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.Autogrow.TemplatePrefix(
input=io.AnyType.Input("value"),
prefix="value",
min=1,
)
return io.Schema(
node_id="ComfyAndNode",
display_name="And",
category="utils/logic",
description="Logical AND operation. Returns true if all of the values are truthy. Uses Python's rules for truthiness.",
search_aliases=["all", "every"],
inputs=[
io.Autogrow.Input("values", template=template),
],
outputs=[
io.Boolean.Output(),
],
)
@classmethod
def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(all(values.values()))
class OrNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.Autogrow.TemplatePrefix(
input=io.AnyType.Input("value"),
prefix="value",
min=1,
)
return io.Schema(
node_id="ComfyOrNode",
display_name="Or",
category="utils/logic",
description="Logical OR operation. Returns true if any of the values are truthy. Uses Python's rules for truthiness.",
search_aliases=["any", "some"],
inputs=[
io.Autogrow.Input("values", template=template),
],
outputs=[
io.Boolean.Output(),
],
)
@classmethod
def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(any(values.values()))
class SwitchNode(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -91,7 +15,7 @@ class SwitchNode(io.ComfyNode):
return io.Schema(
node_id="ComfySwitchNode",
display_name="Switch",
category="utils/logic",
category="logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
@ -122,7 +46,7 @@ class SoftSwitchNode(io.ComfyNode):
return io.Schema(
node_id="ComfySoftSwitchNode",
display_name="Soft Switch",
category="utils/logic",
category="logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
@ -212,7 +136,7 @@ class DCTestNode(io.ComfyNode):
return io.Schema(
node_id="DCTestNode",
display_name="DCTest",
category="utils/logic",
category="logic",
is_output_node=True,
inputs=[io.DynamicCombo.Input("combo", options=[
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
@ -250,7 +174,7 @@ class AutogrowNamesTestNode(io.ComfyNode):
return io.Schema(
node_id="AutogrowNamesTestNode",
display_name="AutogrowNamesTest",
category="utils/logic",
category="logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
@ -270,7 +194,7 @@ class AutogrowPrefixTestNode(io.ComfyNode):
return io.Schema(
node_id="AutogrowPrefixTestNode",
display_name="AutogrowPrefixTest",
category="utils/logic",
category="logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
@ -289,7 +213,7 @@ class ComboOutputTestNode(io.ComfyNode):
return io.Schema(
node_id="ComboOptionTestNode",
display_name="ComboOptionTest",
category="utils/logic",
category="logic",
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
outputs=[io.Combo.Output(), io.Combo.Output()],
@ -306,7 +230,7 @@ class ConvertStringToComboNode(io.ComfyNode):
node_id="ConvertStringToComboNode",
search_aliases=["string to dropdown", "text to combo"],
display_name="Convert String to Combo",
category="utils/logic",
category="logic",
inputs=[io.String.Input("string")],
outputs=[io.Combo.Output()],
)
@ -322,7 +246,7 @@ class InvertBooleanNode(io.ComfyNode):
node_id="InvertBooleanNode",
search_aliases=["not", "toggle", "negate", "flip boolean"],
display_name="Invert Boolean",
category="utils/logic",
category="logic",
inputs=[io.Boolean.Input("boolean")],
outputs=[io.Boolean.Output()],
)
@ -337,9 +261,6 @@ class LogicExtension(ComfyExtension):
return [
SwitchNode,
CustomComboNode,
NotNode,
AndNode,
OrNode,
# SoftSwitchNode,
# ConvertStringToComboNode,
# DCTestNode,

View File

@ -182,7 +182,7 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
),
io.Combo.Input(
"device",
options=["default", "cpu"],
options=comfy.model_management.get_gpu_device_options(),
advanced=True,
)
],
@ -197,8 +197,12 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
model_options = {}
if device == "cpu":
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
resolved = comfy.model_management.resolve_gpu_device_option(device)
if resolved is not None:
if resolved.type == "cpu":
model_options["load_device"] = model_options["offload_device"] = resolved
else:
model_options["load_device"] = resolved
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
return io.NodeOutput(clip)

View File

@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
return io.Schema(
node_id="ComfyMathExpression",
display_name="Math Expression",
category="utils",
category="logic",
search_aliases=[
"expression", "formula", "calculate", "calculator",
"eval", "math",

View File

@ -0,0 +1,100 @@
from __future__ import annotations
from inspect import cleandoc
from typing import TYPE_CHECKING
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.multigpu
class MultiGPUCFGSplitNode(io.ComfyNode):
"""
Prepares model to have sampling accelerated via splitting work units.
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
Other than those exceptions, this node can be placed in any order.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MultiGPU_WorkUnits",
display_name="MultiGPU CFG Split",
category="advanced/multigpu",
description=cleandoc(cls.__doc__),
inputs=[
io.Model.Input("model"),
io.Int.Input("max_gpus", default=2, min=1, step=1),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
return io.NodeOutput(model)
class MultiGPUOptionsNode(io.ComfyNode):
"""
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
NOTE (not registered yet, see MultiGPUExtension.get_node_list below):
The output GPUOptionsGroup is plumbed through create_multigpu_deepclones() and stored on
model.model_options['multigpu_options'] via GPUOptionsGroup.register(), but the cond
scheduler in comfy/samplers.py (calc_cond_batch_outer_multigpu) does NOT yet consult
relative_speed when distributing conds across devices; it uses a uniform conds_per_device
round-robin via next_available_device(). Before re-enabling this node, wire its
relative_speed into the scheduler (e.g. via comfy.multigpu.load_balance_devices(),
which already implements the proportional split) so the input actually affects work
distribution.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MultiGPU_Options",
display_name="MultiGPU Options",
category="advanced/multigpu",
description=cleandoc(cls.__doc__),
inputs=[
io.Int.Input("device_index", default=0, min=0, max=64),
io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01),
io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True),
],
outputs=[
io.Custom("GPU_OPTIONS").Output(),
],
)
@classmethod
def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput:
if not gpu_options:
gpu_options = comfy.multigpu.GPUOptionsGroup()
else:
gpu_options = gpu_options.clone()
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
gpu_options.add(opt)
return io.NodeOutput(gpu_options)
class MultiGPUExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
MultiGPUCFGSplitNode,
# MultiGPUOptionsNode,
]
async def comfy_entrypoint() -> MultiGPUExtension:
return MultiGPUExtension()

View File

@ -14,7 +14,7 @@ class CreateList(io.ComfyNode):
return io.Schema(
node_id="CreateList",
display_name="Create List",
category="utils",
category="logic",
is_input_list=True,
search_aliases=["Image Iterator", "Text Iterator", "Iterator"],
inputs=[io.Autogrow.Input("inputs", template=template_autogrow)],

View File

@ -23,6 +23,69 @@ class ImageOnlyCheckpointLoader:
return (out[0], out[3], out[2])
class ImageOnlyCheckpointLoaderDevice:
@classmethod
def INPUT_TYPES(s):
device_options = comfy.model_management.get_gpu_device_options()
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
},
"optional": {
"model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}),
"clip_vision_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP vision encoder."}),
"vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}),
}
}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "loaders/video_models"
@classmethod
def VALIDATE_INPUTS(cls, model_device="default", clip_vision_device="default", vae_device="default"):
return True
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True, model_device="default", clip_vision_device="default", vae_device="default"):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
model_options = {}
resolved_model = comfy.model_management.resolve_gpu_device_option(model_device)
if resolved_model is not None:
if resolved_model.type == "cpu":
model_options["load_device"] = model_options["offload_device"] = resolved_model
else:
model_options["load_device"] = resolved_model
cv_model_options = {}
resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_vision_device)
if resolved_clip is not None:
if resolved_clip.type == "cpu":
cv_model_options["load_device"] = cv_model_options["offload_device"] = resolved_clip
else:
cv_model_options["load_device"] = resolved_clip
# VAE device is passed via model_options["load_device"] which
# load_state_dict_guess_config forwards to the VAE constructor.
# If vae_device differs from model_device, we override after loading.
resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
model_patcher, clip, vae, clip_vision = out[:4]
# Apply VAE device override if it differs from the model device
if resolved_vae is not None and vae is not None:
vae.device = resolved_vae
if resolved_vae.type == "cpu":
offload = resolved_vae
else:
offload = comfy.model_management.vae_offload_device()
vae.patcher.load_device = resolved_vae
vae.patcher.offload_device = offload
return (model_patcher, clip_vision, vae)
class SVD_img2vid_Conditioning:
@classmethod
def INPUT_TYPES(s):
@ -149,6 +212,7 @@ class ConditioningSetAreaPercentageVideo:
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"ImageOnlyCheckpointLoaderDevice": ImageOnlyCheckpointLoaderDevice,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
@ -158,6 +222,7 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageOnlyCheckpointLoader": "Load Checkpoint Image Only (img2vid model)",
"ImageOnlyCheckpointLoaderDevice": "Image Only Checkpoint Loader (Device)",
"VideoLinearCFGGuidance": "Video Linear CFG Guidance",
"VideoTriangleCFGGuidance": "Video Triangle CFG Guidance",
}

View File

@ -200,7 +200,7 @@ import gc
if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
import torch
import comfy.utils
import execution
@ -218,7 +218,7 @@ import comfy.model_patcher
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
elif comfy_aimdo.control.init_devices(range(torch.cuda.device_count())):
if args.verbose == 'DEBUG':
comfy_aimdo.control.set_log_debug()
elif args.verbose == 'CRITICAL':

128
nodes.py
View File

@ -608,6 +608,73 @@ class CheckpointLoaderSimple:
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out[:3]
class CheckpointLoaderDevice:
@classmethod
def INPUT_TYPES(s):
device_options = comfy.model_management.get_gpu_device_options()
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
},
"optional": {
"model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}),
"clip_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP text encoder."}),
"vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}),
}
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
"The CLIP model used for encoding text prompts.",
"The VAE model used for encoding and decoding images to and from latent space.")
FUNCTION = "load_checkpoint"
CATEGORY = "advanced/loaders"
DESCRIPTION = "Loads a diffusion model checkpoint with per-component device selection for multi-GPU setups."
@classmethod
def VALIDATE_INPUTS(cls, model_device="default", clip_device="default", vae_device="default"):
return True
def load_checkpoint(self, ckpt_name, model_device="default", clip_device="default", vae_device="default"):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
model_options = {}
resolved_model = comfy.model_management.resolve_gpu_device_option(model_device)
if resolved_model is not None:
if resolved_model.type == "cpu":
model_options["load_device"] = model_options["offload_device"] = resolved_model
else:
model_options["load_device"] = resolved_model
te_model_options = {}
resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_device)
if resolved_clip is not None:
if resolved_clip.type == "cpu":
te_model_options["load_device"] = te_model_options["offload_device"] = resolved_clip
else:
te_model_options["load_device"] = resolved_clip
# VAE device is passed via model_options["load_device"] which
# load_state_dict_guess_config forwards to the VAE constructor.
# If vae_device differs from model_device, we override after loading.
resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options, te_model_options=te_model_options)
model_patcher, clip, vae = out[:3]
# Apply VAE device override if it differs from the model device
if resolved_vae is not None and vae is not None:
vae.device = resolved_vae
if resolved_vae.type == "cpu":
offload = resolved_vae
else:
offload = comfy.model_management.vae_offload_device()
vae.patcher.load_device = resolved_vae
vae.patcher.offload_device = offload
return (model_patcher, clip, vae)
class DiffusersLoader:
SEARCH_ALIASES = ["load diffusers model"]
@ -786,14 +853,21 @@ class VAELoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (s.vae_list(s), )}}
return {"required": { "vae_name": (s.vae_list(s), )},
"optional": {
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
CATEGORY = "loaders"
@classmethod
def VALIDATE_INPUTS(cls, device="default"):
return True
#TODO: scale factor?
def load_vae(self, vae_name):
def load_vae(self, vae_name, device="default"):
metadata = None
if vae_name == "pixel_space":
sd = {}
@ -811,7 +885,8 @@ class VAELoader:
metadata = {"tae_latent_channels": 128}
else:
metadata["tae_latent_channels"] = 128
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
resolved = comfy.model_management.resolve_gpu_device_option(device)
vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved)
vae.throw_exception_if_invalid()
return (vae,)
@ -937,13 +1012,20 @@ class UNETLoader:
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
},
"optional": {
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"
CATEGORY = "advanced/loaders"
def load_unet(self, unet_name, weight_dtype):
@classmethod
def VALIDATE_INPUTS(cls, device="default"):
return True
def load_unet(self, unet_name, weight_dtype, device="default"):
model_options = {}
if weight_dtype == "fp8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn
@ -953,6 +1035,13 @@ class UNETLoader:
elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2
resolved = comfy.model_management.resolve_gpu_device_option(device)
if resolved is not None:
if resolved.type == "cpu":
model_options["load_device"] = model_options["offload_device"] = resolved
else:
model_options["load_device"] = resolved
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)
@ -964,7 +1053,7 @@ class CLIPLoader:
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@ -973,12 +1062,20 @@ class CLIPLoader:
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
@classmethod
def VALIDATE_INPUTS(cls, device="default"):
return True
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
model_options = {}
if device == "cpu":
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
resolved = comfy.model_management.resolve_gpu_device_option(device)
if resolved is not None:
if resolved.type == "cpu":
model_options["load_device"] = model_options["offload_device"] = resolved
else:
model_options["load_device"] = resolved
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
@ -992,7 +1089,7 @@ class DualCLIPLoader:
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@ -1001,6 +1098,10 @@ class DualCLIPLoader:
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"
@classmethod
def VALIDATE_INPUTS(cls, device="default"):
return True
def load_clip(self, clip_name1, clip_name2, type, device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
@ -1008,8 +1109,12 @@ class DualCLIPLoader:
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
model_options = {}
if device == "cpu":
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
resolved = comfy.model_management.resolve_gpu_device_option(device)
if resolved is not None:
if resolved.type == "cpu":
model_options["load_device"] = model_options["offload_device"] = resolved
else:
model_options["load_device"] = resolved
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
return (clip,)
@ -2072,6 +2177,7 @@ NODE_CLASS_MAPPINGS = {
"InpaintModelConditioning": InpaintModelConditioning,
"CheckpointLoader": CheckpointLoader,
"CheckpointLoaderDevice": CheckpointLoaderDevice,
"DiffusersLoader": DiffusersLoader,
"LoadLatent": LoadLatent,
@ -2089,6 +2195,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
# Loaders
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
"CheckpointLoaderSimple": "Load Checkpoint",
"CheckpointLoaderDevice": "Load Checkpoint (Device)",
"VAELoader": "Load VAE",
"LoraLoader": "Load LoRA (Model and CLIP)",
"LoraLoaderModelOnly": "Load LoRA",
@ -2389,6 +2496,7 @@ async def init_builtin_extra_nodes():
"nodes_lt_audio.py",
"nodes_lt.py",
"nodes_hooks.py",
"nodes_multigpu.py",
"nodes_load_3d.py",
"nodes_cosmos.py",
"nodes_video.py",

File diff suppressed because it is too large Load Diff

View File

@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
filelock
av>=14.2.0
comfy-kitchen>=0.2.8
comfy-aimdo==0.4.3
comfy-aimdo==0.4.4
requests
simpleeval>=1.0.0
blake3

View File

@ -646,18 +646,37 @@ class PromptServer():
@routes.get("/system_stats")
async def system_stats(request):
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)
primary_device = comfy.model_management.get_torch_device()
cpu_device = comfy.model_management.torch.device("cpu")
ram_total = comfy.model_management.get_total_memory(cpu_device)
ram_free = comfy.model_management.get_free_memory(cpu_device)
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
required_frontend_version = FrontendManager.get_required_frontend_version()
installed_templates_version = FrontendManager.get_installed_templates_version()
required_templates_version = FrontendManager.get_required_templates_version()
comfy_package_versions = FrontendManager.get_comfy_package_versions()
# Report every torch device visible to multigpu, with the primary
# device first so existing clients that read devices[0] keep working.
torch_devices = comfy.model_management.get_all_torch_devices()
if primary_device in torch_devices:
torch_devices = [primary_device] + [d for d in torch_devices if d != primary_device]
else:
torch_devices = [primary_device] + list(torch_devices)
device_entries = []
for d in torch_devices:
vram_total, torch_vram_total = comfy.model_management.get_total_memory(d, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(d, torch_free_too=True)
device_entries.append({
"name": comfy.model_management.get_torch_device_name(d),
"type": d.type,
"index": d.index,
"vram_total": vram_total,
"vram_free": vram_free,
"torch_vram_total": torch_vram_total,
"torch_vram_free": torch_vram_free,
})
system_stats = {
"system": {
"os": sys.platform,
@ -673,17 +692,7 @@ class PromptServer():
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
"argv": sys.argv
},
"devices": [
{
"name": device_name,
"type": device.type,
"index": device.index,
"vram_total": vram_total,
"vram_free": vram_free,
"torch_vram_total": torch_vram_total,
"torch_vram_free": torch_vram_free,
}
]
"devices": device_entries
}
return web.json_response(system_stats)