Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-28 15:56:18 +08:00)

Compare commits: 1 commit (pysssss/ba ... cb/video-s)

| Author | SHA1 | Date |
|---|---|---|
| | e4f3d335dc | |
.github/workflows/test-build.yml (vendored): 4 changed lines
@@ -25,10 +25,6 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y libx11-dev
- name: Install dependencies
run: |
python -m pip install --upgrade pip
@@ -18,12 +18,12 @@ class CompressedTimestep:
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
"""
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
"""
self.batch_size, num_tokens, self.feature_dim = tensor.shape

# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
self.patches_per_frame = patches_per_frame
self.num_frames = num_tokens // patches_per_frame
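To make the compression idea concrete, here is a minimal standalone sketch (not the class's actual implementation) of collapsing a per-token timestep tensor to one entry per frame and expanding it back, assuming every spatial patch within a frame shares the same timestep value:

import torch

def compress_per_frame(t: torch.Tensor, patches_per_frame: int) -> torch.Tensor:
    # [batch, frames * patches, dim] -> [batch, frames, dim], keeping one entry per frame
    b, n, d = t.shape
    assert n % patches_per_frame == 0
    return t.view(b, n // patches_per_frame, patches_per_frame, d)[:, :, 0]

def expand_per_frame(t: torch.Tensor, patches_per_frame: int) -> torch.Tensor:
    # [batch, frames, dim] -> [batch, frames * patches, dim], repeating each frame's entry
    b, f, d = t.shape
    return t.unsqueeze(2).expand(b, f, patches_per_frame, d).reshape(b, f * patches_per_frame, d)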
@@ -215,9 +215,22 @@ class BasicAVTransformerBlock(nn.Module):
return (*scale_shift_ada_values, *gate_ada_values)

def forward(
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
self,
x: Tuple[torch.Tensor, torch.Tensor],
v_context=None,
a_context=None,
attention_mask=None,
v_timestep=None,
a_timestep=None,
v_pe=None,
a_pe=None,
v_cross_pe=None,
a_cross_pe=None,
v_cross_scale_shift_timestep=None,
a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None,
a_cross_gate_timestep=None,
transformer_options=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
@@ -227,102 +240,144 @@ class BasicAVTransformerBlock(nn.Module):
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)

# video
if run_vx:
# video self-attention
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
vshift_msa, vscale_msa, vgate_msa = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
)

norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
del vshift_msa, vscale_msa
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
del norm_vx
# video cross-attention
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
vx.addcmul_(attn1_out, vgate_msa)
del vgate_msa, attn1_out
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
vx += self.attn2(
comfy.ldm.common_dit.rms_norm(vx),
context=v_context,
mask=attention_mask,
transformer_options=transformer_options,
)

del vshift_msa, vscale_msa, vgate_msa

# audio
if run_ax:
# audio self-attention
ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
del ashift_msa, ascale_msa
attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
del norm_ax
# audio cross-attention
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
ax.addcmul_(attn1_out, agate_msa)
del agate_msa, attn1_out
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
ashift_msa, ascale_msa, agate_msa = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
)

# video - audio cross attention.
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
ax += (
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
* agate_msa
)
ax += self.audio_attn2(
comfy.ldm.common_dit.rms_norm(ax),
context=a_context,
mask=attention_mask,
transformer_options=transformer_options,
)

del ashift_msa, ascale_msa, agate_msa
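The gated residual updates above appear in two equivalent forms; the in-place variant avoids materialising the product as a separate temporary. A small illustration with generic tensors (not the model's real shapes):

import torch

x = torch.randn(2, 1024, 512)      # hidden states
attn_out = torch.randn_like(x)     # attention output
gate = torch.randn(2, 1, 512)      # broadcastable gate

reference = x + attn_out * gate    # out-of-place: allocates a temporary for attn_out * gate
x.addcmul_(attn_out, gate)         # in-place fused multiply-add, same result
assert torch.allclose(x, reference)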
# Audio - Video cross attention.
if run_a2v or run_v2a:
# norm3
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)

# audio to video cross attention
(
scale_ca_audio_hidden_states_a2v,
shift_ca_audio_hidden_states_a2v,
scale_ca_audio_hidden_states_v2a,
shift_ca_audio_hidden_states_v2a,
gate_out_v2a,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
a_cross_scale_shift_timestep,
a_cross_gate_timestep,
)

(
scale_ca_video_hidden_states_a2v,
shift_ca_video_hidden_states_a2v,
scale_ca_video_hidden_states_v2a,
shift_ca_video_hidden_states_v2a,
gate_out_a2v,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
v_cross_scale_shift_timestep,
v_cross_gate_timestep,
)

if run_a2v:
scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
+ shift_ca_video_hidden_states_a2v
)
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
+ shift_ca_audio_hidden_states_a2v
)
vx += (
self.audio_to_video_attn(
vx_scaled,
context=ax_scaled,
pe=v_cross_pe,
k_pe=a_cross_pe,
transformer_options=transformer_options,
)
* gate_out_a2v
)

vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
del gate_out_a2v
del scale_ca_video_hidden_states_a2v,\
shift_ca_video_hidden_states_a2v,\
scale_ca_audio_hidden_states_a2v,\
shift_ca_audio_hidden_states_a2v,\

a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
del vx_scaled, ax_scaled

gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
vx.addcmul_(a2v_out, gate_out_a2v)
del gate_out_a2v, a2v_out

# video to audio cross attention
if run_v2a:
scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
+ shift_ca_audio_hidden_states_v2a
)
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
+ shift_ca_video_hidden_states_v2a
)
ax += (
self.video_to_audio_attn(
ax_scaled,
context=vx_scaled,
pe=a_cross_pe,
k_pe=v_cross_pe,
transformer_options=transformer_options,
)
* gate_out_v2a
)

ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
del gate_out_v2a
del scale_ca_video_hidden_states_v2a,\
shift_ca_video_hidden_states_v2a,\
scale_ca_audio_hidden_states_v2a,\
shift_ca_audio_hidden_states_v2a

v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
del ax_scaled, vx_scaled

gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
ax.addcmul_(v2a_out, gate_out_v2a)
del gate_out_v2a, v2a_out

del vx_norm3, ax_norm3
# video feedforward
if run_vx:
vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
vshift_mlp, vscale_mlp, vgate_mlp = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
)

vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
del vshift_mlp, vscale_mlp
vx += self.ff(vx_scaled) * vgate_mlp
del vshift_mlp, vscale_mlp, vgate_mlp

ff_out = self.ff(vx_scaled)
del vx_scaled

vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
vx.addcmul_(ff_out, vgate_mlp)
del vgate_mlp, ff_out

# audio feedforward
if run_ax:
ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
ashift_mlp, ascale_mlp, agate_mlp = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
)

ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
del ashift_mlp, ascale_mlp
ax += self.audio_ff(ax_scaled) * agate_mlp

ff_out = self.audio_ff(ax_scaled)
del ax_scaled
del ashift_mlp, ascale_mlp, agate_mlp

agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
ax.addcmul_(ff_out, agate_mlp)
del agate_mlp, ff_out

return vx, ax
@@ -534,20 +589,9 @@ class LTXAVModel(LTXVModel):
audio_length = kwargs.get("audio_length", 0)
# Separate audio and video latents
vx, ax = self.separate_audio_and_video_latents(x, audio_length)

has_spatial_mask = False
if denoise_mask is not None:
# check if any frame has spatial variation (inpainting)
for frame_idx in range(denoise_mask.shape[2]):
frame_mask = denoise_mask[0, 0, frame_idx]
if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
has_spatial_mask = True
break

[vx, v_pixel_coords, additional_args] = super()._process_input(
vx, keyframe_idxs, denoise_mask, **kwargs
)
additional_args["has_spatial_mask"] = has_spatial_mask

ax, a_latent_coords = self.a_patchifier.patchify(ax)
ax = self.audio_patchify_proj(ax)
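The denoise-mask check above flags inpainting-style masks. A hedged illustration of the same test, assuming a mask shaped [batch, 1, frames, height, width]:

import torch

def frame_is_spatially_masked(frame_mask: torch.Tensor) -> bool:
    # a frame counts as spatially varying when its mask is not constant
    return frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max()

mask = torch.zeros(1, 1, 4, 8, 8)
mask[0, 0, 2, :, :4] = 1.0  # inpaint the left half of frame 2
has_spatial_mask = any(frame_is_spatially_masked(mask[0, 0, f]) for f in range(mask.shape[2]))
print(has_spatial_mask)  # True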
@@ -574,9 +618,8 @@ class LTXAVModel(LTXVModel):
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
orig_shape = kwargs.get("orig_shape")
has_spatial_mask = kwargs.get("has_spatial_mask", None)
v_patches_per_frame = None
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
if orig_shape is not None and len(orig_shape) == 5:
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
v_patches_per_frame = orig_shape[3] * orig_shape[4]

@@ -619,11 +662,10 @@ class LTXAVModel(LTXVModel):
)

# Compress cross-attention timesteps (only video side, audio is too small to benefit)
# v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
cross_av_timestep_ss = [
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
]
@@ -260,7 +260,6 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
key_map[k[:-len(".weight")]] = to #DiffSynth lora format
for k in sdk:
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
if k.endswith(".weight") and ".linear1." in k:
@@ -1,10 +1,12 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .video_types import VideoInput
from .video_types import VideoInput, VideoOp, SliceOp

__all__ = [
"ImageInput",
"AudioInput",
"VideoInput",
"VideoOp",
"SliceOp",
"MaskInput",
"LatentInput",
]
@@ -1,11 +1,48 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from fractions import Fraction
from typing import Optional, Union, IO
import copy
import io
import av
from .._util import VideoContainer, VideoCodec, VideoComponents


class VideoOp(ABC):
"""Base class for lazy video operations."""

@abstractmethod
def apply(self, components: VideoComponents) -> VideoComponents:
pass

@abstractmethod
def compute_frame_count(self, input_frame_count: int) -> int:
pass


@dataclass(frozen=True)
class SliceOp(VideoOp):
"""Extract a range of frames from the video."""
start_frame: int
frame_count: int

def apply(self, components: VideoComponents) -> VideoComponents:
total = components.images.shape[0]
start = max(0, min(self.start_frame, total))
end = min(start + self.frame_count, total)
return VideoComponents(
images=components.images[start:end],
audio=components.audio,
frame_rate=components.frame_rate,
metadata=getattr(components, 'metadata', None),
)

def compute_frame_count(self, input_frame_count: int) -> int:
start = max(0, min(self.start_frame, input_frame_count))
return min(self.frame_count, input_frame_count - start)


class VideoInput(ABC):
"""
Abstract base class for video input types.

@@ -21,6 +58,12 @@ class VideoInput(ABC):
"""
pass

def sliced(self, start_frame: int, frame_count: int) -> "VideoInput":
"""Return a copy of this video with a slice operation appended."""
new = copy.copy(self)
new._operations = getattr(self, '_operations', []) + [SliceOp(start_frame, frame_count)]
return new

@abstractmethod
def save_to(
self,
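A hedged usage sketch of the lazy API introduced above, using the VideoFromFile implementation shown later in this diff (the import path and file name are assumptions):

from comfy_api.input_impl import VideoFromFile

video = VideoFromFile("example.mp4")                 # nothing is decoded yet
clip = video.sliced(start_frame=30, frame_count=60)  # records a SliceOp, still lazy
print(clip.get_frame_count())                        # frame math only, no full decode of the slice
components = clip.get_components()                   # decodes once and applies the queued ops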
@@ -1,7 +1,8 @@
from .video_types import VideoFromFile, VideoFromComponents
from .._input import SliceOp

__all__ = [
# Implementations
"VideoFromFile",
"VideoFromComponents",
"SliceOp",
]
@@ -3,7 +3,7 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from .._input import AudioInput, VideoInput
from .._input import AudioInput, VideoInput, VideoOp
import av
import io
import json

@@ -63,6 +63,8 @@ class VideoFromFile(VideoInput):
containing the file contents.
"""
self.__file = file
self._operations: list[VideoOp] = []
self.__materialized: Optional[VideoFromComponents] = None

def get_stream_source(self) -> str | io.BytesIO:
"""

@@ -161,6 +163,10 @@ class VideoFromFile(VideoInput):
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")

# Apply operations to get final frame count
for op in self._operations:
frame_count = op.compute_frame_count(frame_count)
return frame_count

def get_frame_rate(self) -> Fraction:

@@ -239,10 +245,18 @@ class VideoFromFile(VideoInput):
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)

def get_components(self) -> VideoComponents:
if self.__materialized is not None:
return self.__materialized.get_components()

if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
return self.get_components_internal(container)
components = self.get_components_internal(container)
for op in self._operations:
components = op.apply(components)
self.__materialized = VideoFromComponents(components)
self._operations = []
return components
raise ValueError(f"No video stream found in file '{self.__file}'")

def save_to(
@@ -317,14 +331,27 @@ class VideoFromComponents(VideoInput):

def __init__(self, components: VideoComponents):
self.__components = components
self._operations: list[VideoOp] = []

def get_components(self) -> VideoComponents:
if self._operations:
components = self.__components
for op in self._operations:
components = op.apply(components)
self.__components = components
self._operations = []
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
)

def get_frame_count(self) -> int:
count = int(self.__components.images.shape[0])
for op in self._operations:
count = op.compute_frame_count(count)
return count

def save_to(
self,
path: str,

@@ -332,6 +359,9 @@ class VideoFromComponents(VideoInput):
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
# Materialize ops before saving
components = self.get_components()

if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:

@@ -345,22 +375,22 @@ class VideoFromComponents(VideoInput):
for key, value in metadata.items():
output.metadata[key] = json.dumps(value)

frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
frame_rate = Fraction(round(components.frame_rate * 1000), 1000)
# Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1]
video_stream.width = components.images.shape[2]
video_stream.height = components.images.shape[1]
video_stream.pix_fmt = 'yuv420p'

# Create an audio stream
audio_sample_rate = 1
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
if components.audio:
audio_sample_rate = int(components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)

# Encode video
for i, frame in enumerate(self.__components.images):
for i, frame in enumerate(components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264

@@ -371,9 +401,9 @@ class VideoFromComponents(VideoInput):
packet = video_stream.encode(None)
output.mux(packet)

if audio_stream and self.__components.audio:
waveform = self.__components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
if audio_stream and components.audio:
waveform = components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
frame.sample_rate = audio_sample_rate
frame.pts = 0
@@ -1344,8 +1344,7 @@ class Schema:
"""The category of the node, as per the "Add Node" menu."""
inputs: list[Input] = field(default_factory=list)
outputs: list[Output] = field(default_factory=list)
hidden: list[Hidden | str] = field(default_factory=list)
"""Hidden inputs. Use Hidden enum for system values (PROMPT, UNIQUE_ID, etc.) or plain strings for custom frontend-provided values."""
hidden: list[Hidden] = field(default_factory=list)
description: str=""
"""Node description, shown as a tooltip when hovering over the node."""
search_aliases: list[str] = field(default_factory=list)

@@ -1444,10 +1443,7 @@ class Schema:
input = create_input_dict_v1(self.inputs)
if self.hidden:
for hidden in self.hidden:
if isinstance(hidden, str):
input.setdefault("hidden", {})[hidden] = (hidden,)
else:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
# create separate lists from output fields
output = []
output_is_list = []

@@ -1508,10 +1504,7 @@ class Schema:
add_to_dict_v3(output, output_dict)
if self.hidden:
for hidden in self.hidden:
if isinstance(hidden, str):
hidden_list.append(hidden)
else:
hidden_list.append(hidden.value)
hidden_list.append(hidden.value)

info = NodeInfoV3(
input=input_dict,
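For reference, a small self-contained sketch of how the enum-only loop above populates the V1 hidden-input dict (the enum here is a stand-in for illustration, not the project's real Hidden class):

from enum import Enum

class Hidden(str, Enum):
    unique_id = "UNIQUE_ID"
    prompt = "PROMPT"

input_dict = {}
for hidden in [Hidden.unique_id, Hidden.prompt]:
    input_dict.setdefault("hidden", {})[hidden.name] = (hidden.value,)

print(input_dict)  # {'hidden': {'unique_id': ('UNIQUE_ID',), 'prompt': ('PROMPT',)}}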
@@ -28,7 +28,6 @@ class AlignYourStepsScheduler(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AlignYourStepsScheduler",
search_aliases=["AYS scheduler"],
category="sampling/custom_sampling/schedulers",
inputs=[
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),

@@ -71,7 +71,6 @@ class CLIPAttentionMultiply(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CLIPAttentionMultiply",
search_aliases=["clip attention scale", "text encoder attention"],
category="_for_testing/attention_experiments",
inputs=[
io.Clip.Input("clip"),
@@ -69,7 +69,6 @@ class VAEEncodeAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VAEEncodeAudio",
search_aliases=["audio to latent"],
display_name="VAE Encode Audio",
category="latent/audio",
inputs=[

@@ -98,7 +97,6 @@ class VAEDecodeAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeAudio",
search_aliases=["latent to audio"],
display_name="VAE Decode Audio",
category="latent/audio",
inputs=[

@@ -124,7 +122,6 @@ class SaveAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveAudio",
search_aliases=["export flac"],
display_name="Save Audio (FLAC)",
category="audio",
inputs=[

@@ -149,7 +146,6 @@ class SaveAudioMP3(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveAudioMP3",
search_aliases=["export mp3"],
display_name="Save Audio (MP3)",
category="audio",
inputs=[

@@ -177,7 +173,6 @@ class SaveAudioOpus(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveAudioOpus",
search_aliases=["export opus"],
display_name="Save Audio (Opus)",
category="audio",
inputs=[

@@ -205,7 +200,6 @@ class PreviewAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="PreviewAudio",
search_aliases=["play audio"],
display_name="Preview Audio",
category="audio",
inputs=[

@@ -265,7 +259,6 @@ class LoadAudio(IO.ComfyNode):
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
return IO.Schema(
node_id="LoadAudio",
search_aliases=["import audio", "open audio", "audio file"],
display_name="Load Audio",
category="audio",
inputs=[

@@ -303,7 +296,6 @@ class RecordAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="RecordAudio",
search_aliases=["microphone input", "audio capture", "voice input"],
display_name="Record Audio",
category="audio",
inputs=[

@@ -328,7 +320,6 @@ class TrimAudioDuration(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TrimAudioDuration",
search_aliases=["cut audio", "audio clip", "shorten audio"],
display_name="Trim Audio Duration",
description="Trim audio tensor into chosen time range.",
category="audio",

@@ -381,7 +372,6 @@ class SplitAudioChannels(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SplitAudioChannels",
search_aliases=["stereo to mono"],
display_name="Split Audio Channels",
description="Separates the audio into left and right channels.",
category="audio",

@@ -482,7 +472,6 @@ class AudioConcat(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="AudioConcat",
search_aliases=["join audio", "combine audio", "append audio"],
display_name="Audio Concat",
description="Concatenates the audio1 to audio2 in the specified direction.",
category="audio",

@@ -530,7 +519,6 @@ class AudioMerge(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="AudioMerge",
search_aliases=["mix audio", "overlay audio", "layer audio"],
display_name="Audio Merge",
description="Combine two audio tracks by overlaying their waveforms.",
category="audio",

@@ -591,7 +579,6 @@ class AudioAdjustVolume(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="AudioAdjustVolume",
search_aliases=["audio gain", "loudness", "audio level"],
display_name="Audio Adjust Volume",
category="audio",
inputs=[

@@ -627,7 +614,6 @@ class EmptyAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="EmptyAudio",
search_aliases=["blank audio"],
display_name="Empty Audio",
category="audio",
inputs=[
@@ -10,7 +10,6 @@ class Canny(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Canny",
search_aliases=["edge detection", "outline", "contour detection", "line art"],
category="image/preprocessors",
inputs=[
io.Image.Input("image"),

@@ -109,7 +109,6 @@ class PorterDuffImageComposite(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PorterDuffImageComposite",
search_aliases=["alpha composite", "blend modes", "layer blend", "transparency blend"],
display_name="Porter-Duff Image Composite",
category="mask/compositing",
inputs=[

@@ -166,7 +165,6 @@ class SplitImageWithAlpha(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SplitImageWithAlpha",
search_aliases=["extract alpha", "separate transparency", "remove alpha"],
display_name="Split Image with Alpha",
category="mask/compositing",
inputs=[

@@ -190,7 +188,6 @@ class JoinImageWithAlpha(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="JoinImageWithAlpha",
search_aliases=["add transparency", "apply alpha", "composite alpha", "RGBA"],
display_name="Join Image with Alpha",
category="mask/compositing",
inputs=[
@@ -38,7 +38,6 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ControlNetInpaintingAliMamaApply",
search_aliases=["masked controlnet"],
category="conditioning/controlnet",
inputs=[
io.Conditioning.Input("positive"),
@@ -297,7 +297,6 @@ class ExtendIntermediateSigmas(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ExtendIntermediateSigmas",
search_aliases=["interpolate sigmas"],
category="sampling/custom_sampling/sigmas",
inputs=[
io.Sigmas.Input("sigmas"),

@@ -857,7 +856,6 @@ class DualCFGGuider(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="DualCFGGuider",
search_aliases=["dual prompt guidance"],
category="sampling/custom_sampling/guiders",
inputs=[
io.Model.Input("model"),

@@ -885,7 +883,6 @@ class DisableNoise(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="DisableNoise",
search_aliases=["zero noise"],
category="sampling/custom_sampling/noise",
inputs=[],
outputs=[io.Noise.Output()]

@@ -1022,7 +1019,6 @@ class ManualSigmas(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ManualSigmas",
search_aliases=["custom noise schedule", "define sigmas"],
category="_for_testing/custom_sampling",
is_experimental=True,
inputs=[
@@ -1223,11 +1223,11 @@ class ResolutionBucket(io.ComfyNode):

class MakeTrainingDataset(io.ComfyNode):
"""Encode images with VAE and texts with CLIP to create a training dataset."""

@classmethod
def define_schema(cls):
return io.Schema(
node_id="MakeTrainingDataset",
search_aliases=["encode dataset"],
display_name="Make Training Dataset",
category="dataset",
is_experimental=True,

@@ -1309,11 +1309,11 @@ class MakeTrainingDataset(io.ComfyNode):

class SaveTrainingDataset(io.ComfyNode):
"""Save encoded training dataset (latents + conditioning) to disk."""

@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveTrainingDataset",
search_aliases=["export training data"],
display_name="Save Training Dataset",
category="dataset",
is_experimental=True,

@@ -1410,11 +1410,11 @@ class SaveTrainingDataset(io.ComfyNode):

class LoadTrainingDataset(io.ComfyNode):
"""Load encoded training dataset from disk."""

@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadTrainingDataset",
search_aliases=["import dataset", "training data"],
display_name="Load Training Dataset",
category="dataset",
is_experimental=True,
@@ -11,7 +11,6 @@ class DifferentialDiffusion(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="DifferentialDiffusion",
search_aliases=["inpaint gradient", "variable denoise strength"],
display_name="Differential Diffusion",
category="_for_testing",
inputs=[

@@ -58,7 +58,6 @@ class FreSca(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="FreSca",
search_aliases=["frequency guidance"],
display_name="FreSca",
category="_for_testing",
description="Applies frequency-dependent scaling to the guidance",
@@ -1,439 +0,0 @@
import os
import re
import logging
from contextlib import contextmanager
from typing import TypedDict, Generator

import numpy as np
import torch

import nodes
from comfy_api.latest import ComfyExtension, io, ui
from comfy.cli_args import args
from typing_extensions import override
from utils.install_util import get_missing_requirements_message


class SizeModeInput(TypedDict):
size_mode: str
width: int
height: int


MAX_IMAGES = 5 # u_image0-4
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)

logger = logging.getLogger(__name__)

try:
import moderngl
except ImportError as e:
raise RuntimeError(f"ModernGL is not available.\n{get_missing_requirements_message()}") from e

# Default NOOP fragment shader that passes through the input image unchanged
# For multiple outputs, use: layout(location = 0) out vec4 fragColor0; etc.
DEFAULT_FRAGMENT_SHADER = """#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;

in vec2 v_texcoord;
layout(location = 0) out vec4 fragColor0;

void main() {
fragColor0 = texture(u_image0, v_texcoord);
}
"""
# Simple vertex shader for full-screen quad
VERTEX_SHADER = """#version 330

in vec2 in_position;
in vec2 in_texcoord;

out vec2 v_texcoord;

void main() {
gl_Position = vec4(in_position, 0.0, 1.0);
v_texcoord = in_texcoord;
}
"""


def _convert_es_to_desktop_glsl(source: str) -> str:
"""Convert GLSL ES 3.00 shader to desktop GLSL 3.30 for ModernGL compatibility."""
return re.sub(r'#version\s+300\s+es', '#version 330', source)


def _create_software_gl_context() -> moderngl.Context:
original_env = os.environ.get("LIBGL_ALWAYS_SOFTWARE")
os.environ["LIBGL_ALWAYS_SOFTWARE"] = "1"
try:
ctx = moderngl.create_standalone_context(require=330)
logger.info(f"Created software-rendered OpenGL context: {ctx.info['GL_RENDERER']}")
return ctx
finally:
if original_env is None:
os.environ.pop("LIBGL_ALWAYS_SOFTWARE", None)
else:
os.environ["LIBGL_ALWAYS_SOFTWARE"] = original_env


def _create_gl_context(force_software: bool = False) -> moderngl.Context:
if force_software:
try:
return _create_software_gl_context()
except Exception as e:
raise RuntimeError(
"Failed to create software-rendered OpenGL context.\n"
"Ensure Mesa/llvmpipe is installed for software rendering support."
) from e

# Try hardware rendering first, fall back to software
try:
ctx = moderngl.create_standalone_context(require=330)
logger.info(f"Created OpenGL context: {ctx.info['GL_RENDERER']}")
return ctx
except Exception as hw_error:
logger.warning(f"Hardware OpenGL context creation failed: {hw_error}")
logger.info("Attempting software rendering fallback...")
try:
return _create_software_gl_context()
except Exception as sw_error:
raise RuntimeError(
f"Failed to create OpenGL context.\n"
f"Hardware error: {hw_error}\n\n"
f"Possible solutions:\n"
f"1. Install GPU drivers with OpenGL 3.3+ support\n"
f"2. Install Mesa for software rendering (Linux: apt install libgl1-mesa-dri)\n"
f"3. On headless servers, ensure virtual framebuffer (Xvfb) or EGL is available"
) from sw_error
def _image_to_texture(ctx: moderngl.Context, image: np.ndarray) -> moderngl.Texture:
height, width = image.shape[:2]
channels = image.shape[2] if len(image.shape) > 2 else 1

components = min(channels, 4)

image_uint8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)

# Flip vertically for OpenGL coordinate system (origin at bottom-left)
image_uint8 = np.ascontiguousarray(np.flipud(image_uint8))

texture = ctx.texture((width, height), components, image_uint8.tobytes())
texture.filter = (moderngl.LINEAR, moderngl.LINEAR)
texture.repeat_x = False
texture.repeat_y = False

return texture


def _texture_to_image(fbo: moderngl.Framebuffer, attachment: int = 0, channels: int = 4) -> np.ndarray:
width, height = fbo.size

data = fbo.read(components=channels, attachment=attachment)
image = np.frombuffer(data, dtype=np.uint8).reshape((height, width, channels))

image = np.ascontiguousarray(np.flipud(image))

return image.astype(np.float32) / 255.0
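The two vertical flips above exist because OpenGL places the texture origin at the bottom-left while the image tensors are stored top-to-bottom; a quick numpy check of the round trip:

import numpy as np

img = np.arange(12, dtype=np.float32).reshape(3, 4)  # rows ordered top-to-bottom
uploaded = np.flipud(img)                            # flipped before upload (as in _image_to_texture)
readback = np.flipud(uploaded)                       # flipped again after readback (as in _texture_to_image)
assert np.array_equal(readback, img)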
def _compile_shader(ctx: moderngl.Context, fragment_source: str) -> moderngl.Program:
# Convert user's GLSL ES 3.00 fragment shader to desktop GLSL 3.30 for ModernGL
fragment_source = _convert_es_to_desktop_glsl(fragment_source)

try:
program = ctx.program(
vertex_shader=VERTEX_SHADER,
fragment_shader=fragment_source,
)
return program
except Exception as e:
raise RuntimeError(
"Fragment shader compilation failed.\n\n"
"Make sure your shader:\n"
"1. Uses #version 300 es (WebGL 2.0 compatible)\n"
"2. Has valid GLSL ES 3.00 syntax\n"
"3. Includes 'precision highp float;' after version\n"
"4. Uses 'out vec4 fragColor' instead of gl_FragColor\n"
"5. Declares uniforms correctly (e.g., uniform sampler2D u_image0;)"
) from e
def _render_shader(
ctx: moderngl.Context,
program: moderngl.Program,
width: int,
height: int,
textures: list[moderngl.Texture],
uniforms: dict[str, int | float],
) -> list[np.ndarray]:
# Create output textures
output_textures = []
for _ in range(MAX_OUTPUTS):
tex = ctx.texture((width, height), 4)
tex.filter = (moderngl.LINEAR, moderngl.LINEAR)
output_textures.append(tex)

fbo = ctx.framebuffer(color_attachments=output_textures)

# Full-screen quad vertices (position + texcoord)
vertices = np.array([
# Position (x, y), Texcoord (u, v)
-1.0, -1.0, 0.0, 0.0,
1.0, -1.0, 1.0, 0.0,
-1.0, 1.0, 0.0, 1.0,
1.0, 1.0, 1.0, 1.0,
], dtype='f4')

vbo = ctx.buffer(vertices.tobytes())
vao = ctx.vertex_array(
program,
[(vbo, '2f 2f', 'in_position', 'in_texcoord')],
)

try:
# Bind textures
for i, texture in enumerate(textures):
texture.use(i)
uniform_name = f'u_image{i}'
if uniform_name in program:
program[uniform_name].value = i

# Set uniforms
if 'u_resolution' in program:
program['u_resolution'].value = (float(width), float(height))

for name, value in uniforms.items():
if name in program:
program[name].value = value

# Render
fbo.use()
fbo.clear(0.0, 0.0, 0.0, 1.0)
vao.render(moderngl.TRIANGLE_STRIP)

# Read results from all attachments
results = []
for i in range(MAX_OUTPUTS):
results.append(_texture_to_image(fbo, attachment=i, channels=4))
return results
finally:
vao.release()
vbo.release()
for tex in output_textures:
tex.release()
fbo.release()
def _prepare_textures(
ctx: moderngl.Context,
image_list: list[torch.Tensor],
batch_idx: int,
) -> list[moderngl.Texture]:
textures = []
for img_tensor in image_list[:MAX_IMAGES]:
img_idx = min(batch_idx, img_tensor.shape[0] - 1)
img_np = img_tensor[img_idx].cpu().numpy()
textures.append(_image_to_texture(ctx, img_np))
return textures


def _prepare_uniforms(int_list: list[int], float_list: list[float]) -> dict[str, int | float]:
uniforms: dict[str, int | float] = {}
for i, val in enumerate(int_list[:MAX_UNIFORMS]):
uniforms[f'u_int{i}'] = int(val)
for i, val in enumerate(float_list[:MAX_UNIFORMS]):
uniforms[f'u_float{i}'] = float(val)
return uniforms


def _release_textures(textures: list[moderngl.Texture]) -> None:
for texture in textures:
texture.release()


@contextmanager
def _gl_context(force_software: bool = False) -> Generator[moderngl.Context, None, None]:
ctx = _create_gl_context(force_software)
try:
yield ctx
finally:
ctx.release()


@contextmanager
def _shader_program(ctx: moderngl.Context, fragment_source: str) -> Generator[moderngl.Program, None, None]:
program = _compile_shader(ctx, fragment_source)
try:
yield program
finally:
program.release()


@contextmanager
def _textures_context(
ctx: moderngl.Context,
image_list: list[torch.Tensor],
batch_idx: int,
) -> Generator[list[moderngl.Texture], None, None]:
textures = _prepare_textures(ctx, image_list, batch_idx)
try:
yield textures
finally:
_release_textures(textures)
class GLSLShader(io.ComfyNode):

@classmethod
def define_schema(cls) -> io.Schema:
# Create autogrow templates
image_template = io.Autogrow.TemplatePrefix(
io.Image.Input("image"),
prefix="image",
min=1,
max=MAX_IMAGES,
)

float_template = io.Autogrow.TemplatePrefix(
io.Float.Input("float", default=0.0),
prefix="u_float",
min=0,
max=MAX_UNIFORMS,
)

int_template = io.Autogrow.TemplatePrefix(
io.Int.Input("int", default=0),
prefix="u_int",
min=0,
max=MAX_UNIFORMS,
)

return io.Schema(
node_id="GLSLShader",
display_name="GLSL Shader",
category="image/shader",
description=(
f"Apply GLSL fragment shaders to images. "
f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), "
f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. "
f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}."
),
inputs=[
io.String.Input(
"fragment_shader",
default=DEFAULT_FRAGMENT_SHADER,
multiline=True,
tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
),
io.DynamicCombo.Input(
"size_mode",
options=[
io.DynamicCombo.Option(
"from_input",
[], # No extra inputs - uses first input image dimensions
),
io.DynamicCombo.Option(
"custom",
[
io.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION),
io.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION),
],
),
],
tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
),
io.Autogrow.Input("images", template=image_template),
io.Autogrow.Input("floats", template=float_template),
io.Autogrow.Input("ints", template=int_template),
],
outputs=[
io.Image.Output(display_name="IMAGE0"),
io.Image.Output(display_name="IMAGE1"),
io.Image.Output(display_name="IMAGE2"),
io.Image.Output(display_name="IMAGE3"),
],
)
@classmethod
def execute(
cls,
fragment_shader: str,
size_mode: SizeModeInput,
images: io.Autogrow.Type,
floats: io.Autogrow.Type = None,
ints: io.Autogrow.Type = None,
**kwargs,
) -> io.NodeOutput:
image_list = [v for v in images.values() if v is not None]
float_list = [v if v is not None else 0.0 for v in floats.values()] if floats else []
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []

if not image_list:
raise ValueError("At least one input image is required")

# Determine output dimensions
if size_mode["size_mode"] == "custom":
out_width, out_height = size_mode["width"], size_mode["height"]
else:
out_height, out_width = image_list[0].shape[1], image_list[0].shape[2]

batch_size = image_list[0].shape[0]
uniforms = _prepare_uniforms(int_list, float_list)

with _gl_context(force_software=args.cpu) as ctx:
with _shader_program(ctx, fragment_shader) as program:
# Collect outputs for each render target across all batches
all_outputs: list[list[torch.Tensor]] = [[] for _ in range(MAX_OUTPUTS)]

for b in range(batch_size):
with _textures_context(ctx, image_list, b) as textures:
results = _render_shader(ctx, program, out_width, out_height, textures, uniforms)
for i, result in enumerate(results):
all_outputs[i].append(torch.from_numpy(result))

# Stack batches for each output
output_values = []
for i in range(MAX_OUTPUTS):
output_batch = torch.stack(all_outputs[i], dim=0)
output_values.append(output_batch)

return io.NodeOutput(*output_values, ui=cls._build_ui_output(image_list, output_values[0]))

@classmethod
def _build_ui_output(cls, image_list: list[torch.Tensor], output_batch: torch.Tensor) -> dict[str, list]:
"""Build UI output with input and output images for client-side shader execution."""
combined_inputs = torch.cat(image_list, dim=0)
input_images_ui = ui.ImageSaveHelper.save_images(
combined_inputs,
filename_prefix="GLSLShader_input",
folder_type=io.FolderType.temp,
cls=None,
compress_level=1,
)

output_images_ui = ui.ImageSaveHelper.save_images(
output_batch,
filename_prefix="GLSLShader_output",
folder_type=io.FolderType.temp,
cls=None,
compress_level=1,
)

return {"input_images": input_images_ui, "images": output_images_ui}


class GLSLExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [GLSLShader]


async def comfy_entrypoint() -> GLSLExtension:
return GLSLExtension()
@@ -38,7 +38,6 @@ class CLIPTextEncodeHiDream(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeHiDream",
search_aliases=["hidream prompt"],
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
@@ -259,7 +259,6 @@ class SetClipHooks:
return (clip,)

class ConditioningTimestepsRange:
SEARCH_ALIASES = ["prompt scheduling", "timestep segments", "conditioning phases"]
NodeId = 'ConditioningTimestepsRange'
NodeName = 'Timesteps Range'
@classmethod

@@ -469,7 +468,6 @@ class SetHookKeyframes:
return (hooks,)

class CreateHookKeyframe:
SEARCH_ALIASES = ["hook scheduling", "strength animation", "timed hook"]
NodeId = 'CreateHookKeyframe'
NodeName = 'Create Hook Keyframe'
@classmethod

@@ -499,7 +497,6 @@ class CreateHookKeyframe:
return (prev_hook_kf,)

class CreateHookKeyframesInterpolated:
SEARCH_ALIASES = ["ease hook strength", "smooth hook transition", "interpolate keyframes"]
NodeId = 'CreateHookKeyframesInterpolated'
NodeName = 'Create Hook Keyframes Interp.'
@classmethod

@@ -547,7 +544,6 @@ class CreateHookKeyframesInterpolated:
return (prev_hook_kf,)

class CreateHookKeyframesFromFloats:
SEARCH_ALIASES = ["batch keyframes", "strength list to keyframes"]
NodeId = 'CreateHookKeyframesFromFloats'
NodeName = 'Create Hook Keyframes From Floats'
@classmethod

@@ -622,7 +618,6 @@ class SetModelHooksOnCond:
# Combine Hooks
#------------------------------------------
class CombineHooks:
SEARCH_ALIASES = ["merge hooks"]
NodeId = 'CombineHooks2'
NodeName = 'Combine Hooks [2]'
@classmethod
@@ -618,7 +618,6 @@ class SaveGLB(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveGLB",
search_aliases=["export 3d model", "save mesh"],
category="3d",
is_output_node=True,
inputs=[
@@ -22,7 +22,6 @@ class ImageCrop(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageCrop",
search_aliases=["trim"],
display_name="Image Crop",
category="image/transform",
inputs=[

@@ -52,7 +51,6 @@ class RepeatImageBatch(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="RepeatImageBatch",
search_aliases=["duplicate image", "clone image"],
category="image/batch",
inputs=[
IO.Image.Input("image"),

@@ -74,7 +72,6 @@ class ImageFromBatch(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageFromBatch",
search_aliases=["select image", "pick from batch", "extract image"],
category="image/batch",
inputs=[
IO.Image.Input("image"),

@@ -100,7 +97,6 @@ class ImageAddNoise(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageAddNoise",
search_aliases=["film grain"],
category="image",
inputs=[
IO.Image.Input("image"),

@@ -198,11 +194,11 @@ class SaveAnimatedPNG(IO.ComfyNode):

class ImageStitch(IO.ComfyNode):
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""

@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageStitch",
search_aliases=["combine images", "join images", "concatenate images", "side by side"],
display_name="Image Stitch",
description="Stitches image2 to image1 in the specified direction.\n"
"If image2 is not provided, returns image1 unchanged.\n"

@@ -373,11 +369,11 @@ class ImageStitch(IO.ComfyNode):

class ResizeAndPadImage(IO.ComfyNode):

@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ResizeAndPadImage",
search_aliases=["fit to size"],
category="image/transform",
inputs=[
IO.Image.Input("image"),

@@ -424,11 +420,11 @@ class ResizeAndPadImage(IO.ComfyNode):

class SaveSVGNode(IO.ComfyNode):

@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveSVGNode",
search_aliases=["export vector", "save vector graphics"],
description="Save SVG files on disk.",
category="image/save",
inputs=[

@@ -496,11 +492,11 @@ class SaveSVGNode(IO.ComfyNode):

class GetImageSize(IO.ComfyNode):

@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GetImageSize",
search_aliases=["dimensions", "resolution", "image info"],
display_name="Get Image Size",
description="Returns width and height of the image, and passes it through unchanged.",
category="image",

@@ -531,11 +527,11 @@ class GetImageSize(IO.ComfyNode):

class ImageRotate(IO.ComfyNode):

@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageRotate",
search_aliases=["turn", "flip orientation"],
category="image/transform",
inputs=[
IO.Image.Input("image"),

@@ -561,11 +557,11 @@ class ImageRotate(IO.ComfyNode):

class ImageFlip(IO.ComfyNode):

@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageFlip",
search_aliases=["mirror", "reflect"],
category="image/transform",
inputs=[
IO.Image.Input("image"),
@@ -104,7 +104,6 @@ class CLIPTextEncodeKandinsky5(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeKandinsky5",
search_aliases=["kandinsky prompt"],
category="advanced/conditioning/kandinsky5",
inputs=[
io.Clip.Input("clip"),
@@ -21,7 +21,6 @@ class LatentAdd(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentAdd",
search_aliases=["combine latents", "sum latents"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),

@@ -48,7 +47,6 @@ class LatentSubtract(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentSubtract",
search_aliases=["difference latent", "remove features"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),

@@ -75,7 +73,6 @@ class LatentMultiply(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentMultiply",
search_aliases=["scale latent", "amplify latent", "latent gain"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),

@@ -99,7 +96,6 @@ class LatentInterpolate(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentInterpolate",
search_aliases=["blend latent", "mix latent", "lerp latent", "transition"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),

@@ -138,7 +134,6 @@ class LatentConcat(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentConcat",
search_aliases=["join latents", "stitch latents"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),

@@ -178,7 +173,6 @@ class LatentCut(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentCut",
search_aliases=["crop latent", "slice latent", "extract region"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),

@@ -219,7 +213,6 @@ class LatentCutToBatch(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentCutToBatch",
search_aliases=["slice to batch", "split latent", "tile latent"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),

@@ -261,7 +254,6 @@ class LatentBatch(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentBatch",
search_aliases=["combine latents", "merge latents", "join latents"],
category="latent/batch",
is_deprecated=True,
inputs=[

@@ -318,7 +310,6 @@ class LatentApplyOperation(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentApplyOperation",
search_aliases=["transform latent"],
category="latent/advanced/operations",
is_experimental=True,
inputs=[

@@ -374,7 +365,6 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentOperationTonemapReinhard",
search_aliases=["hdr latent"],
category="latent/advanced/operations",
is_experimental=True,
inputs=[
@@ -75,7 +75,6 @@ class Preview3D(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="Preview3D",
            search_aliases=["view mesh", "3d viewer"],
            display_name="Preview 3D & Animation",
            category="3d",
            is_experimental=True,


@@ -104,11 +104,7 @@ class CustomComboNode(io.ComfyNode):
            category="utils",
            is_experimental=True,
            inputs=[io.Combo.Input("choice", options=[])],
            outputs=[
                io.String.Output(display_name="STRING"),
                io.Int.Output(display_name="INDEX"),
            ],
            hidden=["index"],
            outputs=[io.String.Output()]
        )

    @classmethod
@@ -119,8 +115,8 @@ class CustomComboNode(io.ComfyNode):
        return True

    @classmethod
    def execute(cls, choice: io.Combo.Type, index: int = 0) -> io.NodeOutput:
        return io.NodeOutput(choice, index)
    def execute(cls, choice: io.Combo.Type) -> io.NodeOutput:
        return io.NodeOutput(choice)


class DCTestNode(io.ComfyNode):
@@ -228,7 +224,6 @@ class ConvertStringToComboNode(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="ConvertStringToComboNode",
            search_aliases=["string to dropdown", "text to combo"],
            display_name="Convert String to Combo",
            category="logic",
            inputs=[io.String.Input("string")],
@@ -244,7 +239,6 @@ class InvertBooleanNode(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="InvertBooleanNode",
            search_aliases=["not", "toggle", "negate", "flip boolean"],
            display_name="Invert Boolean",
            category="logic",
            inputs=[io.Boolean.Input("boolean")],


@@ -78,7 +78,6 @@ class LoraSave(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="LoraSave",
            search_aliases=["export lora"],
            display_name="Extract and Save Lora",
            category="_for_testing",
            inputs=[


@@ -79,7 +79,6 @@ class CLIPTextEncodeLumina2(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="CLIPTextEncodeLumina2",
            search_aliases=["lumina prompt"],
            display_name="CLIP Text Encode for Lumina2",
            category="conditioning",
            description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "


@@ -50,7 +50,6 @@ class LatentCompositeMasked(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="LatentCompositeMasked",
            search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"],
            category="latent",
            inputs=[
                IO.Latent.Input("destination"),
@@ -79,7 +78,6 @@ class ImageCompositeMasked(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="ImageCompositeMasked",
            search_aliases=["paste image", "overlay", "layer"],
            category="image",
            inputs=[
                IO.Image.Input("destination"),
@@ -107,7 +105,6 @@ class MaskToImage(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="MaskToImage",
            search_aliases=["convert mask"],
            display_name="Convert Mask to Image",
            category="mask",
            inputs=[
@@ -129,7 +126,6 @@ class ImageToMask(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="ImageToMask",
            search_aliases=["extract channel", "channel to mask"],
            display_name="Convert Image to Mask",
            category="mask",
            inputs=[
@@ -153,7 +149,6 @@ class ImageColorToMask(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="ImageColorToMask",
            search_aliases=["color keying", "chroma key"],
            category="mask",
            inputs=[
                IO.Image.Input("image"),
@@ -199,7 +194,6 @@ class InvertMask(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="InvertMask",
            search_aliases=["reverse mask", "flip mask"],
            category="mask",
            inputs=[
                IO.Mask.Input("mask"),
@@ -220,7 +214,6 @@ class CropMask(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="CropMask",
            search_aliases=["cut mask", "extract mask region", "mask slice"],
            category="mask",
            inputs=[
                IO.Mask.Input("mask"),
@@ -246,7 +239,6 @@ class MaskComposite(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="MaskComposite",
            search_aliases=["combine masks", "blend masks", "layer masks"],
            category="mask",
            inputs=[
                IO.Mask.Input("destination"),
@@ -295,7 +287,6 @@ class FeatherMask(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="FeatherMask",
            search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"],
            category="mask",
            inputs=[
                IO.Mask.Input("mask"),
@@ -342,7 +333,6 @@ class GrowMask(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="GrowMask",
            search_aliases=["expand mask", "shrink mask"],
            display_name="Grow Mask",
            category="mask",
            inputs=[
@@ -380,7 +370,6 @@ class ThresholdMask(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="ThresholdMask",
            search_aliases=["binary mask"],
            category="mask",
            inputs=[
                IO.Mask.Input("mask"),
@@ -405,7 +394,6 @@ class MaskPreview(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="MaskPreview",
            search_aliases=["show mask", "view mask", "inspect mask", "debug mask"],
            display_name="Preview Mask",
            category="mask",
            description="Saves the input images to your ComfyUI output directory.",

@@ -299,7 +299,6 @@ class RescaleCFG:
        return (m, )

class ModelComputeDtype:
    SEARCH_ALIASES = ["model precision", "change dtype"]
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),


@@ -91,7 +91,6 @@ class CLIPMergeSimple:


class CLIPSubtract:
    SEARCH_ALIASES = ["clip difference", "text encoder subtract"]
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip1": ("CLIP",),
@@ -114,7 +113,6 @@ class CLIPSubtract:


class CLIPAdd:
    SEARCH_ALIASES = ["combine clip"]
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip1": ("CLIP",),
@@ -227,7 +225,6 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
    comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)

class CheckpointSave:
    SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"]
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

@@ -340,7 +337,6 @@ class VAESave:
        return {}

class ModelSave:
    SEARCH_ALIASES = ["export model", "checkpoint save"]
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()


@@ -12,7 +12,6 @@ class Morphology(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="Morphology",
            search_aliases=["erode", "dilate"],
            display_name="ImageMorphology",
            category="image/postprocessing",
            inputs=[
@@ -58,7 +57,6 @@ class ImageRGBToYUV(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="ImageRGBToYUV",
            search_aliases=["color space conversion"],
            category="image/batch",
            inputs=[
                io.Image.Input("image"),
@@ -80,7 +78,6 @@ class ImageYUVToRGB(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="ImageYUVToRGB",
            search_aliases=["color space conversion"],
            category="image/batch",
            inputs=[
                io.Image.Input("Y"),


@@ -7,7 +7,6 @@ class CLIPTextEncodePixArtAlpha(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="CLIPTextEncodePixArtAlpha",
            search_aliases=["pixart prompt"],
            category="advanced/conditioning",
            description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.",
            inputs=[

@@ -402,6 +402,7 @@ def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: st
    return input[:, y0:y1, x0:x1]

class ResizeImageMaskNode(io.ComfyNode):
    scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
    crop_methods = ["disabled", "center"]

@@ -420,62 +421,46 @@ class ResizeImageMaskNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        template = io.MatchType.Template("input_type", [io.Image, io.Mask])
        crop_combo = io.Combo.Input(
            "crop",
            options=cls.crop_methods,
            default="center",
            tooltip="How to handle aspect ratio mismatch: 'disabled' stretches to fit, 'center' crops to maintain aspect ratio.",
        )
        crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center")
        return io.Schema(
            node_id="ResizeImageMaskNode",
            display_name="Resize Image/Mask",
            description="Resize an image or mask using various scaling methods.",
            category="transform",
            search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
            inputs=[
                io.MatchType.Input("input", template=template),
                io.DynamicCombo.Input(
                    "resize_type",
                    tooltip="Select how to resize: by exact dimensions, scale factor, matching another image, etc.",
                    options=[
                        io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
                            io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Set to 0 to auto-calculate from height while preserving aspect ratio."),
                            io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Set to 0 to auto-calculate from width while preserving aspect ratio."),
                            crop_combo,
                io.DynamicCombo.Input("resize_type", options=[
                    io.DynamicCombo.Option(ResizeType.SCALE_BY, [
                        io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01),
                    ]),
                        io.DynamicCombo.Option(ResizeType.SCALE_BY, [
                            io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01, tooltip="Scale factor (e.g., 2.0 doubles size, 0.5 halves size)."),
                    io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
                        io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
                        io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
                        crop_combo,
                    ]),
                        io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
                            io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The longer edge will be resized to this value. Aspect ratio is preserved."),
                    io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
                        io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
                    ]),
                        io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
                            io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The shorter edge will be resized to this value. Aspect ratio is preserved."),
                    io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
                        io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
                    ]),
                        io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
                            io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Height auto-adjusts to preserve aspect ratio."),
                    io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
                        io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
                    ]),
                        io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
                            io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Width auto-adjusts to preserve aspect ratio."),
                    io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
                        io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
                    ]),
                        io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
                            io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01, tooltip="Target total megapixels (e.g., 1.0 ≈ 1024×1024). Aspect ratio is preserved."),
                    io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
                        io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
                    ]),
                        io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
                            io.MultiType.Input("match", [io.Image, io.Mask], tooltip="Resize input to match the dimensions of this reference image or mask."),
                            crop_combo,
                    io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
                        io.MultiType.Input("match", [io.Image, io.Mask]),
                        crop_combo,
                    ]),
                        io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
                            io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1, tooltip="Resize so width and height are divisible by this number. Useful for latent alignment (e.g., 8 or 64)."),
                    io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
                        io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1),
                    ]),
                    ],
                ),
                io.Combo.Input(
                    "scale_method",
                    options=cls.scale_methods,
                    default="area",
                    tooltip="Interpolation algorithm. 'area' is best for downscaling, 'lanczos' for upscaling, 'nearest-exact' for pixel art.",
                ),
                ]),
                io.Combo.Input("scale_method", options=cls.scale_methods, default="area"),
            ],
            outputs=[io.MatchType.Output(template=template, display_name="resized")]
        )
@@ -584,7 +569,6 @@ class BatchMasksNode(io.ComfyNode):
        autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50)
        return io.Schema(
            node_id="BatchMasksNode",
            search_aliases=["combine masks", "stack masks", "merge masks"],
            display_name="Batch Masks",
            category="mask",
            inputs=[
@@ -605,7 +589,6 @@ class BatchLatentsNode(io.ComfyNode):
        autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50)
        return io.Schema(
            node_id="BatchLatentsNode",
            search_aliases=["combine latents", "stack latents", "merge latents"],
            display_name="Batch Latents",
            category="latent",
            inputs=[
@@ -629,7 +612,6 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
                                                            prefix="input", min=1, max=50)
        return io.Schema(
            node_id="BatchImagesMasksLatentsNode",
            search_aliases=["combine batch", "merge batch", "stack inputs"],
            display_name="Batch Images/Masks/Latents",
            category="util",
            inputs=[

@@ -16,7 +16,7 @@ class PreviewAny():
    OUTPUT_NODE = True

    CATEGORY = "utils"
    SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"]
    SEARCH_ALIASES = ["preview", "show", "display", "view", "show text", "display text", "preview text", "show output", "inspect", "debug"]

    def main(self, source=None):
        value = 'None'


@@ -65,7 +65,6 @@ class CLIPTextEncodeSD3(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="CLIPTextEncodeSD3",
            search_aliases=["sd3 prompt"],
            category="advanced/conditioning",
            inputs=[
                io.Clip.Input("clip"),


@@ -32,7 +32,6 @@ class StringSubstring(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="StringSubstring",
            search_aliases=["extract text", "text portion"],
            display_name="Substring",
            category="utils/string",
            inputs=[
@@ -55,7 +54,6 @@ class StringLength(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="StringLength",
            search_aliases=["character count", "text size"],
            display_name="Length",
            category="utils/string",
            inputs=[
@@ -76,7 +74,6 @@ class CaseConverter(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="CaseConverter",
            search_aliases=["text case", "uppercase", "lowercase", "capitalize"],
            display_name="Case Converter",
            category="utils/string",
            inputs=[
@@ -109,7 +106,6 @@ class StringTrim(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="StringTrim",
            search_aliases=["clean whitespace", "remove whitespace"],
            display_name="Trim",
            category="utils/string",
            inputs=[
@@ -140,7 +136,6 @@ class StringReplace(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="StringReplace",
            search_aliases=["find and replace", "substitute", "swap text"],
            display_name="Replace",
            category="utils/string",
            inputs=[
@@ -163,7 +158,6 @@ class StringContains(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="StringContains",
            search_aliases=["text includes", "string includes"],
            display_name="Contains",
            category="utils/string",
            inputs=[
@@ -191,7 +185,6 @@ class StringCompare(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="StringCompare",
            search_aliases=["text match", "string equals", "starts with", "ends with"],
            display_name="Compare",
            category="utils/string",
            inputs=[
@@ -227,7 +220,6 @@ class RegexMatch(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="RegexMatch",
            search_aliases=["pattern match", "text contains", "string match"],
            display_name="Regex Match",
            category="utils/string",
            inputs=[
@@ -268,7 +260,6 @@ class RegexExtract(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="RegexExtract",
            search_aliases=["pattern extract", "text parser", "parse text"],
            display_name="Regex Extract",
            category="utils/string",
            inputs=[
@@ -343,7 +334,6 @@ class RegexReplace(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="RegexReplace",
            search_aliases=["pattern replace", "find and replace", "substitution"],
            display_name="Regex Replace",
            category="utils/string",
            description="Find and replace text using regex patterns.",


@@ -1101,7 +1101,6 @@ class SaveLoRA(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="SaveLoRA",
            search_aliases=["export lora"],
            display_name="Save LoRA Weights",
            category="loaders",
            is_experimental=True,
@@ -1145,7 +1144,6 @@ class LossGraphNode(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="LossGraphNode",
            search_aliases=["training chart", "training visualization", "plot loss"],
            display_name="Plot Loss Graph",
            category="training",
            is_experimental=True,

@@ -16,7 +16,6 @@ class SaveWEBM(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="SaveWEBM",
            search_aliases=["export webm"],
            category="image/video",
            is_experimental=True,
            inputs=[
@@ -70,7 +69,6 @@ class SaveVideo(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="SaveVideo",
            search_aliases=["export video"],
            display_name="Save Video",
            category="image/video",
            description="Saves the input images to your ComfyUI output directory.",
@@ -118,7 +116,6 @@ class CreateVideo(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="CreateVideo",
            search_aliases=["images to video"],
            display_name="Create Video",
            category="image/video",
            description="Create a video from images.",
@@ -143,7 +140,6 @@ class GetVideoComponents(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="GetVideoComponents",
            search_aliases=["extract frames", "split video", "video to images", "demux"],
            display_name="Get Video Components",
            category="image/video",
            description="Extracts all components from a video: frames, audio, and framerate.",
@@ -163,6 +159,29 @@ class GetVideoComponents(io.ComfyNode):
        return io.NodeOutput(components.images, components.audio, float(components.frame_rate))


class VideoSlice(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="VideoSlice",
            display_name="Video Slice",
            category="image/video",
            description="Extract a range of frames from a video.",
            inputs=[
                io.Video.Input("video", tooltip="The video to slice."),
                io.Int.Input("start_frame", default=0, min=0, tooltip="The frame index to start from (0-indexed)."),
                io.Int.Input("frame_count", default=1, min=1, tooltip="Number of frames to extract."),
            ],
            outputs=[
                io.Video.Output(tooltip="The sliced video."),
            ],
        )

    @classmethod
    def execute(cls, video: Input.Video, start_frame: int, frame_count: int) -> io.NodeOutput:
        return io.NodeOutput(video.sliced(start_frame, frame_count))


class LoadVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
@@ -171,7 +190,6 @@ class LoadVideo(io.ComfyNode):
        files = folder_paths.filter_files_content_types(files, ["video"])
        return io.Schema(
            node_id="LoadVideo",
            search_aliases=["import video", "open video", "video file"],
            display_name="Load Video",
            category="image/video",
            inputs=[
@@ -211,6 +229,7 @@ class VideoExtension(ComfyExtension):
            SaveVideo,
            CreateVideo,
            GetVideoComponents,
            VideoSlice,
            LoadVideo,
        ]
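
The VideoSlice node above is a thin wrapper around the video input's sliced(start_frame, frame_count) call; the slice is recorded as a pending operation and only materialized when frames are actually requested (see the unit tests added at the end of this diff). A minimal sketch of driving the same API directly from Python, assuming the VideoFromComponents and VideoComponents types exercised in those tests:

# Sketch only: exercises the same sliced() call that VideoSlice.execute() makes,
# using the comfy_api types referenced elsewhere in this diff.
import torch
from fractions import Fraction
from comfy_api.input_impl.video_types import VideoFromComponents
from comfy_api.util.video_types import VideoComponents

frames = torch.rand(30, 64, 64, 3)  # 30 RGB frames with values in [0, 1]
video = VideoFromComponents(VideoComponents(images=frames, frame_rate=Fraction(24)))

clip = video.sliced(5, 10)                 # lazy; returns a new video object
print(clip.get_frame_count())              # expected: 10
print(clip.get_components().images.shape)  # expected: torch.Size([10, 64, 64, 3])
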

@@ -287,7 +287,6 @@ class WanVaceToVideo(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="WanVaceToVideo",
            search_aliases=["video conditioning", "video control"],
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
@@ -706,7 +705,6 @@ class WanTrackToVideo(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="WanTrackToVideo",
            search_aliases=["motion tracking", "trajectory video", "point tracking", "keypoint animation"],
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),


@@ -324,7 +324,6 @@ class GenerateTracks(io.ComfyNode):
    def define_schema(cls):
        return io.Schema(
            node_id="GenerateTracks",
            search_aliases=["motion paths", "camera movement", "trajectory"],
            category="conditioning/video_models",
            inputs=[
                io.Int.Input("width", default=832, min=16, max=4096, step=16),


@@ -5,7 +5,6 @@ MAX_RESOLUTION = nodes.MAX_RESOLUTION


class WebcamCapture(nodes.LoadImage):
    SEARCH_ALIASES = ["camera input", "live capture", "camera feed", "snapshot"]
    @classmethod
    def INPUT_TYPES(s):
        return {


@@ -192,11 +192,6 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
            hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
        if io.Hidden.api_key_comfy_org.name in hidden:
            hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
        # Handle custom hidden inputs from prompt data
        system_hidden_names = {h.name for h in io.Hidden}
        for hidden_name in hidden:
            if hidden_name not in system_hidden_names and hidden_name in inputs:
                input_data_all[hidden_name] = [inputs[hidden_name]]
    else:
        if "hidden" in valid_inputs:
            h = valid_inputs["hidden"]

42  nodes.py
@@ -93,8 +93,6 @@ class ConditioningCombine:
        return (conditioning_1 + conditioning_2, )

class ConditioningAverage :
    SEARCH_ALIASES = ["blend prompts", "interpolate conditioning", "mix prompts", "style fusion", "weighted blend"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
@@ -161,8 +159,6 @@ class ConditioningConcat:
        return (out, )

class ConditioningSetArea:
    SEARCH_ALIASES = ["regional prompt", "area prompt", "spatial conditioning", "localized prompt"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
@@ -221,8 +217,6 @@ class ConditioningSetAreaStrength:


class ConditioningSetMask:
    SEARCH_ALIASES = ["masked prompt", "regional inpaint conditioning", "mask conditioning"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
@@ -248,8 +242,6 @@ class ConditioningSetMask:
        return (c, )

class ConditioningZeroOut:
    SEARCH_ALIASES = ["null conditioning", "clear conditioning"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", )}}
@@ -475,8 +467,6 @@ class InpaintModelConditioning:


class SaveLatent:
    SEARCH_ALIASES = ["export latent"]

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

@@ -528,8 +518,6 @@ class SaveLatent:


class LoadLatent:
    SEARCH_ALIASES = ["import latent", "open latent"]

    @classmethod
    def INPUT_TYPES(s):
        input_dir = folder_paths.get_input_directory()
@@ -566,8 +554,6 @@ class LoadLatent:


class CheckpointLoader:
    SEARCH_ALIASES = ["load model", "model loader"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
@@ -607,8 +593,6 @@ class CheckpointLoaderSimple:
        return out[:3]

class DiffusersLoader:
    SEARCH_ALIASES = ["load diffusers model"]

    @classmethod
    def INPUT_TYPES(cls):
        paths = []
@@ -1079,8 +1063,6 @@ class StyleModelLoader:


class StyleModelApply:
    SEARCH_ALIASES = ["style transfer"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
@@ -1234,8 +1216,6 @@ class EmptyLatentImage:


class LatentFromBatch:
    SEARCH_ALIASES = ["select from batch", "pick latent", "batch subset"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",),
@@ -1268,8 +1248,6 @@ class LatentFromBatch:
        return (s,)

class RepeatLatentBatch:
    SEARCH_ALIASES = ["duplicate latent", "clone latent"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",),
@@ -1296,8 +1274,6 @@ class RepeatLatentBatch:
        return (s,)

class LatentUpscale:
    SEARCH_ALIASES = ["enlarge latent", "resize latent"]

    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
    crop_methods = ["disabled", "center"]

@@ -1332,8 +1308,6 @@ class LatentUpscale:
        return (s,)

class LatentUpscaleBy:
    SEARCH_ALIASES = ["enlarge latent", "resize latent", "scale latent"]

    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]

    @classmethod
@@ -1377,8 +1351,6 @@ class LatentRotate:
        return (s,)

class LatentFlip:
    SEARCH_ALIASES = ["mirror latent"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",),
@@ -1399,8 +1371,6 @@ class LatentFlip:
        return (s,)

class LatentComposite:
    SEARCH_ALIASES = ["overlay latent", "layer latent", "paste latent"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples_to": ("LATENT",),
@@ -1443,8 +1413,6 @@ class LatentComposite:
        return (samples_out,)

class LatentBlend:
    SEARCH_ALIASES = ["mix latents", "interpolate latents"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
@@ -1486,8 +1454,6 @@ class LatentBlend:
        raise ValueError(f"Unsupported blend mode: {mode}")

class LatentCrop:
    SEARCH_ALIASES = ["trim latent", "cut latent"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",),
@@ -1773,8 +1739,6 @@ class LoadImage:
        return True

class LoadImageMask:
    SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]

    _color_channels = ["alpha", "red", "green", "blue"]
    @classmethod
    def INPUT_TYPES(s):
@@ -1825,8 +1789,6 @@ class LoadImageMask:


class LoadImageOutput(LoadImage):
    SEARCH_ALIASES = ["output image", "previous generation"]

    @classmethod
    def INPUT_TYPES(s):
        return {
@@ -1900,7 +1862,6 @@ class ImageScaleBy:
        return (s,)

class ImageInvert:
    SEARCH_ALIASES = ["reverse colors"]

    @classmethod
    def INPUT_TYPES(s):
@@ -1916,7 +1877,6 @@ class ImageInvert:
        return (s,)

class ImageBatch:
    SEARCH_ALIASES = ["combine images", "merge images", "stack images"]

    @classmethod
    def INPUT_TYPES(s):
@@ -1962,7 +1922,6 @@ class EmptyImage:
        return (torch.cat((r, g, b), dim=-1), )

class ImagePadForOutpaint:
    SEARCH_ALIASES = ["extend canvas", "expand image"]

    @classmethod
    def INPUT_TYPES(s):
@@ -2430,7 +2389,6 @@ async def init_builtin_extra_nodes():
        "nodes_wanmove.py",
        "nodes_image_compare.py",
        "nodes_zimage.py",
        "nodes_glsl.py",
    ]

    import_failed = []


@@ -28,4 +28,3 @@ kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
moderngl

150  tests-unit/comfy_api_test/video_slice_test.py  Normal file
@@ -0,0 +1,150 @@
import pytest
import torch
import tempfile
import os
import av
from fractions import Fraction
from comfy_api.input_impl.video_types import (
    VideoFromFile,
    VideoFromComponents,
    SliceOp,
)
from comfy_api.util.video_types import VideoComponents


def create_test_video(width=4, height=4, frames=10, fps=30):
    """Helper to create a temporary video file."""
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    with av.open(tmp.name, mode="w") as container:
        stream = container.add_stream("h264", rate=fps)
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"

        for i in range(frames):
            frame_data = torch.ones(height, width, 3, dtype=torch.uint8) * (i * 25)
            frame = av.VideoFrame.from_ndarray(frame_data.numpy(), format="rgb24")
            frame = frame.reformat(format="yuv420p")
            packet = stream.encode(frame)
            container.mux(packet)

        packet = stream.encode(None)
        container.mux(packet)

    return tmp.name


@pytest.fixture
def video_file_10_frames():
    file_path = create_test_video(frames=10)
    yield file_path
    os.unlink(file_path)


@pytest.fixture
def video_components_10_frames():
    images = torch.rand(10, 4, 4, 3)
    return VideoComponents(images=images, frame_rate=Fraction(30))


class TestSliceOp:
    def test_apply_slices_correctly(self, video_components_10_frames):
        op = SliceOp(start_frame=2, frame_count=3)
        result = op.apply(video_components_10_frames)

        assert result.images.shape[0] == 3
        assert torch.equal(result.images, video_components_10_frames.images[2:5])

    def test_compute_frame_count(self):
        op = SliceOp(start_frame=2, frame_count=5)
        assert op.compute_frame_count(10) == 5

    def test_compute_frame_count_clamps(self):
        op = SliceOp(start_frame=8, frame_count=5)
        assert op.compute_frame_count(10) == 2


class TestVideoSliced:
    def test_sliced_returns_new_instance(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        assert video is not sliced
        assert len(video._operations) == 0
        assert len(sliced._operations) == 1

    def test_get_components_applies_operations(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        components = sliced.get_components()

        assert components.images.shape[0] == 3
        assert torch.equal(components.images, video_components_10_frames.images[2:5])

    def test_get_frame_count(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        assert sliced.get_frame_count() == 3

    def test_get_duration(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(0, 3)

        assert sliced.get_duration() == pytest.approx(0.1)

    def test_chained_slices_compose(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 6).sliced(1, 3)

        components = sliced.get_components()

        assert components.images.shape[0] == 3
        assert torch.equal(components.images, video_components_10_frames.images[3:6])

    def test_operations_list_is_immutable(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced1 = video.sliced(0, 5)
        sliced2 = sliced1.sliced(1, 2)

        assert len(video._operations) == 0
        assert len(sliced1._operations) == 1
        assert len(sliced2._operations) == 2

    def test_from_file(self, video_file_10_frames):
        video = VideoFromFile(video_file_10_frames)
        sliced = video.sliced(2, 3)

        components = sliced.get_components()

        assert components.images.shape[0] == 3
        assert sliced.get_frame_count() == 3

    def test_save_sliced_video(self, video_components_10_frames, tmp_path):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        output_path = str(tmp_path / "sliced_output.mp4")
        sliced.save_to(output_path)

        saved_video = VideoFromFile(output_path)
        assert saved_video.get_frame_count() == 3

    def test_materialization_clears_ops(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        assert len(sliced._operations) == 1
        sliced.get_components()
        assert len(sliced._operations) == 0

    def test_second_get_components_uses_cache(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        first = sliced.get_components()
        second = sliced.get_components()

        assert first.images.shape == second.images.shape
        assert torch.equal(first.images, second.images)
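
The tests above pin down the behavior expected of SliceOp and the sliced() call (lazy, chainable, clamped to the frames that exist, and cleared once materialized) rather than its exact implementation. A rough, hypothetical sketch consistent with those expectations; the real code lives in comfy_api/input_impl/video_types.py and may differ:

# Hypothetical sketch inferred from the tests above; not the actual implementation.
from dataclasses import dataclass

@dataclass
class SliceOpSketch:
    start_frame: int
    frame_count: int

    def compute_frame_count(self, total_frames: int) -> int:
        # Clamp to what exists: start=8, count=5, total=10 -> 2 frames.
        if self.start_frame >= total_frames:
            return 0
        return min(self.frame_count, total_frames - self.start_frame)

    def apply(self, components):
        # Slice the frame tensor; audio and frame_rate handling are omitted here.
        end = self.start_frame + self.frame_count
        components.images = components.images[self.start_frame:end]
        return components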