Compare commits

..

7 Commits

15 changed files with 166 additions and 159 deletions

View File

@ -4,12 +4,12 @@ early_access: false
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
reviews:
profile: "chill"
request_changes_workflow: false
profile: "assertive"
request_changes_workflow: true
high_level_summary: false
poem: false
review_status: false
review_details: false
review_details: true
commit_status: true
collapse_walkthrough: true
changed_files_summary: false
@ -39,6 +39,14 @@ reviews:
- path: "**"
instructions: |
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
Treat AGENTS.md as mandatory repository policy, not optional style guidance.
Flag PR changes that violate AGENTS.md even when the code is otherwise functional.
In particular, enforce architecture boundaries, dtype/device/memory rules,
interface contracts, import style, no unnecessary try/except blocks, no inline
imports, no outbound internet paths in core ComfyUI, and narrow scoped fixes.
Prefer direct findings over suggestions when a rule is violated. Only ignore
AGENTS.md when it clearly conflicts with a newer explicit maintainer instruction
in the PR.
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
de-indented, or reformatted without logic changes. If code appears in the diff
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
@ -123,5 +131,10 @@ chat:
knowledge_base:
opt_out: false
code_guidelines:
enabled: true
filePatterns:
- files: "AGENTS.md"
applyTo: "**"
learnings:
scope: "auto"

1
CLAUDE.md Symbolic link
View File

@ -0,0 +1 @@
AGENTS.md

View File

@ -229,7 +229,7 @@ Python 3.14 works but some custom nodes may have issues. The free threaded varia
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
torch 2.5 is minimally supported, but using a newer version is strongly recommended. Some features and optimizations might only work on newer versions. We generally recommend using the latest major version of PyTorch with the latest CUDA version unless it is less than 2 weeks old. If your PyTorch is more than 6 months old, please update it.
### Instructions:

View File

@ -217,10 +217,7 @@ class AceStepAttention(nn.Module):
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
n_rep = self.num_heads // self.num_kv_heads
if n_rep > 1:
key_states = key_states.repeat_interleave(n_rep, dim=1)
value_states = value_states.repeat_interleave(n_rep, dim=1)
gqa_kwargs = {"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}
attn_bias = None
if self.sliding_window is not None and not self.is_cross_attention:
@ -244,7 +241,7 @@ class AceStepAttention(nn.Module):
else:
attn_bias = window_bias
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False, **gqa_kwargs)
attn_output = self.o_proj(attn_output)
return attn_output

View File

@ -425,19 +425,16 @@ class Attention(nn.Module):
if n == 1 and causal:
causal = False
if h != kv_h:
# Repeat interleave kv_heads to match q_heads
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
gqa_kwargs = {"enable_gqa": True} if h != kv_h else {}
if self.differential:
q, q_diff = q.unbind(dim=1)
k, k_diff = k.unbind(dim=1)
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options, **gqa_kwargs)
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options, **gqa_kwargs)
out = out - out_diff
else:
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options, **gqa_kwargs)
out = self.to_out(out)

View File

@ -74,11 +74,8 @@ class BooguDoubleStreamProcessor(nn.Module):
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if attn.kv_heads < attn.heads:
key = key.repeat_interleave(attn.heads // attn.kv_heads, dim=1)
value = value.repeat_interleave(attn.heads // attn.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
gqa_kwargs = {"enable_gqa": True} if attn.kv_heads < attn.heads else {}
hidden_states = optimized_attention_masked(query, key, value, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options, **gqa_kwargs)
# Split back to instruction/image, apply per-stream output projections, recombine.
instruct_hidden_states = self.instruct_out(hidden_states[:, :L_instruct])

View File

@ -1,5 +1,6 @@
import math
import sys
import inspect
import torch
import torch.nn.functional as F
@ -14,16 +15,16 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
if model_management.xformers_enabled():
import xformers
import xformers.ops
SAGE_ATTENTION_IS_AVAILABLE = False
SAGE_ATTENTION_SUPPORTS_MASK = False
try:
from sageattention import sageattn
SAGE_ATTENTION_IS_AVAILABLE = True
SAGE_ATTENTION_SUPPORTS_MASK = "attn_mask" in inspect.signature(sageattn).parameters
except ImportError as e:
if model_management.sage_attention_enabled():
if e.name == "sageattention":
@ -89,6 +90,44 @@ def default(val, d):
return val
return d
def _gqa_repeat_factor(query_heads, key_heads, value_heads):
    """Return how many query heads share each key/value head.

    Args:
        query_heads: Number of query heads.
        key_heads: Number of key heads.
        value_heads: Number of value heads; must equal ``key_heads``.

    Returns:
        ``1`` when the query and key/value head counts are equal, otherwise
        ``query_heads // key_heads``.

    Raises:
        ValueError: If the key and value head counts differ, if the key/value
            head count is not positive, or if ``query_heads`` is not an exact
            multiple of ``key_heads``.
    """
    if key_heads != value_heads:
        raise ValueError(f"Key/value head count mismatch for GQA: {key_heads} != {value_heads}")
    if query_heads == key_heads:
        return 1
    # Without this guard a zero head count hits `%` below and surfaces as
    # ZeroDivisionError, and a negative count yields a meaningless negative
    # factor; report both as the same ValueError used for other bad configs.
    if key_heads <= 0:
        raise ValueError(f"Key/value head count must be positive for GQA: {key_heads}")
    if query_heads % key_heads != 0:
        raise ValueError(f"Query heads must be divisible by key/value heads for GQA: {query_heads} vs {key_heads}")
    return query_heads // key_heads
def _repeat_kv_for_gqa(k, v, query_heads, head_dim):
    """Repeat key/value heads so their count matches ``query_heads``.

    Args:
        k: Key tensor whose size along ``head_dim`` is the key head count.
        v: Value tensor whose size along ``head_dim`` is the value head count.
        query_heads: Number of query heads to match.
        head_dim: Index of the head axis in ``k`` and ``v``.

    Returns:
        A ``(k, v)`` tuple. The inputs are returned untouched when the repeat
        factor is 1; otherwise each is expanded with ``repeat_interleave``
        along ``head_dim``.

    Raises:
        ValueError: Propagated from ``_gqa_repeat_factor`` when the head
            counts are incompatible.
    """
    factor = _gqa_repeat_factor(query_heads, k.shape[head_dim], v.shape[head_dim])
    if factor <= 1:
        # Head counts already line up; nothing to expand.
        return k, v
    expanded_k = k.repeat_interleave(factor, dim=head_dim)
    expanded_v = v.repeat_interleave(factor, dim=head_dim)
    return expanded_k, expanded_v
def _heads_from_dim(tensor, dim_head, name):
    """Work out how many heads of width ``dim_head`` fit in the last axis.

    Args:
        tensor: Object exposing ``shape``; only ``shape[-1]`` is read.
        dim_head: Width of a single attention head.
        name: Label used in the error message (e.g. ``"Key"``).

    Returns:
        ``tensor.shape[-1] // dim_head``.

    Raises:
        ValueError: If the last axis is not an exact multiple of ``dim_head``.
    """
    inner_dim = tensor.shape[-1]
    head_count, leftover = divmod(inner_dim, dim_head)
    if leftover:
        raise ValueError(f"{name} inner dimension {inner_dim} is not divisible by head dimension {dim_head}")
    return head_count
def _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, enable_gqa=False, expand_kv=True):
    """Split the packed last axis of q/k/v into ``(heads, dim_head)``.

    Args:
        q: Query tensor; reshaped to ``(b, -1, heads, dim_head)``.
        k: Key tensor; reshaped to ``(b, -1, key_heads, dim_head)``.
        v: Value tensor; reshaped to ``(b, -1, value_heads, dim_head)``.
        b: Leading (batch) size used for every reshape.
        heads: Number of query heads.
        dim_head: Width of a single head.
        enable_gqa: When true, the key/value head counts are derived from the
            last axis of ``k``/``v`` and validated against ``heads``; when
            false they are taken to be ``heads``.
        expand_kv: Only consulted when ``enable_gqa`` is true. When true, key
            and value heads are repeated along axis ``-2`` to match ``heads``;
            when false they keep their own head count.

    Returns:
        The reshaped ``(q, k, v)`` tuple.

    Raises:
        ValueError: From ``_heads_from_dim`` or ``_gqa_repeat_factor`` when
            the key/value widths or head counts are incompatible.
    """
    q = q.unsqueeze(3).reshape(b, -1, heads, dim_head)

    # Without GQA every tensor carries the same number of heads.
    n_key_heads = n_value_heads = heads
    if enable_gqa:
        n_key_heads = _heads_from_dim(k, dim_head, "Key")
        n_value_heads = _heads_from_dim(v, dim_head, "Value")

    k = k.unsqueeze(3).reshape(b, -1, n_key_heads, dim_head)
    v = v.unsqueeze(3).reshape(b, -1, n_value_heads, dim_head)

    if not enable_gqa:
        return q, k, v

    # Validate the head configuration even when the caller skips expansion.
    _gqa_repeat_factor(heads, n_key_heads, n_value_heads)
    if expand_kv:
        k, v = _repeat_kv_for_gqa(k, v, heads, -2)
    return q, k, v
# feedforward
class GEGLU(nn.Module):
@ -152,28 +191,19 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
n_rep = q.shape[-3] // k.shape[-3]
k = k.repeat_interleave(n_rep, dim=-3)
v = v.repeat_interleave(n_rep, dim=-3)
scale = kwargs.get("scale", dim_head ** -0.5)
h = heads
if skip_reshape:
q, k, v = map(
if kwargs.get("enable_gqa", False):
k, v = _repeat_kv_for_gqa(k, v, q.shape[-3], -3)
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v),
)
else:
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
q, k, v = map(lambda t: t.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head).contiguous(), (q, k, v))
# force cast to fp32 to avoid overflowing
if attn_precision == torch.float32:
@ -231,13 +261,16 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
query = query * (kwargs["scale"] * dim_head ** 0.5)
if skip_reshape:
if kwargs.get("enable_gqa", False):
key, value = _repeat_kv_for_gqa(key, value, query.shape[-3], -3)
query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head)
key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
else:
query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
query, key, value = _reshape_qkv_to_heads(query, key, value, b, heads, dim_head, kwargs.get("enable_gqa", False))
query = query.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
value = value.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
key = key.permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
dtype = query.dtype
@ -304,19 +337,15 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
scale = kwargs.get("scale", dim_head ** -0.5)
if skip_reshape:
q, k, v = map(
if kwargs.get("enable_gqa", False):
k, v = _repeat_kv_for_gqa(k, v, q.shape[-3], -3)
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v),
)
else:
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
q, k, v = map(lambda t: t.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head).contiguous(), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
@ -438,7 +467,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True
if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
if skip_reshape:
# b h k d -> b k h d
@ -446,13 +475,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
lambda t: t.permute(0, 2, 1, 3),
(q, k, v),
)
if kwargs.get("enable_gqa", False):
k, v = _repeat_kv_for_gqa(k, v, q.shape[-2], -2)
# actually do the reshaping
else:
dim_head //= heads
q, k, v = map(
lambda t: t.reshape(b, -1, heads, dim_head),
(q, k, v),
)
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
if mask is not None:
# add a singleton batch dimension
@ -474,7 +502,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
mask = mask_out[..., :mask.shape[-1]]
mask = mask.expand(b, heads, -1, -1)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask, scale=kwargs.get("scale", None))
if skip_output_reshape:
out = out.permute(0, 2, 1, 3)
@ -498,10 +526,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False), expand_kv=False)
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
if mask is not None:
# add a batch dimension if there isn't already one
@ -511,9 +537,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3:
mask = mask.unsqueeze(1)
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
sdpa_keys = ("scale", "enable_gqa")
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
if SDP_BATCH_LIMIT >= b:
@ -541,20 +565,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if kwargs.get("low_precision_attention", True) is False:
if kwargs.get("low_precision_attention", True) is False or (mask is not None and not SAGE_ATTENTION_SUPPORTS_MASK):
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
if kwargs.get("enable_gqa", False):
k, v = _repeat_kv_for_gqa(k, v, q.shape[-3], -3)
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head),
(q, k, v),
)
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
tensor_layout = "NHD"
if mask is not None:
@ -565,8 +588,12 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
if mask.ndim == 3:
mask = mask.unsqueeze(1)
sage_kwargs = {"is_causal": False, "tensor_layout": tensor_layout, "sm_scale": kwargs.get("scale", None), "smooth_k": False}
if mask is not None:
sage_kwargs["attn_mask"] = mask
try:
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
out = sageattn(q, k, v, **sage_kwargs)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
exception_fallback = True
@ -616,7 +643,6 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
skip_output_reshape=skip_output_reshape,
**kwargs
)
q_s, k_s, v_s = q, k, v
N = q.shape[2]
dim_head = D
else:
@ -642,11 +668,15 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
**kwargs
)
if not skip_reshape:
q_s, k_s, v_s = map(
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
(q, k, v),
)
if skip_reshape:
q_s = q
if kwargs.get("enable_gqa", False):
k_s, v_s = _repeat_kv_for_gqa(k, v, H, -3)
else:
k_s, v_s = k, v
else:
q_s, k_s, v_s = _reshape_qkv_to_heads(q, k, v, B, heads, dim_head, kwargs.get("enable_gqa", False))
q_s, k_s, v_s = map(lambda t: t.permute(0, 2, 1, 3).contiguous(), (q_s, k_s, v_s))
B, H, L, D = q_s.shape
try:
@ -662,7 +692,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)
@ -681,19 +711,20 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
dropout_p: float = 0.0, causal: bool = False, softmax_scale: float = -1.0) -> torch.Tensor:
softmax_scale_arg = None if softmax_scale == -1.0 else softmax_scale
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal, softmax_scale=softmax_scale_arg)
@flash_attn_wrapper.register_fake
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False, softmax_scale=-1.0):
# Output shape is the same as q
return q.new_empty(q.shape)
except AttributeError as error:
FLASH_ATTN_ERROR = error
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
dropout_p: float = 0.0, causal: bool = False, softmax_scale: float = -1.0) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
@wrap_attn
@ -703,10 +734,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False), expand_kv=False)
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
if mask is not None:
# add a batch dimension if there isn't already one
@ -725,10 +754,16 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
v.transpose(1, 2),
dropout_p=0.0,
causal=False,
softmax_scale=kwargs.get("scale", -1.0),
).transpose(1, 2)
except Exception as e:
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
sdpa_extra = {}
if kwargs.get("enable_gqa", False):
sdpa_extra["enable_gqa"] = True
if "scale" in kwargs:
sdpa_extra["scale"] = kwargs["scale"]
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@ -1209,5 +1244,3 @@ class SpatialVideoTransformer(SpatialTransformer):
x = self.proj_out(x)
out = x + x_in
return out

View File

@ -141,11 +141,8 @@ class Attention(nn.Module):
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.kv_heads < self.heads:
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
gqa_kwargs = {"enable_gqa": True} if self.kv_heads < self.heads else {}
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options, **gqa_kwargs)
hidden_states = self.to_out[0](hidden_states)
return hidden_states

View File

@ -543,18 +543,24 @@ class SDTokenizer:
def _try_get_embedding(self, embedding_name:str):
'''
Takes a potential embedding name and tries to retrieve it.
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
Returns a Tuple consisting of the embedding, the cleaned embedding name, and any leftover string, embedding can be None.
'''
split_embed = embedding_name.split()
embedding_name = split_embed[0]
leftover = ' '.join(split_embed[1:])
match = re.search(r'[<\[]', embedding_name)
if match is not None:
leftover = embedding_name[match.start():] + (" " + leftover if leftover else "")
embedding_name = embedding_name[:match.start()]
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
if embed is None:
stripped = embedding_name.strip(',')
if len(stripped) < len(embedding_name):
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, leftover)
return (embed, embedding_name, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, embedding_name, leftover)
def pad_tokens(self, tokens, amount):
if self.pad_left:
@ -585,7 +591,7 @@ class SDTokenizer:
tokens = []
for weighted_segment, weight in parsed_weights:
to_tokenize = unescape_important(weighted_segment)
split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
split = re.split(r'(?<=\s){}'.format(re.escape(self.embedding_identifier)), to_tokenize)
to_tokenize = [split[0]]
for i in range(1, len(split)):
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
@ -595,7 +601,7 @@ class SDTokenizer:
# if we find an embedding, deal with the embedding
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
embedding_name = word[len(self.embedding_identifier):].strip('\n')
embed, leftover = self._try_get_embedding(embedding_name)
embed, embedding_name, leftover = self._try_get_embedding(embedding_name)
if embed is None:
logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
else:

View File

@ -12,7 +12,7 @@ import torch.nn.functional as F
import comfy.ops
from comfy import sd1_clip
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.text_encoders.llama import RMSNorm, apply_rope
@ -110,10 +110,6 @@ def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sin
putting the sink logit in the mask at that column.
"""
if num_kv_groups > 1 and not TORCH_HAS_GQA:
k = k.repeat_interleave(num_kv_groups, dim=1)
v = v.repeat_interleave(num_kv_groups, dim=1)
B, _, S_q, D = q.shape
H_kv = k.shape[1]
S_kv = k.shape[-2]

View File

@ -550,10 +550,8 @@ class Attention(nn.Module):
xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
gqa_kwargs = {"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, **gqa_kwargs)
return self.o_proj(output), present_key_value
class MLP(nn.Module):

View File

@ -366,12 +366,8 @@ class GatedAttention(nn.Module):
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
present_key_value = (xk, xv, index + num_tokens)
# Expand KV heads for GQA
if self.num_heads != self.num_kv_heads:
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
gqa_kwargs = {"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, **gqa_kwargs)
output = output * gate.sigmoid()
return self.o_proj(output), present_key_value

View File

@ -1,6 +1,5 @@
import asyncio
import bisect
import gc
import itertools
import psutil
import time
@ -529,38 +528,6 @@ class RAMPressureCache(LRUCache):
if psutil.virtual_memory().available >= target:
return
def remove_cache_key(key):
del self.cache[key]
self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)
def has_old_model_patcher(outputs):
if outputs is None:
return False
for output in outputs:
if isinstance(output, (list, tuple)):
if has_old_model_patcher(output):
return True
elif isinstance(output, ModelPatcher):
return True
return False
old_modelpatcher_keys = []
for key, cache_entry in self.cache.items():
if self.used_generation[key] == self.generation:
continue
if has_old_model_patcher(cache_entry.outputs):
old_modelpatcher_keys.append(key)
for key in old_modelpatcher_keys:
remove_cache_key(key)
if old_modelpatcher_keys:
gc.collect()
if psutil.virtual_memory().available >= target:
return
clean_list = []
for key, cache_entry in self.cache.items():
@ -578,17 +545,19 @@ class RAMPressureCache(LRUCache):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
ram_usage += output.numel() * output.element_size()
elif isinstance(output, ModelPatcher) and self.used_generation[key] != self.generation:
#old ModelPatchers are the first to go
ram_usage = 1e30
scan_list_for_ram_usage(cache_entry.outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], ram_usage, key))
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
to_free = target - psutil.virtual_memory().available
while to_free > 0 and clean_list:
_, _, ram_usage, key = clean_list.pop()
remove_cache_key(key)
to_free -= ram_usage
gc.collect()
while psutil.virtual_memory().available < target and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)

View File

@ -16,23 +16,30 @@ class ColorToRGBInt(io.ComfyNode):
],
outputs=[
io.Int.Output(display_name="rgb_int"),
io.Color.Output(display_name="hex")
io.Color.Output(display_name="hex"),
io.Float.Output(display_name="alpha"),
],
)
@classmethod
def execute(cls, color: str) -> io.NodeOutput:
# expect format #RRGGBB
if len(color) != 7 or color[0] != "#":
raise ValueError("Color must be in format #RRGGBB")
# expect format #RRGGBB or #RRGGBBAA
if len(color) not in (7, 9) or color[0] != "#":
raise ValueError("Color must be in format #RRGGBB or #RRGGBBAA")
try:
int(color[1:], 16)
except ValueError:
raise ValueError("Color must be in format #RRGGBB") from None
raise ValueError("Color must be in format #RRGGBB or #RRGGBBAA") from None
alpha = 1.0
if len(color) == 9:
alpha = int(color[7:9], 16) / 255.0
color = color[:7]
r, g, b = hex_to_rgb(color)
rgb_int = r * 256 * 256 + g * 256 + b
return io.NodeOutput(rgb_int, color)
return io.NodeOutput(rgb_int, color, alpha)
class ColorExtension(ComfyExtension):

View File

@ -314,7 +314,7 @@ def prompt_worker(q, server_instance):
cache_ram = 0
cache_ram_inactive = 0
if not args.cache_classic and not args.cache_none and args.cache_lru <= 0:
cache_ram = min(10.0, max(1.5, comfy.model_management.total_ram * 0.05 / 1024.0))
cache_ram = min(10.0, max(2.0, comfy.model_management.total_ram * 0.10 / 1024.0))
cache_ram_inactive = min(96.0, comfy.model_management.total_ram / 1024.0)
if len(args.cache_ram) > 0:
cache_ram = args.cache_ram[0]