Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-26 14:56:02 +08:00

Compare commits: jk/remove- ... pysssss/ba (8 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 46a83e9630 |  |
|  | 5b0fb64d20 |  |
|  | 521ca3b5d2 |  |
|  | 53094efd1d |  |
|  | e89b22993a |  |
|  | 55bd606e92 |  |
|  | cc30293d65 |  |
|  | 866d863128 |  |
.github/workflows/test-build.yml (vendored, 4 lines changed)

@@ -25,6 +25,10 @@ jobs:
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install system dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y libx11-dev
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
@@ -18,12 +18,12 @@ class CompressedTimestep:

    def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
        """
        tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
        patches_per_frame: Number of spatial patches per frame (height * width in latent space)
        patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
        """
        self.batch_size, num_tokens, self.feature_dim = tensor.shape

        # Check if compression is valid (num_tokens must be divisible by patches_per_frame)
        if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
        if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
            self.patches_per_frame = patches_per_frame
            self.num_frames = num_tokens // patches_per_frame
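The check above now treats `patches_per_frame=None` as "compression disabled". For orientation, here is a minimal sketch of the per-frame compression idea the docstring describes, assuming every spatial patch within a frame shares one timestep embedding; the helper names are illustrative and are not the PR's actual CompressedTimestep methods:

```python
import torch

def compress_per_frame(t: torch.Tensor, patches_per_frame: int) -> torch.Tensor:
    # [B, F * P, D] -> [B, F, D], keeping one representative value per frame
    b, n, d = t.shape
    assert n % patches_per_frame == 0 and n >= patches_per_frame
    return t.view(b, n // patches_per_frame, patches_per_frame, d)[:, :, 0, :]

def expand_per_frame(t: torch.Tensor, patches_per_frame: int) -> torch.Tensor:
    # [B, F, D] -> [B, F * P, D], broadcasting each frame value back to all its patches
    b, f, d = t.shape
    return t.unsqueeze(2).expand(b, f, patches_per_frame, d).reshape(b, f * patches_per_frame, d)
```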
@@ -215,22 +215,9 @@ class BasicAVTransformerBlock(nn.Module):
        return (*scale_shift_ada_values, *gate_ada_values)

    def forward(
        self,
        x: Tuple[torch.Tensor, torch.Tensor],
        v_context=None,
        a_context=None,
        attention_mask=None,
        v_timestep=None,
        a_timestep=None,
        v_pe=None,
        a_pe=None,
        v_cross_pe=None,
        a_cross_pe=None,
        v_cross_scale_shift_timestep=None,
        a_cross_scale_shift_timestep=None,
        v_cross_gate_timestep=None,
        a_cross_gate_timestep=None,
        transformer_options=None,
        self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
        v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
        v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        run_vx = transformer_options.get("run_vx", True)
        run_ax = transformer_options.get("run_ax", True)

@@ -240,144 +227,102 @@ class BasicAVTransformerBlock(nn.Module):
        run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
        run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)

        # video
        if run_vx:
            vshift_msa, vscale_msa, vgate_msa = (
                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
            )

            # video self-attention
            vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
            norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
            vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
            vx += self.attn2(
                comfy.ldm.common_dit.rms_norm(vx),
                context=v_context,
                mask=attention_mask,
                transformer_options=transformer_options,
            )

            del vshift_msa, vscale_msa, vgate_msa
            del vshift_msa, vscale_msa
            attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
            del norm_vx
            # video cross-attention
            vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
            vx.addcmul_(attn1_out, vgate_msa)
            del vgate_msa, attn1_out
            vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))

        # audio
        if run_ax:
            ashift_msa, ascale_msa, agate_msa = (
                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
            )

            # audio self-attention
            ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
            norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
            ax += (
                self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
                * agate_msa
            )
            ax += self.audio_attn2(
                comfy.ldm.common_dit.rms_norm(ax),
                context=a_context,
                mask=attention_mask,
                transformer_options=transformer_options,
            )
            del ashift_msa, ascale_msa
            attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
            del norm_ax
            # audio cross-attention
            agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
            ax.addcmul_(attn1_out, agate_msa)
            del agate_msa, attn1_out
            ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))

            del ashift_msa, ascale_msa, agate_msa

        # Audio - Video cross attention.
        # video - audio cross attention.
        if run_a2v or run_v2a:
            # norm3
            vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
            ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)

            (
                scale_ca_audio_hidden_states_a2v,
                shift_ca_audio_hidden_states_a2v,
                scale_ca_audio_hidden_states_v2a,
                shift_ca_audio_hidden_states_v2a,
                gate_out_v2a,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_audio,
                ax.shape[0],
                a_cross_scale_shift_timestep,
                a_cross_gate_timestep,
            )

            (
                scale_ca_video_hidden_states_a2v,
                shift_ca_video_hidden_states_a2v,
                scale_ca_video_hidden_states_v2a,
                shift_ca_video_hidden_states_v2a,
                gate_out_a2v,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_video,
                vx.shape[0],
                v_cross_scale_shift_timestep,
                v_cross_gate_timestep,
            )

            # audio to video cross attention
            if run_a2v:
                vx_scaled = (
                    vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
                    + shift_ca_video_hidden_states_a2v
                )
                ax_scaled = (
                    ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
                    + shift_ca_audio_hidden_states_a2v
                )
                vx += (
                    self.audio_to_video_attn(
                        vx_scaled,
                        context=ax_scaled,
                        pe=v_cross_pe,
                        k_pe=a_cross_pe,
                        transformer_options=transformer_options,
                    )
                    * gate_out_a2v
                )
                scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
                    self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
                scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
                    self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]

                del gate_out_a2v
                del scale_ca_video_hidden_states_a2v,\
                    shift_ca_video_hidden_states_a2v,\
                    scale_ca_audio_hidden_states_a2v,\
                    shift_ca_audio_hidden_states_a2v,\
                vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
                ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
                del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v

                a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
                del vx_scaled, ax_scaled

                gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
                vx.addcmul_(a2v_out, gate_out_a2v)
                del gate_out_a2v, a2v_out

            # video to audio cross attention
            if run_v2a:
                ax_scaled = (
                    ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
                    + shift_ca_audio_hidden_states_v2a
                )
                vx_scaled = (
                    vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
                    + shift_ca_video_hidden_states_v2a
                )
                ax += (
                    self.video_to_audio_attn(
                        ax_scaled,
                        context=vx_scaled,
                        pe=a_cross_pe,
                        k_pe=v_cross_pe,
                        transformer_options=transformer_options,
                    )
                    * gate_out_v2a
                )
                scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
                    self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
                scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
                    self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]

                del gate_out_v2a
                del scale_ca_video_hidden_states_v2a,\
                    shift_ca_video_hidden_states_v2a,\
                    scale_ca_audio_hidden_states_v2a,\
                    shift_ca_audio_hidden_states_v2a
                ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
                vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
                del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a

                v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
                del ax_scaled, vx_scaled

                gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
                ax.addcmul_(v2a_out, gate_out_v2a)
                del gate_out_v2a, v2a_out

            del vx_norm3, ax_norm3

        # video feedforward
        if run_vx:
            vshift_mlp, vscale_mlp, vgate_mlp = (
                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
            )

            vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
            vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
            vx += self.ff(vx_scaled) * vgate_mlp
            del vshift_mlp, vscale_mlp, vgate_mlp
            del vshift_mlp, vscale_mlp

            ff_out = self.ff(vx_scaled)
            del vx_scaled

            vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
            vx.addcmul_(ff_out, vgate_mlp)
            del vgate_mlp, ff_out

        # audio feedforward
        if run_ax:
            ashift_mlp, ascale_mlp, agate_mlp = (
                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
            )

            ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
            ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
            ax += self.audio_ff(ax_scaled) * agate_mlp
            del ashift_mlp, ascale_mlp

            del ashift_mlp, ascale_mlp, agate_mlp
            ff_out = self.audio_ff(ax_scaled)
            del ax_scaled

            agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
            ax.addcmul_(ff_out, agate_mlp)
            del agate_mlp, ff_out

        return vx, ax

@@ -589,9 +534,20 @@ class LTXAVModel(LTXVModel):
        audio_length = kwargs.get("audio_length", 0)
        # Separate audio and video latents
        vx, ax = self.separate_audio_and_video_latents(x, audio_length)

        has_spatial_mask = False
        if denoise_mask is not None:
            # check if any frame has spatial variation (inpainting)
            for frame_idx in range(denoise_mask.shape[2]):
                frame_mask = denoise_mask[0, 0, frame_idx]
                if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
                    has_spatial_mask = True
                    break

        [vx, v_pixel_coords, additional_args] = super()._process_input(
            vx, keyframe_idxs, denoise_mask, **kwargs
        )
        additional_args["has_spatial_mask"] = has_spatial_mask

        ax, a_latent_coords = self.a_patchifier.patchify(ax)
        ax = self.audio_patchify_proj(ax)

@@ -618,8 +574,9 @@ class LTXAVModel(LTXVModel):
        # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
        # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
        orig_shape = kwargs.get("orig_shape")
        has_spatial_mask = kwargs.get("has_spatial_mask", None)
        v_patches_per_frame = None
        if orig_shape is not None and len(orig_shape) == 5:
        if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
            # orig_shape[3] = height, orig_shape[4] = width (in latent space)
            v_patches_per_frame = orig_shape[3] * orig_shape[4]

@@ -662,10 +619,11 @@ class LTXAVModel(LTXVModel):
        )

        # Compress cross-attention timesteps (only video side, audio is too small to benefit)
        # v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
        cross_av_timestep_ss = [
            av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
            CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
            CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
            CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame),  # video - compressed if possible
            CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame),  # video - compressed if possible
            av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
        ]
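Throughout the block above, out-of-place updates of the form `x += attn_out * gate` are rewritten as `x.addcmul_(attn_out, gate)`, and the gate values are only fetched from the ada tables right before they are consumed. As a rough standalone illustration of why this helps peak memory (not code from the PR itself), the two forms are numerically equivalent, but the in-place fused version avoids materializing the `attn_out * gate` temporary and the extra copy of `x`:

```python
import torch

x = torch.randn(2, 1024, 128)
attn_out = torch.randn_like(x)
gate = torch.randn(2, 1024, 1)

out_of_place = x + attn_out * gate  # allocates a full-size temporary plus a new result tensor
x.addcmul_(attn_out, gate)          # fused multiply-add, written directly into x

assert torch.allclose(out_of_place, x, atol=1e-6)
```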

@@ -260,6 +260,7 @@ def model_lora_keys_unet(model, key_map={}):
            key_map["transformer.{}".format(k[:-len(".weight")])] = to  #simpletrainer and probably regular diffusers flux lora format
            key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to  #simpletrainer lycoris
            key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to  #onetrainer
            key_map[k[:-len(".weight")]] = to  #DiffSynth lora format
        for k in sdk:
            hidden_size = model.model_config.unet_config.get("hidden_size", 0)
            if k.endswith(".weight") and ".linear1." in k:
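As a concrete illustration of the mappings registered above (the key name here is hypothetical, chosen only to show the string transformations), a single diffusers-style weight key now ends up with four accepted LoRA aliases, the last one being the new DiffSynth format:

```python
k = "single_blocks.0.linear1.weight"  # hypothetical diffusers-style key
base = k[:-len(".weight")]            # "single_blocks.0.linear1"

aliases = [
    "transformer.{}".format(base),                         # simpletrainer / diffusers flux lora format
    "lycoris_{}".format(base.replace(".", "_")),           # simpletrainer lycoris
    "lora_transformer_{}".format(base.replace(".", "_")),  # onetrainer
    base,                                                   # DiffSynth lora format (added in this commit)
]
# -> ["transformer.single_blocks.0.linear1",
#     "lycoris_single_blocks_0_linear1",
#     "lora_transformer_single_blocks_0_linear1",
#     "single_blocks.0.linear1"]
```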

@@ -1251,6 +1251,22 @@ class NodeInfoV1:
    price_badge: dict | None = None
    search_aliases: list[str]=None

@dataclass
class NodeInfoV3:
    input: dict=None
    output: dict=None
    hidden: list[str]=None
    name: str=None
    display_name: str=None
    description: str=None
    python_module: Any = None
    category: str=None
    output_node: bool=None
    deprecated: bool=None
    experimental: bool=None
    api_node: bool=None
    price_badge: dict | None = None


@dataclass
class PriceBadgeDepends:

@@ -1328,7 +1344,8 @@ class Schema:
    """The category of the node, as per the "Add Node" menu."""
    inputs: list[Input] = field(default_factory=list)
    outputs: list[Output] = field(default_factory=list)
    hidden: list[Hidden] = field(default_factory=list)
    hidden: list[Hidden | str] = field(default_factory=list)
    """Hidden inputs. Use Hidden enum for system values (PROMPT, UNIQUE_ID, etc.) or plain strings for custom frontend-provided values."""
    description: str=""
    """Node description, shown as a tooltip when hovering over the node."""
    search_aliases: list[str] = field(default_factory=list)
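With `hidden` now accepting plain strings alongside the `Hidden` enum, a node can declare a custom, frontend-provided hidden input. A minimal sketch of the pattern, mirroring the CustomComboNode change later in this diff (the node id and i/o here are illustrative, not part of the PR):

```python
from comfy_api.latest import io

class ExampleHiddenNode(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="ExampleHiddenNode",  # illustrative node id
            category="utils",
            inputs=[io.Combo.Input("choice", options=[])],
            outputs=[io.String.Output(display_name="STRING")],
            hidden=["index"],  # custom string hidden, supplied by the frontend via the prompt data
        )

    @classmethod
    def execute(cls, choice: io.Combo.Type, index: int = 0) -> io.NodeOutput:
        # "index" arrives as a keyword argument but never shows up as a visible widget
        return io.NodeOutput(choice)
```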

@@ -1427,7 +1444,10 @@ class Schema:
        input = create_input_dict_v1(self.inputs)
        if self.hidden:
            for hidden in self.hidden:
                input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
                if isinstance(hidden, str):
                    input.setdefault("hidden", {})[hidden] = (hidden,)
                else:
                    input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
        # create separate lists from output fields
        output = []
        output_is_list = []

@@ -1474,6 +1494,42 @@ class Schema:
        )
        return info


    def get_v3_info(self, cls) -> NodeInfoV3:
        input_dict = {}
        output_dict = {}
        hidden_list = []
        # TODO: make sure dynamic types will be handled correctly
        if self.inputs:
            for input in self.inputs:
                add_to_dict_v3(input, input_dict)
        if self.outputs:
            for output in self.outputs:
                add_to_dict_v3(output, output_dict)
        if self.hidden:
            for hidden in self.hidden:
                if isinstance(hidden, str):
                    hidden_list.append(hidden)
                else:
                    hidden_list.append(hidden.value)

        info = NodeInfoV3(
            input=input_dict,
            output=output_dict,
            hidden=hidden_list,
            name=self.node_id,
            display_name=self.display_name,
            description=self.description,
            category=self.category,
            output_node=self.is_output_node,
            deprecated=self.is_deprecated,
            experimental=self.is_experimental,
            api_node=self.is_api_node,
            python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
            price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
        )
        return info

def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
    out_dict = {
        "required": {},

@@ -1528,6 +1584,9 @@ def add_to_dict_v1(i: Input, d: dict):
    as_dict.pop("optional", None)
    d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)

def add_to_dict_v3(io: Input | Output, d: dict):
    d[io.id] = (io.get_io_type(), io.as_dict())

class DynamicPathsDefaultValue:
    EMPTY_DICT = "empty_dict"

@@ -1688,6 +1747,13 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
        # set hidden
        type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
        return type_clone

    @final
    @classmethod
    def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
        schema = cls.GET_SCHEMA()
        info = schema.get_v3_info(cls)
        return asdict(info)
    #############################################
    # V1 Backwards Compatibility code
    #--------------------------------------------

@@ -2020,10 +2086,12 @@ __all__ = [
    "HiddenHolder",
    "Hidden",
    "NodeInfoV1",
    "NodeInfoV3",
    "Schema",
    "ComfyNode",
    "NodeOutput",
    "add_to_dict_v1",
    "add_to_dict_v3",
    "V3Data",
    "ImageCompare",
    "PriceBadgeDepends",

comfy_extras/nodes_glsl.py (new file, 439 lines)

@@ -0,0 +1,439 @@
import os
import re
import logging
from contextlib import contextmanager
from typing import TypedDict, Generator

import numpy as np
import torch

import nodes
from comfy_api.latest import ComfyExtension, io, ui
from comfy.cli_args import args
from typing_extensions import override
from utils.install_util import get_missing_requirements_message


class SizeModeInput(TypedDict):
    size_mode: str
    width: int
    height: int


MAX_IMAGES = 5  # u_image0-4
MAX_UNIFORMS = 5  # u_float0-4, u_int0-4
MAX_OUTPUTS = 4  # fragColor0-3 (MRT)

logger = logging.getLogger(__name__)

try:
    import moderngl
except ImportError as e:
    raise RuntimeError(f"ModernGL is not available.\n{get_missing_requirements_message()}") from e

# Default NOOP fragment shader that passes through the input image unchanged
# For multiple outputs, use: layout(location = 0) out vec4 fragColor0; etc.
DEFAULT_FRAGMENT_SHADER = """#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;

in vec2 v_texcoord;
layout(location = 0) out vec4 fragColor0;

void main() {
    fragColor0 = texture(u_image0, v_texcoord);
}
"""


# Simple vertex shader for full-screen quad
VERTEX_SHADER = """#version 330

in vec2 in_position;
in vec2 in_texcoord;

out vec2 v_texcoord;

void main() {
    gl_Position = vec4(in_position, 0.0, 1.0);
    v_texcoord = in_texcoord;
}
"""


def _convert_es_to_desktop_glsl(source: str) -> str:
    """Convert GLSL ES 3.00 shader to desktop GLSL 3.30 for ModernGL compatibility."""
    return re.sub(r'#version\s+300\s+es', '#version 330', source)


def _create_software_gl_context() -> moderngl.Context:
    original_env = os.environ.get("LIBGL_ALWAYS_SOFTWARE")
    os.environ["LIBGL_ALWAYS_SOFTWARE"] = "1"
    try:
        ctx = moderngl.create_standalone_context(require=330)
        logger.info(f"Created software-rendered OpenGL context: {ctx.info['GL_RENDERER']}")
        return ctx
    finally:
        if original_env is None:
            os.environ.pop("LIBGL_ALWAYS_SOFTWARE", None)
        else:
            os.environ["LIBGL_ALWAYS_SOFTWARE"] = original_env


def _create_gl_context(force_software: bool = False) -> moderngl.Context:
    if force_software:
        try:
            return _create_software_gl_context()
        except Exception as e:
            raise RuntimeError(
                "Failed to create software-rendered OpenGL context.\n"
                "Ensure Mesa/llvmpipe is installed for software rendering support."
            ) from e

    # Try hardware rendering first, fall back to software
    try:
        ctx = moderngl.create_standalone_context(require=330)
        logger.info(f"Created OpenGL context: {ctx.info['GL_RENDERER']}")
        return ctx
    except Exception as hw_error:
        logger.warning(f"Hardware OpenGL context creation failed: {hw_error}")
        logger.info("Attempting software rendering fallback...")
        try:
            return _create_software_gl_context()
        except Exception as sw_error:
            raise RuntimeError(
                f"Failed to create OpenGL context.\n"
                f"Hardware error: {hw_error}\n\n"
                f"Possible solutions:\n"
                f"1. Install GPU drivers with OpenGL 3.3+ support\n"
                f"2. Install Mesa for software rendering (Linux: apt install libgl1-mesa-dri)\n"
                f"3. On headless servers, ensure virtual framebuffer (Xvfb) or EGL is available"
            ) from sw_error


def _image_to_texture(ctx: moderngl.Context, image: np.ndarray) -> moderngl.Texture:
    height, width = image.shape[:2]
    channels = image.shape[2] if len(image.shape) > 2 else 1

    components = min(channels, 4)

    image_uint8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)

    # Flip vertically for OpenGL coordinate system (origin at bottom-left)
    image_uint8 = np.ascontiguousarray(np.flipud(image_uint8))

    texture = ctx.texture((width, height), components, image_uint8.tobytes())
    texture.filter = (moderngl.LINEAR, moderngl.LINEAR)
    texture.repeat_x = False
    texture.repeat_y = False

    return texture


def _texture_to_image(fbo: moderngl.Framebuffer, attachment: int = 0, channels: int = 4) -> np.ndarray:
    width, height = fbo.size

    data = fbo.read(components=channels, attachment=attachment)
    image = np.frombuffer(data, dtype=np.uint8).reshape((height, width, channels))

    image = np.ascontiguousarray(np.flipud(image))

    return image.astype(np.float32) / 255.0


def _compile_shader(ctx: moderngl.Context, fragment_source: str) -> moderngl.Program:
    # Convert user's GLSL ES 3.00 fragment shader to desktop GLSL 3.30 for ModernGL
    fragment_source = _convert_es_to_desktop_glsl(fragment_source)

    try:
        program = ctx.program(
            vertex_shader=VERTEX_SHADER,
            fragment_shader=fragment_source,
        )
        return program
    except Exception as e:
        raise RuntimeError(
            "Fragment shader compilation failed.\n\n"
            "Make sure your shader:\n"
            "1. Uses #version 300 es (WebGL 2.0 compatible)\n"
            "2. Has valid GLSL ES 3.00 syntax\n"
            "3. Includes 'precision highp float;' after version\n"
            "4. Uses 'out vec4 fragColor' instead of gl_FragColor\n"
            "5. Declares uniforms correctly (e.g., uniform sampler2D u_image0;)"
        ) from e


def _render_shader(
    ctx: moderngl.Context,
    program: moderngl.Program,
    width: int,
    height: int,
    textures: list[moderngl.Texture],
    uniforms: dict[str, int | float],
) -> list[np.ndarray]:
    # Create output textures
    output_textures = []
    for _ in range(MAX_OUTPUTS):
        tex = ctx.texture((width, height), 4)
        tex.filter = (moderngl.LINEAR, moderngl.LINEAR)
        output_textures.append(tex)

    fbo = ctx.framebuffer(color_attachments=output_textures)

    # Full-screen quad vertices (position + texcoord)
    vertices = np.array([
        # Position (x, y), Texcoord (u, v)
        -1.0, -1.0, 0.0, 0.0,
        1.0, -1.0, 1.0, 0.0,
        -1.0, 1.0, 0.0, 1.0,
        1.0, 1.0, 1.0, 1.0,
    ], dtype='f4')

    vbo = ctx.buffer(vertices.tobytes())
    vao = ctx.vertex_array(
        program,
        [(vbo, '2f 2f', 'in_position', 'in_texcoord')],
    )

    try:
        # Bind textures
        for i, texture in enumerate(textures):
            texture.use(i)
            uniform_name = f'u_image{i}'
            if uniform_name in program:
                program[uniform_name].value = i

        # Set uniforms
        if 'u_resolution' in program:
            program['u_resolution'].value = (float(width), float(height))

        for name, value in uniforms.items():
            if name in program:
                program[name].value = value

        # Render
        fbo.use()
        fbo.clear(0.0, 0.0, 0.0, 1.0)
        vao.render(moderngl.TRIANGLE_STRIP)

        # Read results from all attachments
        results = []
        for i in range(MAX_OUTPUTS):
            results.append(_texture_to_image(fbo, attachment=i, channels=4))
        return results
    finally:
        vao.release()
        vbo.release()
        for tex in output_textures:
            tex.release()
        fbo.release()


def _prepare_textures(
    ctx: moderngl.Context,
    image_list: list[torch.Tensor],
    batch_idx: int,
) -> list[moderngl.Texture]:
    textures = []
    for img_tensor in image_list[:MAX_IMAGES]:
        img_idx = min(batch_idx, img_tensor.shape[0] - 1)
        img_np = img_tensor[img_idx].cpu().numpy()
        textures.append(_image_to_texture(ctx, img_np))
    return textures


def _prepare_uniforms(int_list: list[int], float_list: list[float]) -> dict[str, int | float]:
    uniforms: dict[str, int | float] = {}
    for i, val in enumerate(int_list[:MAX_UNIFORMS]):
        uniforms[f'u_int{i}'] = int(val)
    for i, val in enumerate(float_list[:MAX_UNIFORMS]):
        uniforms[f'u_float{i}'] = float(val)
    return uniforms


def _release_textures(textures: list[moderngl.Texture]) -> None:
    for texture in textures:
        texture.release()


@contextmanager
def _gl_context(force_software: bool = False) -> Generator[moderngl.Context, None, None]:
    ctx = _create_gl_context(force_software)
    try:
        yield ctx
    finally:
        ctx.release()


@contextmanager
def _shader_program(ctx: moderngl.Context, fragment_source: str) -> Generator[moderngl.Program, None, None]:
    program = _compile_shader(ctx, fragment_source)
    try:
        yield program
    finally:
        program.release()


@contextmanager
def _textures_context(
    ctx: moderngl.Context,
    image_list: list[torch.Tensor],
    batch_idx: int,
) -> Generator[list[moderngl.Texture], None, None]:
    textures = _prepare_textures(ctx, image_list, batch_idx)
    try:
        yield textures
    finally:
        _release_textures(textures)


class GLSLShader(io.ComfyNode):

    @classmethod
    def define_schema(cls) -> io.Schema:
        # Create autogrow templates
        image_template = io.Autogrow.TemplatePrefix(
            io.Image.Input("image"),
            prefix="image",
            min=1,
            max=MAX_IMAGES,
        )

        float_template = io.Autogrow.TemplatePrefix(
            io.Float.Input("float", default=0.0),
            prefix="u_float",
            min=0,
            max=MAX_UNIFORMS,
        )

        int_template = io.Autogrow.TemplatePrefix(
            io.Int.Input("int", default=0),
            prefix="u_int",
            min=0,
            max=MAX_UNIFORMS,
        )

        return io.Schema(
            node_id="GLSLShader",
            display_name="GLSL Shader",
            category="image/shader",
            description=(
                f"Apply GLSL fragment shaders to images. "
                f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), "
                f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. "
                f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}."
            ),
            inputs=[
                io.String.Input(
                    "fragment_shader",
                    default=DEFAULT_FRAGMENT_SHADER,
                    multiline=True,
                    tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
                ),
                io.DynamicCombo.Input(
                    "size_mode",
                    options=[
                        io.DynamicCombo.Option(
                            "from_input",
                            [],  # No extra inputs - uses first input image dimensions
                        ),
                        io.DynamicCombo.Option(
                            "custom",
                            [
                                io.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION),
                                io.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION),
                            ],
                        ),
                    ],
                    tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
                ),
                io.Autogrow.Input("images", template=image_template),
                io.Autogrow.Input("floats", template=float_template),
                io.Autogrow.Input("ints", template=int_template),
            ],
            outputs=[
                io.Image.Output(display_name="IMAGE0"),
                io.Image.Output(display_name="IMAGE1"),
                io.Image.Output(display_name="IMAGE2"),
                io.Image.Output(display_name="IMAGE3"),
            ],
        )

    @classmethod
    def execute(
        cls,
        fragment_shader: str,
        size_mode: SizeModeInput,
        images: io.Autogrow.Type,
        floats: io.Autogrow.Type = None,
        ints: io.Autogrow.Type = None,
        **kwargs,
    ) -> io.NodeOutput:
        image_list = [v for v in images.values() if v is not None]
        float_list = [v if v is not None else 0.0 for v in floats.values()] if floats else []
        int_list = [v if v is not None else 0 for v in ints.values()] if ints else []

        if not image_list:
            raise ValueError("At least one input image is required")

        # Determine output dimensions
        if size_mode["size_mode"] == "custom":
            out_width, out_height = size_mode["width"], size_mode["height"]
        else:
            out_height, out_width = image_list[0].shape[1], image_list[0].shape[2]

        batch_size = image_list[0].shape[0]
        uniforms = _prepare_uniforms(int_list, float_list)

        with _gl_context(force_software=args.cpu) as ctx:
            with _shader_program(ctx, fragment_shader) as program:
                # Collect outputs for each render target across all batches
                all_outputs: list[list[torch.Tensor]] = [[] for _ in range(MAX_OUTPUTS)]

                for b in range(batch_size):
                    with _textures_context(ctx, image_list, b) as textures:
                        results = _render_shader(ctx, program, out_width, out_height, textures, uniforms)
                        for i, result in enumerate(results):
                            all_outputs[i].append(torch.from_numpy(result))

                # Stack batches for each output
                output_values = []
                for i in range(MAX_OUTPUTS):
                    output_batch = torch.stack(all_outputs[i], dim=0)
                    output_values.append(output_batch)

        return io.NodeOutput(*output_values, ui=cls._build_ui_output(image_list, output_values[0]))

    @classmethod
    def _build_ui_output(cls, image_list: list[torch.Tensor], output_batch: torch.Tensor) -> dict[str, list]:
        """Build UI output with input and output images for client-side shader execution."""
        combined_inputs = torch.cat(image_list, dim=0)
        input_images_ui = ui.ImageSaveHelper.save_images(
            combined_inputs,
            filename_prefix="GLSLShader_input",
            folder_type=io.FolderType.temp,
            cls=None,
            compress_level=1,
        )

        output_images_ui = ui.ImageSaveHelper.save_images(
            output_batch,
            filename_prefix="GLSLShader_output",
            folder_type=io.FolderType.temp,
            cls=None,
            compress_level=1,
        )

        return {"input_images": input_images_ui, "images": output_images_ui}


class GLSLExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [GLSLShader]


async def comfy_entrypoint() -> GLSLExtension:
    return GLSLExtension()
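As a usage sketch (not part of the new file), a fragment shader pasted into the node's fragment_shader input follows the conventions described in its schema: GLSL ES 3.00, u_image0 for the first image, u_float0-4 / u_int0-4 as numeric knobs, fragColor0-3 for the outputs. A hypothetical brightness shader driven by u_float0 could look like this:

```python
# Hypothetical user shader for the GLSLShader node, shown as a Python string
# in the same style as DEFAULT_FRAGMENT_SHADER above.
EXAMPLE_BRIGHTNESS_SHADER = """#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0;  // brightness offset, wired from the node's u_float0 input

in vec2 v_texcoord;
layout(location = 0) out vec4 fragColor0;

void main() {
    vec4 color = texture(u_image0, v_texcoord);
    fragColor0 = vec4(color.rgb + u_float0, color.a);
}
"""
```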

@@ -104,7 +104,11 @@ class CustomComboNode(io.ComfyNode):
            category="utils",
            is_experimental=True,
            inputs=[io.Combo.Input("choice", options=[])],
            outputs=[io.String.Output()]
            outputs=[
                io.String.Output(display_name="STRING"),
                io.Int.Output(display_name="INDEX"),
            ],
            hidden=["index"],
        )

    @classmethod

@@ -115,8 +119,8 @@ class CustomComboNode(io.ComfyNode):
        return True

    @classmethod
    def execute(cls, choice: io.Combo.Type) -> io.NodeOutput:
        return io.NodeOutput(choice)
    def execute(cls, choice: io.Combo.Type, index: int = 0) -> io.NodeOutput:
        return io.NodeOutput(choice, index)


class DCTestNode(io.ComfyNode):

@@ -192,6 +192,11 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
                hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
            if io.Hidden.api_key_comfy_org.name in hidden:
                hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
            # Handle custom hidden inputs from prompt data
            system_hidden_names = {h.name for h in io.Hidden}
            for hidden_name in hidden:
                if hidden_name not in system_hidden_names and hidden_name in inputs:
                    input_data_all[hidden_name] = [inputs[hidden_name]]
    else:
        if "hidden" in valid_inputs:
            h = valid_inputs["hidden"]

nodes.py (1 line changed)

@@ -2430,6 +2430,7 @@ async def init_builtin_extra_nodes():
        "nodes_wanmove.py",
        "nodes_image_compare.py",
        "nodes_zimage.py",
        "nodes_glsl.py",
    ]

    import_failed = []

@@ -28,3 +28,4 @@ kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
moderngl