Compare commits

..

8 Commits

21 changed files with 316 additions and 460 deletions

View File

@ -0,0 +1,23 @@
from __future__ import annotations
from aiohttp import web
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy_api.latest._node_replace import NodeReplace
REGISTERED_NODE_REPLACEMENTS: dict[str, list[NodeReplace]] = {}
def register_node_replacement(node_replace: NodeReplace):
REGISTERED_NODE_REPLACEMENTS.setdefault(node_replace.old_node_id, []).append(node_replace)
def registered_as_dict():
return {
k: [v.as_dict() for v in v_list] for k, v_list in REGISTERED_NODE_REPLACEMENTS.items()
}
class NodeReplaceManager:
def add_routes(self, routes):
@routes.get("/node_replacements")
async def get_node_replacements(request):
return web.json_response(registered_as_dict())

View File

@ -1,11 +1,11 @@
from typing import Tuple, Union from typing import Tuple, Union
import threading
import torch import torch
import torch.nn as nn import torch.nn as nn
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module): class CausalConv3d(nn.Module):
def __init__( def __init__(
self, self,
@ -42,34 +42,23 @@ class CausalConv3d(nn.Module):
padding_mode=spatial_padding_mode, padding_mode=spatial_padding_mode,
groups=groups, groups=groups,
) )
self.temporal_cache_state={}
def forward(self, x, causal: bool = True): def forward(self, x, causal: bool = True):
tid = threading.get_ident() if causal:
first_frame_pad = x[:, :, :1, :, :].repeat(
cached, is_end = self.temporal_cache_state.get(tid, (None, False)) (1, 1, self.time_kernel_size - 1, 1, 1)
if cached is None: )
padding_length = self.time_kernel_size - 1 x = torch.concatenate((first_frame_pad, x), dim=2)
if not causal: else:
padding_length = padding_length // 2 first_frame_pad = x[:, :, :1, :, :].repeat(
if x.shape[2] == 0: (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
return x )
cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1)) last_frame_pad = x[:, :, -1:, :, :].repeat(
pieces = [ cached, x ] (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
if is_end and not causal: )
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))) x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
x = self.conv(x)
needs_caching = not is_end return x
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
x = torch.cat(pieces, dim=2)
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
@property @property
def weight(self): def weight(self):

View File

@ -1,5 +1,4 @@
from __future__ import annotations from __future__ import annotations
import threading
import torch import torch
from torch import nn from torch import nn
from functools import partial from functools import partial
@ -7,35 +6,12 @@ import math
from einops import rearrange from einops import rearrange
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd from .conv_nd_factory import make_conv_nd, make_linear_nd
from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops import comfy.ops
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
def mark_conv3d_ended(module):
tid = threading.get_ident()
for _, m in module.named_modules():
if isinstance(m, CausalConv3d):
current = m.temporal_cache_state.get(tid, (None, False))
m.temporal_cache_state[tid] = (current[0], True)
def split2(tensor, split_point, dim=2):
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
def add_exchange_cache(dest, cache_in, new_input, dim=2):
if dest is not None:
if cache_in is not None:
cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
lead_in_dest.add_(lead_in_source)
body, new_input = split2(new_input, dest.shape[dim], dim)
dest.add_(body)
return torch_cat_if_needed([cache_in, new_input], dim=dim)
class Encoder(nn.Module): class Encoder(nn.Module):
r""" r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
@ -229,7 +205,7 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor: def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class.""" r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
@ -278,22 +254,6 @@ class Encoder(nn.Module):
return sample return sample
def forward(self, *args, **kwargs):
#No encoder support so just flag the end so it doesnt use the cache.
mark_conv3d_ended(self)
try:
return self.forward_orig(*args, **kwargs)
finally:
tid = threading.get_ident()
for _, module in self.named_modules():
# ComfyUI doesn't thread this kind of stuff today, but just in case
# we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
class Decoder(nn.Module): class Decoder(nn.Module):
r""" r"""
@ -381,6 +341,18 @@ class Decoder(nn.Module):
timestep_conditioning=timestep_conditioning, timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
) )
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y": elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2) output_channel = output_channel // block_params.get("multiplier", 2)
block = ResnetBlock3D( block = ResnetBlock3D(
@ -456,9 +428,8 @@ class Decoder(nn.Module):
) )
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel)) self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward_orig( def forward(
self, self,
sample: torch.FloatTensor, sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None, timestep: Optional[torch.Tensor] = None,
@ -466,7 +437,6 @@ class Decoder(nn.Module):
r"""The forward method of the `Decoder` class.""" r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0] batch_size = sample.shape[0]
mark_conv3d_ended(self.conv_in)
sample = self.conv_in(sample, causal=self.causal) sample = self.conv_in(sample, causal=self.causal)
checkpoint_fn = ( checkpoint_fn = (
@ -475,12 +445,24 @@ class Decoder(nn.Module):
else lambda x: x else lambda x: x
) )
timestep_shift_scale = None scaled_timestep = None
if self.timestep_conditioning: if self.timestep_conditioning:
assert ( assert (
timestep is not None timestep is not None
), "should pass timestep with timestep_conditioning=True" ), "should pass timestep with timestep_conditioning=True"
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device) scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
for up_block in self.up_blocks:
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
sample = self.conv_norm_out(sample)
if self.timestep_conditioning:
embedded_timestep = self.last_time_embedder( embedded_timestep = self.last_time_embedder(
timestep=scaled_timestep.flatten(), timestep=scaled_timestep.flatten(),
resolution=None, resolution=None,
@ -501,62 +483,16 @@ class Decoder(nn.Module):
embedded_timestep.shape[-2], embedded_timestep.shape[-2],
embedded_timestep.shape[-1], embedded_timestep.shape[-1],
) )
timestep_shift_scale = ada_values.unbind(dim=1) shift, scale = ada_values.unbind(dim=1)
sample = sample * (1 + scale) + shift
output = [] sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
def run_up(idx, sample, ended):
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample)
return
up_block = self.up_blocks[idx]
if (ended):
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
run_up(0, sample, True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample return sample
def forward(self, *args, **kwargs):
try:
return self.forward_orig(*args, **kwargs)
finally:
for _, module in self.named_modules():
#ComfyUI doesn't thread this kind of stuff today, but just incase
#we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
class UNetMidBlock3D(nn.Module): class UNetMidBlock3D(nn.Module):
""" """
@ -727,22 +663,8 @@ class DepthToSpaceUpsample(nn.Module):
) )
self.residual = residual self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor self.out_channels_reduction_factor = out_channels_reduction_factor
self.temporal_cache_state = {}
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None): def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
tid = threading.get_ident()
cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True))
y = self.conv(x, causal=causal)
y = rearrange(
y,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv:
y = y[:, :, 1:, :, :]
drop_first_conv = False
if self.residual: if self.residual:
# Reshape and duplicate the input to match the output shape # Reshape and duplicate the input to match the output shape
x_in = rearrange( x_in = rearrange(
@ -754,20 +676,21 @@ class DepthToSpaceUpsample(nn.Module):
) )
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
x_in = x_in.repeat(1, num_repeat, 1, 1, 1) x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res: if self.stride[0] == 2:
x_in = x_in[:, :, 1:, :, :] x_in = x_in[:, :, 1:, :, :]
drop_first_res = False x = self.conv(x, causal=causal)
x = rearrange(
if y.shape[2] == 0: x,
y = None "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
cached = add_exchange_cache(y, cached, x_in, dim=2) p2=self.stride[1],
self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res) p3=self.stride[2],
)
else: if self.stride[0] == 2:
self.temporal_cache_state[tid] = (None, drop_first_conv, False) x = x[:, :, 1:, :, :]
if self.residual:
return y x = x + x_in
return x
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None: def __init__(self, dim, eps, elementwise_affine=True) -> None:
@ -884,8 +807,6 @@ class ResnetBlock3D(nn.Module):
torch.randn(4, in_channels) / in_channels**0.5 torch.randn(4, in_channels) / in_channels**0.5
) )
self.temporal_cache_state={}
def _feed_spatial_noise( def _feed_spatial_noise(
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
@ -959,12 +880,9 @@ class ResnetBlock3D(nn.Module):
input_tensor = self.conv_shortcut(input_tensor) input_tensor = self.conv_shortcut(input_tensor)
tid = threading.get_ident() output_tensor = input_tensor + hidden_states
cached = self.temporal_cache_state.get(tid, None)
cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2)
self.temporal_cache_state[tid] = cached
return hidden_states return output_tensor
def patchify(x, patch_size_hw, patch_size_t=1): def patchify(x, patch_size_hw, patch_size_t=1):

View File

@ -14,13 +14,10 @@ if model_management.xformers_enabled_vae():
import xformers.ops import xformers.ops
def torch_cat_if_needed(xl, dim): def torch_cat_if_needed(xl, dim):
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
if len(xl) > 1: if len(xl) > 1:
return torch.cat(xl, dim) return torch.cat(xl, dim)
elif len(xl) == 1:
return xl[0]
else: else:
return None return xl[0]
def get_timestep_embedding(timesteps, embedding_dim): def get_timestep_embedding(timesteps, embedding_dim):
""" """

View File

@ -170,14 +170,8 @@ class Attention(nn.Module):
joint_query = apply_rope1(joint_query, image_rotary_emb) joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb) joint_key = apply_rope1(joint_key, image_rotary_emb)
if encoder_hidden_states_mask is not None:
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
else:
attn_mask = None
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attn_mask, transformer_options=transformer_options, attention_mask, transformer_options=transformer_options,
skip_reshape=True) skip_reshape=True)
txt_attn_output = joint_hidden_states[:, :seq_txt, :] txt_attn_output = joint_hidden_states[:, :seq_txt, :]
@ -436,9 +430,6 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states = context encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask encoder_hidden_states_mask = attention_mask
if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask):
encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
hidden_states, img_ids, orig_shape = self.process_img(x) hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1] num_embeds = hidden_states.shape[1]

View File

@ -1578,9 +1578,6 @@ class QwenImage(BaseModel):
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None) cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None: if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

View File

@ -771,24 +771,10 @@ class Flux2(Flux):
return out return out
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return None # TODO
pref = self.text_encoder_key_prefix[0] pref = self.text_encoder_key_prefix[0]
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
if len(detect) > 0: return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
detect["model_type"] = "qwen3_4b"
return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer, comfy.text_encoders.flux.klein_te(**detect))
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_8b.transformer.".format(pref))
if len(detect) > 0:
detect["model_type"] = "qwen3_8b"
return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer8B, comfy.text_encoders.flux.klein_te(**detect))
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}mistral3_24b.transformer.".format(pref))
if len(detect) > 0:
if "{}mistral3_24b.transformer.model.layers.39.post_attention_layernorm.weight".format(pref) not in state_dict:
detect["pruned"] = True
return supported_models_base.ClipTarget(comfy.text_encoders.flux.Flux2Tokenizer, comfy.text_encoders.flux.flux2_te(**detect))
return None
class GenmoMochi(supported_models_base.BASE): class GenmoMochi(supported_models_base.BASE):
unet_config = { unet_config = {

View File

@ -10,11 +10,9 @@ import comfy.utils
def llama_detect(state_dict, prefix=""): def llama_detect(state_dict, prefix=""):
out = {} out = {}
norm_keys = ["{}model.norm.weight".format(prefix), "{}model.layers.0.input_layernorm.weight".format(prefix)] t5_key = "{}model.norm.weight".format(prefix)
for norm_key in norm_keys: if t5_key in state_dict:
if norm_key in state_dict: out["dtype_llama"] = state_dict[t5_key].dtype
out["dtype_llama"] = state_dict[norm_key].dtype
break
quant = comfy.utils.detect_layer_quantization(state_dict, prefix) quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
if quant is not None: if quant is not None:

View File

@ -10,6 +10,7 @@ from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io_public as io from . import _io_public as io
from . import _ui_public as ui from . import _ui_public as ui
from . import _node_replace_public as node_replace
from comfy_execution.utils import get_executing_context from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image from PIL import Image
@ -130,4 +131,5 @@ __all__ = [
"IO", "IO",
"ui", "ui",
"UI", "UI",
"node_replace",
] ]

View File

@ -1,12 +1,10 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .video_types import VideoInput, VideoOp, SliceOp from .video_types import VideoInput
__all__ = [ __all__ = [
"ImageInput", "ImageInput",
"AudioInput", "AudioInput",
"VideoInput", "VideoInput",
"VideoOp",
"SliceOp",
"MaskInput", "MaskInput",
"LatentInput", "LatentInput",
] ]

View File

@ -1,48 +1,11 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from fractions import Fraction from fractions import Fraction
from typing import Optional, Union, IO from typing import Optional, Union, IO
import copy
import io import io
import av import av
from .._util import VideoContainer, VideoCodec, VideoComponents from .._util import VideoContainer, VideoCodec, VideoComponents
class VideoOp(ABC):
"""Base class for lazy video operations."""
@abstractmethod
def apply(self, components: VideoComponents) -> VideoComponents:
pass
@abstractmethod
def compute_frame_count(self, input_frame_count: int) -> int:
pass
@dataclass(frozen=True)
class SliceOp(VideoOp):
"""Extract a range of frames from the video."""
start_frame: int
frame_count: int
def apply(self, components: VideoComponents) -> VideoComponents:
total = components.images.shape[0]
start = max(0, min(self.start_frame, total))
end = min(start + self.frame_count, total)
return VideoComponents(
images=components.images[start:end],
audio=components.audio,
frame_rate=components.frame_rate,
metadata=getattr(components, 'metadata', None),
)
def compute_frame_count(self, input_frame_count: int) -> int:
start = max(0, min(self.start_frame, input_frame_count))
return min(self.frame_count, input_frame_count - start)
class VideoInput(ABC): class VideoInput(ABC):
""" """
Abstract base class for video input types. Abstract base class for video input types.
@ -58,12 +21,6 @@ class VideoInput(ABC):
""" """
pass pass
def sliced(self, start_frame: int, frame_count: int) -> "VideoInput":
"""Return a copy of this video with a slice operation appended."""
new = copy.copy(self)
new._operations = getattr(self, '_operations', []) + [SliceOp(start_frame, frame_count)]
return new
@abstractmethod @abstractmethod
def save_to( def save_to(
self, self,

View File

@ -1,8 +1,7 @@
from .video_types import VideoFromFile, VideoFromComponents from .video_types import VideoFromFile, VideoFromComponents
from .._input import SliceOp
__all__ = [ __all__ = [
# Implementations
"VideoFromFile", "VideoFromFile",
"VideoFromComponents", "VideoFromComponents",
"SliceOp",
] ]

View File

@ -3,7 +3,7 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream from av.subtitles.stream import SubtitleStream
from fractions import Fraction from fractions import Fraction
from typing import Optional from typing import Optional
from .._input import AudioInput, VideoInput, VideoOp from .._input import AudioInput, VideoInput
import av import av
import io import io
import json import json
@ -63,8 +63,6 @@ class VideoFromFile(VideoInput):
containing the file contents. containing the file contents.
""" """
self.__file = file self.__file = file
self._operations: list[VideoOp] = []
self.__materialized: Optional[VideoFromComponents] = None
def get_stream_source(self) -> str | io.BytesIO: def get_stream_source(self) -> str | io.BytesIO:
""" """
@ -163,10 +161,6 @@ class VideoFromFile(VideoInput):
if frame_count == 0: if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'") raise ValueError(f"Could not determine frame count for file '{self.__file}'")
# Apply operations to get final frame count
for op in self._operations:
frame_count = op.compute_frame_count(frame_count)
return frame_count return frame_count
def get_frame_rate(self) -> Fraction: def get_frame_rate(self) -> Fraction:
@ -245,18 +239,10 @@ class VideoFromFile(VideoInput):
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents: def get_components(self) -> VideoComponents:
if self.__materialized is not None:
return self.__materialized.get_components()
if isinstance(self.__file, io.BytesIO): if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container: with av.open(self.__file, mode='r') as container:
components = self.get_components_internal(container) return self.get_components_internal(container)
for op in self._operations:
components = op.apply(components)
self.__materialized = VideoFromComponents(components)
self._operations = []
return components
raise ValueError(f"No video stream found in file '{self.__file}'") raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to( def save_to(
@ -331,27 +317,14 @@ class VideoFromComponents(VideoInput):
def __init__(self, components: VideoComponents): def __init__(self, components: VideoComponents):
self.__components = components self.__components = components
self._operations: list[VideoOp] = []
def get_components(self) -> VideoComponents: def get_components(self) -> VideoComponents:
if self._operations:
components = self.__components
for op in self._operations:
components = op.apply(components)
self.__components = components
self._operations = []
return VideoComponents( return VideoComponents(
images=self.__components.images, images=self.__components.images,
audio=self.__components.audio, audio=self.__components.audio,
frame_rate=self.__components.frame_rate frame_rate=self.__components.frame_rate
) )
def get_frame_count(self) -> int:
count = int(self.__components.images.shape[0])
for op in self._operations:
count = op.compute_frame_count(count)
return count
def save_to( def save_to(
self, self,
path: str, path: str,
@ -359,9 +332,6 @@ class VideoFromComponents(VideoInput):
codec: VideoCodec = VideoCodec.AUTO, codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None metadata: Optional[dict] = None
): ):
# Materialize ops before saving
components = self.get_components()
if format != VideoContainer.AUTO and format != VideoContainer.MP4: if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now") raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264: if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@ -375,22 +345,22 @@ class VideoFromComponents(VideoInput):
for key, value in metadata.items(): for key, value in metadata.items():
output.metadata[key] = json.dumps(value) output.metadata[key] = json.dumps(value)
frame_rate = Fraction(round(components.frame_rate * 1000), 1000) frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
# Create a video stream # Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate) video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = components.images.shape[2] video_stream.width = self.__components.images.shape[2]
video_stream.height = components.images.shape[1] video_stream.height = self.__components.images.shape[1]
video_stream.pix_fmt = 'yuv420p' video_stream.pix_fmt = 'yuv420p'
# Create an audio stream # Create an audio stream
audio_sample_rate = 1 audio_sample_rate = 1
audio_stream: Optional[av.AudioStream] = None audio_stream: Optional[av.AudioStream] = None
if components.audio: if self.__components.audio:
audio_sample_rate = int(components.audio['sample_rate']) audio_sample_rate = int(self.__components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate) audio_stream = output.add_stream('aac', rate=audio_sample_rate)
# Encode video # Encode video
for i, frame in enumerate(components.images): for i, frame in enumerate(self.__components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24') frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
@ -401,9 +371,9 @@ class VideoFromComponents(VideoInput):
packet = video_stream.encode(None) packet = video_stream.encode(None)
output.mux(packet) output.mux(packet)
if audio_stream and components.audio: if audio_stream and self.__components.audio:
waveform = components.audio['waveform'] waveform = self.__components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * components.images.shape[0])] waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo') frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
frame.sample_rate = audio_sample_rate frame.sample_rate = audio_sample_rate
frame.pts = 0 frame.pts = 0

View File

@ -0,0 +1,109 @@
from __future__ import annotations
from typing import Any
import app.node_replace_manager
def register_node_replacement(node_replace: NodeReplace):
"""
Register node replacement.
"""
app.node_replace_manager.register_node_replacement(node_replace)
class NodeReplace:
"""
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
Also supports assigning specific values to the input widgets of the new node.
"""
def __init__(self,
new_node_id: str,
old_node_id: str,
old_widget_ids: list[str] | None=None,
input_mapping: list[InputMap] | None=None,
output_mapping: list[OutputMap] | None=None,
):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def as_dict(self):
"""
Create serializable representation of the node replacement.
"""
return {
"new_node_id": self.new_node_id,
"old_node_id": self.old_node_id,
"old_widget_ids": self.old_widget_ids,
"input_mapping": [m.as_dict() for m in self.input_mapping] if self.input_mapping else None,
"output_mapping": [m.as_dict() for m in self.output_mapping] if self.output_mapping else None,
}
class InputMap:
"""
Map inputs of node replacement.
Use InputMap.OldId or InputMap.SetValue for mapping purposes.
"""
class _Assign:
def __init__(self, assign_type: str):
self.assign_type = assign_type
def as_dict(self):
return {
"assign_type": self.assign_type,
}
class OldId(_Assign):
"""
Connect the input of the old node with given id to new node when replacing.
"""
def __init__(self, old_id: str):
super().__init__("old_id")
self.old_id = old_id
def as_dict(self):
return super().as_dict() | {
"old_id": self.old_id,
}
class SetValue(_Assign):
"""
Use the given value for the input of the new node when replacing; assumes input is a widget.
"""
def __init__(self, value: Any):
super().__init__("set_value")
self.value = value
def as_dict(self):
return super().as_dict() | {
"value": self.value,
}
def __init__(self, new_id: str, assign: OldId | SetValue):
self.new_id = new_id
self.assign = assign
def as_dict(self):
return {
"new_id": self.new_id,
"assign": self.assign.as_dict(),
}
class OutputMap:
"""
Map outputs of node replacement via indexes, as that's how outputs are stored.
"""
def __init__(self, new_idx: int, old_idx: int):
self.new_idx = new_idx
self.old_idx = old_idx
def as_dict(self):
return {
"new_idx": self.new_idx,
"old_idx": self.old_idx,
}

View File

@ -0,0 +1 @@
from ._node_replace import * # noqa: F403

View File

@ -6,7 +6,7 @@ from comfy_api.latest import (
) )
from typing import Type, TYPE_CHECKING from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401 from comfy_api.latest import io, ui, IO, UI, ComfyExtension, node_replace #noqa: F401
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest): class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@ -46,4 +46,5 @@ __all__ = [
"IO", "IO",
"ui", "ui",
"UI", "UI",
"node_replace",
] ]

View File

@ -24,7 +24,7 @@ class Load3D(IO.ComfyNode):
files = [ files = [
normalize_path(str(file_path.relative_to(base_path))) normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*") for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'} if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
] ]
return IO.Schema( return IO.Schema(
node_id="Load3D", node_id="Load3D",

View File

@ -637,6 +637,97 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
batched = batch_masks(values) batched = batch_masks(values)
return io.NodeOutput(batched) return io.NodeOutput(batched)
from comfy_api.latest import node_replace
def register_replacements():
register_replacements_longeredge()
register_replacements_batchimages()
register_replacements_upscaleimage()
register_replacements_controlnet()
register_replacements_load3d()
register_replacements_preview3d()
register_replacements_svdimg2vid()
register_replacements_conditioningavg()
def register_replacements_longeredge():
# No dynamic inputs here
node_replace.register_node_replacement(node_replace.NodeReplace(
new_node_id="ImageScaleToMaxDimension",
old_node_id="ResizeImagesByLongerEdge",
old_widget_ids=["longer_edge"],
input_mapping=[
node_replace.InputMap(new_id="image", assign=node_replace.InputMap.OldId("images")),
node_replace.InputMap(new_id="largest_size", assign=node_replace.InputMap.OldId("longer_edge")),
node_replace.InputMap(new_id="upscale_method", assign=node_replace.InputMap.SetValue("lanczos")),
],
# just to test the frontend output_mapping code, does nothing really here
output_mapping=[node_replace.OutputMap(new_idx=0, old_idx=0)],
))
def register_replacements_batchimages():
# BatchImages node uses Autogrow
node_replace.register_node_replacement(node_replace.NodeReplace(
new_node_id="BatchImagesNode",
old_node_id="ImageBatch",
input_mapping=[
node_replace.InputMap(new_id="images.image0", assign=node_replace.InputMap.OldId("image1")),
node_replace.InputMap(new_id="images.image1", assign=node_replace.InputMap.OldId("image2")),
],
))
def register_replacements_upscaleimage():
# ResizeImageMaskNode uses DynamicCombo
node_replace.register_node_replacement(node_replace.NodeReplace(
new_node_id="ResizeImageMaskNode",
old_node_id="ImageScaleBy",
old_widget_ids=["upscale_method", "scale_by"],
input_mapping=[
node_replace.InputMap(new_id="input", assign=node_replace.InputMap.OldId("image")),
node_replace.InputMap(new_id="resize_type", assign=node_replace.InputMap.SetValue("scale by multiplier")),
node_replace.InputMap(new_id="resize_type.multiplier", assign=node_replace.InputMap.OldId("scale_by")),
node_replace.InputMap(new_id="scale_method", assign=node_replace.InputMap.OldId("upscale_method")),
],
))
def register_replacements_controlnet():
# T2IAdapterLoader → ControlNetLoader
node_replace.register_node_replacement(node_replace.NodeReplace(
new_node_id="ControlNetLoader",
old_node_id="T2IAdapterLoader",
input_mapping=[
node_replace.InputMap(new_id="control_net_name", assign=node_replace.InputMap.OldId("t2i_adapter_name")),
],
))
def register_replacements_load3d():
# Load3DAnimation merged into Load3D
node_replace.register_node_replacement(node_replace.NodeReplace(
new_node_id="Load3D",
old_node_id="Load3DAnimation",
))
def register_replacements_preview3d():
# Preview3DAnimation merged into Preview3D
node_replace.register_node_replacement(node_replace.NodeReplace(
new_node_id="Preview3D",
old_node_id="Preview3DAnimation",
))
def register_replacements_svdimg2vid():
# Typo fix: SDV → SVD
node_replace.register_node_replacement(node_replace.NodeReplace(
new_node_id="SVD_img2vid_Conditioning",
old_node_id="SDV_img2vid_Conditioning",
))
def register_replacements_conditioningavg():
# Typo fix: trailing space in node name
node_replace.register_node_replacement(node_replace.NodeReplace(
new_node_id="ConditioningAverage",
old_node_id="ConditioningAverage ",
))
class PostProcessingExtension(ComfyExtension): class PostProcessingExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:

View File

@ -159,29 +159,6 @@ class GetVideoComponents(io.ComfyNode):
return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
class VideoSlice(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VideoSlice",
display_name="Video Slice",
category="image/video",
description="Extract a range of frames from a video.",
inputs=[
io.Video.Input("video", tooltip="The video to slice."),
io.Int.Input("start_frame", default=0, min=0, tooltip="The frame index to start from (0-indexed)."),
io.Int.Input("frame_count", default=1, min=1, tooltip="Number of frames to extract."),
],
outputs=[
io.Video.Output(tooltip="The sliced video."),
],
)
@classmethod
def execute(cls, video: Input.Video, start_frame: int, frame_count: int) -> io.NodeOutput:
return io.NodeOutput(video.sliced(start_frame, frame_count))
class LoadVideo(io.ComfyNode): class LoadVideo(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -229,7 +206,6 @@ class VideoExtension(ComfyExtension):
SaveVideo, SaveVideo,
CreateVideo, CreateVideo,
GetVideoComponents, GetVideoComponents,
VideoSlice,
LoadVideo, LoadVideo,
] ]

View File

@ -40,6 +40,7 @@ from app.user_manager import UserManager
from app.model_manager import ModelFileManager from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager
from typing import Optional, Union from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes from protocol import BinaryEventTypes
@ -204,6 +205,7 @@ class PromptServer():
self.model_file_manager = ModelFileManager() self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager() self.custom_node_manager = CustomNodeManager()
self.subgraph_manager = SubgraphManager() self.subgraph_manager = SubgraphManager()
self.node_replace_manager = NodeReplaceManager()
self.internal_routes = InternalRoutes(self) self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"] self.supports = ["custom_nodes_from_web"]
self.prompt_queue = execution.PromptQueue(self) self.prompt_queue = execution.PromptQueue(self)
@ -992,6 +994,7 @@ class PromptServer():
self.model_file_manager.add_routes(self.routes) self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items()) self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items()) self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
self.node_replace_manager.add_routes(self.routes)
self.app.add_subapp('/internal', self.internal_routes.get_app()) self.app.add_subapp('/internal', self.internal_routes.get_app())
# Prefix every route with /api for easier matching for delegation. # Prefix every route with /api for easier matching for delegation.

View File

@ -1,150 +0,0 @@
import pytest
import torch
import tempfile
import os
import av
from fractions import Fraction
from comfy_api.input_impl.video_types import (
VideoFromFile,
VideoFromComponents,
SliceOp,
)
from comfy_api.util.video_types import VideoComponents
def create_test_video(width=4, height=4, frames=10, fps=30):
"""Helper to create a temporary video file."""
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
with av.open(tmp.name, mode="w") as container:
stream = container.add_stream("h264", rate=fps)
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
for i in range(frames):
frame_data = torch.ones(height, width, 3, dtype=torch.uint8) * (i * 25)
frame = av.VideoFrame.from_ndarray(frame_data.numpy(), format="rgb24")
frame = frame.reformat(format="yuv420p")
packet = stream.encode(frame)
container.mux(packet)
packet = stream.encode(None)
container.mux(packet)
return tmp.name
@pytest.fixture
def video_file_10_frames():
file_path = create_test_video(frames=10)
yield file_path
os.unlink(file_path)
@pytest.fixture
def video_components_10_frames():
images = torch.rand(10, 4, 4, 3)
return VideoComponents(images=images, frame_rate=Fraction(30))
class TestSliceOp:
def test_apply_slices_correctly(self, video_components_10_frames):
op = SliceOp(start_frame=2, frame_count=3)
result = op.apply(video_components_10_frames)
assert result.images.shape[0] == 3
assert torch.equal(result.images, video_components_10_frames.images[2:5])
def test_compute_frame_count(self):
op = SliceOp(start_frame=2, frame_count=5)
assert op.compute_frame_count(10) == 5
def test_compute_frame_count_clamps(self):
op = SliceOp(start_frame=8, frame_count=5)
assert op.compute_frame_count(10) == 2
class TestVideoSliced:
def test_sliced_returns_new_instance(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert video is not sliced
assert len(video._operations) == 0
assert len(sliced._operations) == 1
def test_get_components_applies_operations(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert torch.equal(components.images, video_components_10_frames.images[2:5])
def test_get_frame_count(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert sliced.get_frame_count() == 3
def test_get_duration(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(0, 3)
assert sliced.get_duration() == pytest.approx(0.1)
def test_chained_slices_compose(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 6).sliced(1, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert torch.equal(components.images, video_components_10_frames.images[3:6])
def test_operations_list_is_immutable(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced1 = video.sliced(0, 5)
sliced2 = sliced1.sliced(1, 2)
assert len(video._operations) == 0
assert len(sliced1._operations) == 1
assert len(sliced2._operations) == 2
def test_from_file(self, video_file_10_frames):
video = VideoFromFile(video_file_10_frames)
sliced = video.sliced(2, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert sliced.get_frame_count() == 3
def test_save_sliced_video(self, video_components_10_frames, tmp_path):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
output_path = str(tmp_path / "sliced_output.mp4")
sliced.save_to(output_path)
saved_video = VideoFromFile(output_path)
assert saved_video.get_frame_count() == 3
def test_materialization_clears_ops(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert len(sliced._operations) == 1
sliced.get_components()
assert len(sliced._operations) == 0
def test_second_get_components_uses_cache(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
first = sliced.get_components()
second = sliced.get_components()
assert first.images.shape == second.images.shape
assert torch.equal(first.images, second.images)