mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-06 04:17:48 +08:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ec8c2e7315 | |||
| 2131e630e0 | |||
| 501b808481 | |||
| 7f287b705e |
@ -229,7 +229,7 @@ Python 3.14 works but some custom nodes may have issues. The free threaded varia
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
|
||||
torch 2.5 is minimally supported but using a newer version is extremely recommended. Some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old. If your pytorch is more than 6 months old, please update it.
|
||||
|
||||
### Instructions:
|
||||
|
||||
|
||||
@ -217,10 +217,7 @@ class AceStepAttention(nn.Module):
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
n_rep = self.num_heads // self.num_kv_heads
|
||||
if n_rep > 1:
|
||||
key_states = key_states.repeat_interleave(n_rep, dim=1)
|
||||
value_states = value_states.repeat_interleave(n_rep, dim=1)
|
||||
gqa_kwargs = {"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}
|
||||
|
||||
attn_bias = None
|
||||
if self.sliding_window is not None and not self.is_cross_attention:
|
||||
@ -244,7 +241,7 @@ class AceStepAttention(nn.Module):
|
||||
else:
|
||||
attn_bias = window_bias
|
||||
|
||||
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)
|
||||
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False, **gqa_kwargs)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
@ -425,19 +425,16 @@ class Attention(nn.Module):
|
||||
if n == 1 and causal:
|
||||
causal = False
|
||||
|
||||
if h != kv_h:
|
||||
# Repeat interleave kv_heads to match q_heads
|
||||
heads_per_kv_head = h // kv_h
|
||||
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||
gqa_kwargs = {"enable_gqa": True} if h != kv_h else {}
|
||||
|
||||
if self.differential:
|
||||
q, q_diff = q.unbind(dim=1)
|
||||
k, k_diff = k.unbind(dim=1)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options, **gqa_kwargs)
|
||||
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options, **gqa_kwargs)
|
||||
out = out - out_diff
|
||||
else:
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options, **gqa_kwargs)
|
||||
|
||||
out = self.to_out(out)
|
||||
|
||||
|
||||
@ -74,11 +74,8 @@ class BooguDoubleStreamProcessor(nn.Module):
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
if attn.kv_heads < attn.heads:
|
||||
key = key.repeat_interleave(attn.heads // attn.kv_heads, dim=1)
|
||||
value = value.repeat_interleave(attn.heads // attn.kv_heads, dim=1)
|
||||
|
||||
hidden_states = optimized_attention_masked(query, key, value, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
gqa_kwargs = {"enable_gqa": True} if attn.kv_heads < attn.heads else {}
|
||||
hidden_states = optimized_attention_masked(query, key, value, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options, **gqa_kwargs)
|
||||
|
||||
# Split back to instruction/image, apply per-stream output projections, recombine.
|
||||
instruct_hidden_states = self.instruct_out(hidden_states[:, :L_instruct])
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import math
|
||||
import sys
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -14,16 +15,16 @@ from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
from comfy import model_management
|
||||
|
||||
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
SAGE_ATTENTION_IS_AVAILABLE = False
|
||||
SAGE_ATTENTION_SUPPORTS_MASK = False
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
SAGE_ATTENTION_IS_AVAILABLE = True
|
||||
SAGE_ATTENTION_SUPPORTS_MASK = "attn_mask" in inspect.signature(sageattn).parameters
|
||||
except ImportError as e:
|
||||
if model_management.sage_attention_enabled():
|
||||
if e.name == "sageattention":
|
||||
@ -89,6 +90,44 @@ def default(val, d):
|
||||
return val
|
||||
return d
|
||||
|
||||
def _gqa_repeat_factor(query_heads, key_heads, value_heads):
|
||||
if key_heads != value_heads:
|
||||
raise ValueError(f"Key/value head count mismatch for GQA: {key_heads} != {value_heads}")
|
||||
if query_heads == key_heads:
|
||||
return 1
|
||||
if query_heads % key_heads != 0:
|
||||
raise ValueError(f"Query heads must be divisible by key/value heads for GQA: {query_heads} vs {key_heads}")
|
||||
return query_heads // key_heads
|
||||
|
||||
def _repeat_kv_for_gqa(k, v, query_heads, head_dim):
|
||||
n_rep = _gqa_repeat_factor(query_heads, k.shape[head_dim], v.shape[head_dim])
|
||||
if n_rep > 1:
|
||||
k = k.repeat_interleave(n_rep, dim=head_dim)
|
||||
v = v.repeat_interleave(n_rep, dim=head_dim)
|
||||
return k, v
|
||||
|
||||
def _heads_from_dim(tensor, dim_head, name):
|
||||
inner_dim = tensor.shape[-1]
|
||||
if inner_dim % dim_head != 0:
|
||||
raise ValueError(f"{name} inner dimension {inner_dim} is not divisible by head dimension {dim_head}")
|
||||
return inner_dim // dim_head
|
||||
|
||||
def _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, enable_gqa=False, expand_kv=True):
|
||||
q = q.unsqueeze(3).reshape(b, -1, heads, dim_head)
|
||||
if enable_gqa:
|
||||
key_heads = _heads_from_dim(k, dim_head, "Key")
|
||||
value_heads = _heads_from_dim(v, dim_head, "Value")
|
||||
else:
|
||||
key_heads = heads
|
||||
value_heads = heads
|
||||
k = k.unsqueeze(3).reshape(b, -1, key_heads, dim_head)
|
||||
v = v.unsqueeze(3).reshape(b, -1, value_heads, dim_head)
|
||||
if enable_gqa:
|
||||
_gqa_repeat_factor(heads, key_heads, value_heads)
|
||||
if expand_kv:
|
||||
k, v = _repeat_kv_for_gqa(k, v, heads, -2)
|
||||
return q, k, v
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
@ -152,28 +191,19 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
|
||||
n_rep = q.shape[-3] // k.shape[-3]
|
||||
k = k.repeat_interleave(n_rep, dim=-3)
|
||||
v = v.repeat_interleave(n_rep, dim=-3)
|
||||
|
||||
scale = kwargs.get("scale", dim_head ** -0.5)
|
||||
|
||||
h = heads
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
if kwargs.get("enable_gqa", False):
|
||||
k, v = _repeat_kv_for_gqa(k, v, q.shape[-3], -3)
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
else:
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, -1, heads, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * heads, -1, dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
|
||||
q, k, v = map(lambda t: t.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head).contiguous(), (q, k, v))
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if attn_precision == torch.float32:
|
||||
@ -231,13 +261,16 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
query = query * (kwargs["scale"] * dim_head ** 0.5)
|
||||
|
||||
if skip_reshape:
|
||||
if kwargs.get("enable_gqa", False):
|
||||
key, value = _repeat_kv_for_gqa(key, value, query.shape[-3], -3)
|
||||
query = query.reshape(b * heads, -1, dim_head)
|
||||
value = value.reshape(b * heads, -1, dim_head)
|
||||
key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
|
||||
else:
|
||||
query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
|
||||
query, key, value = _reshape_qkv_to_heads(query, key, value, b, heads, dim_head, kwargs.get("enable_gqa", False))
|
||||
query = query.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||
value = value.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||
key = key.permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
|
||||
|
||||
|
||||
dtype = query.dtype
|
||||
@ -304,19 +337,15 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
scale = kwargs.get("scale", dim_head ** -0.5)
|
||||
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
if kwargs.get("enable_gqa", False):
|
||||
k, v = _repeat_kv_for_gqa(k, v, q.shape[-3], -3)
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
else:
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, -1, heads, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * heads, -1, dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
|
||||
q, k, v = map(lambda t: t.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head).contiguous(), (q, k, v))
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
@ -438,7 +467,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
disabled_xformers = True
|
||||
|
||||
if disabled_xformers:
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
|
||||
if skip_reshape:
|
||||
# b h k d -> b k h d
|
||||
@ -446,13 +475,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
lambda t: t.permute(0, 2, 1, 3),
|
||||
(q, k, v),
|
||||
)
|
||||
if kwargs.get("enable_gqa", False):
|
||||
k, v = _repeat_kv_for_gqa(k, v, q.shape[-2], -2)
|
||||
# actually do the reshaping
|
||||
else:
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
|
||||
|
||||
if mask is not None:
|
||||
# add a singleton batch dimension
|
||||
@ -474,7 +502,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
mask = mask_out[..., :mask.shape[-1]]
|
||||
mask = mask.expand(b, heads, -1, -1)
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask, scale=kwargs.get("scale", None))
|
||||
|
||||
if skip_output_reshape:
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
@ -498,10 +526,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False), expand_kv=False)
|
||||
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
@ -511,9 +537,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
|
||||
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
|
||||
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
|
||||
sdpa_keys = ("scale", "enable_gqa")
|
||||
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
@ -541,20 +565,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if kwargs.get("low_precision_attention", True) is False:
|
||||
if kwargs.get("low_precision_attention", True) is False or (mask is not None and not SAGE_ATTENTION_SUPPORTS_MASK):
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
|
||||
exception_fallback = False
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout = "HND"
|
||||
if kwargs.get("enable_gqa", False):
|
||||
k, v = _repeat_kv_for_gqa(k, v, q.shape[-3], -3)
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
|
||||
tensor_layout = "NHD"
|
||||
|
||||
if mask is not None:
|
||||
@ -565,8 +588,12 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
sage_kwargs = {"is_causal": False, "tensor_layout": tensor_layout, "sm_scale": kwargs.get("scale", None), "smooth_k": False}
|
||||
if mask is not None:
|
||||
sage_kwargs["attn_mask"] = mask
|
||||
|
||||
try:
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
out = sageattn(q, k, v, **sage_kwargs)
|
||||
except Exception as e:
|
||||
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
|
||||
exception_fallback = True
|
||||
@ -616,7 +643,6 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
q_s, k_s, v_s = q, k, v
|
||||
N = q.shape[2]
|
||||
dim_head = D
|
||||
else:
|
||||
@ -642,11 +668,15 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not skip_reshape:
|
||||
q_s, k_s, v_s = map(
|
||||
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
if skip_reshape:
|
||||
q_s = q
|
||||
if kwargs.get("enable_gqa", False):
|
||||
k_s, v_s = _repeat_kv_for_gqa(k, v, H, -3)
|
||||
else:
|
||||
k_s, v_s = k, v
|
||||
else:
|
||||
q_s, k_s, v_s = _reshape_qkv_to_heads(q, k, v, B, heads, dim_head, kwargs.get("enable_gqa", False))
|
||||
q_s, k_s, v_s = map(lambda t: t.permute(0, 2, 1, 3).contiguous(), (q_s, k_s, v_s))
|
||||
B, H, L, D = q_s.shape
|
||||
|
||||
try:
|
||||
@ -662,7 +692,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=False,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
@ -681,19 +711,20 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
try:
|
||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
|
||||
dropout_p: float = 0.0, causal: bool = False, softmax_scale: float = -1.0) -> torch.Tensor:
|
||||
softmax_scale_arg = None if softmax_scale == -1.0 else softmax_scale
|
||||
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal, softmax_scale=softmax_scale_arg)
|
||||
|
||||
|
||||
@flash_attn_wrapper.register_fake
|
||||
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
|
||||
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False, softmax_scale=-1.0):
|
||||
# Output shape is the same as q
|
||||
return q.new_empty(q.shape)
|
||||
except AttributeError as error:
|
||||
FLASH_ATTN_ERROR = error
|
||||
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
dropout_p: float = 0.0, causal: bool = False, softmax_scale: float = -1.0) -> torch.Tensor:
|
||||
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
||||
|
||||
@wrap_attn
|
||||
@ -703,10 +734,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False), expand_kv=False)
|
||||
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
@ -725,10 +754,16 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
v.transpose(1, 2),
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
softmax_scale=kwargs.get("scale", -1.0),
|
||||
).transpose(1, 2)
|
||||
except Exception as e:
|
||||
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
sdpa_extra = {}
|
||||
if kwargs.get("enable_gqa", False):
|
||||
sdpa_extra["enable_gqa"] = True
|
||||
if "scale" in kwargs:
|
||||
sdpa_extra["scale"] = kwargs["scale"]
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
@ -1209,5 +1244,3 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
x = self.proj_out(x)
|
||||
out = x + x_in
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@ -141,11 +141,8 @@ class Attention(nn.Module):
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
if self.kv_heads < self.heads:
|
||||
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
|
||||
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
gqa_kwargs = {"enable_gqa": True} if self.kv_heads < self.heads else {}
|
||||
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options, **gqa_kwargs)
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy import sd1_clip
|
||||
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.text_encoders.llama import RMSNorm, apply_rope
|
||||
|
||||
|
||||
@ -110,10 +110,6 @@ def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sin
|
||||
putting the sink logit in the mask at that column.
|
||||
"""
|
||||
|
||||
if num_kv_groups > 1 and not TORCH_HAS_GQA:
|
||||
k = k.repeat_interleave(num_kv_groups, dim=1)
|
||||
v = v.repeat_interleave(num_kv_groups, dim=1)
|
||||
|
||||
B, _, S_q, D = q.shape
|
||||
H_kv = k.shape[1]
|
||||
S_kv = k.shape[-2]
|
||||
|
||||
@ -550,10 +550,8 @@ class Attention(nn.Module):
|
||||
xv = xv[:, :, -sliding_window:]
|
||||
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
|
||||
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||
gqa_kwargs = {"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, **gqa_kwargs)
|
||||
return self.o_proj(output), present_key_value
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
||||
@ -366,12 +366,8 @@ class GatedAttention(nn.Module):
|
||||
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
# Expand KV heads for GQA
|
||||
if self.num_heads != self.num_kv_heads:
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||
gqa_kwargs = {"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, **gqa_kwargs)
|
||||
output = output * gate.sigmoid()
|
||||
|
||||
return self.o_proj(output), present_key_value
|
||||
|
||||
@ -1261,155 +1261,6 @@ class DynamicSlot(ComfyTypeI):
|
||||
out_dict[input_type][finalized_id] = value
|
||||
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICGROUP_V3")
|
||||
class DynamicGroup(ComfyTypeI):
|
||||
"""A repeatable group of widget inputs (e.g. lora_name + strength stacked into N rows).
|
||||
|
||||
At execution time the node receives a ``list[dict]`` where each element is a row.
|
||||
|
||||
Example::
|
||||
|
||||
io.DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[
|
||||
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras")),
|
||||
io.Float.Input("strength", default=1.0, min=-100, max=100, step=0.01),
|
||||
],
|
||||
min=0,
|
||||
max=50,
|
||||
)
|
||||
# execute receives: loras: list[dict] = [{"lora_name": "x.safetensors", "strength": 1.0}, ...]
|
||||
"""
|
||||
|
||||
Type = list[dict[str, Any]]
|
||||
_MaxRows = 100
|
||||
|
||||
class Input(DynamicInput):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
template: list["Input"],
|
||||
min: int = 0,
|
||||
max: int = 50,
|
||||
display_name: str = None,
|
||||
optional: bool = False,
|
||||
tooltip: str = None,
|
||||
lazy: bool = None,
|
||||
extra_dict=None,
|
||||
group_name: str = "Group",
|
||||
):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
assert len(template) > 0, "DynamicGroup template must have at least one field."
|
||||
for t in template:
|
||||
assert isinstance(t, WidgetInput), (
|
||||
f"DynamicGroup template field '{t.id}' must be a WidgetInput subclass "
|
||||
f"(Combo, Float, Int, String, Boolean, Color). Got {type(t).__name__}."
|
||||
)
|
||||
assert not isinstance(t, DynamicInput), (
|
||||
f"DynamicGroup template field '{t.id}' must not be a DynamicInput. "
|
||||
"Nesting dynamic inputs inside DynamicGroup is not supported."
|
||||
)
|
||||
field_ids = [t.id for t in template]
|
||||
assert len(field_ids) == len(set(field_ids)), (
|
||||
f"DynamicGroup template field ids must be unique within a row. Got: {field_ids}"
|
||||
)
|
||||
# Reject "." in group id and template field ids: slot_id encoding uses "." as a
|
||||
# delimiter (<group_id>.<row>.<field_id>), so any "." in these names would cause
|
||||
# path.split(".") to produce the wrong number of segments during decoding.
|
||||
assert "." not in id, (
|
||||
f"DynamicGroup id must not contain '.'. Got: '{id}'"
|
||||
)
|
||||
for t in template:
|
||||
assert "." not in t.id, (
|
||||
f"DynamicGroup template field id must not contain '.'. Got: '{t.id}'"
|
||||
)
|
||||
assert min >= 0, "DynamicGroup min must be >= 0."
|
||||
assert max >= 1, "DynamicGroup max must be >= 1."
|
||||
assert max <= DynamicGroup._MaxRows, f"DynamicGroup max must be <= {DynamicGroup._MaxRows}."
|
||||
assert min <= max, "DynamicGroup min must be <= max."
|
||||
self.template = template
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.group_name = group_name
|
||||
|
||||
def get_all(self) -> list["Input"]:
|
||||
return [self] + list(self.template)
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": create_input_dict_v1(self.template),
|
||||
"min": self.min,
|
||||
"max": self.max,
|
||||
"group_name": self.group_name,
|
||||
})
|
||||
|
||||
def validate(self):
|
||||
for t in self.template:
|
||||
t.validate()
|
||||
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(
|
||||
out_dict: dict[str, Any],
|
||||
live_inputs: dict[str, Any],
|
||||
value: tuple[str, dict[str, Any]],
|
||||
input_type: str,
|
||||
curr_prefix: list[str] | None,
|
||||
):
|
||||
info = value[1]
|
||||
min_rows: int = info.get("min", 0)
|
||||
max_rows: int = info.get("max", DynamicGroup._MaxRows)
|
||||
template: dict[str, Any] = info.get("template", {})
|
||||
|
||||
# Collect all template field specs across required/optional sections
|
||||
field_specs: list[tuple[str, tuple[str, dict[str, Any]], bool]] = []
|
||||
for field_required_key in ("required", "optional"):
|
||||
section = template.get(field_required_key, {})
|
||||
is_required_field = field_required_key == "required"
|
||||
for field_id, field_value in section.items():
|
||||
field_specs.append((field_id, field_value, is_required_field))
|
||||
|
||||
# Determine how many rows are currently present by scanning live_inputs
|
||||
finalized_prefix = finalize_prefix(curr_prefix)
|
||||
present_rows = 0
|
||||
for live_key in live_inputs:
|
||||
# Keys look like "<prefix>.<row>.<field_id>"
|
||||
if live_key.startswith(finalized_prefix + "."):
|
||||
remainder = live_key[len(finalized_prefix) + 1:]
|
||||
parts = remainder.split(".", 1)
|
||||
if len(parts) >= 1:
|
||||
try:
|
||||
row_idx = int(parts[0])
|
||||
present_rows = max(present_rows, row_idx + 1)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if present_rows > max_rows:
|
||||
raise ValueError(
|
||||
f"DynamicGroup input '{finalized_prefix}' received {present_rows} rows but max is {max_rows}."
|
||||
)
|
||||
row_count = max(min_rows, present_rows)
|
||||
|
||||
for row in range(row_count):
|
||||
for field_id, field_value, is_required_field in field_specs:
|
||||
slot_id = f"{finalized_prefix}.{row}.{field_id}"
|
||||
if row < min_rows and is_required_field:
|
||||
out_dict["required"][slot_id] = field_value
|
||||
else:
|
||||
out_dict["optional"][slot_id] = field_value
|
||||
# Register into dynamic_paths so build_nested_inputs places value at the right path
|
||||
out_dict["dynamic_paths"][slot_id] = slot_id
|
||||
|
||||
# Track the list root path so build_nested_inputs can convert the index dict to a list
|
||||
out_dict.setdefault("list_paths", set()).add(finalized_prefix)
|
||||
|
||||
# Handle the empty case (0 rows) – emit an empty-list default for the parent.
|
||||
# This must only fire when there are genuinely no rows; otherwise the parent
|
||||
# path would clobber the per-row dict built from the slot ids above.
|
||||
if row_count == 0:
|
||||
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
|
||||
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_LIST
|
||||
|
||||
|
||||
@comfytype(io_type="IMAGECOMPARE")
|
||||
class ImageCompare(ComfyTypeI):
|
||||
Type = dict
|
||||
@ -1567,8 +1418,6 @@ def setup_dynamic_input_funcs():
|
||||
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
||||
# DynamicSlot.Input
|
||||
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
||||
# DynamicGroup.Input
|
||||
register_dynamic_input_func(DynamicGroup.io_type, DynamicGroup._expand_schema_for_dynamic)
|
||||
|
||||
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
||||
setup_dynamic_input_funcs()
|
||||
@ -1580,8 +1429,6 @@ class V3Data(TypedDict):
|
||||
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||
dynamic_paths_default_value: dict[str, Any]
|
||||
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
|
||||
list_paths: set[str]
|
||||
'Set of top-level keys whose index-keyed dict values should be converted to a sorted list[dict] after build_nested_inputs runs.'
|
||||
create_dynamic_tuple: bool
|
||||
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||
|
||||
@ -1923,7 +1770,6 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
"optional": {},
|
||||
"dynamic_paths": {},
|
||||
"dynamic_paths_default_value": {},
|
||||
"list_paths": set(),
|
||||
}
|
||||
d = d.copy()
|
||||
# ignore hidden for parsing
|
||||
@ -1939,10 +1785,6 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
|
||||
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
|
||||
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
|
||||
# list_paths: keys whose nested dict should be post-converted to a sorted list[dict]
|
||||
list_paths = out_dict.pop("list_paths", None)
|
||||
if list_paths:
|
||||
v3_data["list_paths"] = list_paths
|
||||
return out_dict, hidden, v3_data
|
||||
|
||||
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||
@ -1978,12 +1820,10 @@ def add_to_dict_v1(i: Input, d: dict):
|
||||
|
||||
class DynamicPathsDefaultValue:
|
||||
EMPTY_DICT = "empty_dict"
|
||||
EMPTY_LIST = "empty_list"
|
||||
|
||||
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
paths = v3_data.get("dynamic_paths", None)
|
||||
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
|
||||
list_paths: set[str] = v3_data.get("list_paths", set()) or set()
|
||||
if paths is None:
|
||||
return values
|
||||
values = values.copy()
|
||||
@ -2006,8 +1846,6 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
default_option = default_value_dict.get(key, None)
|
||||
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
|
||||
value = {}
|
||||
elif default_option == DynamicPathsDefaultValue.EMPTY_LIST:
|
||||
value = []
|
||||
if create_tuple:
|
||||
value = (value, key)
|
||||
current[p] = value
|
||||
@ -2015,34 +1853,6 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
current = current.setdefault(p, {})
|
||||
|
||||
values.update(result)
|
||||
|
||||
# Post-pass: convert index-keyed dicts to sorted lists for io.DynamicGroup fields
|
||||
for list_path in list_paths:
|
||||
parts = list_path.split(".")
|
||||
# Navigate to the parent container, then convert the leaf
|
||||
container = values
|
||||
for part in parts[:-1]:
|
||||
if not isinstance(container, dict) or part not in container:
|
||||
container = None
|
||||
break
|
||||
container = container[part]
|
||||
if container is None:
|
||||
continue
|
||||
leaf_key = parts[-1]
|
||||
leaf = container.get(leaf_key, None)
|
||||
if isinstance(leaf, dict):
|
||||
try:
|
||||
sorted_rows = [leaf[k] for k in sorted(leaf.keys(), key=int)]
|
||||
container[leaf_key] = sorted_rows
|
||||
except (ValueError, TypeError):
|
||||
# Keys are not all integers; leave as-is
|
||||
pass
|
||||
elif isinstance(leaf, list):
|
||||
# Already a list (e.g. the EMPTY_LIST default was applied above)
|
||||
pass
|
||||
elif leaf is None:
|
||||
container[leaf_key] = []
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@ -2607,9 +2417,7 @@ __all__ = [
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
"DynamicCombo",
|
||||
"DynamicSlot",
|
||||
"Autogrow",
|
||||
"DynamicGroup",
|
||||
# Other classes
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
|
||||
@ -16,23 +16,30 @@ class ColorToRGBInt(io.ComfyNode):
|
||||
],
|
||||
outputs=[
|
||||
io.Int.Output(display_name="rgb_int"),
|
||||
io.Color.Output(display_name="hex")
|
||||
io.Color.Output(display_name="hex"),
|
||||
io.Float.Output(display_name="alpha"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, color: str) -> io.NodeOutput:
|
||||
# expect format #RRGGBB
|
||||
if len(color) != 7 or color[0] != "#":
|
||||
raise ValueError("Color must be in format #RRGGBB")
|
||||
# expect format #RRGGBB or #RRGGBBAA
|
||||
if len(color) not in (7, 9) or color[0] != "#":
|
||||
raise ValueError("Color must be in format #RRGGBB or #RRGGBBAA")
|
||||
try:
|
||||
int(color[1:], 16)
|
||||
except ValueError:
|
||||
raise ValueError("Color must be in format #RRGGBB") from None
|
||||
raise ValueError("Color must be in format #RRGGBB or #RRGGBBAA") from None
|
||||
|
||||
alpha = 1.0
|
||||
if len(color) == 9:
|
||||
alpha = int(color[7:9], 16) / 255.0
|
||||
color = color[:7]
|
||||
|
||||
r, g, b = hex_to_rgb(color)
|
||||
|
||||
rgb_int = r * 256 * 256 + g * 256 + b
|
||||
return io.NodeOutput(rgb_int, color)
|
||||
return io.NodeOutput(rgb_int, color, alpha)
|
||||
|
||||
|
||||
class ColorExtension(ComfyExtension):
|
||||
|
||||
@ -1,107 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def _load_lora_file(lora_name: str):
|
||||
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
||||
return comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
|
||||
|
||||
|
||||
def _lora_template() -> list[io.Input]:
|
||||
return [
|
||||
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras"),
|
||||
tooltip="The name of the LoRA file to apply."),
|
||||
io.Float.Input("strength", default=1.0, min=-100.0, max=100.0, step=0.01,
|
||||
tooltip="How strongly to apply this LoRA. 0 = off, negative inverts the effect."),
|
||||
]
|
||||
|
||||
|
||||
class LoadLoraModel(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadLoraModel",
|
||||
display_name="Load LoRA (Model)",
|
||||
search_aliases=["lora", "load lora", "apply lora", "lora model", "lora stack"],
|
||||
category="model/loaders",
|
||||
description="Apply a stack of LoRAs to a diffusion model. Add one row per LoRA; "
|
||||
"each row picks a LoRA file and its strength.",
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The diffusion model the LoRAs will be applied to."),
|
||||
io.DynamicGroup.Input(
|
||||
"loras",
|
||||
template=_lora_template(),
|
||||
min=1,
|
||||
max=50,
|
||||
tooltip="Each row applies one LoRA to the model.",
|
||||
group_name="LoRA",
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output(tooltip="The modified diffusion model.")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, loras: list[dict]) -> io.NodeOutput:
|
||||
for row in loras:
|
||||
lora_name = row.get("lora_name")
|
||||
strength = row.get("strength", 1.0)
|
||||
if not lora_name or lora_name == "none" or strength == 0:
|
||||
continue
|
||||
lora, metadata = _load_lora_file(lora_name)
|
||||
model, _ = comfy.sd.load_lora_for_models(model, None, lora, strength, 0, lora_metadata=metadata)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
class LoadLoraTextEncoder(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadLoraTextEncoder",
|
||||
display_name="Load LoRA (Text Encoder)",
|
||||
search_aliases=["lora", "load lora", "apply lora", "clip lora", "lora stack"],
|
||||
category="model/loaders",
|
||||
description="Apply a stack of LoRAs to a CLIP text encoder. Add one row per LoRA; "
|
||||
"each row picks a LoRA file and its strength.",
|
||||
inputs=[
|
||||
io.Clip.Input("clip", tooltip="The CLIP text encoder the LoRAs will be applied to."),
|
||||
io.DynamicGroup.Input(
|
||||
"loras",
|
||||
template=_lora_template(),
|
||||
min=1,
|
||||
max=50,
|
||||
tooltip="Each row applies one LoRA to the text encoder.",
|
||||
group_name="LoRA",
|
||||
),
|
||||
],
|
||||
outputs=[io.Clip.Output(tooltip="The modified CLIP text encoder.")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, loras: list[dict]) -> io.NodeOutput:
|
||||
for row in loras:
|
||||
lora_name = row.get("lora_name")
|
||||
strength = row.get("strength", 1.0)
|
||||
if not lora_name or lora_name == "none" or strength == 0:
|
||||
continue
|
||||
lora, metadata = _load_lora_file(lora_name)
|
||||
_, clip = comfy.sd.load_lora_for_models(None, clip, lora, 0, strength, lora_metadata=metadata)
|
||||
return io.NodeOutput(clip)
|
||||
|
||||
|
||||
class LoraStackExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
LoadLoraModel,
|
||||
LoadLoraTextEncoder,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> LoraStackExtension:
|
||||
return LoraStackExtension()
|
||||
1
nodes.py
1
nodes.py
@ -2502,7 +2502,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_triposplat.py",
|
||||
"nodes_depth_anything_3.py",
|
||||
"nodes_seed.py",
|
||||
"nodes_lora_stack.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@ -1,204 +0,0 @@
|
||||
"""Unit tests for io.DynamicGroup: expansion/reconstruction (0-row and N-row cases)."""
|
||||
import sys
|
||||
import types
|
||||
import pytest
|
||||
|
||||
# Stub torch (type-hint only in _io.py; real torch not available in unit-test env)
|
||||
if "torch" not in sys.modules:
|
||||
_torch_stub = types.ModuleType("torch")
|
||||
_torch_stub.Tensor = object # type: ignore[attr-defined]
|
||||
sys.modules["torch"] = _torch_stub
|
||||
|
||||
from comfy_api.latest._io import ( # noqa: E402
|
||||
DynamicGroup,
|
||||
Float,
|
||||
Int,
|
||||
String,
|
||||
Boolean,
|
||||
get_finalized_class_inputs,
|
||||
build_nested_inputs,
|
||||
create_input_dict_v1,
|
||||
setup_dynamic_input_funcs,
|
||||
)
|
||||
|
||||
# Make sure dynamic input funcs are registered (may already be done at import time)
|
||||
setup_dynamic_input_funcs()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_class_inputs(group_input: DynamicGroup.Input) -> dict:
|
||||
"""Wrap a DynamicGroup.Input into the required/optional dict structure."""
|
||||
return create_input_dict_v1([group_input])
|
||||
|
||||
|
||||
def _run(group_input: DynamicGroup.Input, live_values: dict) -> dict:
|
||||
"""End-to-end helper: expand schema + reconstruct values.
|
||||
|
||||
Mirrors the production split in execution.py:
|
||||
1. get_finalized_class_inputs (schema expansion, line 162)
|
||||
2. build_nested_inputs (value reconstruction, line 281)
|
||||
|
||||
The two steps are separate in production because the engine resolves
|
||||
linked node outputs between them, but in tests we supply values directly.
|
||||
"""
|
||||
class_inputs = _make_class_inputs(group_input)
|
||||
_, _, v3_data = get_finalized_class_inputs(class_inputs, live_values)
|
||||
return build_nested_inputs(dict(live_values), v3_data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDynamicGroupInputConstruction:
|
||||
def test_basic_construction(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[
|
||||
Float.Input("strength", default=1.0),
|
||||
String.Input("name"),
|
||||
],
|
||||
min=0,
|
||||
max=10,
|
||||
)
|
||||
assert inp.id == "loras"
|
||||
assert inp.min == 0
|
||||
assert inp.max == 10
|
||||
assert len(inp.template) == 2
|
||||
|
||||
def test_get_all_includes_self_and_template(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"items",
|
||||
template=[Float.Input("value")],
|
||||
)
|
||||
all_inputs = inp.get_all()
|
||||
assert all_inputs[0] is inp
|
||||
assert all_inputs[1].id == "value"
|
||||
|
||||
def test_as_dict_has_template_min_max(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"items",
|
||||
template=[Float.Input("val", default=0.5)],
|
||||
min=1,
|
||||
max=5,
|
||||
)
|
||||
d = inp.as_dict()
|
||||
assert "template" in d
|
||||
assert d["min"] == 1
|
||||
assert d["max"] == 5
|
||||
|
||||
def test_duplicate_field_ids_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input(
|
||||
"bad",
|
||||
template=[Float.Input("x"), Float.Input("x")],
|
||||
)
|
||||
|
||||
def test_empty_template_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[])
|
||||
|
||||
def test_min_gt_max_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[Float.Input("x")], min=5, max=3)
|
||||
|
||||
def test_max_exceeds_limit_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[Float.Input("x")], max=101)
|
||||
|
||||
def test_dynamic_input_in_template_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input(
|
||||
"bad",
|
||||
template=[DynamicGroup.Input("nested", template=[Float.Input("x")])],
|
||||
)
|
||||
|
||||
def test_validate_calls_through(self):
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", min=-1.0, max=1.0)])
|
||||
inp.validate() # should not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 0-row case
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestZeroRows:
|
||||
def test_empty_live_inputs_produces_empty_list(self):
|
||||
"""With min=0 and no live values, the result should be an empty list."""
|
||||
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||
assert _run(inp, {}).get("loras") == []
|
||||
|
||||
def test_min_zero_with_values(self):
|
||||
"""min=0 but 2 rows of live data."""
|
||||
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||
result = _run(inp, {"loras.0.strength": 0.8, "loras.1.strength": 0.5})
|
||||
assert result["loras"] == [{"strength": 0.8}, {"strength": 0.5}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# N-row case
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNRows:
|
||||
def test_two_rows_two_fields(self):
|
||||
"""Two rows with two fields each produce a list[dict]."""
|
||||
inp = DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[String.Input("lora_name"), Float.Input("strength", default=1.0)],
|
||||
min=0, max=50,
|
||||
)
|
||||
result = _run(inp, {
|
||||
"loras.0.lora_name": "model_a.safetensors", "loras.0.strength": 0.9,
|
||||
"loras.1.lora_name": "model_b.safetensors", "loras.1.strength": 0.4,
|
||||
})
|
||||
assert result["loras"] == [
|
||||
{"lora_name": "model_a.safetensors", "strength": 0.9},
|
||||
{"lora_name": "model_b.safetensors", "strength": 0.4},
|
||||
]
|
||||
|
||||
def test_rows_are_sorted_by_index(self):
|
||||
"""Rows must be in ascending index order even if dict iteration is unordered."""
|
||||
inp = DynamicGroup.Input("items", template=[Int.Input("v", default=0)], min=0, max=10)
|
||||
result = _run(inp, {"items.0.v": 10, "items.2.v": 30, "items.1.v": 20})
|
||||
assert [row["v"] for row in result["items"]] == [10, 20, 30]
|
||||
|
||||
def test_min_rows_schema_slots(self):
|
||||
"""With min=2 and no live data, 2 slots must appear in the expanded schema."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
out, _, _ = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||
all_slots = {**out.get("required", {}), **out.get("optional", {})}
|
||||
assert "items.0.val" in all_slots
|
||||
assert "items.1.val" in all_slots
|
||||
|
||||
def test_min_rows_reconstructs_when_no_values(self):
|
||||
"""min=2 with NO live values must still yield a 2-element list,
|
||||
not collapse to [] (regression: parent-path clobber)."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
result = _run(inp, {})
|
||||
assert len(result["items"]) == 2
|
||||
assert all("val" in row for row in result["items"])
|
||||
|
||||
def test_min_rows_reconstructs_with_partial_values(self):
|
||||
"""min=2 with only the first row's value present still yields 2 rows."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
result = _run(inp, {"items.0.val": 0.7})
|
||||
assert len(result["items"]) == 2
|
||||
assert result["items"][0]["val"] == 0.7
|
||||
assert result["items"][1]["val"] is None
|
||||
|
||||
def test_list_paths_in_v3_data(self):
|
||||
"""list_paths must contain the group id so build_nested_inputs knows to convert."""
|
||||
inp = DynamicGroup.Input("things", template=[Boolean.Input("flag")], min=0, max=5)
|
||||
_, _, v3_data = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||
assert "things" in v3_data.get("list_paths", set())
|
||||
|
||||
def test_no_leftover_flat_keys(self):
|
||||
"""Flat keys must be consumed; only the reconstructed list remains."""
|
||||
inp = DynamicGroup.Input("rows", template=[Float.Input("x", default=0.0)], min=0, max=5)
|
||||
result = _run(inp, {"rows.0.x": 1.0, "rows.1.x": 2.0})
|
||||
assert "rows.0.x" not in result
|
||||
assert "rows.1.x" not in result
|
||||
assert isinstance(result["rows"], list)
|
||||
Reference in New Issue
Block a user