mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-20 16:36:41 +08:00
Compare commits
137 Commits
feat/gpt-i
...
worksplit-
| Author | SHA1 | Date | |
|---|---|---|---|
| 3931ef6a1f | |||
| 767b4ee099 | |||
| 89d4964cf0 | |||
| 48acefc923 | |||
| eae101da07 | |||
| 17fe23868a | |||
| bf0e7bd246 | |||
| bae191c294 | |||
| b502bcfff9 | |||
| 37deccb0d4 | |||
| 4b8fa98c3c | |||
| f0d550bd02 | |||
| 48deb15c0e | |||
| 4b93c4360f | |||
| da3864436c | |||
| b418fb1582 | |||
| 20803749c3 | |||
| 3fab720be9 | |||
| afdddcee66 | |||
| 1d8e379f41 | |||
| 5f4fcd19e7 | |||
| d52dcbc88f | |||
| 84f465e791 | |||
| be35378986 | |||
| f410d28b33 | |||
| f4b99bc623 | |||
| df2fd4c869 | |||
| 4661d1db5a | |||
| b326a544d5 | |||
| d89dd5f0b0 | |||
| 8cbbf0be6c | |||
| c2115a4bac | |||
| bb44c2ecb9 | |||
| efcd8280d6 | |||
| 9e9c129cd0 | |||
| ac14ee68c0 | |||
| 2c8f485434 | |||
| 383f9b34cb | |||
| b0741c7e5b | |||
| 1489399cb5 | |||
| 3677943fa5 | |||
| cfb63bfcd7 | |||
| 962c3c832c | |||
| 6ea69369ce | |||
| b4f559b34d | |||
| df122a7dba | |||
| 67e906aa64 | |||
| 382f84a826 | |||
| 9cca36fa2b | |||
| 5d5024296d | |||
| 3b90a30178 | |||
| 3c4104652b | |||
| 9855baaab3 | |||
| d53479a197 | |||
| 443a795850 | |||
| 431dec8e53 | |||
| 44e053c26d | |||
| 1ae98932f1 | |||
| 0336b0ace8 | |||
| 8ae25235ec | |||
| 9726eac475 | |||
| 272e8d42c1 | |||
| 6211d2be5a | |||
| 8be711715c | |||
| b5cccf1325 | |||
| 2a54a904f4 | |||
| ed6f92c975 | |||
| adc66c0698 | |||
| ccd5c01e5a | |||
| 2fa9affcc1 | |||
| 407a5a656f | |||
| 9ce9ff8ef8 | |||
| 63567c0ce8 | |||
| a786ce5ead | |||
| 4879b47648 | |||
| 5ccec33c22 | |||
| 219d3cd0d0 | |||
| c4ba399475 | |||
| cc928a786d | |||
| 6e144b98c4 | |||
| 6dca17bd2d | |||
| 5080105c23 | |||
| 093914a247 | |||
| 605893d3cf | |||
| 048f4f0b3a | |||
| d2504fb701 | |||
| b03763bca6 | |||
| 476aa79b64 | |||
| 441cfd1a7a | |||
| 99a5c1068a | |||
| 02747cde7d | |||
| 0b3233b4e2 | |||
| eda866bf51 | |||
| e3298b84de | |||
| c7feef9060 | |||
| 51af7fa1b4 | |||
| 46969c380a | |||
| 5db4277449 | |||
| 02a4d0ad7d | |||
| ef137ac0b6 | |||
| 328d4f16a9 | |||
| bdbcb85b8d | |||
| 6c9e94bae7 | |||
| bfce723311 | |||
| 31f5458938 | |||
| 2145a202eb | |||
| 25818dc848 | |||
| 198953cd08 | |||
| ec16ee2f39 | |||
| d5088072fb | |||
| 8d4b50158e | |||
| e88c6c03ff | |||
| d3cf2b7b24 | |||
| 7448f02b7c | |||
| 871258aa72 | |||
| 66838ebd39 | |||
| 7333281698 | |||
| 3cd4c5cb0a | |||
| 11c6d56037 | |||
| 216fea15ee | |||
| 58bf8815c8 | |||
| 1b38f5bf57 | |||
| 2724ac4a60 | |||
| f48f90e471 | |||
| 6463c39ce0 | |||
| 0a7e2ae787 | |||
| 03a97b604a | |||
| 4446c86052 | |||
| 8270ff312f | |||
| db2d7ad9ba | |||
| 6620d86318 | |||
| 111fd0cadf | |||
| 776aa734e1 | |||
| 5a2ad032cb | |||
| d44295ef71 | |||
| bf21be066f | |||
| 72bbf49349 |
@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
|
||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
|
||||
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use. All other devices will not be visible.")
|
||||
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
|
||||
cm_group = parser.add_mutually_exclusive_group()
|
||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||
|
||||
@ -15,13 +15,14 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from enum import Enum
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.ldm.qwen_image.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Union
|
||||
if TYPE_CHECKING:
|
||||
from comfy.hooks import HookGroup
|
||||
|
||||
@ -64,6 +65,18 @@ class StrengthType(Enum):
|
||||
CONSTANT = 1
|
||||
LINEAR_UP = 2
|
||||
|
||||
class ControlIsolation:
|
||||
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
|
||||
def __init__(self, control: ControlBase):
|
||||
self.control = control
|
||||
self.orig_previous_controlnet = control.previous_controlnet
|
||||
|
||||
def __enter__(self):
|
||||
self.control.previous_controlnet = None
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.control.previous_controlnet = self.orig_previous_controlnet
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self):
|
||||
self.cond_hint_original = None
|
||||
@ -77,7 +90,7 @@ class ControlBase:
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
self.previous_controlnet = None
|
||||
self.previous_controlnet: Union[ControlBase, None] = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
self.concat_mask = False
|
||||
@ -85,6 +98,7 @@ class ControlBase:
|
||||
self.extra_concat = None
|
||||
self.extra_hooks: HookGroup = None
|
||||
self.preprocess_image = lambda a: a
|
||||
self.multigpu_clones: dict[torch.device, ControlBase] = {}
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||
self.cond_hint_original = cond_hint
|
||||
@ -111,17 +125,38 @@ class ControlBase:
|
||||
def cleanup(self):
|
||||
if self.previous_controlnet is not None:
|
||||
self.previous_controlnet.cleanup()
|
||||
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
with ControlIsolation(device_cnet):
|
||||
device_cnet.cleanup()
|
||||
self.cond_hint = None
|
||||
self.extra_concat = None
|
||||
self.timestep_range = None
|
||||
|
||||
def get_models(self):
|
||||
out = []
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
out += device_cnet.get_models_only_self()
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
def get_models_only_self(self):
|
||||
'Calls get_models, but temporarily sets previous_controlnet to None.'
|
||||
with ControlIsolation(self):
|
||||
return self.get_models()
|
||||
|
||||
def get_instance_for_device(self, device):
|
||||
'Returns instance of this Control object intended for selected device.'
|
||||
return self.multigpu_clones.get(device, self)
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
'''
|
||||
Create deep clone of Control object where model(s) is set to other devices.
|
||||
|
||||
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
|
||||
'''
|
||||
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
|
||||
|
||||
def get_extra_hooks(self):
|
||||
out = []
|
||||
if self.extra_hooks is not None:
|
||||
@ -130,7 +165,7 @@ class ControlBase:
|
||||
out += self.previous_controlnet.get_extra_hooks()
|
||||
return out
|
||||
|
||||
def copy_to(self, c):
|
||||
def copy_to(self, c: ControlBase):
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
c.timestep_percent_range = self.timestep_percent_range
|
||||
@ -284,6 +319,14 @@ class ControlNet(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.control_model = copy.deepcopy(c.control_model)
|
||||
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def get_models(self):
|
||||
out = super().get_models()
|
||||
out.append(self.control_model_wrapped)
|
||||
@ -906,6 +949,14 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.t2i_model = copy.deepcopy(c.t2i_model)
|
||||
c.device = load_device
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
compression_ratio = 8
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
|
||||
@ -607,9 +607,14 @@ class HunYuanDiTPlain(nn.Module):
|
||||
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
||||
|
||||
x = x.movedim(-1, -2)
|
||||
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
||||
|
||||
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
||||
cond_or_uncond = transformer_options.get("cond_or_uncond", [])
|
||||
swap_cfg_halves = len(cond_or_uncond) == 2 and set(cond_or_uncond) == {0, 1}
|
||||
|
||||
if swap_cfg_halves:
|
||||
first_half, second_half = context.chunk(2, dim = 0)
|
||||
context = torch.cat([second_half, first_half], dim = 0)
|
||||
|
||||
main_condition = context
|
||||
|
||||
t = 1.0 - t
|
||||
@ -657,5 +662,8 @@ class HunYuanDiTPlain(nn.Module):
|
||||
output = self.final_layer(combined)
|
||||
output = output.movedim(-2, -1) * (-1.0)
|
||||
|
||||
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
||||
return torch.cat([uncond_emb, cond_emb])
|
||||
if swap_cfg_halves:
|
||||
first_half, second_half = output.chunk(2, dim = 0)
|
||||
output = torch.cat([second_half, first_half], dim = 0)
|
||||
|
||||
return output
|
||||
|
||||
@ -4,6 +4,9 @@ import math
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
import comfy.utils as utils
|
||||
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||
@ -40,6 +43,30 @@ class AudioVAEComponentConfig:
|
||||
|
||||
return cls(autoencoder=audio_config, vocoder=vocoder_config)
|
||||
|
||||
|
||||
class ModelDeviceManager:
|
||||
"""Manages device placement and GPU residency for the composed model."""
|
||||
|
||||
def __init__(self, module: torch.nn.Module):
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
offload_device = comfy.model_management.vae_offload_device()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
|
||||
|
||||
def ensure_model_loaded(self) -> None:
|
||||
comfy.model_management.free_memory(
|
||||
self.patcher.model_size(),
|
||||
self.patcher.load_device,
|
||||
)
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
|
||||
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(self.patcher.load_device)
|
||||
|
||||
@property
|
||||
def load_device(self):
|
||||
return self.patcher.load_device
|
||||
|
||||
|
||||
class AudioLatentNormalizer:
|
||||
"""Applies per-channel statistics in patch space and restores original layout."""
|
||||
|
||||
@ -105,17 +132,23 @@ class AudioPreprocessor:
|
||||
class AudioVAE(torch.nn.Module):
|
||||
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
|
||||
|
||||
def __init__(self, metadata: dict):
|
||||
def __init__(self, state_dict: dict, metadata: dict):
|
||||
super().__init__()
|
||||
|
||||
component_config = AudioVAEComponentConfig.from_metadata(metadata)
|
||||
|
||||
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
|
||||
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||
|
||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||
if "bwe" in component_config.vocoder:
|
||||
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
|
||||
else:
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
|
||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||
|
||||
autoencoder_config = self.autoencoder.get_config()
|
||||
self.normalizer = AudioLatentNormalizer(
|
||||
AudioPatchifier(
|
||||
@ -135,12 +168,18 @@ class AudioVAE(torch.nn.Module):
|
||||
n_fft=autoencoder_config["n_fft"],
|
||||
)
|
||||
|
||||
def encode(self, audio, sample_rate=44100) -> torch.Tensor:
|
||||
self.device_manager = ModelDeviceManager(self)
|
||||
|
||||
def encode(self, audio: dict) -> torch.Tensor:
|
||||
"""Encode a waveform dictionary into normalized latent tensors."""
|
||||
|
||||
waveform = audio
|
||||
waveform_sample_rate = sample_rate
|
||||
waveform = audio["waveform"]
|
||||
waveform_sample_rate = audio["sample_rate"]
|
||||
input_device = waveform.device
|
||||
# Ensure that Audio VAE is loaded on the correct device.
|
||||
self.device_manager.ensure_model_loaded()
|
||||
|
||||
waveform = self.device_manager.move_to_load_device(waveform)
|
||||
expected_channels = self.autoencoder.encoder.in_channels
|
||||
if waveform.shape[1] != expected_channels:
|
||||
if waveform.shape[1] == 1:
|
||||
@ -151,7 +190,7 @@ class AudioVAE(torch.nn.Module):
|
||||
)
|
||||
|
||||
mel_spec = self.preprocessor.waveform_to_mel(
|
||||
waveform, waveform_sample_rate, device=waveform.device
|
||||
waveform, waveform_sample_rate, device=self.device_manager.load_device
|
||||
)
|
||||
|
||||
latents = self.autoencoder.encode(mel_spec)
|
||||
@ -165,13 +204,17 @@ class AudioVAE(torch.nn.Module):
|
||||
"""Decode normalized latent tensors into an audio waveform."""
|
||||
original_shape = latents.shape
|
||||
|
||||
# Ensure that Audio VAE is loaded on the correct device.
|
||||
self.device_manager.ensure_model_loaded()
|
||||
|
||||
latents = self.device_manager.move_to_load_device(latents)
|
||||
latents = self.normalizer.denormalize(latents)
|
||||
|
||||
target_shape = self.target_shape_from_latents(original_shape)
|
||||
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
|
||||
|
||||
waveform = self.run_vocoder(mel_spec)
|
||||
return waveform
|
||||
return self.device_manager.move_to_load_device(waveform)
|
||||
|
||||
def target_shape_from_latents(self, latents_shape):
|
||||
batch, _, time, _ = latents_shape
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
@ -27,11 +28,16 @@ import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from contextlib import contextmanager, nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
@ -206,6 +212,89 @@ def get_torch_device():
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
def get_all_torch_devices(exclude_current=False):
|
||||
global cpu_state
|
||||
devices = []
|
||||
if cpu_state == CPUState.GPU:
|
||||
if is_nvidia():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
elif is_intel_xpu():
|
||||
for i in range(torch.xpu.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
elif is_ascend_npu():
|
||||
for i in range(torch.npu.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
else:
|
||||
devices.append(get_torch_device())
|
||||
if exclude_current:
|
||||
devices.remove(get_torch_device())
|
||||
return devices
|
||||
|
||||
def get_gpu_device_options():
|
||||
"""Return list of device option strings for node widgets.
|
||||
|
||||
Always includes "default" and "cpu". When multiple GPUs are present,
|
||||
adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
|
||||
"""
|
||||
options = ["default", "cpu"]
|
||||
devices = get_all_torch_devices()
|
||||
if len(devices) > 1:
|
||||
for i in range(len(devices)):
|
||||
options.append(f"gpu:{i}")
|
||||
return options
|
||||
|
||||
def resolve_gpu_device_option(option: str):
|
||||
"""Resolve a device option string to a torch.device.
|
||||
|
||||
Returns None for "default" (let the caller use its normal default).
|
||||
Returns torch.device("cpu") for "cpu".
|
||||
For "gpu:N", returns the Nth torch device. Falls back to None if
|
||||
the index is out of range (caller should use default).
|
||||
"""
|
||||
if option is None or option == "default":
|
||||
return None
|
||||
if option == "cpu":
|
||||
return torch.device("cpu")
|
||||
if option.startswith("gpu:"):
|
||||
try:
|
||||
idx = int(option[4:])
|
||||
devices = get_all_torch_devices()
|
||||
if 0 <= idx < len(devices):
|
||||
return devices[idx]
|
||||
else:
|
||||
logging.warning(f"Device '{option}' not available (only {len(devices)} GPU(s)), using default.")
|
||||
return None
|
||||
except (ValueError, IndexError):
|
||||
logging.warning(f"Invalid device option '{option}', using default.")
|
||||
return None
|
||||
logging.warning(f"Unrecognized device option '{option}', using default.")
|
||||
return None
|
||||
|
||||
@contextmanager
|
||||
def cuda_device_context(device):
|
||||
"""Context manager that sets torch.cuda.current_device to match *device*.
|
||||
|
||||
Used when running operations on a non-default CUDA device so that custom
|
||||
CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
|
||||
device index. The previous device is restored on exit.
|
||||
|
||||
No-op when *device* is not CUDA, has no explicit index, or already matches
|
||||
the current device.
|
||||
"""
|
||||
prev = None
|
||||
if device.type == "cuda" and device.index is not None:
|
||||
prev = torch.cuda.current_device()
|
||||
if prev != device.index:
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
prev = None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if prev is not None:
|
||||
torch.cuda.set_device(prev)
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
@ -494,9 +583,13 @@ try:
|
||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
except:
|
||||
logging.warning("Could not pick default device.")
|
||||
try:
|
||||
for device in get_all_torch_devices(exclude_current=True):
|
||||
logging.info("Device: {}".format(get_torch_device_name(device)))
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
current_loaded_models = []
|
||||
current_loaded_models: list[LoadedModel] = []
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
@ -529,7 +622,7 @@ def module_mmap_residency(module, free=False):
|
||||
return mmap_touched_mem, module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
def __init__(self, model: ModelPatcher):
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
self.real_model = None
|
||||
@ -537,7 +630,7 @@ class LoadedModel:
|
||||
self.model_finalizer = None
|
||||
self._patcher_finalizer = None
|
||||
|
||||
def _set_model(self, model):
|
||||
def _set_model(self, model: ModelPatcher):
|
||||
self._model = weakref.ref(model)
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
@ -548,6 +641,7 @@ class LoadedModel:
|
||||
model = self._parent_model()
|
||||
if model is not None:
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
@ -1794,7 +1888,34 @@ def soft_empty_cache(force=False):
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device)
|
||||
|
||||
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
|
||||
'Unload only model and its clones - primarily for multigpu cloning purposes.'
|
||||
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
|
||||
additional_models = []
|
||||
if unload_additional_models:
|
||||
additional_models = model.get_nested_additional_models()
|
||||
keep_loaded = []
|
||||
for loaded_model in initial_keep_loaded:
|
||||
if loaded_model.model is not None:
|
||||
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
continue
|
||||
# check additional models if they are a match
|
||||
skip = False
|
||||
for add_model in additional_models:
|
||||
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
skip = True
|
||||
break
|
||||
if skip:
|
||||
continue
|
||||
keep_loaded.append(loaded_model)
|
||||
if not all_devices:
|
||||
free_memory(1e30, get_torch_device(), keep_loaded)
|
||||
else:
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device, keep_loaded)
|
||||
|
||||
def debug_memory_summary():
|
||||
if is_amd() or is_nvidia():
|
||||
|
||||
@ -23,6 +23,7 @@ import inspect
|
||||
import logging
|
||||
import math
|
||||
import uuid
|
||||
import copy
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
@ -75,12 +76,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
def create_model_options_clone(orig_model_options: dict):
|
||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||
|
||||
def create_hook_patches_clone(orig_hook_patches):
|
||||
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
|
||||
new_hook_patches = {}
|
||||
for hook_ref in orig_hook_patches:
|
||||
new_hook_patches[hook_ref] = {}
|
||||
for k in orig_hook_patches[hook_ref]:
|
||||
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
||||
if copy_tuples:
|
||||
for i in range(len(new_hook_patches[hook_ref][k])):
|
||||
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
|
||||
return new_hook_patches
|
||||
|
||||
def wipe_lowvram_weight(m):
|
||||
@ -272,7 +276,10 @@ class ModelPatcher:
|
||||
self.is_clip = False
|
||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||
|
||||
self.cached_patcher_init: tuple[Callable, tuple] | None = None
|
||||
self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
|
||||
self.is_multigpu_base_clone = False
|
||||
self.clone_base_uuid = uuid.uuid4()
|
||||
|
||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
@ -312,7 +319,8 @@ class ModelPatcher:
|
||||
#than pays for CFG. So return everything both torch and Aimdo could give us
|
||||
aimdo_mem = 0
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
|
||||
aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None
|
||||
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device)
|
||||
return comfy.model_management.get_free_memory(device) + aimdo_mem
|
||||
|
||||
def get_clone_model_override(self):
|
||||
@ -326,6 +334,8 @@ class ModelPatcher:
|
||||
if self.cached_patcher_init is None:
|
||||
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
||||
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||
if len(self.cached_patcher_init) > 2:
|
||||
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||
model_override = temp_model_patcher.get_clone_model_override()
|
||||
if model_override is None:
|
||||
model_override = self.get_clone_model_override()
|
||||
@ -384,19 +394,98 @@ class ModelPatcher:
|
||||
n.hook_mode = self.hook_mode
|
||||
|
||||
n.cached_patcher_init = self.cached_patcher_init
|
||||
n.is_multigpu_base_clone = self.is_multigpu_base_clone
|
||||
n.clone_base_uuid = self.clone_base_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
||||
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
||||
comfy.model_management.unload_model_and_clones(self)
|
||||
n = self.clone()
|
||||
# set load device, if present
|
||||
if new_load_device is not None:
|
||||
n.load_device = new_load_device
|
||||
if self.cached_patcher_init is not None:
|
||||
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
|
||||
if len(self.cached_patcher_init) > 2:
|
||||
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||
n.model = temp_model_patcher.model
|
||||
else:
|
||||
n.model = copy.deepcopy(n.model)
|
||||
# unlike for normal clone, backup dicts that shared same ref should not;
|
||||
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
|
||||
n.backup = copy.deepcopy(n.backup)
|
||||
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
|
||||
n.hook_backup = copy.deepcopy(n.hook_backup)
|
||||
# multigpu clone should not have multigpu additional_models entry
|
||||
n.remove_additional_models("multigpu")
|
||||
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
||||
if models_cache is None:
|
||||
models_cache = {}
|
||||
for key, model_list in n.additional_models.items():
|
||||
for i in range(len(model_list)):
|
||||
add_model = n.additional_models[key][i]
|
||||
if add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
|
||||
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def match_multigpu_clones(self):
|
||||
multigpu_models = self.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
new_multigpu_models = []
|
||||
for mm in multigpu_models:
|
||||
# clone main model, but bring over relevant props from existing multigpu clone
|
||||
n = self.clone()
|
||||
n.load_device = mm.load_device
|
||||
n.backup = mm.backup
|
||||
n.object_patches_backup = mm.object_patches_backup
|
||||
n.hook_backup = mm.hook_backup
|
||||
n.model = mm.model
|
||||
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
|
||||
n.remove_additional_models("multigpu")
|
||||
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
|
||||
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
|
||||
# figure out which additional models are not present in multigpu clone
|
||||
models_cache = {}
|
||||
for mm_add_model in mm.get_additional_models():
|
||||
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
|
||||
remove_models_uuids = set(list(models_cache.keys()))
|
||||
for key, model_list in orig_additional_models.items():
|
||||
for orig_add_model in model_list:
|
||||
if orig_add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
|
||||
existing_list = n.get_additional_models_with_key(key)
|
||||
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
|
||||
n.set_additional_models(key, existing_list)
|
||||
if orig_add_model.clone_base_uuid in remove_models_uuids:
|
||||
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
|
||||
# remove duplicate additional models
|
||||
for key, model_list in n.additional_models.items():
|
||||
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
|
||||
n.set_additional_models(key, new_model_list)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
|
||||
callback(self, n)
|
||||
new_multigpu_models.append(n)
|
||||
self.set_additional_models("multigpu", new_multigpu_models)
|
||||
|
||||
def is_clone(self, other):
|
||||
if hasattr(other, 'model') and self.model is other.model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
|
||||
if allow_multigpu:
|
||||
if self.clone_base_uuid != clone.clone_base_uuid:
|
||||
return False
|
||||
else:
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
|
||||
if self.current_hooks != clone.current_hooks:
|
||||
return False
|
||||
@ -1171,7 +1260,7 @@ class ModelPatcher:
|
||||
return self.additional_models.get(key, [])
|
||||
|
||||
def get_additional_models(self):
|
||||
all_models = []
|
||||
all_models: list[ModelPatcher] = []
|
||||
for models in self.additional_models.values():
|
||||
all_models.extend(models)
|
||||
return all_models
|
||||
@ -1225,9 +1314,13 @@ class ModelPatcher:
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||
callback(self)
|
||||
|
||||
def prepare_state(self, timestep):
|
||||
def prepare_state(self, timestep, model_options, ignore_multigpu=False):
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||
callback(self, timestep)
|
||||
callback(self, timestep, model_options, ignore_multigpu)
|
||||
if not ignore_multigpu and "multigpu_clones" in model_options:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p.prepare_state(timestep, model_options, ignore_multigpu=True)
|
||||
|
||||
def restore_hook_patches(self):
|
||||
if self.hook_patches_backup is not None:
|
||||
@ -1240,12 +1333,18 @@ class ModelPatcher:
|
||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||
curr_t = t[0]
|
||||
reset_current_hooks = False
|
||||
multigpu_kf_changed_cache = None
|
||||
transformer_options = model_options.get("transformer_options", {})
|
||||
for hook in hook_group.hooks:
|
||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||
# this will cause the weights to be recalculated when sampling
|
||||
if changed:
|
||||
# cache changed for multigpu usage
|
||||
if "multigpu_clones" in model_options:
|
||||
if multigpu_kf_changed_cache is None:
|
||||
multigpu_kf_changed_cache = []
|
||||
multigpu_kf_changed_cache.append(hook)
|
||||
# reset current_hooks if contains hook that changed
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
@ -1257,6 +1356,28 @@ class ModelPatcher:
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
if "multigpu_clones" in model_options:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
|
||||
|
||||
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
|
||||
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
|
||||
if kf_changed_cache is None:
|
||||
return
|
||||
reset_current_hooks = False
|
||||
# reset current_hooks if contains hook that changed
|
||||
for hook in kf_changed_cache:
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
if current_hook == hook:
|
||||
reset_current_hooks = True
|
||||
break
|
||||
for cached_group in list(self.cached_hook_patches.keys()):
|
||||
if cached_group.contains(hook):
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
|
||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||
registered: comfy.hooks.HookGroup = None):
|
||||
|
||||
230
comfy/multigpu.py
Normal file
230
comfy/multigpu.py
Normal file
@ -0,0 +1,230 @@
|
||||
from __future__ import annotations
|
||||
import queue
|
||||
import threading
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.utils
|
||||
import comfy.patcher_extension
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class MultiGPUThreadPool:
|
||||
"""Persistent thread pool for multi-GPU work distribution.
|
||||
|
||||
Maintains one worker thread per extra GPU device. Each thread calls
|
||||
torch.cuda.set_device() once at startup so that compiled kernel caches
|
||||
(inductor/triton) stay warm across diffusion steps.
|
||||
"""
|
||||
|
||||
def __init__(self, devices: list[torch.device]):
|
||||
self._workers: list[threading.Thread] = []
|
||||
self._work_queues: dict[torch.device, queue.Queue] = {}
|
||||
self._result_queues: dict[torch.device, queue.Queue] = {}
|
||||
|
||||
for device in devices:
|
||||
wq = queue.Queue()
|
||||
rq = queue.Queue()
|
||||
self._work_queues[device] = wq
|
||||
self._result_queues[device] = rq
|
||||
t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
|
||||
t.start()
|
||||
self._workers.append(t)
|
||||
|
||||
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
|
||||
try:
|
||||
torch.cuda.set_device(device)
|
||||
except Exception as e:
|
||||
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
|
||||
while True:
|
||||
item = work_q.get()
|
||||
if item is None:
|
||||
return
|
||||
result_q.put((None, e))
|
||||
return
|
||||
while True:
|
||||
item = work_q.get()
|
||||
if item is None:
|
||||
break
|
||||
fn, args, kwargs = item
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
result_q.put((result, None))
|
||||
except Exception as e:
|
||||
result_q.put((None, e))
|
||||
|
||||
def submit(self, device: torch.device, fn, *args, **kwargs):
|
||||
self._work_queues[device].put((fn, args, kwargs))
|
||||
|
||||
def get_result(self, device: torch.device):
|
||||
return self._result_queues[device].get()
|
||||
|
||||
@property
|
||||
def devices(self) -> list[torch.device]:
|
||||
return list(self._work_queues.keys())
|
||||
|
||||
def shutdown(self):
|
||||
for wq in self._work_queues.values():
|
||||
wq.put(None) # sentinel
|
||||
for t in self._workers:
|
||||
t.join(timeout=5.0)
|
||||
|
||||
|
||||
class GPUOptions:
|
||||
def __init__(self, device_index: int, relative_speed: float):
|
||||
self.device_index = device_index
|
||||
self.relative_speed = relative_speed
|
||||
|
||||
def clone(self):
|
||||
return GPUOptions(self.device_index, self.relative_speed)
|
||||
|
||||
def create_dict(self):
|
||||
return {
|
||||
"relative_speed": self.relative_speed
|
||||
}
|
||||
|
||||
class GPUOptionsGroup:
|
||||
def __init__(self):
|
||||
self.options: dict[int, GPUOptions] = {}
|
||||
|
||||
def add(self, info: GPUOptions):
|
||||
self.options[info.device_index] = info
|
||||
|
||||
def clone(self):
|
||||
c = GPUOptionsGroup()
|
||||
for opt in self.options.values():
|
||||
c.add(opt)
|
||||
return c
|
||||
|
||||
def register(self, model: ModelPatcher):
|
||||
opts_dict = {}
|
||||
# get devices that are valid for this model
|
||||
devices: list[torch.device] = [model.load_device]
|
||||
for extra_model in model.get_additional_models_with_key("multigpu"):
|
||||
extra_model: ModelPatcher
|
||||
devices.append(extra_model.load_device)
|
||||
# create dictionary with actual device mapped to its GPUOptions
|
||||
device_opts_list: list[GPUOptions] = []
|
||||
for device in devices:
|
||||
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
|
||||
opts_dict[device] = device_opts.create_dict()
|
||||
device_opts_list.append(device_opts)
|
||||
# make relative_speed relative to 1.0
|
||||
min_speed = min([x.relative_speed for x in device_opts_list])
|
||||
for value in opts_dict.values():
|
||||
value['relative_speed'] /= min_speed
|
||||
model.model_options['multigpu_options'] = opts_dict
|
||||
|
||||
|
||||
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
||||
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
||||
model = model.clone()
|
||||
# check if multigpu is already prepared - get the load devices from them if possible to exclude
|
||||
skip_devices = set()
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
for mm in multigpu_models:
|
||||
skip_devices.add(mm.load_device)
|
||||
skip_devices = list(skip_devices)
|
||||
|
||||
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
|
||||
limit_extra_devices = full_extra_devices[:max_gpus-1]
|
||||
extra_devices = limit_extra_devices.copy()
|
||||
# exclude skipped devices
|
||||
for skip in skip_devices:
|
||||
if skip in extra_devices:
|
||||
extra_devices.remove(skip)
|
||||
# create new deepclones
|
||||
if len(extra_devices) > 0:
|
||||
for device in extra_devices:
|
||||
device_patcher = None
|
||||
if reuse_loaded:
|
||||
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
|
||||
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
|
||||
for lm in loaded_models:
|
||||
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
|
||||
device_patcher = lm.clone()
|
||||
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
|
||||
break
|
||||
if device_patcher is None:
|
||||
device_patcher = model.deepclone_multigpu(new_load_device=device)
|
||||
device_patcher.is_multigpu_base_clone = True
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
multigpu_models.append(device_patcher)
|
||||
model.set_additional_models("multigpu", multigpu_models)
|
||||
model.match_multigpu_clones()
|
||||
if gpu_options is None:
|
||||
gpu_options = GPUOptionsGroup()
|
||||
gpu_options.register(model)
|
||||
else:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
|
||||
# TODO: only keep model clones that don't go 'past' the intended max_gpu count
|
||||
# multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
# new_multigpu_models = []
|
||||
# for m in multigpu_models:
|
||||
# if m.load_device in limit_extra_devices:
|
||||
# new_multigpu_models.append(m)
|
||||
# model.set_additional_models("multigpu", new_multigpu_models)
|
||||
# persist skip_devices for use in sampling code
|
||||
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
|
||||
# model.model_options["multigpu_skip_devices"] = skip_devices
|
||||
return model
|
||||
|
||||
|
||||
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||
opts_dict = model_options['multigpu_options']
|
||||
devices = list(model_options['multigpu_clones'].keys())
|
||||
speed_per_device = []
|
||||
work_per_device = []
|
||||
# get sum of each device's relative_speed
|
||||
total_speed = 0.0
|
||||
for opts in opts_dict.values():
|
||||
total_speed += opts['relative_speed']
|
||||
# get relative work for each device;
|
||||
# obtained by w = (W*r)/R
|
||||
for device in devices:
|
||||
relative_speed = opts_dict[device]['relative_speed']
|
||||
relative_work = (total_work*relative_speed) / total_speed
|
||||
speed_per_device.append(relative_speed)
|
||||
work_per_device.append(relative_work)
|
||||
# relative work must be expressed in whole numbers, but likely is a decimal;
|
||||
# perform rounding while maintaining total sum equal to total work (sum of relative works)
|
||||
work_per_device = round_preserved(work_per_device)
|
||||
dict_work_per_device = {}
|
||||
for device, relative_work in zip(devices, work_per_device):
|
||||
dict_work_per_device[device] = relative_work
|
||||
if not return_idle_time:
|
||||
return LoadBalance(dict_work_per_device, None)
|
||||
# divide relative work by relative speed to get estimated completion time of said work by each device;
|
||||
# time here is relative and does not correspond to real-world units
|
||||
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
|
||||
# calculate relative time spent by the devices waiting on each other after their work is completed
|
||||
idle_time = abs(min(completion_time) - max(completion_time))
|
||||
# if need to compare work idle time, need to normalize to a common total work
|
||||
if work_normalized:
|
||||
idle_time *= (work_normalized/total_work)
|
||||
|
||||
return LoadBalance(dict_work_per_device, idle_time)
|
||||
|
||||
def round_preserved(values: list[float]):
|
||||
'Round all values in a list, preserving the combined sum of values.'
|
||||
# get floor of values; casting to int does it too
|
||||
floored = [int(x) for x in values]
|
||||
total_floored = sum(floored)
|
||||
# get remainder to distribute
|
||||
remainder = round(sum(values)) - total_floored
|
||||
# pair values with fractional portions
|
||||
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
|
||||
# sort by fractional part in descending order
|
||||
fractional.sort(key=lambda x: x[1], reverse=True)
|
||||
# distribute the remainder
|
||||
for i in range(remainder):
|
||||
index = fractional[i][0]
|
||||
floored[index] += 1
|
||||
return floored
|
||||
@ -3,6 +3,8 @@ from typing import Callable
|
||||
|
||||
class CallbacksMP:
|
||||
ON_CLONE = "on_clone"
|
||||
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
|
||||
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
|
||||
ON_LOAD = "on_load_after"
|
||||
ON_DETACH = "on_detach_after"
|
||||
ON_CLEANUP = "on_cleanup"
|
||||
|
||||
@ -20,7 +20,6 @@ try:
|
||||
if cuda_version < (13,):
|
||||
ck.registry.disable("cuda")
|
||||
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
|
||||
|
||||
ck.registry.disable("triton")
|
||||
for k, v in ck.list_backends().items():
|
||||
logging.info(f"Found comfy_kitchen backend {k}: {v}")
|
||||
|
||||
@ -1,16 +1,18 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import uuid
|
||||
import math
|
||||
import collections
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.controlnet import ControlBase
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
@ -118,6 +120,47 @@ def cleanup_additional_models(models):
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
|
||||
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
|
||||
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) == 0:
|
||||
return
|
||||
extra_devices = [x.load_device for x in multigpu_models]
|
||||
# handle controlnets
|
||||
controlnets: set[ControlBase] = set()
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
if 'control' in kk:
|
||||
controlnets.add(kk['control'])
|
||||
if len(controlnets) > 0:
|
||||
# first, unload all controlnet clones
|
||||
for cnet in list(controlnets):
|
||||
cnet_models = cnet.get_models()
|
||||
for cm in cnet_models:
|
||||
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
|
||||
|
||||
# next, make sure each controlnet has a deepclone for all relevant devices
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
for device in extra_devices:
|
||||
if device not in curr_cnet.multigpu_clones:
|
||||
curr_cnet.deepclone_multigpu(device, autoregister=True)
|
||||
curr_cnet = curr_cnet.previous_controlnet
|
||||
# since all device clones are now present, recreate the linked list for cloned cnets per device
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
prev_cnet = curr_cnet.previous_controlnet
|
||||
for device in extra_devices:
|
||||
device_cnet = curr_cnet.get_instance_for_device(device)
|
||||
prev_device_cnet = None
|
||||
if prev_cnet is not None:
|
||||
prev_device_cnet = prev_cnet.get_instance_for_device(device)
|
||||
device_cnet.set_previous_controlnet(prev_device_cnet)
|
||||
curr_cnet = prev_cnet
|
||||
# potentially handle gligen - since not widely used, ignored for now
|
||||
|
||||
def estimate_memory(model, noise_shape, conds):
|
||||
cond_shapes = collections.defaultdict(list)
|
||||
cond_shapes_min = {}
|
||||
@ -142,7 +185,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||
real_model: BaseModel = None
|
||||
model.match_multigpu_clones()
|
||||
preprocess_multigpu_conds(conds, model, model_options)
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
@ -154,7 +198,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
|
||||
memory_required += inference_memory
|
||||
minimum_memory_required += inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||
real_model = model.model
|
||||
real_model: BaseModel = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
@ -200,3 +244,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
||||
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||
copy_dict1=False)
|
||||
return to_load_options
|
||||
|
||||
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
|
||||
'''
|
||||
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
|
||||
'''
|
||||
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
|
||||
if len(multigpu_patchers) > 0:
|
||||
multigpu_dict: dict[torch.device, ModelPatcher] = {}
|
||||
multigpu_dict[model_patcher.load_device] = model_patcher
|
||||
for x in multigpu_patchers:
|
||||
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
|
||||
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
|
||||
multigpu_dict[x.load_device] = x
|
||||
model_options["multigpu_clones"] = multigpu_dict
|
||||
return multigpu_patchers
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import comfy.model_management
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
@ -16,6 +18,7 @@ import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import comfy.multigpu
|
||||
import comfy.utils
|
||||
import scipy.stats
|
||||
import numpy
|
||||
@ -141,7 +144,7 @@ def can_concat_cond(c1, c2):
|
||||
|
||||
return cond_equal_size(c1.conditioning, c2.conditioning)
|
||||
|
||||
def cond_cat(c_list):
|
||||
def cond_cat(c_list, device=None):
|
||||
temp = {}
|
||||
for x in c_list:
|
||||
for k in x:
|
||||
@ -153,6 +156,8 @@ def cond_cat(c_list):
|
||||
for k in temp:
|
||||
conds = temp[k]
|
||||
out[k] = conds[0].concat(conds[1:])
|
||||
if device is not None and hasattr(out[k], 'to'):
|
||||
out[k] = out[k].to(device)
|
||||
|
||||
return out
|
||||
|
||||
@ -212,7 +217,9 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
if 'multigpu_clones' in model_options:
|
||||
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@ -244,7 +251,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep)
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
@ -345,6 +352,212 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
|
||||
return out_conds
|
||||
|
||||
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||
default_conds = []
|
||||
has_default_conds = False
|
||||
|
||||
output_device = x_in.device
|
||||
|
||||
for i in range(len(conds)):
|
||||
out_conds.append(torch.zeros_like(x_in))
|
||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||
|
||||
cond = conds[i]
|
||||
default_c = []
|
||||
if cond is not None:
|
||||
for x in cond:
|
||||
if 'default' in x:
|
||||
default_c.append(x)
|
||||
has_default_conds = True
|
||||
continue
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
default_conds.append(default_c)
|
||||
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
|
||||
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
|
||||
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
||||
|
||||
total_conds = 0
|
||||
for to_run in hooked_to_run.values():
|
||||
total_conds += len(to_run)
|
||||
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
|
||||
index_device = 0
|
||||
current_device = devices[index_device]
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
while len(to_run) > 0:
|
||||
current_device = devices[index_device % len(devices)]
|
||||
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
|
||||
# keep track of conds currently scheduled onto this device
|
||||
batched_to_run_length = 0
|
||||
for btr in batched_to_run:
|
||||
batched_to_run_length += len(btr[1])
|
||||
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
# make sure not over conds_per_device limit when creating temp batch
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = comfy.model_management.get_free_memory(current_device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
conds_to_batch = []
|
||||
for x in to_batch:
|
||||
conds_to_batch.append(to_run.pop(x))
|
||||
batched_to_run_length += len(conds_to_batch)
|
||||
|
||||
batched_to_run.append((hooks, conds_to_batch))
|
||||
if batched_to_run_length >= conds_per_device:
|
||||
index_device += 1
|
||||
|
||||
class thread_result(NamedTuple):
|
||||
output: Any
|
||||
mult: Any
|
||||
area: Any
|
||||
batch_chunks: int
|
||||
cond_or_uncond: Any
|
||||
error: Exception = None
|
||||
|
||||
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||
try:
|
||||
torch.cuda.set_device(device)
|
||||
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||
# run every hooked_to_run separately
|
||||
with torch.no_grad():
|
||||
for hooks, to_batch in batch_tuple:
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
uuids = []
|
||||
area = []
|
||||
control: ControlBase = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = x
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
uuids.append(p.uuid)
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x).to(device)
|
||||
c = cond_cat(c, device=device)
|
||||
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
|
||||
|
||||
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||
model_options['transformer_options'],
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
|
||||
transformer_options.get("patches", {}),
|
||||
patches
|
||||
)
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
transformer_options["sigmas"] = timestep.to(device)
|
||||
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
|
||||
transformer_options["multigpu_thread_device"] = device
|
||||
|
||||
cast_transformer_options(transformer_options, device=device)
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if control is not None:
|
||||
device_control = control.get_instance_for_device(device)
|
||||
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
||||
else:
|
||||
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
|
||||
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
|
||||
except Exception as e:
|
||||
results.append(thread_result(None, None, None, None, None, error=e))
|
||||
raise
|
||||
|
||||
|
||||
def _handle_batch_pooled(device, batch_tuple):
|
||||
worker_results = []
|
||||
_handle_batch(device, batch_tuple, worker_results)
|
||||
return worker_results
|
||||
|
||||
results: list[thread_result] = []
|
||||
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
|
||||
|
||||
# Submit all GPU work to pool threads
|
||||
pool_devices = []
|
||||
for device, batch_tuple in device_batched_hooked_to_run.items():
|
||||
if thread_pool is not None:
|
||||
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
|
||||
pool_devices.append(device)
|
||||
else:
|
||||
# Fallback: no pool, run everything on main thread
|
||||
_handle_batch(device, batch_tuple, results)
|
||||
|
||||
# Collect results from pool workers
|
||||
for device in pool_devices:
|
||||
worker_results, error = thread_pool.get_result(device)
|
||||
if error is not None:
|
||||
raise error
|
||||
results.extend(worker_results)
|
||||
|
||||
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
|
||||
if error is not None:
|
||||
raise error
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
|
||||
return out_conds
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||
@ -649,6 +862,8 @@ def pre_run_control(model, conds):
|
||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||
if 'control' in x:
|
||||
x['control'].pre_run(model, percent_to_timestep_function)
|
||||
for device_cnet in x['control'].multigpu_clones.values():
|
||||
device_cnet.pre_run(model, percent_to_timestep_function)
|
||||
|
||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
cond_cnets = []
|
||||
@ -891,7 +1106,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
to_load_options = model_options.get("to_load_options", None)
|
||||
if to_load_options is None:
|
||||
return
|
||||
cast_transformer_options(to_load_options, device, dtype)
|
||||
|
||||
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
|
||||
casts = []
|
||||
if device is not None:
|
||||
casts.append(device)
|
||||
@ -900,18 +1117,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
# if nothing to apply, do nothing
|
||||
if len(casts) == 0:
|
||||
return
|
||||
|
||||
# try to call .to on patches
|
||||
if "patches" in to_load_options:
|
||||
patches = to_load_options["patches"]
|
||||
if "patches" in transformer_options:
|
||||
patches = transformer_options["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
for cast in casts:
|
||||
patch_list[i] = patch_list[i].to(cast)
|
||||
if "patches_replace" in to_load_options:
|
||||
patches = to_load_options["patches_replace"]
|
||||
if "patches_replace" in transformer_options:
|
||||
patches = transformer_options["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
@ -921,8 +1137,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
# try to call .to on any wrappers/callbacks
|
||||
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||
for wc_name in wrappers_and_callbacks:
|
||||
if wc_name in to_load_options:
|
||||
wc: dict[str, list] = to_load_options[wc_name]
|
||||
if wc_name in transformer_options:
|
||||
wc: dict[str, list] = transformer_options[wc_name]
|
||||
for wc_dict in wc.values():
|
||||
for wc_list in wc_dict.values():
|
||||
for i in range(len(wc_list)):
|
||||
@ -930,7 +1146,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
for cast in casts:
|
||||
wc_list[i] = wc_list[i].to(cast)
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher: ModelPatcher):
|
||||
self.model_patcher = model_patcher
|
||||
@ -985,16 +1200,32 @@ class CFGGuider:
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
|
||||
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
# Create persistent thread pool for all GPU devices (main + extras)
|
||||
if multigpu_patchers:
|
||||
extra_devices = [p.load_device for p in multigpu_patchers]
|
||||
all_devices = [device] + extra_devices
|
||||
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
|
||||
|
||||
with comfy.model_management.cuda_device_context(device):
|
||||
try:
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
self.model_patcher.pre_run()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
||||
if thread_pool is not None:
|
||||
thread_pool.shutdown()
|
||||
self.model_patcher.cleanup()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
|
||||
337
comfy/sd.py
337
comfy/sd.py
@ -12,7 +12,6 @@ from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||
import comfy.ldm.genmo.vae.model
|
||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import comfy.ldm.lightricks.vae.audio_vae
|
||||
import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
@ -325,41 +324,43 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model(tokens)
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
device = self.patcher.load_device
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
all_hooks.reset()
|
||||
self.patcher.patch_hooks(None)
|
||||
if show_pbar:
|
||||
pbar = ProgressBar(len(scheduled_keyframes))
|
||||
|
||||
for scheduled_opts in scheduled_keyframes:
|
||||
t_range = scheduled_opts[0]
|
||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||
if "start_percent" in add_dict:
|
||||
if t_range[1] < add_dict["start_percent"]:
|
||||
continue
|
||||
if "end_percent" in add_dict:
|
||||
if t_range[0] > add_dict["end_percent"]:
|
||||
continue
|
||||
hooks_keyframes = scheduled_opts[1]
|
||||
for hook, keyframe in hooks_keyframes:
|
||||
hook.hook_keyframe._current_keyframe = keyframe
|
||||
# apply appropriate hooks with values that match new hook_keyframe
|
||||
self.patcher.patch_hooks(all_hooks)
|
||||
# perform encoding as normal
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
pooled_dict = {"pooled_output": pooled}
|
||||
# add clip_start_percent and clip_end_percent in pooled
|
||||
pooled_dict["clip_start_percent"] = t_range[0]
|
||||
pooled_dict["clip_end_percent"] = t_range[1]
|
||||
# add/update any keys with the provided add_dict
|
||||
pooled_dict.update(add_dict)
|
||||
# add hooks stored on clip
|
||||
self.add_hooks_to_dict(pooled_dict)
|
||||
all_cond_pooled.append([cond, pooled_dict])
|
||||
if show_pbar:
|
||||
pbar.update(1)
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
with model_management.cuda_device_context(device):
|
||||
for scheduled_opts in scheduled_keyframes:
|
||||
t_range = scheduled_opts[0]
|
||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||
if "start_percent" in add_dict:
|
||||
if t_range[1] < add_dict["start_percent"]:
|
||||
continue
|
||||
if "end_percent" in add_dict:
|
||||
if t_range[0] > add_dict["end_percent"]:
|
||||
continue
|
||||
hooks_keyframes = scheduled_opts[1]
|
||||
for hook, keyframe in hooks_keyframes:
|
||||
hook.hook_keyframe._current_keyframe = keyframe
|
||||
# apply appropriate hooks with values that match new hook_keyframe
|
||||
self.patcher.patch_hooks(all_hooks)
|
||||
# perform encoding as normal
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
pooled_dict = {"pooled_output": pooled}
|
||||
# add clip_start_percent and clip_end_percent in pooled
|
||||
pooled_dict["clip_start_percent"] = t_range[0]
|
||||
pooled_dict["clip_end_percent"] = t_range[1]
|
||||
# add/update any keys with the provided add_dict
|
||||
pooled_dict.update(add_dict)
|
||||
# add hooks stored on clip
|
||||
self.add_hooks_to_dict(pooled_dict)
|
||||
all_cond_pooled.append([cond, pooled_dict])
|
||||
if show_pbar:
|
||||
pbar.update(1)
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
all_hooks.reset()
|
||||
return all_cond_pooled
|
||||
|
||||
@ -373,8 +374,12 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model(tokens)
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
device = self.patcher.load_device
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
|
||||
with model_management.cuda_device_context(device):
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
|
||||
cond, pooled = o[:2]
|
||||
if return_dict:
|
||||
out = {"cond": cond, "pooled_output": pooled}
|
||||
@ -429,9 +434,12 @@ class CLIP:
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
self.load_model(tokens)
|
||||
device = self.patcher.load_device
|
||||
self.cond_stage_model.set_clip_options({"layer": None})
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
|
||||
with model_management.cuda_device_context(device):
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
@ -806,24 +814,6 @@ class VAE:
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
|
||||
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||
elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder."})
|
||||
self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata)
|
||||
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = self.first_stage_model.latent_channels
|
||||
self.audio_sample_rate_output = self.first_stage_model.output_sample_rate
|
||||
self.autoencoder = self.first_stage_model.autoencoder # TODO: remove hack for ltxv custom nodes
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 4096
|
||||
self.downscale_ratio = 4096
|
||||
self.latent_dim = 2
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float32]
|
||||
self.disable_offload = True
|
||||
self.extra_1d_channel = 16
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@ -966,50 +956,52 @@ class VAE:
|
||||
do_tile = False
|
||||
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||
samples_in = samples_in[:, :, 0]
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
with model_management.cuda_device_context(self.device):
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
@ -1027,20 +1019,21 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
output = self.decode_tiled_(samples, **args)
|
||||
elif dims == 3:
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||
if tile_t is not None:
|
||||
args["tile_t"] = max(2, tile_t)
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
output = self.decode_tiled_(samples, **args)
|
||||
elif dims == 3:
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||
if tile_t is not None:
|
||||
args["tile_t"] = max(2, tile_t)
|
||||
|
||||
output = self.decode_tiled_3d(samples, **args)
|
||||
output = self.decode_tiled_3d(samples, **args)
|
||||
return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
@ -1053,44 +1046,46 @@ class VAE:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
else:
|
||||
pixel_samples = pixel_samples.unsqueeze(2)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
return samples
|
||||
|
||||
@ -1116,26 +1111,27 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
|
||||
return samples
|
||||
|
||||
@ -1623,10 +1619,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
if out is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||
if output_model and out[0] is not None:
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
if output_clip and out[1] is not None:
|
||||
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
|
||||
return out
|
||||
|
||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
@ -1655,7 +1648,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
||||
load_device = model_management.get_torch_device()
|
||||
load_device = model_options.get("load_device", model_management.get_torch_device())
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
@ -1695,13 +1688,15 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
|
||||
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
||||
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||
vae = VAE(sd=vae_sd, metadata=metadata)
|
||||
vae_device = model_options.get("load_device", None)
|
||||
vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device)
|
||||
|
||||
if output_clip:
|
||||
if te_model_options.get("custom_operations", None) is None:
|
||||
@ -1785,7 +1780,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
load_device = model_management.get_torch_device()
|
||||
load_device = model_options.get("load_device", model_management.get_torch_device())
|
||||
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
||||
|
||||
if model_config is not None:
|
||||
@ -1810,7 +1805,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
else:
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.quant_config is not None:
|
||||
weight_dtype = None
|
||||
|
||||
@ -158,17 +158,10 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# Seedance 2.0 reference video pixel count limits per model and output resolution.
|
||||
# Seedance 2.0 reference video pixel count limits per model.
|
||||
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
|
||||
"dreamina-seedance-2-0-260128": {
|
||||
"480p": {"min": 409_600, "max": 927_408},
|
||||
"720p": {"min": 409_600, "max": 927_408},
|
||||
"1080p": {"min": 409_600, "max": 2_073_600},
|
||||
},
|
||||
"dreamina-seedance-2-0-fast-260128": {
|
||||
"480p": {"min": 409_600, "max": 927_408},
|
||||
"720p": {"min": 409_600, "max": 927_408},
|
||||
},
|
||||
"dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408},
|
||||
"dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408},
|
||||
}
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
|
||||
@ -35,7 +35,6 @@ from comfy_api_nodes.util import (
|
||||
get_number_of_images,
|
||||
image_tensor_pair_to_batch,
|
||||
poll_op,
|
||||
resize_video_to_pixel_budget,
|
||||
sync_op,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
@ -70,12 +69,9 @@ DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-2504
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: str, index: int) -> None:
|
||||
"""Validate reference video pixel count against Seedance 2.0 model limits for the selected resolution."""
|
||||
model_limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
|
||||
if not model_limits:
|
||||
return
|
||||
limits = model_limits.get(resolution)
|
||||
def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None:
|
||||
"""Validate reference video pixel count against Seedance 2.0 model limits."""
|
||||
limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
|
||||
if not limits:
|
||||
return
|
||||
try:
|
||||
@ -1377,14 +1373,6 @@ def _seedance2_reference_inputs(resolutions: list[str]):
|
||||
min=0,
|
||||
),
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"auto_downscale",
|
||||
default=False,
|
||||
advanced=True,
|
||||
optional=True,
|
||||
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
|
||||
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@ -1492,23 +1480,10 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
|
||||
model_id = SEEDANCE_MODELS[model["model"]]
|
||||
has_video_input = len(reference_videos) > 0
|
||||
|
||||
if model.get("auto_downscale") and reference_videos:
|
||||
max_px = (
|
||||
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {})
|
||||
.get(model["resolution"], {})
|
||||
.get("max")
|
||||
)
|
||||
if max_px:
|
||||
for key in reference_videos:
|
||||
reference_videos[key] = resize_video_to_pixel_budget(
|
||||
reference_videos[key], max_px
|
||||
)
|
||||
|
||||
total_video_duration = 0.0
|
||||
for i, key in enumerate(reference_videos, 1):
|
||||
video = reference_videos[key]
|
||||
_validate_ref_video_pixels(video, model_id, model["resolution"], i)
|
||||
_validate_ref_video_pixels(video, model_id, i)
|
||||
try:
|
||||
dur = video.get_duration()
|
||||
if dur < 1.8:
|
||||
|
||||
@ -357,18 +357,13 @@ def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) ->
|
||||
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
|
||||
|
||||
|
||||
def calculate_tokens_price_image_2(response: OpenAIImageGenerationResponse) -> float | None:
|
||||
# https://platform.openai.com/docs/pricing - gpt-image-2: input $8/1M, output $30/1M
|
||||
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 30.0)) / 1_000_000.0
|
||||
|
||||
|
||||
class OpenAIGPTImage1(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImage1",
|
||||
display_name="OpenAI GPT Image 1 & 1.5",
|
||||
display_name="OpenAI GPT Image 1.5",
|
||||
category="api node/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
|
||||
inputs=[
|
||||
@ -447,22 +442,14 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n", "model"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$ranges :=
|
||||
$contains($m, "gpt-image-1.5")
|
||||
? {
|
||||
"low": [0.009, 0.016],
|
||||
"medium": [0.037, 0.056],
|
||||
"high": [0.134, 0.240]
|
||||
}
|
||||
: {
|
||||
"low": [0.011, 0.020],
|
||||
"medium": [0.046, 0.070],
|
||||
"high": [0.167, 0.300]
|
||||
};
|
||||
$ranges := {
|
||||
"low": [0.011, 0.02],
|
||||
"medium": [0.046, 0.07],
|
||||
"high": [0.167, 0.3]
|
||||
};
|
||||
$range := $lookup($ranges, widgets.quality);
|
||||
$n := widgets.n;
|
||||
($n = 1)
|
||||
@ -577,261 +564,6 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||
|
||||
|
||||
_GPT_IMAGE_2_SIZES = [
|
||||
"auto",
|
||||
"1024x1024",
|
||||
"1536x1024",
|
||||
"1024x1536",
|
||||
"2048x2048",
|
||||
"2048x1152",
|
||||
"3840x2160",
|
||||
"2160x3840",
|
||||
]
|
||||
|
||||
|
||||
def _resolve_gpt_image_2_size(size: str, custom_width: int, custom_height: int) -> str:
|
||||
if custom_width <= 0 or custom_height <= 0:
|
||||
return size
|
||||
w, h = custom_width, custom_height
|
||||
if max(w, h) > 3840:
|
||||
raise ValueError(f"Maximum edge length must be ≤ 3840px, got {max(w, h)}")
|
||||
if w % 16 != 0 or h % 16 != 0:
|
||||
raise ValueError(f"Both edges must be multiples of 16px, got {w}x{h}")
|
||||
if max(w, h) / min(w, h) > 3:
|
||||
raise ValueError(f"Long-to-short edge ratio must not exceed 3:1, got {max(w, h) / min(w, h):.2f}:1")
|
||||
total = w * h
|
||||
if total < 655_360 or total > 8_294_400:
|
||||
raise ValueError(f"Total pixels must be between 655,360 and 8,294,400, got {total:,}")
|
||||
return f"{w}x{h}"
|
||||
|
||||
|
||||
class OpenAIGPTImage2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImage2",
|
||||
display_name="OpenAI GPT Image 2",
|
||||
category="api node/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's GPT-Image-2 endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
tooltip="Text prompt for GPT Image 2",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2**31 - 1,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="not implemented yet in backend",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"quality",
|
||||
default="auto",
|
||||
options=["auto", "low", "medium", "high"],
|
||||
tooltip="Image quality. 'auto' lets the model decide based on the prompt. Square images are typically fastest.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"background",
|
||||
default="auto",
|
||||
options=["auto", "opaque"],
|
||||
tooltip="Background style. GPT-Image-2 does not support transparent backgrounds.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
default="auto",
|
||||
options=_GPT_IMAGE_2_SIZES,
|
||||
tooltip="Output image dimensions. Ignored when custom_width and custom_height are both non-zero.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"custom_width",
|
||||
default=0,
|
||||
min=0,
|
||||
max=3840,
|
||||
step=16,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Custom output width in pixels. Set to 0 (default) to use the size preset. When both width and height are non-zero, they override the size preset. Slider enforces multiples of 16 and max edge 3840px. Additional constraints checked at generation: ratio ≤ 3:1, total pixels 655,360–8,294,400.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"custom_height",
|
||||
default=0,
|
||||
min=0,
|
||||
max=3840,
|
||||
step=16,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Custom output height in pixels. Set to 0 (default) to use the size preset. When both width and height are non-zero, they override the size preset. Slider enforces multiples of 16 and max edge 3840px. Additional constraints checked at generation: ratio ≤ 3:1, total pixels 655,360–8,294,400.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="Number of images to generate per run.",
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image for image editing.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Mask.Input(
|
||||
"mask",
|
||||
tooltip="Optional mask for inpainting (white areas will be replaced).",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["gpt-image-2"],
|
||||
default="gpt-image-2",
|
||||
tooltip="Model used for image generation.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["quality", "num_images"]),
|
||||
expr="""
|
||||
(
|
||||
$ranges := {
|
||||
"low": [0.005, 0.010],
|
||||
"medium": [0.041, 0.060],
|
||||
"high": [0.165, 0.250]
|
||||
};
|
||||
$q := widgets.quality;
|
||||
$n := widgets.num_images;
|
||||
$n := ($n != null and $n != 0) ? $n : 1;
|
||||
$range := $lookup($ranges, $q);
|
||||
$lo := $range ? $range[0] : 0.005;
|
||||
$hi := $range ? $range[1] : 0.250;
|
||||
($n = 1)
|
||||
? {"type":"range_usd","min_usd": $lo, "max_usd": $hi, "format": {"approximate": ($range ? false : true)}}
|
||||
: {
|
||||
"type":"range_usd",
|
||||
"min_usd": $lo,
|
||||
"max_usd": $hi,
|
||||
"format": {"approximate": ($range ? false : true), "suffix": " x " & $string($n) & "/Run"}
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
seed: int = 0,
|
||||
quality: str = "auto",
|
||||
background: str = "auto",
|
||||
image: Input.Image | None = None,
|
||||
mask: Input.Image | None = None,
|
||||
num_images: int = 1,
|
||||
size: str = "auto",
|
||||
custom_width: int = 0,
|
||||
custom_height: int = 0,
|
||||
model: str = "gpt-image-2",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
if mask is not None and image is None:
|
||||
raise ValueError("Cannot use a mask without an input image")
|
||||
|
||||
resolved_size = _resolve_gpt_image_2_size(size, custom_width, custom_height)
|
||||
|
||||
if image is not None:
|
||||
files = []
|
||||
batch_size = image.shape[0]
|
||||
for i in range(batch_size):
|
||||
single_image = image[i : i + 1]
|
||||
scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze()
|
||||
|
||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
img_byte_arr = BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
|
||||
if batch_size == 1:
|
||||
files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||
else:
|
||||
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||
|
||||
if mask is not None:
|
||||
if image.shape[0] != 1:
|
||||
raise Exception("Cannot use a mask with multiple image")
|
||||
if mask.shape[1:] != image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
_, height, width = mask.shape
|
||||
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||
|
||||
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048 * 2048).squeeze()
|
||||
|
||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_img_byte_arr = BytesIO()
|
||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||
mask_img_byte_arr.seek(0)
|
||||
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
data=OpenAIImageEditRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=num_images,
|
||||
size=resolved_size,
|
||||
moderation="low",
|
||||
),
|
||||
content_type="multipart/form-data",
|
||||
files=files,
|
||||
price_extractor=calculate_tokens_price_image_2,
|
||||
)
|
||||
else:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
data=OpenAIImageGenerationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=num_images,
|
||||
size=resolved_size,
|
||||
moderation="low",
|
||||
),
|
||||
price_extractor=calculate_tokens_price_image_2,
|
||||
)
|
||||
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||
|
||||
|
||||
class OpenAIChatNode(IO.ComfyNode):
|
||||
"""
|
||||
Node to generate text responses from an OpenAI model.
|
||||
@ -1181,7 +913,6 @@ class OpenAIExtension(ComfyExtension):
|
||||
OpenAIDalle2,
|
||||
OpenAIDalle3,
|
||||
OpenAIGPTImage1,
|
||||
OpenAIGPTImage2,
|
||||
OpenAIChatNode,
|
||||
OpenAIInputFiles,
|
||||
OpenAIChatConfig,
|
||||
|
||||
@ -24,9 +24,8 @@ from comfy_api_nodes.util import (
|
||||
AVERAGE_DURATION_VIDEO_GEN = 32
|
||||
MODELS_MAP = {
|
||||
"veo-2.0-generate-001": "veo-2.0-generate-001",
|
||||
"veo-3.1-generate": "veo-3.1-generate-001",
|
||||
"veo-3.1-fast-generate": "veo-3.1-fast-generate-001",
|
||||
"veo-3.1-lite": "veo-3.1-lite-generate-001",
|
||||
"veo-3.1-generate": "veo-3.1-generate-preview",
|
||||
"veo-3.1-fast-generate": "veo-3.1-fast-generate-preview",
|
||||
"veo-3.0-generate-001": "veo-3.0-generate-001",
|
||||
"veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
|
||||
}
|
||||
@ -248,8 +247,17 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
raise Exception("Video generation completed but no video was returned")
|
||||
|
||||
|
||||
class Veo3VideoGenerationNode(IO.ComfyNode):
|
||||
"""Generates videos from text prompts using Google's Veo 3 API."""
|
||||
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
"""
|
||||
Generates videos from text prompts using Google's Veo 3 API.
|
||||
|
||||
Supported models:
|
||||
- veo-3.0-generate-001
|
||||
- veo-3.0-fast-generate-001
|
||||
|
||||
This node extends the base Veo node with Veo 3 specific features including
|
||||
audio generation and fixed 8-second duration.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -271,13 +279,6 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
|
||||
default="16:9",
|
||||
tooltip="Aspect ratio of the output video",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p", "4k"],
|
||||
default="720p",
|
||||
tooltip="Output video resolution. 4K is not available for veo-3.1-lite and veo-3.0 models.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
@ -288,11 +289,11 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
|
||||
IO.Int.Input(
|
||||
"duration_seconds",
|
||||
default=8,
|
||||
min=4,
|
||||
min=8,
|
||||
max=8,
|
||||
step=2,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
@ -331,10 +332,10 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
|
||||
options=[
|
||||
"veo-3.1-generate",
|
||||
"veo-3.1-fast-generate",
|
||||
"veo-3.1-lite",
|
||||
"veo-3.0-generate-001",
|
||||
"veo-3.0-fast-generate-001",
|
||||
],
|
||||
default="veo-3.0-generate-001",
|
||||
tooltip="Veo 3 model to use for video generation",
|
||||
optional=True,
|
||||
),
|
||||
@ -355,111 +356,21 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "resolution", "duration_seconds"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$r := widgets.resolution;
|
||||
$a := widgets.generate_audio;
|
||||
$seconds := widgets.duration_seconds;
|
||||
$pps :=
|
||||
$contains($m, "lite")
|
||||
? ($r = "1080p" ? ($a ? 0.08 : 0.05) : ($a ? 0.05 : 0.03))
|
||||
: $contains($m, "3.1-fast")
|
||||
? ($r = "4k" ? ($a ? 0.30 : 0.25) : $r = "1080p" ? ($a ? 0.12 : 0.10) : ($a ? 0.10 : 0.08))
|
||||
: $contains($m, "3.1-generate")
|
||||
? ($r = "4k" ? ($a ? 0.60 : 0.40) : ($a ? 0.40 : 0.20))
|
||||
: $contains($m, "3.0-fast")
|
||||
? ($a ? 0.15 : 0.10)
|
||||
: ($a ? 0.40 : 0.20);
|
||||
{"type":"usd","usd": $pps * $seconds}
|
||||
($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate"))
|
||||
? {"type":"usd","usd": ($a ? 1.2 : 0.8)}
|
||||
: ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate"))
|
||||
? {"type":"usd","usd": ($a ? 3.2 : 1.6)}
|
||||
: {"type":"range_usd","min_usd":0.8,"max_usd":3.2}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
aspect_ratio="16:9",
|
||||
resolution="720p",
|
||||
negative_prompt="",
|
||||
duration_seconds=8,
|
||||
enhance_prompt=True,
|
||||
person_generation="ALLOW",
|
||||
seed=0,
|
||||
image=None,
|
||||
model="veo-3.0-generate-001",
|
||||
generate_audio=False,
|
||||
):
|
||||
if "lite" in model and resolution == "4k":
|
||||
raise Exception("4K resolution is not supported by the veo-3.1-lite model.")
|
||||
|
||||
model = MODELS_MAP[model]
|
||||
|
||||
instances = [{"prompt": prompt}]
|
||||
if image is not None:
|
||||
image_base64 = tensor_to_base64_string(image)
|
||||
if image_base64:
|
||||
instances[0]["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}
|
||||
|
||||
parameters = {
|
||||
"aspectRatio": aspect_ratio,
|
||||
"personGeneration": person_generation,
|
||||
"durationSeconds": duration_seconds,
|
||||
"enhancePrompt": True,
|
||||
"generateAudio": generate_audio,
|
||||
}
|
||||
if negative_prompt:
|
||||
parameters["negativePrompt"] = negative_prompt
|
||||
if seed > 0:
|
||||
parameters["seed"] = seed
|
||||
if "veo-3.1" in model:
|
||||
parameters["resolution"] = resolution
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
|
||||
response_model=VeoGenVidResponse,
|
||||
data=VeoGenVidRequest(
|
||||
instances=instances,
|
||||
parameters=parameters,
|
||||
),
|
||||
)
|
||||
|
||||
poll_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
|
||||
response_model=VeoGenVidPollResponse,
|
||||
status_extractor=lambda r: "completed" if r.done else "pending",
|
||||
data=VeoGenVidPollRequest(operationName=initial_response.name),
|
||||
poll_interval=9.0,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||
)
|
||||
|
||||
if poll_response.error:
|
||||
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
|
||||
|
||||
response = poll_response.response
|
||||
filtered_count = response.raiMediaFilteredCount
|
||||
if filtered_count:
|
||||
reasons = response.raiMediaFilteredReasons or []
|
||||
reason_part = f": {reasons[0]}" if reasons else ""
|
||||
raise Exception(
|
||||
f"Content blocked by Google's Responsible AI filters{reason_part} "
|
||||
f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)."
|
||||
)
|
||||
|
||||
if response.videos:
|
||||
video = response.videos[0]
|
||||
if video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
if video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
raise Exception("Video generation completed but no video was returned")
|
||||
|
||||
|
||||
class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
|
||||
@ -483,7 +394,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["720p", "1080p", "4k"]),
|
||||
IO.Combo.Input("resolution", options=["720p", "1080p"]),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16"],
|
||||
@ -513,7 +424,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
IO.Image.Input("last_frame", tooltip="End frame"),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.1-lite"],
|
||||
options=["veo-3.1-generate", "veo-3.1-fast-generate"],
|
||||
default="veo-3.1-fast-generate",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
@ -531,20 +443,26 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration", "resolution"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {
|
||||
"veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 },
|
||||
"veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 }
|
||||
};
|
||||
$m := widgets.model;
|
||||
$r := widgets.resolution;
|
||||
$ga := widgets.generate_audio;
|
||||
$ga := (widgets.generate_audio = "true");
|
||||
$seconds := widgets.duration;
|
||||
$pps :=
|
||||
$contains($m, "lite")
|
||||
? ($r = "1080p" ? ($ga ? 0.08 : 0.05) : ($ga ? 0.05 : 0.03))
|
||||
: $contains($m, "fast")
|
||||
? ($r = "4k" ? ($ga ? 0.30 : 0.25) : $r = "1080p" ? ($ga ? 0.12 : 0.10) : ($ga ? 0.10 : 0.08))
|
||||
: ($r = "4k" ? ($ga ? 0.60 : 0.40) : ($ga ? 0.40 : 0.20));
|
||||
{"type":"usd","usd": $pps * $seconds}
|
||||
$modelKey :=
|
||||
$contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" :
|
||||
$contains($m, "veo-3.1-generate") ? "veo-3.1-generate" :
|
||||
"";
|
||||
$audioKey := $ga ? "audio" : "no_audio";
|
||||
$modelPrices := $lookup($prices, $modelKey);
|
||||
$pps := $lookup($modelPrices, $audioKey);
|
||||
($pps != null)
|
||||
? {"type":"usd","usd": $pps * $seconds}
|
||||
: {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -564,9 +482,6 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
model: str,
|
||||
generate_audio: bool,
|
||||
):
|
||||
if "lite" in model and resolution == "4k":
|
||||
raise Exception("4K resolution is not supported by the veo-3.1-lite model.")
|
||||
|
||||
model = MODELS_MAP[model]
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
@ -604,7 +519,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
data=VeoGenVidPollRequest(
|
||||
operationName=initial_response.name,
|
||||
),
|
||||
poll_interval=9.0,
|
||||
poll_interval=5.0,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||
)
|
||||
|
||||
|
||||
@ -19,7 +19,6 @@ from .conversions import (
|
||||
image_tensor_pair_to_batch,
|
||||
pil_to_bytesio,
|
||||
resize_mask_to_image,
|
||||
resize_video_to_pixel_budget,
|
||||
tensor_to_base64_string,
|
||||
tensor_to_bytesio,
|
||||
tensor_to_pil,
|
||||
@ -91,7 +90,6 @@ __all__ = [
|
||||
"image_tensor_pair_to_batch",
|
||||
"pil_to_bytesio",
|
||||
"resize_mask_to_image",
|
||||
"resize_video_to_pixel_budget",
|
||||
"tensor_to_base64_string",
|
||||
"tensor_to_bytesio",
|
||||
"tensor_to_pil",
|
||||
|
||||
@ -129,38 +129,22 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
||||
return img_byte_arr
|
||||
|
||||
|
||||
def _compute_downscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None:
|
||||
"""Return downscaled (w, h) with even dims fitting ``total_pixels``, or None if already fits.
|
||||
|
||||
Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions
|
||||
are rounded down to even values (many codecs require divisible-by-2).
|
||||
"""
|
||||
pixels = src_w * src_h
|
||||
if pixels <= total_pixels:
|
||||
return None
|
||||
scale = math.sqrt(total_pixels / pixels)
|
||||
new_w = max(2, int(src_w * scale))
|
||||
new_h = max(2, int(src_h * scale))
|
||||
new_w -= new_w % 2
|
||||
new_h -= new_h % 2
|
||||
return new_w, new_h
|
||||
|
||||
|
||||
def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
|
||||
"""Downscale input image tensor to roughly the specified total pixels.
|
||||
|
||||
Output dimensions are rounded down to even values so that the result is guaranteed to fit within ``total_pixels``
|
||||
and is compatible with codecs that require even dimensions (e.g. yuv420p).
|
||||
"""
|
||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||
samples = image.movedim(-1, 1)
|
||||
dims = _compute_downscale_dims(samples.shape[3], samples.shape[2], int(total_pixels))
|
||||
if dims is None:
|
||||
total = int(total_pixels)
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
if scale_by >= 1:
|
||||
return image
|
||||
new_w, new_h = dims
|
||||
return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1)
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
|
||||
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return s
|
||||
|
||||
|
||||
def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor:
|
||||
def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor:
|
||||
"""Downscale input image tensor so the largest dimension is at most max_side pixels."""
|
||||
samples = image.movedim(-1, 1)
|
||||
height, width = samples.shape[2], samples.shape[3]
|
||||
@ -415,72 +399,6 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
|
||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||
|
||||
|
||||
def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video:
|
||||
"""Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio.
|
||||
|
||||
Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio.
|
||||
Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
|
||||
"""
|
||||
src_w, src_h = video.get_dimensions()
|
||||
scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels)
|
||||
if scale_dims is None:
|
||||
return video
|
||||
return _apply_video_scale(video, scale_dims)
|
||||
|
||||
|
||||
def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input.Video:
|
||||
"""Re-encode ``video`` scaled to ``scale_dims`` with a single decode/encode pass."""
|
||||
out_w, out_h = scale_dims
|
||||
output_buffer = BytesIO()
|
||||
input_container = None
|
||||
output_container = None
|
||||
|
||||
try:
|
||||
input_source = video.get_stream_source()
|
||||
input_container = av.open(input_source, mode="r")
|
||||
output_container = av.open(output_buffer, mode="w", format="mp4")
|
||||
|
||||
video_stream = output_container.add_stream("h264", rate=video.get_frame_rate())
|
||||
video_stream.width = out_w
|
||||
video_stream.height = out_h
|
||||
video_stream.pix_fmt = "yuv420p"
|
||||
|
||||
audio_stream = None
|
||||
for stream in input_container.streams:
|
||||
if isinstance(stream, av.AudioStream):
|
||||
audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
|
||||
audio_stream.sample_rate = stream.sample_rate
|
||||
audio_stream.layout = stream.layout
|
||||
break
|
||||
|
||||
for frame in input_container.decode(video=0):
|
||||
frame = frame.reformat(width=out_w, height=out_h, format="yuv420p")
|
||||
for packet in video_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
for packet in video_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
if audio_stream is not None:
|
||||
input_container.seek(0)
|
||||
for audio_frame in input_container.decode(audio=0):
|
||||
for packet in audio_stream.encode(audio_frame):
|
||||
output_container.mux(packet)
|
||||
for packet in audio_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
output_buffer.seek(0)
|
||||
return InputImpl.VideoFromFile(output_buffer)
|
||||
|
||||
except Exception as e:
|
||||
if input_container is not None:
|
||||
input_container.close()
|
||||
if output_container is not None:
|
||||
output_container.close()
|
||||
raise RuntimeError(f"Failed to resize video: {str(e)}") from e
|
||||
|
||||
|
||||
def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
|
||||
if wav.dtype.is_floating_point:
|
||||
|
||||
@ -3,136 +3,136 @@ from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class TextEncodeAceStepAudio(IO.ComfyNode):
|
||||
class TextEncodeAceStepAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
return io.Schema(
|
||||
node_id="TextEncodeAceStepAudio",
|
||||
category="conditioning",
|
||||
inputs=[
|
||||
IO.Clip.Input("clip"),
|
||||
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||
IO.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[IO.Conditioning.Output()],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, lyrics_strength) -> IO.NodeOutput:
|
||||
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
||||
return IO.NodeOutput(conditioning)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
class TextEncodeAceStepAudio15(IO.ComfyNode):
|
||||
class TextEncodeAceStepAudio15(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
return io.Schema(
|
||||
node_id="TextEncodeAceStepAudio1.5",
|
||||
category="conditioning",
|
||||
inputs=[
|
||||
IO.Clip.Input("clip"),
|
||||
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
|
||||
IO.Int.Input("bpm", default=120, min=10, max=300),
|
||||
IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
|
||||
IO.Combo.Input("timesignature", options=['2', '3', '4', '6']),
|
||||
IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
|
||||
IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
|
||||
IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
|
||||
IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
|
||||
IO.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
||||
IO.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
||||
IO.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
||||
IO.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
|
||||
io.Int.Input("bpm", default=120, min=10, max=300),
|
||||
io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
|
||||
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
|
||||
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
|
||||
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
|
||||
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
|
||||
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
|
||||
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
||||
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
||||
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
||||
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
|
||||
],
|
||||
outputs=[IO.Conditioning.Output()],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> IO.NodeOutput:
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
return IO.NodeOutput(conditioning)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class EmptyAceStepLatentAudio(IO.ComfyNode):
|
||||
class EmptyAceStepLatentAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
return io.Schema(
|
||||
node_id="EmptyAceStepLatentAudio",
|
||||
display_name="Empty Ace Step 1.0 Latent Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||
IO.Int.Input(
|
||||
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||
io.Int.Input(
|
||||
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||
),
|
||||
],
|
||||
outputs=[IO.Latent.Output()],
|
||||
outputs=[io.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
|
||||
def execute(cls, seconds, batch_size) -> io.NodeOutput:
|
||||
length = int(seconds * 44100 / 512 / 8)
|
||||
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
return IO.NodeOutput({"samples": latent, "type": "audio"})
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
|
||||
class EmptyAceStep15LatentAudio(IO.ComfyNode):
|
||||
class EmptyAceStep15LatentAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
return io.Schema(
|
||||
node_id="EmptyAceStep1.5LatentAudio",
|
||||
display_name="Empty Ace Step 1.5 Latent Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
|
||||
IO.Int.Input(
|
||||
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
|
||||
io.Int.Input(
|
||||
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||
),
|
||||
],
|
||||
outputs=[IO.Latent.Output()],
|
||||
outputs=[io.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
|
||||
def execute(cls, seconds, batch_size) -> io.NodeOutput:
|
||||
length = round((seconds * 48000 / 1920))
|
||||
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
return IO.NodeOutput({"samples": latent, "type": "audio"})
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
class ReferenceAudio(IO.ComfyNode):
|
||||
class ReferenceAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
return io.Schema(
|
||||
node_id="ReferenceTimbreAudio",
|
||||
display_name="Reference Audio",
|
||||
category="advanced/conditioning/audio",
|
||||
is_experimental=True,
|
||||
description="This node sets the reference audio for ace step 1.5",
|
||||
inputs=[
|
||||
IO.Conditioning.Input("conditioning"),
|
||||
IO.Latent.Input("latent", optional=True),
|
||||
io.Conditioning.Input("conditioning"),
|
||||
io.Latent.Input("latent", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(),
|
||||
io.Conditioning.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, conditioning, latent=None) -> IO.NodeOutput:
|
||||
def execute(cls, conditioning, latent=None) -> io.NodeOutput:
|
||||
if latent is not None:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True)
|
||||
return IO.NodeOutput(conditioning)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
class AceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
TextEncodeAceStepAudio,
|
||||
EmptyAceStepLatentAudio,
|
||||
|
||||
@ -104,7 +104,7 @@ def vae_decode_audio(vae, samples, tile=None, overlap=None):
|
||||
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100))
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
|
||||
|
||||
|
||||
|
||||
@ -3,8 +3,9 @@ import comfy.utils
|
||||
import comfy.model_management
|
||||
import torch
|
||||
|
||||
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.nodes_audio import VAEEncodeAudio
|
||||
|
||||
|
||||
class LTXVAudioVAELoader(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -27,14 +28,10 @@ class LTXVAudioVAELoader(io.ComfyNode):
|
||||
def execute(cls, ckpt_name: str) -> io.NodeOutput:
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True)
|
||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
||||
vae.throw_exception_if_invalid()
|
||||
|
||||
return io.NodeOutput(vae)
|
||||
return io.NodeOutput(AudioVAE(sd, metadata))
|
||||
|
||||
|
||||
class LTXVAudioVAEEncode(VAEEncodeAudio):
|
||||
class LTXVAudioVAEEncode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
@ -53,8 +50,15 @@ class LTXVAudioVAEEncode(VAEEncodeAudio):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, audio_vae) -> io.NodeOutput:
|
||||
return super().execute(audio_vae, audio)
|
||||
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||
audio_latents = audio_vae.encode(audio)
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"samples": audio_latents,
|
||||
"sample_rate": int(audio_vae.sample_rate),
|
||||
"type": "audio",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
@ -76,12 +80,12 @@ class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, audio_vae) -> io.NodeOutput:
|
||||
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||
audio_latent = samples["samples"]
|
||||
if audio_latent.is_nested:
|
||||
audio_latent = audio_latent.unbind()[-1]
|
||||
audio = audio_vae.decode(audio_latent).movedim(-1, 1).to(audio_latent.device)
|
||||
output_audio_sample_rate = audio_vae.first_stage_model.output_sample_rate
|
||||
audio = audio_vae.decode(audio_latent).to(audio_latent.device)
|
||||
output_audio_sample_rate = audio_vae.output_sample_rate
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"waveform": audio,
|
||||
@ -139,17 +143,17 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
|
||||
frames_number: int,
|
||||
frame_rate: int,
|
||||
batch_size: int,
|
||||
audio_vae,
|
||||
audio_vae: AudioVAE,
|
||||
) -> io.NodeOutput:
|
||||
"""Generate empty audio latents matching the reference pipeline structure."""
|
||||
|
||||
assert audio_vae is not None, "Audio VAE model is required"
|
||||
|
||||
z_channels = audio_vae.latent_channels
|
||||
audio_freq = audio_vae.first_stage_model.latent_frequency_bins
|
||||
sampling_rate = int(audio_vae.first_stage_model.sample_rate)
|
||||
audio_freq = audio_vae.latent_frequency_bins
|
||||
sampling_rate = int(audio_vae.sample_rate)
|
||||
|
||||
num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate)
|
||||
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)
|
||||
|
||||
audio_latents = torch.zeros(
|
||||
(batch_size, z_channels, num_audio_latents, audio_freq),
|
||||
@ -184,7 +188,7 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
|
||||
),
|
||||
io.Combo.Input(
|
||||
"device",
|
||||
options=["default", "cpu"],
|
||||
options=comfy.model_management.get_gpu_device_options(),
|
||||
advanced=True,
|
||||
)
|
||||
],
|
||||
@ -199,8 +203,12 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
|
||||
model_options = {}
|
||||
if device == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
if resolved is not None:
|
||||
if resolved.type == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = resolved
|
||||
else:
|
||||
model_options["load_device"] = resolved
|
||||
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||
return io.NodeOutput(clip)
|
||||
|
||||
89
comfy_extras/nodes_multigpu.py
Normal file
89
comfy_extras/nodes_multigpu.py
Normal file
@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import cleandoc
|
||||
from typing import TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.multigpu
|
||||
|
||||
|
||||
class MultiGPUCFGSplitNode(io.ComfyNode):
|
||||
"""
|
||||
Prepares model to have sampling accelerated via splitting work units.
|
||||
|
||||
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
|
||||
|
||||
Other than those exceptions, this node can be placed in any order.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MultiGPU_WorkUnits",
|
||||
display_name="MultiGPU CFG Split",
|
||||
category="advanced/multigpu",
|
||||
description=cleandoc(cls.__doc__),
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Int.Input("max_gpus", default=2, min=1, step=1),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
|
||||
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
class MultiGPUOptionsNode(io.ComfyNode):
|
||||
"""
|
||||
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MultiGPU_Options",
|
||||
display_name="MultiGPU Options",
|
||||
category="advanced/multigpu",
|
||||
description=cleandoc(cls.__doc__),
|
||||
inputs=[
|
||||
io.Int.Input("device_index", default=0, min=0, max=64),
|
||||
io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01),
|
||||
io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Custom("GPU_OPTIONS").Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput:
|
||||
if not gpu_options:
|
||||
gpu_options = comfy.multigpu.GPUOptionsGroup()
|
||||
gpu_options.clone()
|
||||
|
||||
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
|
||||
gpu_options.add(opt)
|
||||
|
||||
return io.NodeOutput(gpu_options)
|
||||
|
||||
|
||||
class MultiGPUExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
MultiGPUCFGSplitNode,
|
||||
# MultiGPUOptionsNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MultiGPUExtension:
|
||||
return MultiGPUExtension()
|
||||
4
main.py
4
main.py
@ -192,7 +192,7 @@ import gc
|
||||
if 'torch' in sys.modules:
|
||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
||||
|
||||
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
import execution
|
||||
@ -210,7 +210,7 @@ import comfy.model_patcher
|
||||
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
|
||||
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
|
||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||
elif comfy_aimdo.control.init_devices(range(torch.cuda.device_count())):
|
||||
if args.verbose == 'DEBUG':
|
||||
comfy_aimdo.control.set_log_debug()
|
||||
elif args.verbose == 'CRITICAL':
|
||||
|
||||
128
nodes.py
128
nodes.py
@ -608,6 +608,73 @@ class CheckpointLoaderSimple:
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return out[:3]
|
||||
|
||||
|
||||
class CheckpointLoaderDevice:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
device_options = comfy.model_management.get_gpu_device_options()
|
||||
return {
|
||||
"required": {
|
||||
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
|
||||
},
|
||||
"optional": {
|
||||
"model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}),
|
||||
"clip_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP text encoder."}),
|
||||
"vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
|
||||
"The CLIP model used for encoding text prompts.",
|
||||
"The VAE model used for encoding and decoding images to and from latent space.")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
DESCRIPTION = "Loads a diffusion model checkpoint with per-component device selection for multi-GPU setups."
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, model_device="default", clip_device="default", vae_device="default"):
|
||||
return True
|
||||
|
||||
def load_checkpoint(self, ckpt_name, model_device="default", clip_device="default", vae_device="default"):
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
|
||||
model_options = {}
|
||||
resolved_model = comfy.model_management.resolve_gpu_device_option(model_device)
|
||||
if resolved_model is not None:
|
||||
if resolved_model.type == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = resolved_model
|
||||
else:
|
||||
model_options["load_device"] = resolved_model
|
||||
|
||||
te_model_options = {}
|
||||
resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_device)
|
||||
if resolved_clip is not None:
|
||||
if resolved_clip.type == "cpu":
|
||||
te_model_options["load_device"] = te_model_options["offload_device"] = resolved_clip
|
||||
else:
|
||||
te_model_options["load_device"] = resolved_clip
|
||||
|
||||
# VAE device is passed via model_options["load_device"] which
|
||||
# load_state_dict_guess_config forwards to the VAE constructor.
|
||||
# If vae_device differs from model_device, we override after loading.
|
||||
resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device)
|
||||
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options, te_model_options=te_model_options)
|
||||
model_patcher, clip, vae = out[:3]
|
||||
|
||||
# Apply VAE device override if it differs from the model device
|
||||
if resolved_vae is not None and vae is not None:
|
||||
vae.device = resolved_vae
|
||||
if resolved_vae.type == "cpu":
|
||||
offload = resolved_vae
|
||||
else:
|
||||
offload = comfy.model_management.vae_offload_device()
|
||||
vae.patcher.load_device = resolved_vae
|
||||
vae.patcher.offload_device = offload
|
||||
|
||||
return (model_patcher, clip, vae)
|
||||
|
||||
class DiffusersLoader:
|
||||
SEARCH_ALIASES = ["load diffusers model"]
|
||||
|
||||
@ -807,14 +874,21 @@ class VAELoader:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "vae_name": (s.vae_list(s), )}}
|
||||
return {"required": { "vae_name": (s.vae_list(s), )},
|
||||
"optional": {
|
||||
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
|
||||
}}
|
||||
RETURN_TYPES = ("VAE",)
|
||||
FUNCTION = "load_vae"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, device="default"):
|
||||
return True
|
||||
|
||||
#TODO: scale factor?
|
||||
def load_vae(self, vae_name):
|
||||
def load_vae(self, vae_name, device="default"):
|
||||
metadata = None
|
||||
if vae_name == "pixel_space":
|
||||
sd = {}
|
||||
@ -827,7 +901,8 @@ class VAELoader:
|
||||
else:
|
||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved)
|
||||
vae.throw_exception_if_invalid()
|
||||
return (vae,)
|
||||
|
||||
@ -953,13 +1028,20 @@ class UNETLoader:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
|
||||
},
|
||||
"optional": {
|
||||
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_unet"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_unet(self, unet_name, weight_dtype):
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, device="default"):
|
||||
return True
|
||||
|
||||
def load_unet(self, unet_name, weight_dtype, device="default"):
|
||||
model_options = {}
|
||||
if weight_dtype == "fp8_e4m3fn":
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
@ -969,6 +1051,13 @@ class UNETLoader:
|
||||
elif weight_dtype == "fp8_e5m2":
|
||||
model_options["dtype"] = torch.float8_e5m2
|
||||
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
if resolved is not None:
|
||||
if resolved.type == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = resolved
|
||||
else:
|
||||
model_options["load_device"] = resolved
|
||||
|
||||
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
|
||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||
return (model,)
|
||||
@ -980,7 +1069,7 @@ class CLIPLoader:
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
@ -989,12 +1078,20 @@ class CLIPLoader:
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, device="default"):
|
||||
return True
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
|
||||
model_options = {}
|
||||
if device == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
if resolved is not None:
|
||||
if resolved.type == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = resolved
|
||||
else:
|
||||
model_options["load_device"] = resolved
|
||||
|
||||
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||
@ -1008,7 +1105,7 @@ class DualCLIPLoader:
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
"device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
@ -1017,6 +1114,10 @@ class DualCLIPLoader:
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, device="default"):
|
||||
return True
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, type, device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
|
||||
@ -1024,8 +1125,12 @@ class DualCLIPLoader:
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||
|
||||
model_options = {}
|
||||
if device == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
if resolved is not None:
|
||||
if resolved.type == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = resolved
|
||||
else:
|
||||
model_options["load_device"] = resolved
|
||||
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||
return (clip,)
|
||||
@ -2098,6 +2203,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"InpaintModelConditioning": InpaintModelConditioning,
|
||||
|
||||
"CheckpointLoader": CheckpointLoader,
|
||||
"CheckpointLoaderDevice": CheckpointLoaderDevice,
|
||||
"DiffusersLoader": DiffusersLoader,
|
||||
|
||||
"LoadLatent": LoadLatent,
|
||||
@ -2115,6 +2221,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
# Loaders
|
||||
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
|
||||
"CheckpointLoaderSimple": "Load Checkpoint",
|
||||
"CheckpointLoaderDevice": "Load Checkpoint (Device)",
|
||||
"VAELoader": "Load VAE",
|
||||
"LoraLoader": "Load LoRA (Model and CLIP)",
|
||||
"LoraLoaderModelOnly": "Load LoRA",
|
||||
@ -2412,6 +2519,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_lt_audio.py",
|
||||
"nodes_lt.py",
|
||||
"nodes_hooks.py",
|
||||
"nodes_multigpu.py",
|
||||
"nodes_load_3d.py",
|
||||
"nodes_cosmos.py",
|
||||
"nodes_video.py",
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.42.14
|
||||
comfyui-frontend-package==1.42.11
|
||||
comfyui-workflow-templates==0.9.57
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
@ -19,11 +19,11 @@ scipy
|
||||
tqdm
|
||||
psutil
|
||||
alembic
|
||||
SQLAlchemy>=2.0
|
||||
SQLAlchemy
|
||||
filelock
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.8
|
||||
comfy-aimdo>=0.2.12
|
||||
comfy-aimdo==0.0.213
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
blake3
|
||||
|
||||
@ -1,246 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from comfy_api_nodes.nodes_openai import (
|
||||
OpenAIGPTImage1,
|
||||
OpenAIGPTImage2,
|
||||
_GPT_IMAGE_2_SIZES,
|
||||
_resolve_gpt_image_2_size,
|
||||
calculate_tokens_price_image_1,
|
||||
calculate_tokens_price_image_1_5,
|
||||
calculate_tokens_price_image_2,
|
||||
)
|
||||
from comfy_api_nodes.apis.openai import OpenAIImageGenerationResponse, Usage
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_response(input_tokens: int, output_tokens: int) -> OpenAIImageGenerationResponse:
|
||||
return OpenAIImageGenerationResponse(
|
||||
data=[],
|
||||
usage=Usage(input_tokens=input_tokens, output_tokens=output_tokens),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Price extractor tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_price_image_1_formula():
|
||||
response = _make_response(input_tokens=1_000_000, output_tokens=1_000_000)
|
||||
assert calculate_tokens_price_image_1(response) == pytest.approx(50.0)
|
||||
|
||||
|
||||
def test_price_image_1_5_formula():
|
||||
response = _make_response(input_tokens=1_000_000, output_tokens=1_000_000)
|
||||
assert calculate_tokens_price_image_1_5(response) == pytest.approx(40.0)
|
||||
|
||||
|
||||
def test_price_image_2_formula():
|
||||
response = _make_response(input_tokens=1_000_000, output_tokens=1_000_000)
|
||||
assert calculate_tokens_price_image_2(response) == pytest.approx(38.0)
|
||||
|
||||
|
||||
def test_price_image_2_cheaper_than_1():
|
||||
response = _make_response(input_tokens=500, output_tokens=196)
|
||||
assert calculate_tokens_price_image_2(response) < calculate_tokens_price_image_1(response)
|
||||
|
||||
|
||||
def test_price_image_2_cheaper_output_than_1_5():
|
||||
# gpt-image-2 output rate ($30/1M) is lower than gpt-image-1.5 ($32/1M)
|
||||
response = _make_response(input_tokens=0, output_tokens=1_000_000)
|
||||
assert calculate_tokens_price_image_2(response) < calculate_tokens_price_image_1_5(response)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_gpt_image_2_size tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_resolve_preset_passthrough_when_custom_zero():
|
||||
# 0/0 means "use size preset"
|
||||
assert _resolve_gpt_image_2_size("1024x1024", 0, 0) == "1024x1024"
|
||||
assert _resolve_gpt_image_2_size("auto", 0, 0) == "auto"
|
||||
assert _resolve_gpt_image_2_size("3840x2160", 0, 0) == "3840x2160"
|
||||
|
||||
|
||||
def test_resolve_preset_passthrough_when_only_one_dim_set():
|
||||
# only one dimension set → still use preset
|
||||
assert _resolve_gpt_image_2_size("auto", 1024, 0) == "auto"
|
||||
assert _resolve_gpt_image_2_size("auto", 0, 1024) == "auto"
|
||||
|
||||
|
||||
def test_resolve_custom_overrides_preset():
|
||||
assert _resolve_gpt_image_2_size("auto", 1024, 1024) == "1024x1024"
|
||||
assert _resolve_gpt_image_2_size("1024x1024", 2048, 1152) == "2048x1152"
|
||||
assert _resolve_gpt_image_2_size("auto", 3840, 2160) == "3840x2160"
|
||||
|
||||
|
||||
def test_resolve_custom_rejects_edge_too_large():
|
||||
with pytest.raises(ValueError, match="3840"):
|
||||
_resolve_gpt_image_2_size("auto", 4096, 1024)
|
||||
|
||||
|
||||
def test_resolve_custom_rejects_non_multiple_of_16():
|
||||
with pytest.raises(ValueError, match="multiple of 16"):
|
||||
_resolve_gpt_image_2_size("auto", 1025, 1024)
|
||||
|
||||
|
||||
def test_resolve_custom_rejects_bad_ratio():
|
||||
with pytest.raises(ValueError, match="ratio"):
|
||||
_resolve_gpt_image_2_size("auto", 3840, 1024) # 3.75:1 > 3:1
|
||||
|
||||
|
||||
def test_resolve_custom_rejects_too_few_pixels():
|
||||
with pytest.raises(ValueError, match="Total pixels"):
|
||||
_resolve_gpt_image_2_size("auto", 16, 16)
|
||||
|
||||
|
||||
def test_resolve_custom_rejects_too_many_pixels():
|
||||
# 3840x2176 exceeds 8,294,400
|
||||
with pytest.raises(ValueError, match="Total pixels"):
|
||||
_resolve_gpt_image_2_size("auto", 3840, 2176)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAIGPTImage1 schema tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOpenAIGPTImage1Schema:
|
||||
def setup_method(self):
|
||||
self.schema = OpenAIGPTImage1.define_schema()
|
||||
|
||||
def test_node_id(self):
|
||||
assert self.schema.node_id == "OpenAIGPTImage1"
|
||||
|
||||
def test_display_name(self):
|
||||
assert self.schema.display_name == "OpenAI GPT Image 1 & 1.5"
|
||||
|
||||
def test_model_options_exclude_gpt_image_2(self):
|
||||
model_input = next(i for i in self.schema.inputs if i.name == "model")
|
||||
assert "gpt-image-2" not in model_input.options
|
||||
|
||||
def test_model_options_include_legacy_models(self):
|
||||
model_input = next(i for i in self.schema.inputs if i.name == "model")
|
||||
assert "gpt-image-1" in model_input.options
|
||||
assert "gpt-image-1.5" in model_input.options
|
||||
|
||||
def test_has_background_with_transparent(self):
|
||||
bg_input = next(i for i in self.schema.inputs if i.name == "background")
|
||||
assert "transparent" in bg_input.options
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAIGPTImage2 schema tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOpenAIGPTImage2Schema:
|
||||
def setup_method(self):
|
||||
self.schema = OpenAIGPTImage2.define_schema()
|
||||
|
||||
def test_node_id(self):
|
||||
assert self.schema.node_id == "OpenAIGPTImage2"
|
||||
|
||||
def test_display_name(self):
|
||||
assert self.schema.display_name == "OpenAI GPT Image 2"
|
||||
|
||||
def test_category(self):
|
||||
assert "OpenAI" in self.schema.category
|
||||
|
||||
def test_no_transparent_background(self):
|
||||
bg_input = next(i for i in self.schema.inputs if i.name == "background")
|
||||
assert "transparent" not in bg_input.options
|
||||
|
||||
def test_background_options(self):
|
||||
bg_input = next(i for i in self.schema.inputs if i.name == "background")
|
||||
assert set(bg_input.options) == {"auto", "opaque"}
|
||||
|
||||
def test_quality_options(self):
|
||||
quality_input = next(i for i in self.schema.inputs if i.name == "quality")
|
||||
assert set(quality_input.options) == {"auto", "low", "medium", "high"}
|
||||
|
||||
def test_quality_default_is_auto(self):
|
||||
quality_input = next(i for i in self.schema.inputs if i.name == "quality")
|
||||
assert quality_input.default == "auto"
|
||||
|
||||
def test_all_popular_sizes_present(self):
|
||||
size_input = next(i for i in self.schema.inputs if i.name == "size")
|
||||
for size in ["1024x1024", "1536x1024", "1024x1536", "2048x2048", "2048x1152", "3840x2160", "2160x3840"]:
|
||||
assert size in size_input.options, f"Missing size: {size}"
|
||||
|
||||
def test_no_custom_size_option(self):
|
||||
size_input = next(i for i in self.schema.inputs if i.name == "size")
|
||||
assert "custom" not in size_input.options
|
||||
|
||||
def test_size_default_is_auto(self):
|
||||
size_input = next(i for i in self.schema.inputs if i.name == "size")
|
||||
assert size_input.default == "auto"
|
||||
|
||||
def test_custom_width_and_height_inputs_exist(self):
|
||||
input_names = [i.name for i in self.schema.inputs]
|
||||
assert "custom_width" in input_names
|
||||
assert "custom_height" in input_names
|
||||
|
||||
def test_custom_width_height_default_zero(self):
|
||||
width_input = next(i for i in self.schema.inputs if i.name == "custom_width")
|
||||
height_input = next(i for i in self.schema.inputs if i.name == "custom_height")
|
||||
assert width_input.default == 0
|
||||
assert height_input.default == 0
|
||||
|
||||
def test_custom_width_height_step_is_16(self):
|
||||
width_input = next(i for i in self.schema.inputs if i.name == "custom_width")
|
||||
height_input = next(i for i in self.schema.inputs if i.name == "custom_height")
|
||||
assert width_input.step == 16
|
||||
assert height_input.step == 16
|
||||
|
||||
def test_custom_width_height_max_is_3840(self):
|
||||
width_input = next(i for i in self.schema.inputs if i.name == "custom_width")
|
||||
height_input = next(i for i in self.schema.inputs if i.name == "custom_height")
|
||||
assert width_input.max == 3840
|
||||
assert height_input.max == 3840
|
||||
|
||||
def test_uses_num_images_not_n(self):
|
||||
input_names = [i.name for i in self.schema.inputs]
|
||||
assert "num_images" in input_names
|
||||
assert "n" not in input_names
|
||||
|
||||
def test_model_input_shows_gpt_image_2(self):
|
||||
model_input = next(i for i in self.schema.inputs if i.name == "model")
|
||||
assert model_input.options == ["gpt-image-2"]
|
||||
assert model_input.default == "gpt-image-2"
|
||||
|
||||
def test_has_image_and_mask_inputs(self):
|
||||
input_names = [i.name for i in self.schema.inputs]
|
||||
assert "image" in input_names
|
||||
assert "mask" in input_names
|
||||
|
||||
def test_is_api_node(self):
|
||||
assert self.schema.is_api_node is True
|
||||
|
||||
def test_sizes_match_constant(self):
|
||||
size_input = next(i for i in self.schema.inputs if i.name == "size")
|
||||
assert size_input.options == _GPT_IMAGE_2_SIZES
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAIGPTImage2 execute validation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_raises_on_empty_prompt():
|
||||
with pytest.raises(Exception):
|
||||
await OpenAIGPTImage2.execute(prompt=" ")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_raises_mask_without_image():
|
||||
import torch
|
||||
mask = torch.ones(1, 64, 64)
|
||||
with pytest.raises(ValueError, match="mask without an input image"):
|
||||
await OpenAIGPTImage2.execute(prompt="test", mask=mask)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_raises_invalid_custom_size():
|
||||
with pytest.raises(ValueError):
|
||||
await OpenAIGPTImage2.execute(prompt="test", custom_width=4096, custom_height=1024)
|
||||
Reference in New Issue
Block a user