Compare commits

..

44 Commits

Author SHA1 Message Date
e9db9554bd ACTUALLY skip timing checks in CI 2025-09-03 12:30:44 -07:00
98be8e1969 Whoops, reverted a change in the last commit 2025-09-03 12:21:13 -07:00
766ff74207 Just disable timing-related assertions in CI
That way there's no risk of periodic non-deterministic test failures.
2025-09-03 12:18:48 -07:00
b1b5f87534 Add more leeway on async tests for Windows CI 2025-09-02 23:43:15 -07:00
cf45fd1742 Add missing test modules 2025-09-02 23:21:43 -07:00
e7314f49e6 Fix showing progress from other sessions
Because `client_id` was missing from ths `progress_state` message, it
was being sent to all connected sessions. This technically meant that if
someone had a graph with the same nodes, they would see the progress
updates for others.

Also added a test to prevent reoccurance and moved the tests around to
make CI easier to hook up.
2025-09-02 23:17:26 -07:00
e2d1e5dad9 Enable Convolution AutoTuning (#9301) 2025-09-01 20:33:50 -04:00
27e067ce50 Implement the USO subject identity lora. (#9674)
Use the lora with FluxContextMultiReferenceLatentMethod node set to "uso"
and a ReferenceLatent node with the reference image.
2025-09-01 18:54:02 -04:00
9b15155972 Probably not necessary anymore. (#9646) 2025-08-31 01:32:10 -04:00
32a627bf1f SEEDS: update noise decomposition and refactor (#9633)
- Update the decomposition to reflect interval dependency
- Extract phi computations into functions
- Use torch.lerp for interpolation
2025-08-31 00:01:45 -04:00
fe442fac2e convert Primitive nodes to V3 schema (#9372) 2025-08-30 23:21:58 -04:00
d2c502e629 convert nodes_stability.py to V3 schema (#9497) 2025-08-30 23:20:17 -04:00
fea9ea8268 convert Video nodes to V3 schema (#9489) 2025-08-30 23:19:54 -04:00
f949094b3c convert Stable Cascade nodes to V3 schema (#9373) 2025-08-30 23:19:21 -04:00
4449e14769 ComfyUI version 0.3.56 2025-08-30 06:31:19 -04:00
885015eecf Lower ram usage on windows. (#9628) 2025-08-29 23:06:04 -04:00
a86aaa4301 ComfyUI v0.3.55 2025-08-29 06:03:41 -04:00
2efb2cbc38 Update template to 0.1.70 (#9620) 2025-08-29 06:03:25 -04:00
15aa9222c4 Trim audio to video when saving video. (#9617) 2025-08-29 04:12:00 -04:00
c7bb3e2bce Support the 5B fun inpaint model. (#9614)
Use the WanFunInpaintToVideo node without the clip_vision_output.
2025-08-28 22:46:57 -04:00
e80a14ad50 Support wan2.2 5B fun control model. (#9611)
Use the Wan22FunControlToVideo node.
2025-08-28 22:13:07 -04:00
d28b39d93d Add a LatentCut node to cut latents. (#9609) 2025-08-28 19:38:28 -04:00
1c184c29eb Fix issue with s2v node when extending past audio length. (#9608) 2025-08-28 18:34:01 -04:00
edde0b5043 WanSoundImageToVideoExtend node to manually extend s2v video. (#9606) 2025-08-28 17:59:48 -04:00
0063610177 ComfyUI version 0.3.54 2025-08-28 10:44:57 -04:00
ce0052c087 Fix diffsynth controlnet regression. (#9597) 2025-08-28 10:37:42 -04:00
0eb821a7b6 ComfyUI 0.3.53 2025-08-27 23:09:06 -04:00
4aa79dbf2c Adjust flux mem usage factor a bit. (#9588) 2025-08-27 23:08:17 -04:00
38f697d953 Add a LatentConcat node. (#9587) 2025-08-27 22:28:10 -04:00
3aad339b63 Add DPM++ 2M SDE Heun (RES) sampler (#9542) 2025-08-27 19:07:31 -04:00
491755325c Better s2v memory estimation. (#9584) 2025-08-27 19:02:42 -04:00
496888fd68 Improve s2v performance when generating videos longer than 120 frames. (#9582) 2025-08-27 16:06:40 -04:00
b5ac6ed7ce Fixes to make controlnet type models work on qwen edit and kontext. (#9581) 2025-08-27 15:26:28 -04:00
b20ba1f27c Fix #9537 (#9576) 2025-08-27 12:45:02 -04:00
31a37686d0 Negative audio in s2v should be zeros. (#9578) 2025-08-27 12:44:29 -04:00
88aee596a3 WIP Wan 2.2 S2V model. (#9568) 2025-08-27 01:10:34 -04:00
6a193ac557 Update template to 0.1.68 (#9569)
* Update template to 0.1.67

* Update template to 0.1.68
2025-08-27 00:10:20 -04:00
47f4db3e84 Adding Google Gemini Image API node (#9566)
* bigcat88's progress on adding Google Gemini Image node

* Made Google Gemini Image node functional

* Bump frontend version to get static pricing badge on Gemini Image node
2025-08-26 22:20:44 -04:00
5352abc6d3 Update template to 0.1.66 (#9557) 2025-08-26 13:33:54 -04:00
39aa06bd5d Make AudioEncoderOutput usable in v3 node schema. (#9554) 2025-08-26 12:50:46 -04:00
914c2a2973 Implement wav2vec2 as an audio encoder model. (#9549)
This is useless on its own but there are multiple models that use it.
2025-08-25 23:26:47 -04:00
e633a47ad1 Add models/audio_encoders directory. (#9548) 2025-08-25 20:13:54 -04:00
f6b93d41a0 Remove models from readme that are not fully implemented. (#9535)
Cosmos model implementations are currently missing the safety part so it is technically not fully implemented and should not be advertised as such.
2025-08-24 15:40:32 -04:00
95ac7794b7 Fix EasyCache/LazyCache crash when tensor shape/dtype/device changes during sampling (#9528)
* Fix EasyCache/LazyCache crash when tensor shape/dtype/device changes during sampling

* Fix missing LazyCache check_metadata method
Ensure LazyCache reset method resets all the tensor state values
2025-08-24 15:29:49 -04:00
53 changed files with 2650 additions and 870 deletions

30
.github/workflows/test-execution.yml vendored Normal file
View File

@ -0,0 +1,30 @@
name: Execution Tests
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install -r tests-unit/requirements.txt
- name: Run Execution Tests
run: |
python -m pytest tests/execution -v --skip-timing-checks

View File

@ -65,7 +65,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
@ -77,7 +76,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
- Audio Models

View File

@ -0,0 +1,42 @@
from .wav2vec2 import Wav2Vec2Model
import comfy.model_management
import comfy.ops
import comfy.utils
import logging
import torchaudio
class AudioEncoderModel():
def __init__(self, config):
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()
def encode_audio(self, audio, sample_rate):
comfy.model_management.load_model_gpu(self.patcher)
audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
out, all_layers = self.model(audio.to(self.load_device))
outputs = {}
outputs["encoded_audio"] = out
outputs["encoded_audio_all_layers"] = all_layers
return outputs
def load_audio_encoder_from_sd(sd, prefix=""):
audio_encoder = AudioEncoderModel(None)
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
m, u = audio_encoder.load_sd(sd)
if len(m) > 0:
logging.warning("missing audio encoder: {}".format(m))
return audio_encoder

View File

@ -0,0 +1,207 @@
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention_masked
class LayerNormConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
super().__init__()
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
class ConvFeatureEncoder(nn.Module):
def __init__(self, conv_dim, dtype=None, device=None, operations=None):
super().__init__()
self.conv_layers = nn.ModuleList([
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
])
def forward(self, x):
x = x.unsqueeze(1)
for conv in self.conv_layers:
x = conv(x)
return x.transpose(1, 2)
class FeatureProjection(nn.Module):
def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
super().__init__()
self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)
def forward(self, x):
x = self.layer_norm(x)
x = self.projection(x)
return x
class PositionalConvEmbedding(nn.Module):
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
super().__init__()
self.conv = nn.Conv1d(
embed_dim,
embed_dim,
kernel_size=kernel_size,
padding=kernel_size // 2,
groups=groups,
)
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
self.activation = nn.GELU()
def forward(self, x):
x = x.transpose(1, 2)
x = self.conv(x)[:, :, :-1]
x = self.activation(x)
x = x.transpose(1, 2)
return x
class TransformerEncoder(nn.Module):
def __init__(
self,
embed_dim=768,
num_heads=12,
num_layers=12,
mlp_ratio=4.0,
dtype=None, device=None, operations=None
):
super().__init__()
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
self.layers = nn.ModuleList([
TransformerEncoderLayer(
embed_dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
device=device, dtype=dtype, operations=operations
)
for _ in range(num_layers)
])
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
def forward(self, x, mask=None):
x = x + self.pos_conv_embed(x)
all_x = ()
for layer in self.layers:
all_x += (x,)
x = layer(x, mask)
x = self.layer_norm(x)
all_x += (x,)
return x, all_x
class Attention(nn.Module):
def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
def forward(self, x, mask=None):
assert (mask is None) # TODO?
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
out = optimized_attention_masked(q, k, v, self.num_heads)
return self.out_proj(out)
class FeedForward(nn.Module):
def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
super().__init__()
self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)
def forward(self, x):
x = self.intermediate_dense(x)
x = torch.nn.functional.gelu(x)
x = self.output_dense(x)
return x
class TransformerEncoderLayer(nn.Module):
def __init__(
self,
embed_dim=768,
num_heads=12,
mlp_ratio=4.0,
dtype=None, device=None, operations=None
):
super().__init__()
self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
def forward(self, x, mask=None):
residual = x
x = self.layer_norm(x)
x = self.attention(x, mask=mask)
x = residual + x
x = x + self.feed_forward(self.final_layer_norm(x))
return x
class Wav2Vec2Model(nn.Module):
"""Complete Wav2Vec 2.0 model."""
def __init__(
self,
embed_dim=1024,
final_dim=256,
num_heads=16,
num_layers=24,
dtype=None, device=None, operations=None
):
super().__init__()
conv_dim = 512
self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
self.encoder = TransformerEncoder(
embed_dim=embed_dim,
num_heads=num_heads,
num_layers=num_layers,
device=device, dtype=dtype, operations=operations
)
def forward(self, x, mask_time_indices=None, return_dict=False):
x = torch.mean(x, dim=1)
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
features = self.feature_extractor(x)
features = self.feature_projection(features)
batch_size, seq_len, _ = features.shape
x, all_x = self.encoder(features)
return x, all_x

View File

@ -143,6 +143,7 @@ class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")

View File

@ -171,6 +171,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
return sigmas
def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
"""Compute the result of h*phi_1(h) in exponential integrator methods."""
return torch.expm1(h)
def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
"""Compute the result of h*phi_2(h) in exponential integrator methods."""
return (torch.expm1(h) - h) / h
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@ -853,6 +863,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
return x
@torch.no_grad()
def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
@ -925,6 +940,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@torch.no_grad()
def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
@ -1535,13 +1560,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@ -1549,55 +1573,53 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
fac = 1 / (2 * r)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s_1 = sigma_fn(lambda_s_1)
continue
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
sigma_s_1 = sigma_fn(lambda_s_1)
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r < 1
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
if inject_noise:
sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
# Step 2
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
if inject_noise:
segment_factor = (r - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
x = x + sde_noise * sigmas[i + 1] * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
arXiv: https://arxiv.org/abs/2305.14267
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@ -1609,45 +1631,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r_1 * h
lambda_s_2 = lambda_s + r_2 * h
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
continue
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r_1 < r_2 < 1
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
if inject_noise:
sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
if inject_noise:
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 2
a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
if inject_noise:
segment_factor = (r_1 - r_2) * h * eta
sde_noise = sde_noise * segment_factor.exp()
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 3
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
# Step 3
b3 = ei_h_phi_2(-h_eta) / r_2
b1 = ei_h_phi_1(-h_eta) - b3
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
if inject_noise:
segment_factor = (r_2 - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
x = x + sde_noise * sigmas[i + 1] * s_noise
return x

View File

@ -158,7 +158,7 @@ class Flux(nn.Module):
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img[:, :add.shape[1]] += add
if img.dtype == torch.float16:
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
@ -189,7 +189,7 @@ class Flux(nn.Module):
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] :, ...] += add
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
img = img[:, txt.shape[1] :, ...]
@ -233,12 +233,18 @@ class Flux(nn.Module):
h = 0
w = 0
index = 0
index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
ref_latents_method = kwargs.get("ref_latents_method", "offset")
for ref in ref_latents:
if index_ref_method:
if ref_latents_method == "index":
index += 1
h_offset = 0
w_offset = 0
elif ref_latents_method == "uso":
index = 0
h_offset = h_len * patch_size + h
w_offset = w_len * patch_size + w
h += ref.shape[-2]
w += ref.shape[-1]
else:
index = 1
h_offset = 0

View File

@ -459,7 +459,7 @@ class QwenImageTransformer2DModel(nn.Module):
if i < len(control_i):
add = control_i[i]
if add is not None:
hidden_states += add
hidden_states[:, :add.shape[1]] += add
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)

View File

@ -4,7 +4,7 @@ import math
import torch
import torch.nn as nn
from einops import repeat
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
@ -153,7 +153,10 @@ def repeat_e(e, x):
repeats = x.size(1) // e.size(1)
if repeats == 1:
return e
return torch.repeat_interleave(e, repeats, dim=1)
if repeats * e.size(1) == x.size(1):
return torch.repeat_interleave(e, repeats, dim=1)
else:
return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)]
class WanAttentionBlock(nn.Module):
@ -573,6 +576,28 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
if steps_t is None:
steps_t = t_len
if steps_h is None:
steps_h = h_len
if steps_w is None:
steps_w = w_len
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2)
return freqs
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
@ -584,26 +609,16 @@ class WanModel(torch.nn.Module):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
t_len = t
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
t_len = x.shape[2]
if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
freqs = self.rope_embedder(img_ids).movedim(1, 2)
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
@ -839,3 +854,468 @@ class CameraWanModel(WanModel):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
class CausalConv1d(nn.Module):
def __init__(self,
chan_in,
chan_out,
kernel_size=3,
stride=1,
dilation=1,
pad_mode='replicate',
operations=None,
**kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = operations.Conv1d(
chan_in,
chan_out,
kernel_size,
stride=stride,
dilation=dilation,
**kwargs)
def forward(self, x):
x = torch.nn.functional.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class MotionEncoder_tc(nn.Module):
def __init__(self,
in_dim: int,
hidden_dim: int,
num_heads=int,
need_global=True,
dtype=None,
device=None,
operations=None,):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.num_heads = num_heads
self.need_global = need_global
self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1, operations=operations, **factory_kwargs)
if need_global:
self.conv1_global = CausalConv1d(
in_dim, hidden_dim // 4, 3, stride=1, operations=operations, **factory_kwargs)
self.norm1 = operations.LayerNorm(
hidden_dim // 4,
elementwise_affine=False,
eps=1e-6,
**factory_kwargs)
self.act = nn.SiLU()
self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2, operations=operations, **factory_kwargs)
self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2, operations=operations, **factory_kwargs)
if need_global:
self.final_linear = operations.Linear(hidden_dim, hidden_dim, **factory_kwargs)
self.norm1 = operations.LayerNorm(
hidden_dim // 4,
elementwise_affine=False,
eps=1e-6,
**factory_kwargs)
self.norm2 = operations.LayerNorm(
hidden_dim // 2,
elementwise_affine=False,
eps=1e-6,
**factory_kwargs)
self.norm3 = operations.LayerNorm(
hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs))
def forward(self, x):
x = rearrange(x, 'b t c -> b c t')
x_ori = x.clone()
b, c, t = x.shape
x = self.conv1_local(x)
x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv2(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv3(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm3(x)
x = self.act(x)
x = rearrange(x, '(b n) t c -> b t n c', b=b)
padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
if not self.need_global:
return x_local
x = self.conv1_global(x_ori)
x = rearrange(x, 'b c t -> b t c')
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv2(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv3(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm3(x)
x = self.act(x)
x = self.final_linear(x)
x = rearrange(x, '(b n) t c -> b t n c', b=b)
return x, x_local
class CausalAudioEncoder(nn.Module):
def __init__(self,
dim=5120,
num_layers=25,
out_dim=2048,
video_rate=8,
num_token=4,
need_global=False,
dtype=None,
device=None,
operations=None):
super().__init__()
self.encoder = MotionEncoder_tc(
in_dim=dim,
hidden_dim=out_dim,
num_heads=num_token,
need_global=need_global, dtype=dtype, device=device, operations=operations)
weight = torch.empty((1, num_layers, 1, 1), dtype=dtype, device=device)
self.weights = torch.nn.Parameter(weight)
self.act = torch.nn.SiLU()
def forward(self, features):
# features B * num_layers * dim * video_length
weights = self.act(comfy.model_management.cast_to(self.weights, dtype=features.dtype, device=features.device))
weights_sum = weights.sum(dim=1, keepdims=True)
weighted_feat = ((features * weights) / weights_sum).sum(
dim=1) # b dim f
weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
res = self.encoder(weighted_feat) # b f n dim
return res # b f n dim
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim, output_dim=None, norm_elementwise_affine=False, norm_eps=1e-5, dtype=None, device=None, operations=None):
super().__init__()
output_dim = output_dim or embedding_dim * 2
self.silu = nn.SiLU()
self.linear = operations.Linear(embedding_dim, output_dim, dtype=dtype, device=device)
self.norm = operations.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine, dtype=dtype, device=device)
def forward(self, x, temb):
temb = self.linear(self.silu(temb))
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
x = self.norm(x) * (1 + scale) + shift
return x
class AudioInjector_WAN(nn.Module):
def __init__(self,
dim=2048,
num_heads=32,
inject_layer=[0, 27],
root_net=None,
enable_adain=False,
adain_dim=2048,
adain_mode=None,
dtype=None,
device=None,
operations=None):
super().__init__()
self.enable_adain = enable_adain
self.adain_mode = adain_mode
self.injected_block_id = {}
audio_injector_id = 0
for inject_id in inject_layer:
self.injected_block_id[inject_id] = audio_injector_id
audio_injector_id += 1
self.injector = nn.ModuleList([
WanT2VCrossAttention(
dim=dim,
num_heads=num_heads,
qk_norm=True, operation_settings={"operations": operations, "device": device, "dtype": dtype}
) for _ in range(audio_injector_id)
])
self.injector_pre_norm_feat = nn.ModuleList([
operations.LayerNorm(
dim,
elementwise_affine=False,
eps=1e-6, dtype=dtype, device=device
) for _ in range(audio_injector_id)
])
self.injector_pre_norm_vec = nn.ModuleList([
operations.LayerNorm(
dim,
elementwise_affine=False,
eps=1e-6, dtype=dtype, device=device
) for _ in range(audio_injector_id)
])
if enable_adain:
self.injector_adain_layers = nn.ModuleList([
AdaLayerNorm(
output_dim=dim * 2, embedding_dim=adain_dim, dtype=dtype, device=device, operations=operations)
for _ in range(audio_injector_id)
])
if adain_mode != "attn_norm":
self.injector_adain_output_layers = nn.ModuleList(
[operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)])
def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len):
audio_attn_id = self.injected_block_id.get(block_id, None)
if audio_attn_id is None:
return x
num_frames = audio_emb.shape[1]
input_hidden_states = rearrange(x[:, :seq_len], "b (t n) c -> (b t) n c", t=num_frames)
if self.enable_adain and self.adain_mode == "attn_norm":
audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
adain_hidden_states = self.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
attn_hidden_states = adain_hidden_states
else:
attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
attn_audio_emb = audio_emb
residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb)
residual_out = rearrange(
residual_out, "(b t) n c -> b (t n) c", t=num_frames)
x[:, :seq_len] = x[:, :seq_len] + residual_out
return x
class FramePackMotioner(nn.Module):
def __init__(
self,
inner_dim=1024,
num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design
zip_frame_buckets=[
1, 2, 16
], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames
drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion
dtype=None,
device=None,
operations=None):
super().__init__()
self.proj = operations.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2), dtype=dtype, device=device)
self.proj_2x = operations.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4), dtype=dtype, device=device)
self.proj_4x = operations.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8), dtype=dtype, device=device)
self.zip_frame_buckets = zip_frame_buckets
self.inner_dim = inner_dim
self.num_heads = num_heads
self.drop_mode = drop_mode
def forward(self, motion_latents, rope_embedder, add_last_motion=2):
lat_height, lat_width = motion_latents.shape[3], motion_latents.shape[4]
padd_lat = torch.zeros(motion_latents.shape[0], 16, sum(self.zip_frame_buckets), lat_height, lat_width).to(device=motion_latents.device, dtype=motion_latents.dtype)
overlap_frame = min(padd_lat.shape[2], motion_latents.shape[2])
if overlap_frame > 0:
padd_lat[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:]
if add_last_motion < 2 and self.drop_mode != "drop":
zero_end_frame = sum(self.zip_frame_buckets[:len(self.zip_frame_buckets) - add_last_motion - 1])
padd_lat[:, :, -zero_end_frame:] = 0
clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -sum(self.zip_frame_buckets):, :, :].split(self.zip_frame_buckets[::-1], dim=2) # 16, 2 ,1
# patchfy
clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
clean_latents_2x = self.proj_2x(clean_latents_2x)
l_2x_shape = clean_latents_2x.shape
clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
clean_latents_4x = self.proj_4x(clean_latents_4x)
l_4x_shape = clean_latents_4x.shape
clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
if add_last_motion < 2 and self.drop_mode == "drop":
clean_latents_post = clean_latents_post[:, :
0] if add_last_motion < 2 else clean_latents_post
clean_latents_2x = clean_latents_2x[:, :
0] if add_last_motion < 1 else clean_latents_2x
motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
rope_post = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-1, device=motion_latents.device, dtype=motion_latents.dtype)
rope_2x = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-3, steps_h=l_2x_shape[-2], steps_w=l_2x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
rope_4x = rope_embedder.rope_encode(4, lat_height, lat_width, t_start=-19, steps_h=l_4x_shape[-2], steps_w=l_4x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
rope = torch.cat([rope_post, rope_2x, rope_4x], dim=1)
return motion_lat, rope
class WanModel_S2V(WanModel):
def __init__(self,
model_type='s2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
audio_dim=1024,
num_audio_token=4,
enable_adain=True,
cond_dim=16,
audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
adain_mode="attn_norm",
framepack_drop_mode="padd",
image_model=None,
device=None,
dtype=None,
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
self.casual_audio_encoder = CausalAudioEncoder(
dim=audio_dim,
out_dim=self.dim,
num_token=num_audio_token,
need_global=enable_adain, dtype=dtype, device=device, operations=operations)
if cond_dim > 0:
self.cond_encoder = operations.Conv3d(
cond_dim,
self.dim,
kernel_size=self.patch_size,
stride=self.patch_size, device=device, dtype=dtype)
self.audio_injector = AudioInjector_WAN(
dim=self.dim,
num_heads=self.num_heads,
inject_layer=audio_inject_layers,
root_net=self,
enable_adain=enable_adain,
adain_dim=self.dim,
adain_mode=adain_mode,
dtype=dtype, device=device, operations=operations
)
self.frame_packer = FramePackMotioner(
inner_dim=self.dim,
num_heads=self.num_heads,
zip_frame_buckets=[1, 2, 16],
drop_mode=framepack_drop_mode,
dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
x,
t,
context,
audio_embed=None,
reference_latent=None,
control_video=None,
reference_motion=None,
clip_fea=None,
freqs=None,
transformer_options={},
**kwargs,
):
if audio_embed is not None:
num_embeds = x.shape[-3] * 4
audio_emb_global, audio_emb = self.casual_audio_encoder(audio_embed[:, :, :, :num_embeds])
else:
audio_emb = None
# embeddings
bs, _, time, height, width = x.shape
x = self.patch_embedding(x.float()).to(x.dtype)
if control_video is not None:
x = x + self.cond_encoder(control_video)
if t.ndim == 1:
t = t.unsqueeze(1).repeat(1, x.shape[2])
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
seq_len = x.size(1)
cond_mask_weight = comfy.model_management.cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1)
x = x + cond_mask_weight[0]
if reference_latent is not None:
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
ref = ref.flatten(2).transpose(1, 2)
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
ref = ref + cond_mask_weight[1]
x = torch.cat([x, ref], dim=1)
freqs = torch.cat([freqs, freqs_ref], dim=1)
t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
del ref, freqs_ref
if reference_motion is not None:
motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
motion_encoded = motion_encoded + cond_mask_weight[2]
x = torch.cat([x, motion_encoded], dim=1)
freqs = torch.cat([freqs, freqs_motion], dim=1)
t = torch.repeat_interleave(t, 2, dim=1)
t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
del motion_encoded, freqs_motion
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
# context
context = self.text_embedding(context)
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
if audio_emb is not None:
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x

View File

@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
for k in sdk:
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
if k.endswith(".weight") and ".linear1." in k:
key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))
if isinstance(model, comfy.model_base.GenmoMochi):
for k in sdk:

View File

@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
def convert_lora_wan_fun(sd): #Wan Fun loras
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
def convert_uso_lora(sd):
sd_out = {}
for k in sd:
tensor = sd[k]
k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight")
.replace(".up.weight", ".lora_up.weight")
.replace(".qkv_lora2.", ".txt_attn.qkv.")
.replace(".qkv_lora1.", ".img_attn.qkv.")
.replace(".proj_lora1.", ".img_attn.proj.")
.replace(".proj_lora2.", ".txt_attn.proj.")
.replace(".qkv_lora.", ".linear1_qkv.")
.replace(".proj_lora.", ".linear2.")
.replace(".processor.", ".")
)
sd_out[k_to] = tensor
return sd_out
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
return convert_lora_wan_fun(sd)
if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
return convert_uso_lora(sd)
return sd

View File

@ -150,6 +150,7 @@ class BaseModel(torch.nn.Module):
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
self.memory_usage_factor_conds = ()
self.memory_usage_shape_process = {}
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@ -350,8 +351,15 @@ class BaseModel(torch.nn.Module):
input_shapes = [input_shape]
for c in self.memory_usage_factor_conds:
shape = cond_shapes.get(c, None)
if shape is not None and len(shape) > 0:
input_shapes += shape
if shape is not None:
if c in self.memory_usage_shape_process:
out = []
for s in shape:
out.append(self.memory_usage_shape_process[c](s))
shape = out
if len(shape) > 0:
input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
@ -1102,9 +1110,10 @@ class WAN21(BaseModel):
shape_image[1] = extra_channels
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
latent_dim = self.latent_format.latent_channels
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
for i in range(0, image.shape[1], 16):
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
for i in range(0, image.shape[1], latent_dim):
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
image = utils.resize_to_batch_size(image, noise.shape[0])
if extra_channels != image.shape[1] + 4:
@ -1201,18 +1210,50 @@ class WAN21_Camera(WAN21):
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
class WAN22(BaseModel):
class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
audio_embed = kwargs.get("audio_embed", None)
if audio_embed is not None:
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
reference_latents = kwargs.get("reference_latents", None)
if reference_latents is not None:
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
reference_motion = kwargs.get("reference_motion", None)
if reference_motion is not None:
out['reference_motion'] = comfy.conds.CONDRegular(self.process_latent_in(reference_motion))
control_video = kwargs.get("control_video", None)
if control_video is not None:
out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
reference_motion = kwargs.get("reference_motion", None)
if reference_motion is not None:
out['reference_motion'] = reference_motion.shape
return out
class WAN22(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
denoise_mask = kwargs.get("denoise_mask", None)
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
return out

View File

@ -368,6 +368,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "camera"
else:
dit_config["model_type"] = "camera_2.2"
elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "s2v"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"

View File

@ -52,6 +52,9 @@ except (ModuleNotFoundError, TypeError):
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
torch.backends.cudnn.benchmark = True
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

2
comfy/samplers.py Normal file → Executable file
View File

@ -729,7 +729,7 @@ class Sampler:
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]

View File

@ -700,7 +700,7 @@ class Flux(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Flux
memory_usage_factor = 2.8
memory_usage_factor = 3.1 # TODO: debug why flux mem usage is so weird on windows.
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
@ -1072,6 +1072,19 @@ class WAN21_Vace(WAN21_T2V):
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
return out
class WAN22_S2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "s2v",
}
def __init__(self, unet_config):
super().__init__(unet_config)
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN22_S2V(self, device=device)
return out
class WAN22_T2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@ -1272,6 +1285,6 @@ class QwenImage(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@ -97,6 +97,9 @@ class LoKrAdapter(WeightAdapterBase):
(mat1, mat2, alpha, None, None, None, None, None, None)
)
def to_train(self):
return LokrDiff(self.weights)
@classmethod
def load(
cls,

View File

@ -8,6 +8,7 @@ import av
import io
import json
import numpy as np
import math
import torch
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
@ -282,8 +283,6 @@ class VideoFromComponents(VideoInput):
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
audio_stream.sample_rate = audio_sample_rate
audio_stream.format = 'fltp'
# Encode video
for i, frame in enumerate(self.__components.images):
@ -298,27 +297,12 @@ class VideoFromComponents(VideoInput):
output.mux(packet)
if audio_stream and self.__components.audio:
# Encode audio
samples_per_frame = int(audio_sample_rate / frame_rate)
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
for i in range(num_frames):
start = i * samples_per_frame
end = start + samples_per_frame
# TODO(Feature) - Add support for stereo audio
chunk = (
self.__components.audio["waveform"][0, 0, start:end]
.unsqueeze(0)
.contiguous()
.numpy()
)
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
audio_frame.sample_rate = audio_sample_rate
audio_frame.pts = i * samples_per_frame
for packet in audio_stream.encode(audio_frame):
output.mux(packet)
# Flush audio
for packet in audio_stream.encode(None):
output.mux(packet)
waveform = self.__components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))
# Flush encoder
output.mux(audio_stream.encode(None))

View File

@ -730,6 +730,14 @@ class AnyType(ComfyTypeIO):
class MODEL_PATCH(ComfyTypeIO):
Type = Any
@comfytype(io_type="AUDIO_ENCODER")
class AudioEncoder(ComfyTypeIO):
Type = Any
@comfytype(io_type="AUDIO_ENCODER_OUTPUT")
class AudioEncoderOutput(ComfyTypeIO):
Type = Any
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any
@ -1584,6 +1592,7 @@ class _IO:
Model = Model
ClipVision = ClipVision
ClipVisionOutput = ClipVisionOutput
AudioEncoderOutput = AudioEncoderOutput
StyleModel = StyleModel
Gligen = Gligen
UpscaleModel = UpscaleModel

View File

@ -0,0 +1,19 @@
from __future__ import annotations
from typing import List, Optional
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
from pydantic import BaseModel
class GeminiImageGenerationConfig(GeminiGenerationConfig):
responseModalities: Optional[List[str]] = None
class GeminiImageGenerateContentRequest(BaseModel):
contents: List[GeminiContent]
generationConfig: Optional[GeminiImageGenerationConfig] = None
safetySettings: Optional[List[GeminiSafetySetting]] = None
systemInstruction: Optional[GeminiSystemInstructionContent] = None
tools: Optional[List[GeminiTool]] = None
videoMetadata: Optional[GeminiVideoMetadata] = None

View File

@ -4,11 +4,12 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
"""
from __future__ import annotations
import json
import time
import os
import uuid
import base64
from io import BytesIO
from enum import Enum
from typing import Optional, Literal
@ -25,6 +26,7 @@ from comfy_api_nodes.apis import (
GeminiPart,
GeminiMimeType,
)
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
@ -35,6 +37,7 @@ from comfy_api_nodes.apinode_utils import (
audio_to_base64_string,
video_to_base64_string,
tensor_to_base64_string,
bytesio_to_image_tensor,
)
@ -53,6 +56,14 @@ class GeminiModel(str, Enum):
gemini_2_5_flash = "gemini-2.5-flash"
class GeminiImageModel(str, Enum):
"""
Gemini Image Model Names allowed by comfy-api
"""
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
def get_gemini_endpoint(
model: GeminiModel,
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
@ -75,6 +86,135 @@ def get_gemini_endpoint(
)
def get_gemini_image_endpoint(
model: GeminiImageModel,
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
"""
Get the API endpoint for a given Gemini model.
Args:
model: The Gemini model to use, either as enum or string value.
Returns:
ApiEndpoint configured for the specific Gemini model.
"""
if isinstance(model, str):
model = GeminiImageModel(model)
return ApiEndpoint(
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
method=HttpMethod.POST,
request_model=GeminiImageGenerateContentRequest,
response_model=GeminiGenerateContentResponse,
)
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
"""
Convert image tensor input to Gemini API compatible parts.
Args:
image_input: Batch of image tensors from ComfyUI.
Returns:
List of GeminiPart objects containing the encoded images.
"""
image_parts: list[GeminiPart] = []
for image_index in range(image_input.shape[0]):
image_as_b64 = tensor_to_base64_string(
image_input[image_index].unsqueeze(0)
)
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=image_as_b64,
)
)
)
return image_parts
def create_text_part(text: str) -> GeminiPart:
"""
Create a text part for the Gemini API request.
Args:
text: The text content to include in the request.
Returns:
A GeminiPart object with the text content.
"""
return GeminiPart(text=text)
def get_parts_from_response(
response: GeminiGenerateContentResponse
) -> list[GeminiPart]:
"""
Extract all parts from the Gemini API response.
Args:
response: The API response from Gemini.
Returns:
List of response parts from the first candidate.
"""
return response.candidates[0].content.parts
def get_parts_by_type(
response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
) -> list[GeminiPart]:
"""
Filter response parts by their type.
Args:
response: The API response from Gemini.
part_type: Type of parts to extract ("text" or a MIME type).
Returns:
List of response parts matching the requested type.
"""
parts = []
for part in get_parts_from_response(response):
if part_type == "text" and hasattr(part, "text") and part.text:
parts.append(part)
elif (
hasattr(part, "inlineData")
and part.inlineData
and part.inlineData.mimeType == part_type
):
parts.append(part)
# Skip parts that don't match the requested type
return parts
def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
"""
Extract and concatenate all text parts from the response.
Args:
response: The API response from Gemini.
Returns:
Combined text from all text parts in the response.
"""
parts = get_parts_by_type(response, "text")
return "\n".join([part.text for part in parts])
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
image_tensors: list[torch.Tensor] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
image_tensors.append(returned_image)
if len(image_tensors) == 0:
return torch.zeros((1,1024,1024,4))
return torch.cat(image_tensors, dim=0)
class GeminiNode(ComfyNodeABC):
"""
Node to generate text responses from a Gemini model.
@ -159,59 +299,6 @@ class GeminiNode(ComfyNodeABC):
CATEGORY = "api node/text/Gemini"
API_NODE = True
def get_parts_from_response(
self, response: GeminiGenerateContentResponse
) -> list[GeminiPart]:
"""
Extract all parts from the Gemini API response.
Args:
response: The API response from Gemini.
Returns:
List of response parts from the first candidate.
"""
return response.candidates[0].content.parts
def get_parts_by_type(
self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
) -> list[GeminiPart]:
"""
Filter response parts by their type.
Args:
response: The API response from Gemini.
part_type: Type of parts to extract ("text" or a MIME type).
Returns:
List of response parts matching the requested type.
"""
parts = []
for part in self.get_parts_from_response(response):
if part_type == "text" and hasattr(part, "text") and part.text:
parts.append(part)
elif (
hasattr(part, "inlineData")
and part.inlineData
and part.inlineData.mimeType == part_type
):
parts.append(part)
# Skip parts that don't match the requested type
return parts
def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str:
"""
Extract and concatenate all text parts from the response.
Args:
response: The API response from Gemini.
Returns:
Combined text from all text parts in the response.
"""
parts = self.get_parts_by_type(response, "text")
return "\n".join([part.text for part in parts])
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
"""
Convert video input to Gemini API compatible parts.
@ -271,43 +358,6 @@ class GeminiNode(ComfyNodeABC):
)
return audio_parts
def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]:
"""
Convert image tensor input to Gemini API compatible parts.
Args:
image_input: Batch of image tensors from ComfyUI.
Returns:
List of GeminiPart objects containing the encoded images.
"""
image_parts: list[GeminiPart] = []
for image_index in range(image_input.shape[0]):
image_as_b64 = tensor_to_base64_string(
image_input[image_index].unsqueeze(0)
)
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=image_as_b64,
)
)
)
return image_parts
def create_text_part(self, text: str) -> GeminiPart:
"""
Create a text part for the Gemini API request.
Args:
text: The text content to include in the request.
Returns:
A GeminiPart object with the text content.
"""
return GeminiPart(text=text)
async def api_call(
self,
prompt: str,
@ -323,11 +373,11 @@ class GeminiNode(ComfyNodeABC):
validate_string(prompt, strip_whitespace=False)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [self.create_text_part(prompt)]
parts: list[GeminiPart] = [create_text_part(prompt)]
# Add other modal parts
if images is not None:
image_parts = self.create_image_parts(images)
image_parts = create_image_parts(images)
parts.extend(image_parts)
if audio is not None:
parts.extend(self.create_audio_parts(audio))
@ -351,7 +401,7 @@ class GeminiNode(ComfyNodeABC):
).execute()
# Get result output
output_text = self.get_text_from_response(response)
output_text = get_text_from_response(response)
if unique_id and output_text:
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
render_spec = {
@ -462,12 +512,162 @@ class GeminiInputFiles(ComfyNodeABC):
return (files,)
class GeminiImage(ComfyNodeABC):
"""
Node to generate text and image responses from a Gemini model.
This node allows users to interact with Google's Gemini AI models, providing
multimodal inputs (text, images, files) to generate coherent
text and image responses. The node works with the latest Gemini models, handling the
API communication and response parsing.
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text prompt for generation",
},
),
"model": (
IO.COMBO,
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiImageModel],
"default": GeminiImageModel.gemini_2_5_flash_image_preview.value,
},
),
"seed": (
IO.INT,
{
"default": 42,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
},
),
},
"optional": {
"images": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
},
),
"files": (
"GEMINI_INPUT_FILES",
{
"default": None,
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
},
),
# TODO: later we can add this parameter later
# "n": (
# IO.INT,
# {
# "default": 1,
# "min": 1,
# "max": 8,
# "step": 1,
# "display": "number",
# "tooltip": "How many images to generate",
# },
# ),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.IMAGE, IO.STRING)
FUNCTION = "api_call"
CATEGORY = "api node/image/Gemini"
DESCRIPTION = "Edit images synchronously via Google API."
API_NODE = True
async def api_call(
self,
prompt: str,
model: GeminiImageModel,
images: Optional[IO.IMAGE] = None,
files: Optional[list[GeminiPart]] = None,
n=1,
unique_id: Optional[str] = None,
**kwargs,
):
# Validate inputs
validate_string(prompt, strip_whitespace=True, min_length=1)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [create_text_part(prompt)]
# Add other modal parts
if images is not None:
image_parts = create_image_parts(images)
parts.extend(image_parts)
if files is not None:
parts.extend(files)
response = await SynchronousOperation(
endpoint=get_gemini_image_endpoint(model),
request=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(
role="user",
parts=parts,
),
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=["TEXT","IMAGE"]
)
),
auth_kwargs=kwargs,
).execute()
output_image = get_image_from_response(response)
output_text = get_text_from_response(response)
if unique_id and output_text:
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
render_spec = {
"node_id": unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
output_text = output_text or "Empty response from Gemini model..."
return (output_image, output_text,)
NODE_CLASS_MAPPINGS = {
"GeminiNode": GeminiNode,
"GeminiImageNode": GeminiImage,
"GeminiInputFiles": GeminiInputFiles,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"GeminiNode": "Google Gemini",
"GeminiImageNode": "Google Gemini Image",
"GeminiInputFiles": "Gemini Input Files",
}

View File

@ -1,5 +1,8 @@
from inspect import cleandoc
from comfy.comfy_types.node_typing import IO
from typing import Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.apis.stability_api import (
StabilityUpscaleConservativeRequest,
StabilityUpscaleCreativeRequest,
@ -46,87 +49,94 @@ def get_async_dummy_status(x: StabilityResultsGetResponse):
return StabilityPollStatus.in_progress
class StabilityStableImageUltraNode:
class StabilityStableImageUltraNode(comfy_io.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
"What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityStableImageUltraNode",
display_name="Stability AI Stable Image Ultra",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
"elements, colors, and subjects will lead to better results. " +
"To control the weight of a given word use the format `(word:weight)`," +
"where `word` is the word you'd like to control the weight of and `weight`" +
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
"would convey a sky that was blue and green, but more green than blue."
},
"would convey a sky that was blue and green, but more green than blue.",
),
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
{
"default": StabilityAspectRatio.ratio_1_1,
"tooltip": "Aspect ratio of generated image.",
},
comfy_io.Combo.Input(
"aspect_ratio",
options=[x.value for x in StabilityAspectRatio],
default=StabilityAspectRatio.ratio_1_1.value,
tooltip="Aspect ratio of generated image.",
),
"style_preset": (get_stability_style_presets(),
{
"tooltip": "Optional desired style of generated image.",
},
comfy_io.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 4294967294,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
},
"optional": {
"image": (IO.IMAGE,),
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "A blurb of text describing what you do not wish to see in the output image. This is an advanced feature."
},
comfy_io.Image.Input(
"image",
optional=True,
),
"image_denoise": (
IO.FLOAT,
{
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
},
comfy_io.String.Input(
"negative_prompt",
default="",
tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
comfy_io.Float.Input(
"image_denoise",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
async def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
**kwargs):
@classmethod
async def execute(
cls,
prompt: str,
aspect_ratio: str,
style_preset: str,
seed: int,
image: Optional[torch.Tensor] = None,
negative_prompt: str = "",
image_denoise: Optional[float] = 0.5,
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
@ -144,6 +154,11 @@ class StabilityStableImageUltraNode:
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/ultra",
@ -161,7 +176,7 @@ class StabilityStableImageUltraNode:
),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@ -171,95 +186,106 @@ class StabilityStableImageUltraNode:
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
return comfy_io.NodeOutput(returned_image)
class StabilityStableImageSD_3_5Node:
class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityStableImageSD_3_5Node",
display_name="Stability AI Stable Diffusion 3.5 Image",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
comfy_io.Combo.Input(
"model",
options=[x.value for x in Stability_SD3_5_Model],
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[x.value for x in StabilityAspectRatio],
default=StabilityAspectRatio.ratio_1_1.value,
tooltip="Aspect ratio of generated image.",
),
comfy_io.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
),
comfy_io.Float.Input(
"cfg_scale",
default=4.0,
min=1.0,
max=10.0,
step=0.1,
tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
comfy_io.Image.Input(
"image",
optional=True,
),
comfy_io.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
),
comfy_io.Float.Input(
"image_denoise",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
},
),
"model": ([x.value for x in Stability_SD3_5_Model],),
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
{
"default": StabilityAspectRatio.ratio_1_1,
"tooltip": "Aspect ratio of generated image.",
},
),
"style_preset": (get_stability_style_presets(),
{
"tooltip": "Optional desired style of generated image.",
},
),
"cfg_scale": (
IO.FLOAT,
{
"default": 4.0,
"min": 1.0,
"max": 10.0,
"step": 0.1,
"tooltip": "How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 4294967294,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
"image": (IO.IMAGE,),
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
},
),
"image_denoise": (
IO.FLOAT,
{
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
**kwargs):
async def execute(
cls,
model: str,
prompt: str,
aspect_ratio: str,
style_preset: str,
seed: int,
cfg_scale: float,
image: Optional[torch.Tensor] = None,
negative_prompt: str = "",
image_denoise: Optional[float] = 0.5,
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
@ -280,6 +306,11 @@ class StabilityStableImageSD_3_5Node:
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/sd3",
@ -300,7 +331,7 @@ class StabilityStableImageSD_3_5Node:
),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@ -310,72 +341,75 @@ class StabilityStableImageSD_3_5Node:
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
return comfy_io.NodeOutput(returned_image)
class StabilityUpscaleConservativeNode:
class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
"""
Upscale image with minimal alterations to 4K resolution.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityUpscaleConservativeNode",
display_name="Stability AI Upscale Conservative",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
comfy_io.Float.Input(
"creativity",
default=0.35,
min=0.2,
max=0.5,
step=0.01,
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
},
),
"creativity": (
IO.FLOAT,
{
"default": 0.35,
"min": 0.2,
"max": 0.5,
"step": 0.01,
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 4294967294,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
**kwargs):
async def execute(
cls,
image: torch.Tensor,
prompt: str,
creativity: float,
seed: int,
negative_prompt: str = "",
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@ -386,6 +420,11 @@ class StabilityUpscaleConservativeNode:
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
@ -401,7 +440,7 @@ class StabilityUpscaleConservativeNode:
),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@ -411,77 +450,81 @@ class StabilityUpscaleConservativeNode:
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
return comfy_io.NodeOutput(returned_image)
class StabilityUpscaleCreativeNode:
class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
"""
Upscale image with minimal alterations to 4K resolution.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityUpscaleCreativeNode",
display_name="Stability AI Upscale Creative",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
comfy_io.Float.Input(
"creativity",
default=0.3,
min=0.1,
max=0.5,
step=0.01,
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
),
comfy_io.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
},
),
"creativity": (
IO.FLOAT,
{
"default": 0.3,
"min": 0.1,
"max": 0.5,
"step": 0.01,
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
},
),
"style_preset": (get_stability_style_presets(),
{
"tooltip": "Optional desired style of generated image.",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 4294967294,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
**kwargs):
async def execute(
cls,
image: torch.Tensor,
prompt: str,
creativity: float,
style_preset: str,
seed: int,
negative_prompt: str = "",
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@ -494,6 +537,11 @@ class StabilityUpscaleCreativeNode:
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/creative",
@ -510,7 +558,7 @@ class StabilityUpscaleCreativeNode:
),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@ -525,7 +573,8 @@ class StabilityUpscaleCreativeNode:
completed_statuses=[StabilityPollStatus.finished],
failed_statuses=[StabilityPollStatus.failed],
status_extractor=lambda x: get_async_dummy_status(x),
auth_kwargs=kwargs,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
)
response_poll: StabilityResultsGetResponse = await operation.execute()
@ -535,41 +584,48 @@ class StabilityUpscaleCreativeNode:
image_data = base64.b64decode(response_poll.result)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
return comfy_io.NodeOutput(returned_image)
class StabilityUpscaleFastNode:
class StabilityUpscaleFastNode(comfy_io.ComfyNode):
"""
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityUpscaleFastNode",
display_name="Stability AI Upscale Fast",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
},
"optional": {
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call(self, image: torch.Tensor, **kwargs):
async def execute(cls, image: torch.Tensor) -> comfy_io.NodeOutput:
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
files = {
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/fast",
@ -580,7 +636,7 @@ class StabilityUpscaleFastNode:
request=EmptyRequest(),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@ -590,24 +646,20 @@ class StabilityUpscaleFastNode:
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
return comfy_io.NodeOutput(returned_image)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"StabilityStableImageUltraNode": StabilityStableImageUltraNode,
"StabilityStableImageSD_3_5Node": StabilityStableImageSD_3_5Node,
"StabilityUpscaleConservativeNode": StabilityUpscaleConservativeNode,
"StabilityUpscaleCreativeNode": StabilityUpscaleCreativeNode,
"StabilityUpscaleFastNode": StabilityUpscaleFastNode,
}
class StabilityExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
StabilityStableImageUltraNode,
StabilityStableImageSD_3_5Node,
StabilityUpscaleConservativeNode,
StabilityUpscaleCreativeNode,
StabilityUpscaleFastNode,
]
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"StabilityStableImageUltraNode": "Stability AI Stable Image Ultra",
"StabilityStableImageSD_3_5Node": "Stability AI Stable Diffusion 3.5 Image",
"StabilityUpscaleConservativeNode": "Stability AI Upscale Conservative",
"StabilityUpscaleCreativeNode": "Stability AI Upscale Creative",
"StabilityUpscaleFastNode": "Stability AI Upscale Fast",
}
async def comfy_entrypoint() -> StabilityExtension:
return StabilityExtension()

View File

@ -181,8 +181,9 @@ class WebUIProgressHandler(ProgressHandler):
}
# Send a combined progress_state message with all node states
# Include client_id to ensure message is only sent to the initiating client
self.server_instance.send_sync(
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
)
@override

View File

@ -0,0 +1,44 @@
import folder_paths
import comfy.audio_encoders.audio_encoders
import comfy.utils
class AudioEncoderLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio_encoder_name": (folder_paths.get_filename_list("audio_encoders"), ),
}}
RETURN_TYPES = ("AUDIO_ENCODER",)
FUNCTION = "load_model"
CATEGORY = "loaders"
def load_model(self, audio_encoder_name):
audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name)
sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True)
audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd)
if audio_encoder is None:
raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.")
return (audio_encoder,)
class AudioEncoderEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio_encoder": ("AUDIO_ENCODER",),
"audio": ("AUDIO",),
}}
RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",)
FUNCTION = "encode"
CATEGORY = "conditioning"
def encode(self, audio_encoder, audio):
output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"])
return (output,)
NODE_CLASS_MAPPINGS = {
"AudioEncoderLoader": AudioEncoderLoader,
"AudioEncoderEncode": AudioEncoderEncode,
}

View File

@ -28,6 +28,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
input_change = None
do_easycache = easycache.should_do_easycache(sigmas)
if do_easycache:
easycache.check_metadata(x)
# if first cond marked this step for skipping, skip it and use appropriate cached values
if easycache.skip_current_step:
if easycache.verbose:
@ -92,6 +93,7 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
input_change = None
do_easycache = easycache.should_do_easycache(timestep)
if do_easycache:
easycache.check_metadata(x)
if easycache.has_x_prev_subsampled():
if easycache.has_x_prev_subsampled():
input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
@ -194,6 +196,7 @@ class EasyCacheHolder:
# how to deal with mismatched dims
self.allow_mismatch = True
self.cut_from_start = True
self.state_metadata = None
def is_past_end_timestep(self, timestep: float) -> bool:
return not (timestep[0] > self.end_t).item()
@ -283,6 +286,17 @@ class EasyCacheHolder:
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
return self.first_cond_uuid in uuids
def check_metadata(self, x: torch.Tensor) -> bool:
metadata = (x.device, x.dtype, x.shape[1:])
if self.state_metadata is None:
self.state_metadata = metadata
return True
if metadata == self.state_metadata:
return True
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
self.reset()
return False
def reset(self):
self.relative_transformation_rate = 0.0
self.cumulative_change_rate = 0.0
@ -299,6 +313,7 @@ class EasyCacheHolder:
del self.uuid_cache_diffs
self.uuid_cache_diffs = {}
self.total_steps_skipped = 0
self.state_metadata = None
return self
def clone(self):
@ -360,6 +375,7 @@ class LazyCacheHolder:
self.output_change_rates = []
self.approx_output_change_rates = []
self.total_steps_skipped = 0
self.state_metadata = None
def has_cache_diff(self) -> bool:
return self.cache_diff is not None
@ -404,6 +420,17 @@ class LazyCacheHolder:
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
self.cache_diff = output - x
def check_metadata(self, x: torch.Tensor) -> bool:
metadata = (x.device, x.dtype, x.shape)
if self.state_metadata is None:
self.state_metadata = metadata
return True
if metadata == self.state_metadata:
return True
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
self.reset()
return False
def reset(self):
self.relative_transformation_rate = 0.0
self.cumulative_change_rate = 0.0
@ -412,7 +439,14 @@ class LazyCacheHolder:
self.approx_output_change_rates = []
del self.cache_diff
self.cache_diff = None
del self.x_prev_subsampled
self.x_prev_subsampled = None
del self.output_prev_subsampled
self.output_prev_subsampled = None
del self.output_prev_norm
self.output_prev_norm = None
self.total_steps_skipped = 0
self.state_metadata = None
return self
def clone(self):

View File

@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod:
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"reference_latents_method": (("offset", "index"), ),
"reference_latents_method": (("offset", "index", "uso"), ),
}}
RETURN_TYPES = ("CONDITIONING",)

View File

@ -1,6 +1,7 @@
import comfy.utils
import comfy_extras.nodes_post_processing
import torch
import nodes
def reshape_latent_to(target_shape, latent, repeat_batch=True):
@ -105,6 +106,73 @@ class LatentInterpolate:
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
class LatentConcat:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, dim):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
s2 = comfy.utils.repeat_to_batch_size(s2, s1.shape[0])
if "-" in dim:
c = (s2, s1)
else:
c = (s1, s2)
if "x" in dim:
dim = -1
elif "y" in dim:
dim = -2
elif "t" in dim:
dim = -3
samples_out["samples"] = torch.cat(c, dim=dim)
return (samples_out,)
class LatentCut:
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT",),
"dim": (["x", "y", "t"], ),
"index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}),
"amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, dim, index, amount):
samples_out = samples.copy()
s1 = samples["samples"]
if "x" in dim:
dim = s1.ndim - 1
elif "y" in dim:
dim = s1.ndim - 2
elif "t" in dim:
dim = s1.ndim - 3
if index >= 0:
index = min(index, s1.shape[dim] - 1)
amount = min(s1.shape[dim] - index, amount)
else:
index = max(index, -s1.shape[dim])
amount = min(-index, amount)
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
return (samples_out,)
class LatentBatch:
@classmethod
def INPUT_TYPES(s):
@ -279,6 +347,8 @@ NODE_CLASS_MAPPINGS = {
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
"LatentConcat": LatentConcat,
"LatentCut": LatentCut,
"LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
"LatentApplyOperation": LatentApplyOperation,

View File

@ -89,6 +89,7 @@ class DiffSynthCnetPatch:
self.strength = strength
self.mask = mask
self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image))
self.encoded_image_size = (image.shape[1], image.shape[2])
def encode_latent_cond(self, image):
latent_image = self.vae.encode(image)
@ -106,14 +107,15 @@ class DiffSynthCnetPatch:
x = kwargs.get("x")
img = kwargs.get("img")
block_index = kwargs.get("block_index")
if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]:
spacial_compression = self.vae.spacial_compression_encode()
spacial_compression = self.vae.spacial_compression_encode()
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1)))
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
comfy.model_management.load_models_gpu(loaded_models)
img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength)
img[:, :self.encoded_image.shape[1]] += (self.model_patch.model.control_block(img[:, :self.encoded_image.shape[1]], self.encoded_image.to(img.dtype), block_index) * self.strength)
kwargs['img'] = img
return kwargs

View File

@ -1,98 +1,109 @@
# Primitive nodes that are evaluated at backend.
from __future__ import annotations
import sys
from typing_extensions import override
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
from comfy_api.latest import ComfyExtension, io
class String(ComfyNodeABC):
class String(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.STRING, {})},
}
def define_schema(cls):
return io.Schema(
node_id="PrimitiveString",
display_name="String",
category="utils/primitive",
inputs=[
io.String.Input("value"),
],
outputs=[io.String.Output()],
)
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: str) -> tuple[str]:
return (value,)
class StringMultiline(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.STRING, {"multiline": True,},)},
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: str) -> tuple[str]:
return (value,)
def execute(cls, value: str) -> io.NodeOutput:
return io.NodeOutput(value)
class Int(ComfyNodeABC):
class StringMultiline(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})},
}
def define_schema(cls):
return io.Schema(
node_id="PrimitiveStringMultiline",
display_name="String (Multiline)",
category="utils/primitive",
inputs=[
io.String.Input("value", multiline=True),
],
outputs=[io.String.Output()],
)
RETURN_TYPES = (IO.INT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: int) -> tuple[int]:
return (value,)
class Float(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})},
}
RETURN_TYPES = (IO.FLOAT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: float) -> tuple[float]:
return (value,)
def execute(cls, value: str) -> io.NodeOutput:
return io.NodeOutput(value)
class Boolean(ComfyNodeABC):
class Int(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.BOOLEAN, {})},
}
def define_schema(cls):
return io.Schema(
node_id="PrimitiveInt",
display_name="Int",
category="utils/primitive",
inputs=[
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
],
outputs=[io.Int.Output()],
)
RETURN_TYPES = (IO.BOOLEAN,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: bool) -> tuple[bool]:
return (value,)
@classmethod
def execute(cls, value: int) -> io.NodeOutput:
return io.NodeOutput(value)
NODE_CLASS_MAPPINGS = {
"PrimitiveString": String,
"PrimitiveStringMultiline": StringMultiline,
"PrimitiveInt": Int,
"PrimitiveFloat": Float,
"PrimitiveBoolean": Boolean,
}
class Float(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PrimitiveFloat",
display_name="Float",
category="utils/primitive",
inputs=[
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
],
outputs=[io.Float.Output()],
)
NODE_DISPLAY_NAME_MAPPINGS = {
"PrimitiveString": "String",
"PrimitiveStringMultiline": "String (Multiline)",
"PrimitiveInt": "Int",
"PrimitiveFloat": "Float",
"PrimitiveBoolean": "Boolean",
}
@classmethod
def execute(cls, value: float) -> io.NodeOutput:
return io.NodeOutput(value)
class Boolean(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PrimitiveBoolean",
display_name="Boolean",
category="utils/primitive",
inputs=[
io.Boolean.Input("value"),
],
outputs=[io.Boolean.Output()],
)
@classmethod
def execute(cls, value: bool) -> io.NodeOutput:
return io.NodeOutput(value)
class PrimitivesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
String,
StringMultiline,
Int,
Float,
Boolean,
]
async def comfy_entrypoint() -> PrimitivesExtension:
return PrimitivesExtension()

View File

@ -17,55 +17,61 @@
"""
import torch
import nodes
from typing_extensions import override
import comfy.utils
import nodes
from comfy_api.latest import ComfyExtension, io
class StableCascade_EmptyLatentImage:
def __init__(self, device="cpu"):
self.device = device
class StableCascade_EmptyLatentImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_EmptyLatentImage",
category="latent/stable_cascade",
inputs=[
io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, width, height, compression, batch_size=1):
def execute(cls, width, height, compression, batch_size=1):
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
return ({
return io.NodeOutput({
"samples": c_latent,
}, {
"samples": b_latent,
})
class StableCascade_StageC_VAEEncode:
def __init__(self, device="cpu"):
self.device = device
class StableCascade_StageC_VAEEncode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_StageC_VAEEncode",
category="latent/stable_cascade",
inputs=[
io.Image.Input("image"),
io.Vae.Input("vae"),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
],
outputs=[
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, image, vae, compression):
def execute(cls, image, vae, compression):
width = image.shape[-2]
height = image.shape[-3]
out_width = (width // compression) * vae.downscale_ratio
@ -75,51 +81,59 @@ class StableCascade_StageC_VAEEncode:
c_latent = vae.encode(s[:,:,:,:3])
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
return ({
return io.NodeOutput({
"samples": c_latent,
}, {
"samples": b_latent,
})
class StableCascade_StageB_Conditioning:
class StableCascade_StageB_Conditioning(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "conditioning": ("CONDITIONING",),
"stage_c": ("LATENT",),
}}
RETURN_TYPES = ("CONDITIONING",)
def define_schema(cls):
return io.Schema(
node_id="StableCascade_StageB_Conditioning",
category="conditioning/stable_cascade",
inputs=[
io.Conditioning.Input("conditioning"),
io.Latent.Input("stage_c"),
],
outputs=[
io.Conditioning.Output(),
],
)
FUNCTION = "set_prior"
CATEGORY = "conditioning/stable_cascade"
def set_prior(self, conditioning, stage_c):
@classmethod
def execute(cls, conditioning, stage_c):
c = []
for t in conditioning:
d = t[1].copy()
d['stable_cascade_prior'] = stage_c['samples']
d["stable_cascade_prior"] = stage_c["samples"]
n = [t[0], d]
c.append(n)
return (c, )
return io.NodeOutput(c)
class StableCascade_SuperResolutionControlnet:
def __init__(self, device="cpu"):
self.device = device
class StableCascade_SuperResolutionControlnet(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_SuperResolutionControlnet",
category="_for_testing/stable_cascade",
is_experimental=True,
inputs=[
io.Image.Input("image"),
io.Vae.Input("vae"),
],
outputs=[
io.Image.Output(display_name="controlnet_input"),
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
}}
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
FUNCTION = "generate"
EXPERIMENTAL = True
CATEGORY = "_for_testing/stable_cascade"
def generate(self, image, vae):
def execute(cls, image, vae):
width = image.shape[-2]
height = image.shape[-3]
batch_size = image.shape[0]
@ -127,15 +141,22 @@ class StableCascade_SuperResolutionControlnet:
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
return (controlnet_input, {
return io.NodeOutput(controlnet_input, {
"samples": c_latent,
}, {
"samples": b_latent,
})
NODE_CLASS_MAPPINGS = {
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet,
}
class StableCascadeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
StableCascade_EmptyLatentImage,
StableCascade_StageB_Conditioning,
StableCascade_StageC_VAEEncode,
StableCascade_SuperResolutionControlnet,
]
async def comfy_entrypoint() -> StableCascadeExtension:
return StableCascadeExtension()

View File

@ -5,52 +5,49 @@ import av
import torch
import folder_paths
import json
from typing import Optional, Literal
from typing import Optional
from typing_extensions import override
from fractions import Fraction
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
from comfy_api.latest import Input, InputImpl, Types
from comfy_api.input import AudioInput, ImageInput, VideoInput
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
from comfy_api.latest import ComfyExtension, io, ui
from comfy.cli_args import args
class SaveWEBM:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
class SaveWEBM(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveWEBM",
category="image/video",
is_experimental=True,
inputs=[
io.Image.Input("images"),
io.String.Input("filename_prefix", default="ComfyUI"),
io.Combo.Input("codec", options=["vp9", "av1"]),
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"codec": (["vp9", "av1"],),
"fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"crf": ("FLOAT", {"default": 32.0, "min": 0, "max": 63.0, "step": 1, "tooltip": "Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "image/video"
EXPERIMENTAL = True
def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
def execute(cls, images, codec, fps, filename_prefix, crf) -> io.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
)
file = f"{filename}_{counter:05}_.webm"
container = av.open(os.path.join(full_output_folder, file), mode="w")
if prompt is not None:
container.metadata["prompt"] = json.dumps(prompt)
if cls.hidden.prompt is not None:
container.metadata["prompt"] = json.dumps(cls.hidden.prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
container.metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
@ -69,63 +66,46 @@ class SaveWEBM:
container.mux(stream.encode())
container.close()
results: list[FileLocator] = [{
"filename": file,
"subfolder": subfolder,
"type": self.type
}]
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side
class SaveVideo(ComfyNodeABC):
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type: Literal["output"] = "output"
self.prefix_append = ""
class SaveVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveVideo",
display_name="Save Video",
category="image/video",
description="Saves the input images to your ComfyUI output directory.",
inputs=[
io.Video.Input("video", tooltip="The video to save."),
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
"format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
"codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO"
},
}
RETURN_TYPES = ()
FUNCTION = "save_video"
OUTPUT_NODE = True
CATEGORY = "image/video"
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix,
self.output_dir,
folder_paths.get_output_directory(),
width,
height
)
results: list[FileLocator] = list()
saved_metadata = None
if not args.disable_metadata:
metadata = {}
if extra_pnginfo is not None:
metadata.update(extra_pnginfo)
if prompt is not None:
metadata["prompt"] = prompt
if cls.hidden.extra_pnginfo is not None:
metadata.update(cls.hidden.extra_pnginfo)
if cls.hidden.prompt is not None:
metadata["prompt"] = cls.hidden.prompt
if len(metadata) > 0:
saved_metadata = metadata
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
video.save_to(
os.path.join(full_output_folder, file),
format=format,
@ -133,83 +113,82 @@ class SaveVideo(ComfyNodeABC):
metadata=saved_metadata
)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
return { "ui": { "images": results, "animated": (True,) } }
class CreateVideo(ComfyNodeABC):
class CreateVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": (IO.IMAGE, {"tooltip": "The images to create a video from."}),
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}),
},
"optional": {
"audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}),
}
}
def define_schema(cls):
return io.Schema(
node_id="CreateVideo",
display_name="Create Video",
category="image/video",
description="Create a video from images.",
inputs=[
io.Image.Input("images", tooltip="The images to create a video from."),
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
],
outputs=[
io.Video.Output(),
],
)
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "create_video"
CATEGORY = "image/video"
DESCRIPTION = "Create a video from images."
def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None):
return (InputImpl.VideoFromComponents(
Types.VideoComponents(
images=images,
audio=audio,
frame_rate=Fraction(fps),
)
),)
class GetVideoComponents(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video": (IO.VIDEO, {"tooltip": "The video to extract components from."}),
}
}
RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT)
RETURN_NAMES = ("images", "audio", "fps")
FUNCTION = "get_components"
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
return io.NodeOutput(
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
)
CATEGORY = "image/video"
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
class GetVideoComponents(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="GetVideoComponents",
display_name="Get Video Components",
category="image/video",
description="Extracts all components from a video: frames, audio, and framerate.",
inputs=[
io.Video.Input("video", tooltip="The video to extract components from."),
],
outputs=[
io.Image.Output(display_name="images"),
io.Audio.Output(display_name="audio"),
io.Float.Output(display_name="fps"),
],
)
def get_components(self, video: Input.Video):
@classmethod
def execute(cls, video: VideoInput) -> io.NodeOutput:
components = video.get_components()
return (components.images, components.audio, float(components.frame_rate))
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
class LoadVideo(ComfyNodeABC):
class LoadVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
def define_schema(cls):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
files = folder_paths.filter_files_content_types(files, ["video"])
return {"required":
{"file": (sorted(files), {"video_upload": True})},
}
CATEGORY = "image/video"
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "load_video"
def load_video(self, file):
video_path = folder_paths.get_annotated_filepath(file)
return (InputImpl.VideoFromFile(video_path),)
return io.Schema(
node_id="LoadVideo",
display_name="Load Video",
category="image/video",
inputs=[
io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video),
],
outputs=[
io.Video.Output(),
],
)
@classmethod
def IS_CHANGED(cls, file):
def execute(cls, file) -> io.NodeOutput:
video_path = folder_paths.get_annotated_filepath(file)
return io.NodeOutput(VideoFromFile(video_path))
@classmethod
def fingerprint_inputs(s, file):
video_path = folder_paths.get_annotated_filepath(file)
mod_time = os.path.getmtime(video_path)
# Instead of hashing the file, we can just use the modification time to avoid
@ -217,24 +196,23 @@ class LoadVideo(ComfyNodeABC):
return mod_time
@classmethod
def VALIDATE_INPUTS(cls, file):
def validate_inputs(s, file):
if not folder_paths.exists_annotated_filepath(file):
return "Invalid video file: {}".format(file)
return True
NODE_CLASS_MAPPINGS = {
"SaveWEBM": SaveWEBM,
"SaveVideo": SaveVideo,
"CreateVideo": CreateVideo,
"GetVideoComponents": GetVideoComponents,
"LoadVideo": LoadVideo,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SaveVideo": "Save Video",
"CreateVideo": "Create Video",
"GetVideoComponents": "Get Video Components",
"LoadVideo": "Load Video",
}
class VideoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SaveWEBM,
SaveVideo,
CreateVideo,
GetVideoComponents,
LoadVideo,
]
async def comfy_entrypoint() -> VideoExtension:
return VideoExtension()

View File

@ -139,16 +139,21 @@ class Wan22FunControlToVideo(io.ComfyNode):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
spacial_scale = vae.spacial_compression_encode()
latent_channels = vae.latent_channels
latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device())
if latent_channels == 48:
concat_latent = comfy.latent_formats.Wan22().process_out(concat_latent)
else:
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
concat_latent[:,latent_channels:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
mask[:, :, :start_image.shape[0] + 3] = 0.0
ref_latent = None
@ -159,11 +164,11 @@ class Wan22FunControlToVideo(io.ComfyNode):
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(control_video[:, :, :, :3])
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
concat_latent[:,:latent_channels,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels})
if ref_latent is not None:
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
@ -201,7 +206,8 @@ class WanFirstLastFrameToVideo(io.ComfyNode):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
spacial_scale = vae.spacial_compression_encode()
latent = torch.zeros([batch_size, vae.latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
@ -786,6 +792,229 @@ class WanTrackToVideo(io.ComfyNode):
return io.NodeOutput(positive, negative, out_latent)
def linear_interpolation(features, input_fps, output_fps, output_len=None):
"""
features: shape=[1, T, 512]
input_fps: fps for audio, f_a
output_fps: fps for video, f_m
output_len: video length
"""
features = features.transpose(1, 2) # [1, 512, T]
seq_len = features.shape[2] / float(input_fps) # T/f_a
if output_len is None:
output_len = int(seq_len * output_fps) # f_m*T/f_a
output_features = torch.nn.functional.interpolate(
features, size=output_len, align_corners=True,
mode='linear') # [1, 512, output_len]
return output_features.transpose(1, 2) # [1, output_len, 512]
def get_sample_indices(original_fps,
total_frames,
target_fps,
num_sample,
fixed_start=None):
required_duration = num_sample / target_fps
required_origin_frames = int(np.ceil(required_duration * original_fps))
if required_duration > total_frames / original_fps:
raise ValueError("required_duration must be less than video length")
if not fixed_start is None and fixed_start >= 0:
start_frame = fixed_start
else:
max_start = total_frames - required_origin_frames
if max_start < 0:
raise ValueError("video length is too short")
start_frame = np.random.randint(0, max_start + 1)
start_time = start_frame / original_fps
end_time = start_time + required_duration
time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
frame_indices = np.clip(frame_indices, 0, total_frames - 1)
return frame_indices
def get_audio_embed_bucket_fps(audio_embed, fps=16, batch_frames=81, m=0, video_rate=30):
num_layers, audio_frame_num, audio_dim = audio_embed.shape
if num_layers > 1:
return_all_layers = True
else:
return_all_layers = False
scale = video_rate / fps
min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
bucket_num = min_batch_num * batch_frames
padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * video_rate) - audio_frame_num
batch_idx = get_sample_indices(
original_fps=video_rate,
total_frames=audio_frame_num + padd_audio_num,
target_fps=fps,
num_sample=bucket_num,
fixed_start=0)
batch_audio_eb = []
audio_sample_stride = int(video_rate / fps)
for bi in batch_idx:
if bi < audio_frame_num:
chosen_idx = list(
range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
chosen_idx = [
audio_frame_num - 1 if c >= audio_frame_num else c
for c in chosen_idx
]
if return_all_layers:
frame_audio_embed = audio_embed[:, chosen_idx].flatten(
start_dim=-2, end_dim=-1)
else:
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
else:
frame_audio_embed = torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
batch_audio_eb.append(frame_audio_embed)
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
return batch_audio_eb, min_batch_num
def wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=0, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None, ref_motion_latent=None):
latent_t = ((length - 1) // 4) + 1
if audio_encoder_output is not None:
feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"])
video_rate = 30
fps = 16
feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
batch_frames = latent_t * 4
audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=batch_frames, m=0, video_rate=video_rate)
audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
if len(audio_embed_bucket.shape) == 3:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
elif len(audio_embed_bucket.shape) == 4:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
audio_embed_bucket = audio_embed_bucket[:, :, :, frame_offset:frame_offset + batch_frames]
if audio_embed_bucket.shape[3] > 0:
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0})
frame_offset += batch_frames
if ref_image is not None:
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
ref_latent = vae.encode(ref_image[:, :, :, :3])
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
if ref_motion is not None:
if ref_motion.shape[0] > 73:
ref_motion = ref_motion[-73:]
ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if ref_motion.shape[0] < 73:
r = torch.ones([73, height, width, 3]) * 0.5
r[-ref_motion.shape[0]:] = ref_motion
ref_motion = r
ref_motion_latent = vae.encode(ref_motion[:, :, :, :3])
if ref_motion_latent is not None:
ref_motion_latent = ref_motion_latent[:, :, -19:]
positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion_latent})
negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion_latent})
latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent))
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
control_video = vae.encode(control_video[:, :, :, :3])
control_video_out[:, :, :control_video.shape[2]] = control_video
# TODO: check if zero is better than none if none provided
positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out})
negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out})
out_latent = {}
out_latent["samples"] = latent
return positive, negative, out_latent, frame_offset
class WanSoundImageToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanSoundImageToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
io.Image.Input("ref_image", optional=True),
io.Image.Input("control_video", optional=True),
io.Image.Input("ref_motion", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None) -> io.NodeOutput:
positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
control_video=control_video, ref_motion=ref_motion)
return io.NodeOutput(positive, negative, out_latent)
class WanSoundImageToVideoExtend(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanSoundImageToVideoExtend",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Latent.Input("video_latent"),
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
io.Image.Input("ref_image", optional=True),
io.Image.Input("control_video", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, length, video_latent, ref_image=None, audio_encoder_output=None, control_video=None) -> io.NodeOutput:
video_latent = video_latent["samples"]
width = video_latent.shape[-1] * 8
height = video_latent.shape[-2] * 8
batch_size = video_latent.shape[0]
frame_offset = video_latent.shape[-3] * 4
positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=frame_offset, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
control_video=control_video, ref_motion=None, ref_motion_latent=video_latent)
return io.NodeOutput(positive, negative, out_latent)
class Wan22ImageToVideoLatent(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -844,6 +1073,8 @@ class WanExtension(ComfyExtension):
TrimVideoLatent,
WanCameraImageToVideo,
WanPhantomSubjectToVideo,
WanSoundImageToVideo,
WanSoundImageToVideoExtend,
Wan22ImageToVideoLatent,
]

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.52"
__version__ = "0.3.56"

View File

@ -48,6 +48,8 @@ folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers"
folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patches")], supported_pt_extensions)
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input")

View File

@ -112,7 +112,7 @@ import gc
if os.name == "nt":
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
if __name__ == "__main__":
if args.default_device is not None:

View File

@ -2324,6 +2324,7 @@ async def init_builtin_extra_nodes():
"nodes_qwen.py",
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",
]
import_failed = []

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.52"
version = "0.3.56"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.25.10
comfyui-workflow-templates==0.1.65
comfyui-frontend-package==1.25.11
comfyui-workflow-templates==0.1.70
comfyui-embedded-docs==0.2.6
torch
torchsde

View File

@ -6,6 +6,7 @@ def pytest_addoption(parser):
parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images')
parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.addoption("--port", type=int, default=8188, help="Set the listen port.")
parser.addoption("--skip-timing-checks", action="store_true", default=False, help="Skip timing-related assertions in tests (useful for CI environments with variable performance)")
# This initializes args at the beginning of the test session
@pytest.fixture(scope="session", autouse=True)
@ -19,6 +20,11 @@ def args_pytest(pytestconfig):
return args
@pytest.fixture(scope="session")
def skip_timing_checks(pytestconfig):
"""Fixture that returns whether timing checks should be skipped."""
return pytestconfig.getoption("--skip-timing-checks")
def pytest_collection_modifyitems(items):
# Modifies items so tests run in the correct order

View File

@ -7,7 +7,7 @@ import subprocess
from pytest import fixture
from comfy_execution.graph_utils import GraphBuilder
from tests.inference.test_execution import ComfyClient, run_warmup
from tests.execution.test_execution import ComfyClient, run_warmup
@pytest.mark.execution
@ -23,7 +23,7 @@ class TestAsyncNodes:
'--output-directory', args_pytest["output_dir"],
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
'--cpu',
]
use_lru, lru_size = request.param
@ -81,7 +81,7 @@ class TestAsyncNodes:
assert len(result_images) == 1, "Should have 1 image"
assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
"""Test that multiple async nodes execute in parallel."""
# Warmup execution to ensure server is fully initialized
run_warmup(client)
@ -104,7 +104,8 @@ class TestAsyncNodes:
elapsed_time = time.time() - start_time
# Should take ~0.5s (max duration) not 1.2s (sum of durations)
assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"
if not skip_timing_checks:
assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"
# Verify all nodes executed
assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)
@ -150,7 +151,7 @@ class TestAsyncNodes:
with pytest.raises(urllib.error.HTTPError):
client.run(g)
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
"""Test async nodes with lazy evaluation."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_lazy")
@ -173,7 +174,8 @@ class TestAsyncNodes:
elapsed_time = time.time() - start_time
# Should only execute sleep1, not sleep2
assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
if not skip_timing_checks:
assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
assert result.did_run(sleep1), "Sleep1 should have executed"
assert not result.did_run(sleep2), "Sleep2 should have been skipped"
@ -310,7 +312,7 @@ class TestAsyncNodes:
images = result.get_images(output)
assert len(images) == 1, "Should have blocked second image"
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
"""Test that async nodes are properly cached."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_cache")
@ -330,9 +332,10 @@ class TestAsyncNodes:
elapsed_time = time.time() - start_time
assert not result2.did_run(sleep_node), "Should be cached"
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"
if not skip_timing_checks:
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
"""Test async nodes within dynamically generated prompts."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_dynamic")
@ -345,8 +348,8 @@ class TestAsyncNodes:
dynamic_async = g.node("TestDynamicAsyncGeneration",
image1=image1.out(0),
image2=image2.out(0),
num_async_nodes=3,
sleep_duration=0.2)
num_async_nodes=5,
sleep_duration=0.4)
g.node("SaveImage", images=dynamic_async.out(0))
start_time = time.time()
@ -354,7 +357,8 @@ class TestAsyncNodes:
elapsed_time = time.time() - start_time
# Should execute async nodes in parallel within dynamic prompt
assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
if not skip_timing_checks:
assert elapsed_time < 1.0, f"Dynamic async execution took {elapsed_time}s"
assert result.did_run(dynamic_async)
def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):

View File

@ -149,7 +149,7 @@ class TestExecution:
'--output-directory', args_pytest["output_dir"],
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
'--cpu',
]
use_lru, lru_size = request.param
@ -518,7 +518,7 @@ class TestExecution:
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
assert not result.did_run(test_node), "The execution should have been cached"
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
# Warmup execution to ensure server is fully initialized
run_warmup(client)
@ -541,14 +541,15 @@ class TestExecution:
# The test should take around 3.0 seconds (the longest sleep duration)
# plus some overhead, but definitely less than the sum of all sleeps (9.0s)
assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
if not skip_timing_checks:
assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
# Verify that all nodes executed
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
assert result.did_run(sleep_node2), "Sleep node 2 should have run"
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
# Warmup execution to ensure server is fully initialized
run_warmup(client)
@ -574,7 +575,9 @@ class TestExecution:
# Similar to the previous test, expect parallel execution of the sleep nodes
# which should complete in less than the sum of all sleeps
assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 5.5s"
# Lots of leeway here since Windows CI is slow
if not skip_timing_checks:
assert elapsed_time < 13.0, f"Expansion execution took {elapsed_time}s"
# Verify the parallel sleep node executed
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"

View File

@ -0,0 +1,233 @@
"""Test that progress updates are properly isolated between WebSocket clients."""
import json
import pytest
import time
import threading
import uuid
import websocket
from typing import List, Dict, Any
from comfy_execution.graph_utils import GraphBuilder
from tests.execution.test_execution import ComfyClient
class ProgressTracker:
"""Tracks progress messages received by a WebSocket client."""
def __init__(self, client_id: str):
self.client_id = client_id
self.progress_messages: List[Dict[str, Any]] = []
self.lock = threading.Lock()
def add_message(self, message: Dict[str, Any]):
"""Thread-safe addition of progress messages."""
with self.lock:
self.progress_messages.append(message)
def get_messages_for_prompt(self, prompt_id: str) -> List[Dict[str, Any]]:
"""Get all progress messages for a specific prompt_id."""
with self.lock:
return [
msg for msg in self.progress_messages
if msg.get('data', {}).get('prompt_id') == prompt_id
]
def has_cross_contamination(self, own_prompt_id: str) -> bool:
"""Check if this client received progress for other prompts."""
with self.lock:
for msg in self.progress_messages:
msg_prompt_id = msg.get('data', {}).get('prompt_id')
if msg_prompt_id and msg_prompt_id != own_prompt_id:
return True
return False
class IsolatedClient(ComfyClient):
"""Extended ComfyClient that tracks all WebSocket messages."""
def __init__(self):
super().__init__()
self.progress_tracker = None
self.all_messages: List[Dict[str, Any]] = []
def connect(self, listen='127.0.0.1', port=8188, client_id=None):
"""Connect with a specific client_id and set up message tracking."""
if client_id is None:
client_id = str(uuid.uuid4())
super().connect(listen, port, client_id)
self.progress_tracker = ProgressTracker(client_id)
def listen_for_messages(self, duration: float = 5.0):
"""Listen for WebSocket messages for a specified duration."""
end_time = time.time() + duration
self.ws.settimeout(0.5) # Non-blocking with timeout
while time.time() < end_time:
try:
out = self.ws.recv()
if isinstance(out, str):
message = json.loads(out)
self.all_messages.append(message)
# Track progress_state messages
if message.get('type') == 'progress_state':
self.progress_tracker.add_message(message)
except websocket.WebSocketTimeoutException:
continue
except Exception:
# Log error silently in test context
break
@pytest.mark.execution
class TestProgressIsolation:
"""Test suite for verifying progress update isolation between clients."""
@pytest.fixture(scope="class", autouse=True)
def _server(self, args_pytest):
"""Start the ComfyUI server for testing."""
import subprocess
pargs = [
'python', 'main.py',
'--output-directory', args_pytest["output_dir"],
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
'--cpu',
]
p = subprocess.Popen(pargs)
yield
p.kill()
def start_client_with_retry(self, listen: str, port: int, client_id: str = None):
"""Start client with connection retries."""
client = IsolatedClient()
# Connect to server (with retries)
n_tries = 5
for i in range(n_tries):
time.sleep(4)
try:
client.connect(listen, port, client_id)
return client
except ConnectionRefusedError as e:
print(e) # noqa: T201
print(f"({i+1}/{n_tries}) Retrying...") # noqa: T201
raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts")
def test_progress_isolation_between_clients(self, args_pytest):
"""Test that progress updates are isolated between different clients."""
listen = args_pytest["listen"]
port = args_pytest["port"]
# Create two separate clients with unique IDs
client_a_id = "client_a_" + str(uuid.uuid4())
client_b_id = "client_b_" + str(uuid.uuid4())
try:
# Connect both clients with retries
client_a = self.start_client_with_retry(listen, port, client_a_id)
client_b = self.start_client_with_retry(listen, port, client_b_id)
# Create simple workflows for both clients
graph_a = GraphBuilder(prefix="client_a")
image_a = graph_a.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
graph_a.node("PreviewImage", images=image_a.out(0))
graph_b = GraphBuilder(prefix="client_b")
image_b = graph_b.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
graph_b.node("PreviewImage", images=image_b.out(0))
# Submit workflows from both clients
prompt_a = graph_a.finalize()
prompt_b = graph_b.finalize()
response_a = client_a.queue_prompt(prompt_a)
prompt_id_a = response_a['prompt_id']
response_b = client_b.queue_prompt(prompt_b)
prompt_id_b = response_b['prompt_id']
# Start threads to listen for messages on both clients
def listen_client_a():
client_a.listen_for_messages(duration=10.0)
def listen_client_b():
client_b.listen_for_messages(duration=10.0)
thread_a = threading.Thread(target=listen_client_a)
thread_b = threading.Thread(target=listen_client_b)
thread_a.start()
thread_b.start()
# Wait for threads to complete
thread_a.join()
thread_b.join()
# Verify isolation
# Client A should only receive progress for prompt_id_a
assert not client_a.progress_tracker.has_cross_contamination(prompt_id_a), \
f"Client A received progress updates for other clients' workflows. " \
f"Expected only {prompt_id_a}, but got messages for multiple prompts."
# Client B should only receive progress for prompt_id_b
assert not client_b.progress_tracker.has_cross_contamination(prompt_id_b), \
f"Client B received progress updates for other clients' workflows. " \
f"Expected only {prompt_id_b}, but got messages for multiple prompts."
# Verify each client received their own progress updates
client_a_messages = client_a.progress_tracker.get_messages_for_prompt(prompt_id_a)
client_b_messages = client_b.progress_tracker.get_messages_for_prompt(prompt_id_b)
assert len(client_a_messages) > 0, \
"Client A did not receive any progress updates for its own workflow"
assert len(client_b_messages) > 0, \
"Client B did not receive any progress updates for its own workflow"
# Ensure no cross-contamination
client_a_other = client_a.progress_tracker.get_messages_for_prompt(prompt_id_b)
client_b_other = client_b.progress_tracker.get_messages_for_prompt(prompt_id_a)
assert len(client_a_other) == 0, \
f"Client A incorrectly received {len(client_a_other)} progress updates for Client B's workflow"
assert len(client_b_other) == 0, \
f"Client B incorrectly received {len(client_b_other)} progress updates for Client A's workflow"
finally:
# Clean up connections
if hasattr(client_a, 'ws'):
client_a.ws.close()
if hasattr(client_b, 'ws'):
client_b.ws.close()
def test_progress_with_missing_client_id(self, args_pytest):
"""Test that progress updates handle missing client_id gracefully."""
listen = args_pytest["listen"]
port = args_pytest["port"]
try:
# Connect client with retries
client = self.start_client_with_retry(listen, port)
# Create a simple workflow
graph = GraphBuilder(prefix="test_missing_id")
image = graph.node("StubImage", content="BLACK", height=128, width=128, batch_size=1)
graph.node("PreviewImage", images=image.out(0))
# Submit workflow
prompt = graph.finalize()
response = client.queue_prompt(prompt)
prompt_id = response['prompt_id']
# Listen for messages
client.listen_for_messages(duration=5.0)
# Should still receive progress updates for own workflow
messages = client.progress_tracker.get_messages_for_prompt(prompt_id)
assert len(messages) > 0, \
"Client did not receive progress updates even though it initiated the workflow"
finally:
if hasattr(client, 'ws'):
client.ws.close()