Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-29 16:26:29 +08:00

Compare commits: cb/video-s...jk/node-re (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 3c0365f6d6 | |
| | 9c7d5f1fdd | |
| | 2c37119ff8 | |
| | 191834c633 | |
| | 5faf2e3cfd | |
app/node_replace_manager.py (Normal file, 23 lines)
@@ -0,0 +1,23 @@
from __future__ import annotations

from aiohttp import web

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from comfy_api.latest._node_replace import NodeReplace


REGISTERED_NODE_REPLACEMENTS: dict[str, list[NodeReplace]] = {}


def register_node_replacement(node_replace: NodeReplace):
    REGISTERED_NODE_REPLACEMENTS.setdefault(node_replace.old_node_id, []).append(node_replace)


def registered_as_dict():
    return {
        k: [v.as_dict() for v in v_list] for k, v_list in REGISTERED_NODE_REPLACEMENTS.items()
    }


class NodeReplaceManager:
    def add_routes(self, routes):
        @routes.get("/node_replacements")
        async def get_node_replacements(request):
            return web.json_response(registered_as_dict())
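For context, the new manager exposes a single read-only endpoint. A minimal sketch of querying it, assuming a local ComfyUI server on the default 127.0.0.1:8188:

```python
import json
import urllib.request

# The route registered above; host/port are an assumption (ComfyUI's default).
with urllib.request.urlopen("http://127.0.0.1:8188/node_replacements") as resp:
    replacements = json.load(resp)

# registered_as_dict() maps old node id -> list of NodeReplace.as_dict() entries.
for old_id, candidates in replacements.items():
    print(old_id, "->", [c["new_node_id"] for c in candidates])
```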
@@ -1,11 +1,11 @@
from typing import Tuple, Union

import threading
import torch
import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init


class CausalConv3d(nn.Module):
    def __init__(
        self,
@@ -42,34 +42,23 @@ class CausalConv3d(nn.Module):
            padding_mode=spatial_padding_mode,
            groups=groups,
        )
        self.temporal_cache_state = {}

    def forward(self, x, causal: bool = True):
        tid = threading.get_ident()

        cached, is_end = self.temporal_cache_state.get(tid, (None, False))
        if cached is None:
            padding_length = self.time_kernel_size - 1
            if not causal:
                padding_length = padding_length // 2
            if x.shape[2] == 0:
                return x
            cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1))
        pieces = [cached, x]
        if is_end and not causal:
            pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))

        needs_caching = not is_end
        if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
            needs_caching = False
            self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)

        x = torch.cat(pieces, dim=2)

        if needs_caching:
            self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)

        return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
        if causal:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, self.time_kernel_size - 1, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x), dim=2)
        else:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            last_frame_pad = x[:, :, -1:, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
        x = self.conv(x)
        return x

    @property
    def weight(self):
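The rewritten forward streams long videos: each thread caches the trailing `time_kernel_size - 1` frames so the next chunk is convolved with the correct left-context instead of being re-padded. A minimal standalone sketch of the idea (hypothetical helper, not the actual CausalConv3d API):

```python
import torch

def stream_causal_conv(conv, chunks, k):
    """Chunked causal conv: carry the last k - 1 frames between chunks so the
    concatenated outputs match one pass over the full tensor (stride 1, no
    temporal padding inside `conv`). Hypothetical helper for illustration."""
    cache = None
    for x in chunks:
        if cache is None:
            # First chunk: replicate frame 0 as causal left-padding.
            cache = x[:, :, :1].repeat(1, 1, k - 1, 1, 1)
        x = torch.cat([cache, x], dim=2)
        cache = x[:, :, -(k - 1):]  # context carried into the next chunk
        yield conv(x)
```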
@@ -1,5 +1,4 @@
from __future__ import annotations
import threading
import torch
from torch import nn
from functools import partial
@@ -7,35 +6,12 @@ import math
from einops import rearrange
from typing import List, Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed

ops = comfy.ops.disable_weight_init


def mark_conv3d_ended(module):
    tid = threading.get_ident()
    for _, m in module.named_modules():
        if isinstance(m, CausalConv3d):
            current = m.temporal_cache_state.get(tid, (None, False))
            m.temporal_cache_state[tid] = (current[0], True)


def split2(tensor, split_point, dim=2):
    return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)


def add_exchange_cache(dest, cache_in, new_input, dim=2):
    if dest is not None:
        if cache_in is not None:
            cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
            lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
            lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
            lead_in_dest.add_(lead_in_source)
        body, new_input = split2(new_input, dest.shape[dim], dim)
        dest.add_(body)
    return torch_cat_if_needed([cache_in, new_input], dim=dim)
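split2 is a two-way torch.split: a head of `split_point` frames plus the remainder along `dim`. add_exchange_cache uses it to overlap-add the cached tail of the previous chunk onto the head of the current output, returning whatever could not be added yet as the new cache. A toy check of the split2 contract (hypothetical values):

```python
import torch

x = torch.arange(10).reshape(1, 1, 10, 1, 1)  # 10 "frames" along dim=2
head, tail = split2(x, 3)  # equivalent to torch.split(x, [3, 7], dim=2)
assert head.shape[2] == 3 and tail.shape[2] == 7
assert torch.equal(torch.cat([head, tail], dim=2), x)
```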
class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
@@ -229,7 +205,7 @@ class Encoder(nn.Module):

        self.gradient_checkpointing = False

    def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
@@ -278,22 +254,6 @@ class Encoder(nn.Module):

        return sample

    def forward(self, *args, **kwargs):
        # No encoder support, so just flag the end so it doesn't use the cache.
        mark_conv3d_ended(self)
        try:
            return self.forward_orig(*args, **kwargs)
        finally:
            # ComfyUI doesn't thread this kind of stuff today, but just in case
            # we key on the thread to make it thread safe.
            tid = threading.get_ident()
            for _, module in self.named_modules():
                if hasattr(module, "temporal_cache_state"):
                    module.temporal_cache_state.pop(tid, None)


MAX_CHUNK_SIZE = (128 * 1024 ** 2)
class Decoder(nn.Module):
    r"""
@@ -381,6 +341,18 @@ class Decoder(nn.Module):
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "attn_res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=timestep_conditioning,
                    attention_head_dim=block_params["attention_head_dim"],
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "res_x_y":
                output_channel = output_channel // block_params.get("multiplier", 2)
                block = ResnetBlock3D(
@@ -456,9 +428,8 @@ class Decoder(nn.Module):
        )
        self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
    # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
    def forward_orig(
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Optional[torch.Tensor] = None,
@@ -466,7 +437,6 @@ class Decoder(nn.Module):
        r"""The forward method of the `Decoder` class."""
        batch_size = sample.shape[0]

        mark_conv3d_ended(self.conv_in)
        sample = self.conv_in(sample, causal=self.causal)

        checkpoint_fn = (
@@ -475,12 +445,24 @@ class Decoder(nn.Module):
            else lambda x: x
        )

        timestep_shift_scale = None
        scaled_timestep = None
        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
        for up_block in self.up_blocks:
            if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
                sample = checkpoint_fn(up_block)(
                    sample, causal=self.causal, timestep=scaled_timestep
                )
            else:
                sample = checkpoint_fn(up_block)(sample, causal=self.causal)

        sample = self.conv_norm_out(sample)

        if self.timestep_conditioning:
            embedded_timestep = self.last_time_embedder(
                timestep=scaled_timestep.flatten(),
                resolution=None,
@@ -501,62 +483,16 @@ class Decoder(nn.Module):
                embedded_timestep.shape[-2],
                embedded_timestep.shape[-1],
            )
            timestep_shift_scale = ada_values.unbind(dim=1)
            shift, scale = ada_values.unbind(dim=1)
            sample = sample * (1 + scale) + shift

        output = []

        def run_up(idx, sample, ended):
            if idx >= len(self.up_blocks):
                sample = self.conv_norm_out(sample)
                if timestep_shift_scale is not None:
                    shift, scale = timestep_shift_scale
                    sample = sample * (1 + scale) + shift
                sample = self.conv_act(sample)
                if ended:
                    mark_conv3d_ended(self.conv_out)
                sample = self.conv_out(sample, causal=self.causal)
                if sample is not None and sample.shape[2] > 0:
                    output.append(sample)
                return

            up_block = self.up_blocks[idx]
            if ended:
                mark_conv3d_ended(up_block)
            if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
                sample = checkpoint_fn(up_block)(
                    sample, causal=self.causal, timestep=scaled_timestep
                )
            else:
                sample = checkpoint_fn(up_block)(sample, causal=self.causal)

            if sample is None or sample.shape[2] == 0:
                return

            total_bytes = sample.numel() * sample.element_size()
            num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
            samples = torch.chunk(sample, chunks=num_chunks, dim=2)

            for chunk_idx, sample1 in enumerate(samples):
                run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)

        run_up(0, sample, True)
        sample = torch.cat(output, dim=2)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=self.causal)

        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

        return sample

    def forward(self, *args, **kwargs):
        try:
            return self.forward_orig(*args, **kwargs)
        finally:
            # ComfyUI doesn't thread this kind of stuff today, but just in case
            # we key on the thread to make it thread safe.
            tid = threading.get_ident()
            for _, module in self.named_modules():
                if hasattr(module, "temporal_cache_state"):
                    module.temporal_cache_state.pop(tid, None)
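The run_up recursion above pushes each level's output through the next block in pieces of at most roughly MAX_CHUNK_SIZE bytes, so upsampled activations never materialize whole. The chunk-count arithmetic, demonstrated on a hypothetical fp16 activation:

```python
import torch

MAX_CHUNK_SIZE = 128 * 1024 ** 2  # 128 MiB budget per piece

# Hypothetical decoder activation: batch 1, 512 channels, 32 frames, 96x96.
sample = torch.empty(1, 512, 32, 96, 96, dtype=torch.float16)
total_bytes = sample.numel() * sample.element_size()               # ~288 MiB
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE  # ceil division -> 3
chunks = torch.chunk(sample, chunks=num_chunks, dim=2)             # split along frames
print(num_chunks, [c.shape[2] for c in chunks])                    # 3 [11, 11, 10]
```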
class UNetMidBlock3D(nn.Module):
    """
@@ -727,22 +663,8 @@ class DepthToSpaceUpsample(nn.Module):
        )
        self.residual = residual
        self.out_channels_reduction_factor = out_channels_reduction_factor
        self.temporal_cache_state = {}

    def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
        tid = threading.get_ident()
        cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True))
        y = self.conv(x, causal=causal)
        y = rearrange(
            y,
            "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )
        if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv:
            y = y[:, :, 1:, :, :]
            drop_first_conv = False
        if self.residual:
            # Reshape and duplicate the input to match the output shape
            x_in = rearrange(
@@ -754,20 +676,21 @@ class DepthToSpaceUpsample(nn.Module):
            )
            num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
            x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
            if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res:
            if self.stride[0] == 2:
                x_in = x_in[:, :, 1:, :, :]
                drop_first_res = False

            if y.shape[2] == 0:
                y = None

            cached = add_exchange_cache(y, cached, x_in, dim=2)
            self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res)

        else:
            self.temporal_cache_state[tid] = (None, drop_first_conv, False)

        return y
        x = self.conv(x, causal=causal)
        x = rearrange(
            x,
            "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )
        if self.stride[0] == 2:
            x = x[:, :, 1:, :, :]
        if self.residual:
            x = x + x_in
        return x
class LayerNorm(nn.Module):
    def __init__(self, dim, eps, elementwise_affine=True) -> None:
@@ -884,8 +807,6 @@ class ResnetBlock3D(nn.Module):
            torch.randn(4, in_channels) / in_channels**0.5
        )

        self.temporal_cache_state = {}

    def _feed_spatial_noise(
        self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
    ) -> torch.FloatTensor:
@@ -959,12 +880,9 @@ class ResnetBlock3D(nn.Module):

            input_tensor = self.conv_shortcut(input_tensor)

        tid = threading.get_ident()
        cached = self.temporal_cache_state.get(tid, None)
        cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2)
        self.temporal_cache_state[tid] = cached
        output_tensor = input_tensor + hidden_states

        return hidden_states
        return output_tensor


def patchify(x, patch_size_hw, patch_size_t=1):
@@ -14,13 +14,10 @@ if model_management.xformers_enabled_vae():
    import xformers.ops

def torch_cat_if_needed(xl, dim):
    xl = [x for x in xl if x is not None and x.shape[dim] > 0]
    if len(xl) > 1:
        return torch.cat(xl, dim)
    elif len(xl) == 1:
        return xl[0]
    else:
        return None
    return xl[0]

def get_timestep_embedding(timesteps, embedding_dim):
    """
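The updated torch_cat_if_needed tolerates None and zero-length segments, which the streaming caches naturally produce at chunk boundaries. Its contract spelled out (assumes the function above is in scope):

```python
import torch

a = torch.ones(1, 1, 2, 4, 4)
empty = torch.ones(1, 1, 0, 4, 4)  # zero frames along dim=2

assert torch_cat_if_needed([a, None, empty, a], dim=2).shape[2] == 4
assert torch_cat_if_needed([None, a], dim=2) is a         # single survivor returned as-is
assert torch_cat_if_needed([None, empty], dim=2) is None  # nothing to concatenate
```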
@@ -170,14 +170,8 @@ class Attention(nn.Module):
            joint_query = apply_rope1(joint_query, image_rotary_emb)
            joint_key = apply_rope1(joint_key, image_rotary_emb)

        if encoder_hidden_states_mask is not None:
            attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
            attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
        else:
            attn_mask = None

        joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
                                                         attn_mask, transformer_options=transformer_options,
                                                         attention_mask, transformer_options=transformer_options,
                                                         skip_reshape=True)

        txt_attn_output = joint_hidden_states[:, :seq_txt, :]
@@ -436,9 +430,6 @@ class QwenImageTransformer2DModel(nn.Module):
        encoder_hidden_states = context
        encoder_hidden_states_mask = attention_mask

        if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask):
            encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max

        hidden_states, img_ids, orig_shape = self.process_img(x)
        num_embeds = hidden_states.shape[1]
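The `(encoder_hidden_states_mask - 1) * torch.finfo(x.dtype).max` conversion above turns a 0/1 token mask into an additive attention bias: kept tokens get 0, masked tokens a very large negative value that vanishes under softmax. In isolation:

```python
import torch

mask = torch.tensor([1, 1, 0])                    # 1 = keep, 0 = mask out
dtype = torch.float16
bias = (mask - 1).to(dtype) * torch.finfo(dtype).max
print(bias)  # tensor([0., 0., -65504.], dtype=torch.float16)
```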
@@ -1578,9 +1578,6 @@ class QwenImage(BaseModel):

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
@@ -771,24 +771,10 @@ class Flux2(Flux):
        return out

    def clip_target(self, state_dict={}):
        return None # TODO
        pref = self.text_encoder_key_prefix[0]
        detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
        if len(detect) > 0:
            detect["model_type"] = "qwen3_4b"
            return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer, comfy.text_encoders.flux.klein_te(**detect))

        detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_8b.transformer.".format(pref))
        if len(detect) > 0:
            detect["model_type"] = "qwen3_8b"
            return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer8B, comfy.text_encoders.flux.klein_te(**detect))

        detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}mistral3_24b.transformer.".format(pref))
        if len(detect) > 0:
            if "{}mistral3_24b.transformer.model.layers.39.post_attention_layernorm.weight".format(pref) not in state_dict:
                detect["pruned"] = True
            return supported_models_base.ClipTarget(comfy.text_encoders.flux.Flux2Tokenizer, comfy.text_encoders.flux.flux2_te(**detect))

        return None
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))


class GenmoMochi(supported_models_base.BASE):
    unet_config = {
@@ -10,11 +10,9 @@ import comfy.utils

def llama_detect(state_dict, prefix=""):
    out = {}
    norm_keys = ["{}model.norm.weight".format(prefix), "{}model.layers.0.input_layernorm.weight".format(prefix)]
    for norm_key in norm_keys:
        if norm_key in state_dict:
            out["dtype_llama"] = state_dict[norm_key].dtype
            break
    t5_key = "{}model.norm.weight".format(prefix)
    if t5_key in state_dict:
        out["dtype_llama"] = state_dict[t5_key].dtype

    quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
    if quant is not None:
@@ -10,6 +10,7 @@ from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io_public as io
from . import _ui_public as ui
from . import _node_replace_public as node_replace
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image
@@ -130,4 +131,5 @@ __all__ = [
    "IO",
    "ui",
    "UI",
    "node_replace",
]
@@ -1,12 +1,10 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .video_types import VideoInput, VideoOp, SliceOp
from .video_types import VideoInput

__all__ = [
    "ImageInput",
    "AudioInput",
    "VideoInput",
    "VideoOp",
    "SliceOp",
    "MaskInput",
    "LatentInput",
]
@@ -1,48 +1,11 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from fractions import Fraction
from typing import Optional, Union, IO
import copy
import io
import av
from .._util import VideoContainer, VideoCodec, VideoComponents


class VideoOp(ABC):
    """Base class for lazy video operations."""

    @abstractmethod
    def apply(self, components: VideoComponents) -> VideoComponents:
        pass

    @abstractmethod
    def compute_frame_count(self, input_frame_count: int) -> int:
        pass


@dataclass(frozen=True)
class SliceOp(VideoOp):
    """Extract a range of frames from the video."""
    start_frame: int
    frame_count: int

    def apply(self, components: VideoComponents) -> VideoComponents:
        total = components.images.shape[0]
        start = max(0, min(self.start_frame, total))
        end = min(start + self.frame_count, total)
        return VideoComponents(
            images=components.images[start:end],
            audio=components.audio,
            frame_rate=components.frame_rate,
            metadata=getattr(components, 'metadata', None),
        )

    def compute_frame_count(self, input_frame_count: int) -> int:
        start = max(0, min(self.start_frame, input_frame_count))
        return min(self.frame_count, input_frame_count - start)


class VideoInput(ABC):
    """
    Abstract base class for video input types.
@@ -58,12 +21,6 @@ class VideoInput(ABC):
        """
        pass

    def sliced(self, start_frame: int, frame_count: int) -> "VideoInput":
        """Return a copy of this video with a slice operation appended."""
        new = copy.copy(self)
        new._operations = getattr(self, '_operations', []) + [SliceOp(start_frame, frame_count)]
        return new

    @abstractmethod
    def save_to(
        self,
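SliceOp instances compose left to right, with the second slice indexing into the output of the first; `sliced()` copies the wrapper and appends to a fresh list, so existing references keep their own pipelines. Composition by hand, using the classes above:

```python
# frames 0..9 -> SliceOp(2, 6) keeps frames 2..7 -> SliceOp(1, 3) keeps
# frames 3..5 of the original video (matching test_chained_slices_compose below).
ops = [SliceOp(start_frame=2, frame_count=6), SliceOp(start_frame=1, frame_count=3)]
count = 10
for op in ops:
    count = op.compute_frame_count(count)
print(count)  # 3
```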
@@ -1,8 +1,7 @@
from .video_types import VideoFromFile, VideoFromComponents
from .._input import SliceOp

__all__ = [
    # Implementations
    "VideoFromFile",
    "VideoFromComponents",
    "SliceOp",
]
@@ -3,7 +3,7 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from .._input import AudioInput, VideoInput, VideoOp
from .._input import AudioInput, VideoInput
import av
import io
import json
@@ -63,8 +63,6 @@ class VideoFromFile(VideoInput):
        containing the file contents.
        """
        self.__file = file
        self._operations: list[VideoOp] = []
        self.__materialized: Optional[VideoFromComponents] = None

    def get_stream_source(self) -> str | io.BytesIO:
        """
@@ -163,10 +161,6 @@ class VideoFromFile(VideoInput):

        if frame_count == 0:
            raise ValueError(f"Could not determine frame count for file '{self.__file}'")

        # Apply operations to get final frame count
        for op in self._operations:
            frame_count = op.compute_frame_count(frame_count)
        return frame_count

    def get_frame_rate(self) -> Fraction:
@@ -245,18 +239,10 @@ class VideoFromFile(VideoInput):
                return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)

    def get_components(self) -> VideoComponents:
        if self.__materialized is not None:
            return self.__materialized.get_components()

        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            components = self.get_components_internal(container)
            for op in self._operations:
                components = op.apply(components)
            self.__materialized = VideoFromComponents(components)
            self._operations = []
            return components
            return self.get_components_internal(container)
        raise ValueError(f"No video stream found in file '{self.__file}'")
    def save_to(
@@ -331,27 +317,14 @@ class VideoFromComponents(VideoInput):

    def __init__(self, components: VideoComponents):
        self.__components = components
        self._operations: list[VideoOp] = []

    def get_components(self) -> VideoComponents:
        if self._operations:
            components = self.__components
            for op in self._operations:
                components = op.apply(components)
            self.__components = components
            self._operations = []
        return VideoComponents(
            images=self.__components.images,
            audio=self.__components.audio,
            frame_rate=self.__components.frame_rate
        )

    def get_frame_count(self) -> int:
        count = int(self.__components.images.shape[0])
        for op in self._operations:
            count = op.compute_frame_count(count)
        return count

    def save_to(
        self,
        path: str,
@@ -359,9 +332,6 @@ class VideoFromComponents(VideoInput):
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        # Materialize ops before saving
        components = self.get_components()

        if format != VideoContainer.AUTO and format != VideoContainer.MP4:
            raise ValueError("Only MP4 format is supported for now")
        if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@@ -375,22 +345,22 @@ class VideoFromComponents(VideoInput):
                for key, value in metadata.items():
                    output.metadata[key] = json.dumps(value)

            frame_rate = Fraction(round(components.frame_rate * 1000), 1000)
            frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
            # Create a video stream
            video_stream = output.add_stream('h264', rate=frame_rate)
            video_stream.width = components.images.shape[2]
            video_stream.height = components.images.shape[1]
            video_stream.width = self.__components.images.shape[2]
            video_stream.height = self.__components.images.shape[1]
            video_stream.pix_fmt = 'yuv420p'

            # Create an audio stream
            audio_sample_rate = 1
            audio_stream: Optional[av.AudioStream] = None
            if components.audio:
                audio_sample_rate = int(components.audio['sample_rate'])
            if self.__components.audio:
                audio_sample_rate = int(self.__components.audio['sample_rate'])
                audio_stream = output.add_stream('aac', rate=audio_sample_rate)

            # Encode video
            for i, frame in enumerate(components.images):
            for i, frame in enumerate(self.__components.images):
                img = (frame * 255).clamp(0, 255).byte().cpu().numpy()  # shape: (H, W, 3)
                frame = av.VideoFrame.from_ndarray(img, format='rgb24')
                frame = frame.reformat(format='yuv420p')  # Convert to YUV420P as required by h264
@@ -401,9 +371,9 @@ class VideoFromComponents(VideoInput):
            packet = video_stream.encode(None)
            output.mux(packet)

            if audio_stream and components.audio:
                waveform = components.audio['waveform']
                waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * components.images.shape[0])]
            if audio_stream and self.__components.audio:
                waveform = self.__components.audio['waveform']
                waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
                frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
                frame.sample_rate = audio_sample_rate
                frame.pts = 0
comfy_api/latest/_node_replace.py (Normal file, 109 lines)
@@ -0,0 +1,109 @@
from __future__ import annotations

from typing import Any
import app.node_replace_manager


def register_node_replacement(node_replace: NodeReplace):
    """
    Register a node replacement.
    """
    app.node_replace_manager.register_node_replacement(node_replace)


class NodeReplace:
    """
    Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.

    Also supports assigning specific values to the input widgets of the new node.
    """
    def __init__(self,
                 new_node_id: str,
                 old_node_id: str,
                 old_widget_ids: list[str] | None = None,
                 input_mapping: list[InputMap] | None = None,
                 output_mapping: list[OutputMap] | None = None,
                 ):
        self.new_node_id = new_node_id
        self.old_node_id = old_node_id
        self.old_widget_ids = old_widget_ids
        self.input_mapping = input_mapping
        self.output_mapping = output_mapping

    def as_dict(self):
        """
        Create a serializable representation of the node replacement.
        """
        return {
            "new_node_id": self.new_node_id,
            "old_node_id": self.old_node_id,
            "old_widget_ids": self.old_widget_ids,
            "input_mapping": [m.as_dict() for m in self.input_mapping] if self.input_mapping else None,
            "output_mapping": [m.as_dict() for m in self.output_mapping] if self.output_mapping else None,
        }


class InputMap:
    """
    Map inputs of a node replacement.

    Use InputMap.OldId or InputMap.SetValue for mapping purposes.
    """
    class _Assign:
        def __init__(self, assign_type: str):
            self.assign_type = assign_type

        def as_dict(self):
            return {
                "assign_type": self.assign_type,
            }

    class OldId(_Assign):
        """
        Connect the input of the old node with the given id to the new node when replacing.
        """
        def __init__(self, old_id: str):
            super().__init__("old_id")
            self.old_id = old_id

        def as_dict(self):
            return super().as_dict() | {
                "old_id": self.old_id,
            }

    class SetValue(_Assign):
        """
        Use the given value for the input of the new node when replacing; assumes the input is a widget.
        """
        def __init__(self, value: Any):
            super().__init__("set_value")
            self.value = value

        def as_dict(self):
            return super().as_dict() | {
                "value": self.value,
            }

    def __init__(self, new_id: str, assign: OldId | SetValue):
        self.new_id = new_id
        self.assign = assign

    def as_dict(self):
        return {
            "new_id": self.new_id,
            "assign": self.assign.as_dict(),
        }


class OutputMap:
    """
    Map outputs of a node replacement via indexes, as that is how outputs are stored.
    """
    def __init__(self, new_idx: int, old_idx: int):
        self.new_idx = new_idx
        self.old_idx = old_idx

    def as_dict(self):
        return {
            "new_idx": self.new_idx,
            "old_idx": self.old_idx,
        }
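Putting the pieces together, a sketch of how an extension could declare a replacement with this API; the node ids, input names, and widget values here are invented for illustration:

```python
from comfy_api.latest import node_replace

replacement = node_replace.NodeReplace(
    new_node_id="ImageScaleV2",      # hypothetical new node
    old_node_id="ImageScale",        # hypothetical deprecated node
    old_widget_ids=["width", "height"],
    input_mapping=[
        # Wire the old node's "image" input to the new node's "image" input.
        node_replace.InputMap("image", node_replace.InputMap.OldId("image")),
        # Pin the new "method" widget to a fixed value.
        node_replace.InputMap("method", node_replace.InputMap.SetValue("lanczos")),
    ],
    output_mapping=[node_replace.OutputMap(new_idx=0, old_idx=0)],
)
node_replace.register_node_replacement(replacement)
```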
comfy_api/latest/_node_replace_public.py (Normal file, 1 line)
@@ -0,0 +1 @@
from ._node_replace import *  # noqa: F403
@@ -6,7 +6,7 @@ from comfy_api.latest import (
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest import io, ui, IO, UI, ComfyExtension  # noqa: F401
from comfy_api.latest import io, ui, IO, UI, ComfyExtension, node_replace  # noqa: F401


class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@@ -46,4 +46,5 @@ __all__ = [
    "IO",
    "ui",
    "UI",
    "node_replace",
]
@@ -24,7 +24,7 @@ class Load3D(IO.ComfyNode):
        files = [
            normalize_path(str(file_path.relative_to(base_path)))
            for file_path in input_path.rglob("*")
            if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'}
            if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
        ]
        return IO.Schema(
            node_id="Load3D",
@@ -159,29 +159,6 @@ class GetVideoComponents(io.ComfyNode):
        return io.NodeOutput(components.images, components.audio, float(components.frame_rate))


class VideoSlice(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="VideoSlice",
            display_name="Video Slice",
            category="image/video",
            description="Extract a range of frames from a video.",
            inputs=[
                io.Video.Input("video", tooltip="The video to slice."),
                io.Int.Input("start_frame", default=0, min=0, tooltip="The frame index to start from (0-indexed)."),
                io.Int.Input("frame_count", default=1, min=1, tooltip="Number of frames to extract."),
            ],
            outputs=[
                io.Video.Output(tooltip="The sliced video."),
            ],
        )

    @classmethod
    def execute(cls, video: Input.Video, start_frame: int, frame_count: int) -> io.NodeOutput:
        return io.NodeOutput(video.sliced(start_frame, frame_count))


class LoadVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
@@ -229,7 +206,6 @@ class VideoExtension(ComfyExtension):
            SaveVideo,
            CreateVideo,
            GetVideoComponents,
            VideoSlice,
            LoadVideo,
        ]
@@ -40,6 +40,7 @@ from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
@@ -204,6 +205,7 @@ class PromptServer():
        self.model_file_manager = ModelFileManager()
        self.custom_node_manager = CustomNodeManager()
        self.subgraph_manager = SubgraphManager()
        self.node_replace_manager = NodeReplaceManager()
        self.internal_routes = InternalRoutes(self)
        self.supports = ["custom_nodes_from_web"]
        self.prompt_queue = execution.PromptQueue(self)
@@ -992,6 +994,7 @@ class PromptServer():
        self.model_file_manager.add_routes(self.routes)
        self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
        self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
        self.node_replace_manager.add_routes(self.routes)
        self.app.add_subapp('/internal', self.internal_routes.get_app())

        # Prefix every route with /api for easier matching for delegation.
@@ -1,150 +0,0 @@
import pytest
import torch
import tempfile
import os
import av
from fractions import Fraction
from comfy_api.input_impl.video_types import (
    VideoFromFile,
    VideoFromComponents,
    SliceOp,
)
from comfy_api.util.video_types import VideoComponents


def create_test_video(width=4, height=4, frames=10, fps=30):
    """Helper to create a temporary video file."""
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    with av.open(tmp.name, mode="w") as container:
        stream = container.add_stream("h264", rate=fps)
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"

        for i in range(frames):
            frame_data = torch.ones(height, width, 3, dtype=torch.uint8) * (i * 25)
            frame = av.VideoFrame.from_ndarray(frame_data.numpy(), format="rgb24")
            frame = frame.reformat(format="yuv420p")
            packet = stream.encode(frame)
            container.mux(packet)

        packet = stream.encode(None)
        container.mux(packet)

    return tmp.name


@pytest.fixture
def video_file_10_frames():
    file_path = create_test_video(frames=10)
    yield file_path
    os.unlink(file_path)


@pytest.fixture
def video_components_10_frames():
    images = torch.rand(10, 4, 4, 3)
    return VideoComponents(images=images, frame_rate=Fraction(30))


class TestSliceOp:
    def test_apply_slices_correctly(self, video_components_10_frames):
        op = SliceOp(start_frame=2, frame_count=3)
        result = op.apply(video_components_10_frames)

        assert result.images.shape[0] == 3
        assert torch.equal(result.images, video_components_10_frames.images[2:5])

    def test_compute_frame_count(self):
        op = SliceOp(start_frame=2, frame_count=5)
        assert op.compute_frame_count(10) == 5

    def test_compute_frame_count_clamps(self):
        op = SliceOp(start_frame=8, frame_count=5)
        assert op.compute_frame_count(10) == 2


class TestVideoSliced:
    def test_sliced_returns_new_instance(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        assert video is not sliced
        assert len(video._operations) == 0
        assert len(sliced._operations) == 1

    def test_get_components_applies_operations(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        components = sliced.get_components()

        assert components.images.shape[0] == 3
        assert torch.equal(components.images, video_components_10_frames.images[2:5])

    def test_get_frame_count(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        assert sliced.get_frame_count() == 3

    def test_get_duration(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(0, 3)

        assert sliced.get_duration() == pytest.approx(0.1)

    def test_chained_slices_compose(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 6).sliced(1, 3)

        components = sliced.get_components()

        assert components.images.shape[0] == 3
        assert torch.equal(components.images, video_components_10_frames.images[3:6])

    def test_operations_list_is_immutable(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced1 = video.sliced(0, 5)
        sliced2 = sliced1.sliced(1, 2)

        assert len(video._operations) == 0
        assert len(sliced1._operations) == 1
        assert len(sliced2._operations) == 2

    def test_from_file(self, video_file_10_frames):
        video = VideoFromFile(video_file_10_frames)
        sliced = video.sliced(2, 3)

        components = sliced.get_components()

        assert components.images.shape[0] == 3
        assert sliced.get_frame_count() == 3

    def test_save_sliced_video(self, video_components_10_frames, tmp_path):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        output_path = str(tmp_path / "sliced_output.mp4")
        sliced.save_to(output_path)

        saved_video = VideoFromFile(output_path)
        assert saved_video.get_frame_count() == 3

    def test_materialization_clears_ops(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        assert len(sliced._operations) == 1
        sliced.get_components()
        assert len(sliced._operations) == 0

    def test_second_get_components_uses_cache(self, video_components_10_frames):
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        first = sliced.get_components()
        second = sliced.get_components()

        assert first.images.shape == second.images.shape
        assert torch.equal(first.images, second.images)