mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-06 04:17:48 +08:00
Compare commits
7 Commits
mem_fix_at
...
temp_pr
| Author | SHA1 | Date | |
|---|---|---|---|
| ec8c2e7315 | |||
| 2131e630e0 | |||
| 501b808481 | |||
| 7f287b705e | |||
| b7ba504e06 | |||
| 6c62ca0b6b | |||
| 3fe9f5fecb |
@ -4,12 +4,12 @@ early_access: false
|
||||
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
|
||||
|
||||
reviews:
|
||||
profile: "chill"
|
||||
request_changes_workflow: false
|
||||
profile: "assertive"
|
||||
request_changes_workflow: true
|
||||
high_level_summary: false
|
||||
poem: false
|
||||
review_status: false
|
||||
review_details: false
|
||||
review_details: true
|
||||
commit_status: true
|
||||
collapse_walkthrough: true
|
||||
changed_files_summary: false
|
||||
@ -39,6 +39,14 @@ reviews:
|
||||
- path: "**"
|
||||
instructions: |
|
||||
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
|
||||
Treat AGENTS.md as mandatory repository policy, not optional style guidance.
|
||||
Flag PR changes that violate AGENTS.md even when the code is otherwise functional.
|
||||
In particular, enforce architecture boundaries, dtype/device/memory rules,
|
||||
interface contracts, import style, no unnecessary try/except blocks, no inline
|
||||
imports, no outbound internet paths in core ComfyUI, and narrow scoped fixes.
|
||||
Prefer direct findings over suggestions when a rule is violated. Only ignore
|
||||
AGENTS.md when it clearly conflicts with a newer explicit maintainer instruction
|
||||
in the PR.
|
||||
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
|
||||
de-indented, or reformatted without logic changes. If code appears in the diff
|
||||
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
|
||||
@ -123,5 +131,10 @@ chat:
|
||||
|
||||
knowledge_base:
|
||||
opt_out: false
|
||||
code_guidelines:
|
||||
enabled: true
|
||||
filePatterns:
|
||||
- files: "AGENTS.md"
|
||||
applyTo: "**"
|
||||
learnings:
|
||||
scope: "auto"
|
||||
|
||||
@ -229,7 +229,7 @@ Python 3.14 works but some custom nodes may have issues. The free threaded varia
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
|
||||
torch 2.5 is minimally supported, but using a newer version is strongly recommended. Some features and optimizations might only work on newer versions. We generally recommend using the latest major version of PyTorch with the latest CUDA version unless it is less than 2 weeks old. If your PyTorch version is more than 6 months old, please update it.
|
||||
|
||||
### Instructions:
|
||||
|
||||
|
||||
@ -217,10 +217,7 @@ class AceStepAttention(nn.Module):
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
n_rep = self.num_heads // self.num_kv_heads
|
||||
if n_rep > 1:
|
||||
key_states = key_states.repeat_interleave(n_rep, dim=1)
|
||||
value_states = value_states.repeat_interleave(n_rep, dim=1)
|
||||
gqa_kwargs = {"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}
|
||||
|
||||
attn_bias = None
|
||||
if self.sliding_window is not None and not self.is_cross_attention:
|
||||
@ -244,7 +241,7 @@ class AceStepAttention(nn.Module):
|
||||
else:
|
||||
attn_bias = window_bias
|
||||
|
||||
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)
|
||||
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False, **gqa_kwargs)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
@ -425,19 +425,16 @@ class Attention(nn.Module):
|
||||
if n == 1 and causal:
|
||||
causal = False
|
||||
|
||||
if h != kv_h:
|
||||
# Repeat interleave kv_heads to match q_heads
|
||||
heads_per_kv_head = h // kv_h
|
||||
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||
gqa_kwargs = {"enable_gqa": True} if h != kv_h else {}
|
||||
|
||||
if self.differential:
|
||||
q, q_diff = q.unbind(dim=1)
|
||||
k, k_diff = k.unbind(dim=1)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options, **gqa_kwargs)
|
||||
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options, **gqa_kwargs)
|
||||
out = out - out_diff
|
||||
else:
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options, **gqa_kwargs)
|
||||
|
||||
out = self.to_out(out)
|
||||
|
||||
|
||||
@ -74,11 +74,8 @@ class BooguDoubleStreamProcessor(nn.Module):
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
if attn.kv_heads < attn.heads:
|
||||
key = key.repeat_interleave(attn.heads // attn.kv_heads, dim=1)
|
||||
value = value.repeat_interleave(attn.heads // attn.kv_heads, dim=1)
|
||||
|
||||
hidden_states = optimized_attention_masked(query, key, value, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
gqa_kwargs = {"enable_gqa": True} if attn.kv_heads < attn.heads else {}
|
||||
hidden_states = optimized_attention_masked(query, key, value, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options, **gqa_kwargs)
|
||||
|
||||
# Split back to instruction/image, apply per-stream output projections, recombine.
|
||||
instruct_hidden_states = self.instruct_out(hidden_states[:, :L_instruct])
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import math
|
||||
import sys
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -14,16 +15,16 @@ from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
from comfy import model_management
|
||||
|
||||
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
SAGE_ATTENTION_IS_AVAILABLE = False
|
||||
SAGE_ATTENTION_SUPPORTS_MASK = False
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
SAGE_ATTENTION_IS_AVAILABLE = True
|
||||
SAGE_ATTENTION_SUPPORTS_MASK = "attn_mask" in inspect.signature(sageattn).parameters
|
||||
except ImportError as e:
|
||||
if model_management.sage_attention_enabled():
|
||||
if e.name == "sageattention":
|
||||
@ -89,6 +90,44 @@ def default(val, d):
|
||||
return val
|
||||
return d
|
||||
|
||||
def _gqa_repeat_factor(query_heads, key_heads, value_heads):
    """Return how many query heads share each key/value head.

    Args:
        query_heads: Number of query heads.
        key_heads: Number of key heads.
        value_heads: Number of value heads; must equal ``key_heads``.

    Returns:
        ``1`` when the query and key head counts already match, otherwise
        ``query_heads // key_heads``.

    Raises:
        ValueError: If the key and value head counts differ, if the
            key/value head count is not positive, or if ``query_heads`` is
            not an exact multiple of ``key_heads``.
    """
    if key_heads != value_heads:
        raise ValueError(f"Key/value head count mismatch for GQA: {key_heads} != {value_heads}")
    if query_heads == key_heads:
        return 1
    # A zero head count would otherwise surface as a ZeroDivisionError from
    # the modulo below; report it the same way as every other bad layout.
    if key_heads <= 0:
        raise ValueError(f"Key/value head count must be positive for GQA: {key_heads}")
    if query_heads % key_heads != 0:
        raise ValueError(f"Query heads must be divisible by key/value heads for GQA: {query_heads} vs {key_heads}")
    return query_heads // key_heads
|
||||
|
||||
def _repeat_kv_for_gqa(k, v, query_heads, head_dim):
    """Expand key/value heads so they line up with the query heads.

    Args:
        k: Key tensor; its size along ``head_dim`` is the key head count.
        v: Value tensor; its size along ``head_dim`` is the value head count.
        query_heads: Number of query heads to match.
        head_dim: Index of the heads dimension in ``k`` and ``v``.

    Returns:
        The ``(k, v)`` pair. When the repeat factor is greater than one,
        each head is repeated that many times along ``head_dim`` with
        ``repeat_interleave``; otherwise the inputs are returned unchanged.

    Raises:
        ValueError: Propagated from ``_gqa_repeat_factor`` when the head
            counts are not a valid grouped-query layout.
    """
    repeats = _gqa_repeat_factor(query_heads, k.shape[head_dim], v.shape[head_dim])
    if repeats <= 1:
        # Head counts already match; nothing to expand.
        return k, v
    expanded_k = k.repeat_interleave(repeats, dim=head_dim)
    expanded_v = v.repeat_interleave(repeats, dim=head_dim)
    return expanded_k, expanded_v
|
||||
|
||||
def _heads_from_dim(tensor, dim_head, name):
    """Infer a head count from the size of ``tensor``'s last dimension.

    Args:
        tensor: Object exposing ``shape``; only ``shape[-1]`` is read.
        dim_head: Size of a single head.
        name: Label used in error messages (e.g. ``"Key"``).

    Returns:
        ``tensor.shape[-1] // dim_head``.

    Raises:
        ValueError: If ``dim_head`` is not positive, or if the last
            dimension is not an exact multiple of ``dim_head``.
    """
    inner_dim = tensor.shape[-1]
    # Guard first: a zero head size would raise ZeroDivisionError in the
    # modulo below instead of a message that names the offending input.
    if dim_head <= 0:
        raise ValueError(f"{name} head dimension must be positive, got {dim_head}")
    if inner_dim % dim_head != 0:
        raise ValueError(f"{name} inner dimension {inner_dim} is not divisible by head dimension {dim_head}")
    return inner_dim // dim_head
|
||||
|
||||
def _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, enable_gqa=False, expand_kv=True):
    """Split the packed last dimension of q/k/v into ``(heads, dim_head)``.

    Each input is reshaped to ``(b, -1, n_heads, dim_head)``.
    NOTE(review): presumably the inputs are ``(batch, tokens, n_heads * dim_head)``
    - confirm against callers.

    Args:
        q, k, v: Query, key and value tensors.
        b: Batch size used for the reshape.
        heads: Number of query heads.
        dim_head: Size of a single head.
        enable_gqa: When true, the key/value head counts are derived from
            the last dimension of ``k``/``v`` instead of assumed equal to
            ``heads``, and the resulting layout is validated.
        expand_kv: With ``enable_gqa``, also repeat the key/value heads so
            they match ``heads``. Has no effect on the output without
            ``enable_gqa``, since the head counts already match.

    Returns:
        The reshaped ``(q, k, v)`` tuple.

    Raises:
        ValueError: From the GQA helpers when the key/value sizes do not
            form a valid grouped-query layout.
    """
    q = q.unsqueeze(3).reshape(b, -1, heads, dim_head)

    if not enable_gqa:
        # Plain multi-head attention: key/value share the query head count.
        k = k.unsqueeze(3).reshape(b, -1, heads, dim_head)
        v = v.unsqueeze(3).reshape(b, -1, heads, dim_head)
        return q, k, v

    kv_key_heads = _heads_from_dim(k, dim_head, "Key")
    kv_value_heads = _heads_from_dim(v, dim_head, "Value")
    k = k.unsqueeze(3).reshape(b, -1, kv_key_heads, dim_head)
    v = v.unsqueeze(3).reshape(b, -1, kv_value_heads, dim_head)

    # Called only for its validation side effect: raises on a bad layout
    # even when the caller asked for no expansion.
    _gqa_repeat_factor(heads, kv_key_heads, kv_value_heads)
    if expand_kv:
        k, v = _repeat_kv_for_gqa(k, v, heads, -2)
    return q, k, v
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
@ -152,28 +191,19 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
|
||||
n_rep = q.shape[-3] // k.shape[-3]
|
||||
k = k.repeat_interleave(n_rep, dim=-3)
|
||||
v = v.repeat_interleave(n_rep, dim=-3)
|
||||
|
||||
scale = kwargs.get("scale", dim_head ** -0.5)
|
||||
|
||||
h = heads
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
if kwargs.get("enable_gqa", False):
|
||||
k, v = _repeat_kv_for_gqa(k, v, q.shape[-3], -3)
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
else:
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, -1, heads, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * heads, -1, dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
|
||||
q, k, v = map(lambda t: t.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head).contiguous(), (q, k, v))
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if attn_precision == torch.float32:
|
||||
@ -231,13 +261,16 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
query = query * (kwargs["scale"] * dim_head ** 0.5)
|
||||
|
||||
if skip_reshape:
|
||||
if kwargs.get("enable_gqa", False):
|
||||
key, value = _repeat_kv_for_gqa(key, value, query.shape[-3], -3)
|
||||
query = query.reshape(b * heads, -1, dim_head)
|
||||
value = value.reshape(b * heads, -1, dim_head)
|
||||
key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
|
||||
else:
|
||||
query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
|
||||
query, key, value = _reshape_qkv_to_heads(query, key, value, b, heads, dim_head, kwargs.get("enable_gqa", False))
|
||||
query = query.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||
value = value.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||
key = key.permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
|
||||
|
||||
|
||||
dtype = query.dtype
|
||||
@ -304,19 +337,15 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
scale = kwargs.get("scale", dim_head ** -0.5)
|
||||
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
if kwargs.get("enable_gqa", False):
|
||||
k, v = _repeat_kv_for_gqa(k, v, q.shape[-3], -3)
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
else:
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, -1, heads, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * heads, -1, dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
|
||||
q, k, v = map(lambda t: t.permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head).contiguous(), (q, k, v))
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
@ -438,7 +467,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
disabled_xformers = True
|
||||
|
||||
if disabled_xformers:
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
|
||||
if skip_reshape:
|
||||
# b h k d -> b k h d
|
||||
@ -446,13 +475,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
lambda t: t.permute(0, 2, 1, 3),
|
||||
(q, k, v),
|
||||
)
|
||||
if kwargs.get("enable_gqa", False):
|
||||
k, v = _repeat_kv_for_gqa(k, v, q.shape[-2], -2)
|
||||
# actually do the reshaping
|
||||
else:
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
|
||||
|
||||
if mask is not None:
|
||||
# add a singleton batch dimension
|
||||
@ -474,7 +502,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
mask = mask_out[..., :mask.shape[-1]]
|
||||
mask = mask.expand(b, heads, -1, -1)
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask, scale=kwargs.get("scale", None))
|
||||
|
||||
if skip_output_reshape:
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
@ -498,10 +526,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False), expand_kv=False)
|
||||
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
@ -511,9 +537,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
|
||||
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
|
||||
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
|
||||
sdpa_keys = ("scale", "enable_gqa")
|
||||
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
@ -541,20 +565,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if kwargs.get("low_precision_attention", True) is False:
|
||||
if kwargs.get("low_precision_attention", True) is False or (mask is not None and not SAGE_ATTENTION_SUPPORTS_MASK):
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
|
||||
exception_fallback = False
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout = "HND"
|
||||
if kwargs.get("enable_gqa", False):
|
||||
k, v = _repeat_kv_for_gqa(k, v, q.shape[-3], -3)
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False))
|
||||
tensor_layout = "NHD"
|
||||
|
||||
if mask is not None:
|
||||
@ -565,8 +588,12 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
sage_kwargs = {"is_causal": False, "tensor_layout": tensor_layout, "sm_scale": kwargs.get("scale", None), "smooth_k": False}
|
||||
if mask is not None:
|
||||
sage_kwargs["attn_mask"] = mask
|
||||
|
||||
try:
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
out = sageattn(q, k, v, **sage_kwargs)
|
||||
except Exception as e:
|
||||
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
|
||||
exception_fallback = True
|
||||
@ -616,7 +643,6 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
q_s, k_s, v_s = q, k, v
|
||||
N = q.shape[2]
|
||||
dim_head = D
|
||||
else:
|
||||
@ -642,11 +668,15 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not skip_reshape:
|
||||
q_s, k_s, v_s = map(
|
||||
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
if skip_reshape:
|
||||
q_s = q
|
||||
if kwargs.get("enable_gqa", False):
|
||||
k_s, v_s = _repeat_kv_for_gqa(k, v, H, -3)
|
||||
else:
|
||||
k_s, v_s = k, v
|
||||
else:
|
||||
q_s, k_s, v_s = _reshape_qkv_to_heads(q, k, v, B, heads, dim_head, kwargs.get("enable_gqa", False))
|
||||
q_s, k_s, v_s = map(lambda t: t.permute(0, 2, 1, 3).contiguous(), (q_s, k_s, v_s))
|
||||
B, H, L, D = q_s.shape
|
||||
|
||||
try:
|
||||
@ -662,7 +692,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=False,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
@ -681,19 +711,20 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
try:
|
||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
|
||||
dropout_p: float = 0.0, causal: bool = False, softmax_scale: float = -1.0) -> torch.Tensor:
|
||||
softmax_scale_arg = None if softmax_scale == -1.0 else softmax_scale
|
||||
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal, softmax_scale=softmax_scale_arg)
|
||||
|
||||
|
||||
@flash_attn_wrapper.register_fake
|
||||
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
|
||||
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False, softmax_scale=-1.0):
|
||||
# Output shape is the same as q
|
||||
return q.new_empty(q.shape)
|
||||
except AttributeError as error:
|
||||
FLASH_ATTN_ERROR = error
|
||||
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
dropout_p: float = 0.0, causal: bool = False, softmax_scale: float = -1.0) -> torch.Tensor:
|
||||
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
||||
|
||||
@wrap_attn
|
||||
@ -703,10 +734,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
q, k, v = _reshape_qkv_to_heads(q, k, v, b, heads, dim_head, kwargs.get("enable_gqa", False), expand_kv=False)
|
||||
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
@ -725,10 +754,16 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
v.transpose(1, 2),
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
softmax_scale=kwargs.get("scale", -1.0),
|
||||
).transpose(1, 2)
|
||||
except Exception as e:
|
||||
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
sdpa_extra = {}
|
||||
if kwargs.get("enable_gqa", False):
|
||||
sdpa_extra["enable_gqa"] = True
|
||||
if "scale" in kwargs:
|
||||
sdpa_extra["scale"] = kwargs["scale"]
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
@ -1209,5 +1244,3 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
x = self.proj_out(x)
|
||||
out = x + x_in
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@ -141,11 +141,8 @@ class Attention(nn.Module):
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
if self.kv_heads < self.heads:
|
||||
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
|
||||
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
gqa_kwargs = {"enable_gqa": True} if self.kv_heads < self.heads else {}
|
||||
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options, **gqa_kwargs)
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -543,18 +543,24 @@ class SDTokenizer:
|
||||
def _try_get_embedding(self, embedding_name:str):
|
||||
'''
|
||||
Takes a potential embedding name and tries to retrieve it.
|
||||
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
||||
Returns a Tuple consisting of the embedding, the cleaned embedding name, and any leftover string, embedding can be None.
|
||||
'''
|
||||
split_embed = embedding_name.split()
|
||||
embedding_name = split_embed[0]
|
||||
leftover = ' '.join(split_embed[1:])
|
||||
|
||||
match = re.search(r'[<\[]', embedding_name)
|
||||
if match is not None:
|
||||
leftover = embedding_name[match.start():] + (" " + leftover if leftover else "")
|
||||
embedding_name = embedding_name[:match.start()]
|
||||
|
||||
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||
if embed is None:
|
||||
stripped = embedding_name.strip(',')
|
||||
if len(stripped) < len(embedding_name):
|
||||
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
|
||||
return (embed, leftover)
|
||||
return (embed, embedding_name, "{} {}".format(embedding_name[len(stripped):], leftover))
|
||||
return (embed, embedding_name, leftover)
|
||||
|
||||
def pad_tokens(self, tokens, amount):
|
||||
if self.pad_left:
|
||||
@ -585,7 +591,7 @@ class SDTokenizer:
|
||||
tokens = []
|
||||
for weighted_segment, weight in parsed_weights:
|
||||
to_tokenize = unescape_important(weighted_segment)
|
||||
split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
|
||||
split = re.split(r'(?<=\s){}'.format(re.escape(self.embedding_identifier)), to_tokenize)
|
||||
to_tokenize = [split[0]]
|
||||
for i in range(1, len(split)):
|
||||
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
|
||||
@ -595,7 +601,7 @@ class SDTokenizer:
|
||||
# if we find an embedding, deal with the embedding
|
||||
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
|
||||
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
||||
embed, leftover = self._try_get_embedding(embedding_name)
|
||||
embed, embedding_name, leftover = self._try_get_embedding(embedding_name)
|
||||
if embed is None:
|
||||
logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
|
||||
else:
|
||||
|
||||
@ -12,7 +12,7 @@ import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy import sd1_clip
|
||||
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.text_encoders.llama import RMSNorm, apply_rope
|
||||
|
||||
|
||||
@ -110,10 +110,6 @@ def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sin
|
||||
putting the sink logit in the mask at that column.
|
||||
"""
|
||||
|
||||
if num_kv_groups > 1 and not TORCH_HAS_GQA:
|
||||
k = k.repeat_interleave(num_kv_groups, dim=1)
|
||||
v = v.repeat_interleave(num_kv_groups, dim=1)
|
||||
|
||||
B, _, S_q, D = q.shape
|
||||
H_kv = k.shape[1]
|
||||
S_kv = k.shape[-2]
|
||||
|
||||
@ -550,10 +550,8 @@ class Attention(nn.Module):
|
||||
xv = xv[:, :, -sliding_window:]
|
||||
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
|
||||
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||
gqa_kwargs = {"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, **gqa_kwargs)
|
||||
return self.o_proj(output), present_key_value
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
||||
@ -366,12 +366,8 @@ class GatedAttention(nn.Module):
|
||||
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
# Expand KV heads for GQA
|
||||
if self.num_heads != self.num_kv_heads:
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||
gqa_kwargs = {"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, **gqa_kwargs)
|
||||
output = output * gate.sigmoid()
|
||||
|
||||
return self.o_proj(output), present_key_value
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import bisect
|
||||
import gc
|
||||
import itertools
|
||||
import psutil
|
||||
import time
|
||||
@ -529,38 +528,6 @@ class RAMPressureCache(LRUCache):
|
||||
if psutil.virtual_memory().available >= target:
|
||||
return
|
||||
|
||||
def remove_cache_key(key):
|
||||
del self.cache[key]
|
||||
self.used_generation.pop(key, None)
|
||||
self.timestamps.pop(key, None)
|
||||
self.children.pop(key, None)
|
||||
|
||||
def has_old_model_patcher(outputs):
|
||||
if outputs is None:
|
||||
return False
|
||||
for output in outputs:
|
||||
if isinstance(output, (list, tuple)):
|
||||
if has_old_model_patcher(output):
|
||||
return True
|
||||
elif isinstance(output, ModelPatcher):
|
||||
return True
|
||||
return False
|
||||
|
||||
old_modelpatcher_keys = []
|
||||
for key, cache_entry in self.cache.items():
|
||||
if self.used_generation[key] == self.generation:
|
||||
continue
|
||||
if has_old_model_patcher(cache_entry.outputs):
|
||||
old_modelpatcher_keys.append(key)
|
||||
|
||||
for key in old_modelpatcher_keys:
|
||||
remove_cache_key(key)
|
||||
|
||||
if old_modelpatcher_keys:
|
||||
gc.collect()
|
||||
if psutil.virtual_memory().available >= target:
|
||||
return
|
||||
|
||||
clean_list = []
|
||||
|
||||
for key, cache_entry in self.cache.items():
|
||||
@ -578,17 +545,19 @@ class RAMPressureCache(LRUCache):
|
||||
scan_list_for_ram_usage(output)
|
||||
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
|
||||
ram_usage += output.numel() * output.element_size()
|
||||
elif isinstance(output, ModelPatcher) and self.used_generation[key] != self.generation:
|
||||
#old ModelPatchers are the first to go
|
||||
ram_usage = 1e30
|
||||
scan_list_for_ram_usage(cache_entry.outputs)
|
||||
|
||||
oom_score *= ram_usage
|
||||
#In the case where we have no information on the node ram usage at all,
|
||||
#break OOM score ties on the last touch timestamp (pure LRU)
|
||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], ram_usage, key))
|
||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
|
||||
|
||||
to_free = target - psutil.virtual_memory().available
|
||||
while to_free > 0 and clean_list:
|
||||
_, _, ram_usage, key = clean_list.pop()
|
||||
remove_cache_key(key)
|
||||
to_free -= ram_usage
|
||||
|
||||
gc.collect()
|
||||
while psutil.virtual_memory().available < target and clean_list:
|
||||
_, _, key = clean_list.pop()
|
||||
del self.cache[key]
|
||||
self.used_generation.pop(key, None)
|
||||
self.timestamps.pop(key, None)
|
||||
self.children.pop(key, None)
|
||||
|
||||
@ -16,23 +16,30 @@ class ColorToRGBInt(io.ComfyNode):
|
||||
],
|
||||
outputs=[
|
||||
io.Int.Output(display_name="rgb_int"),
|
||||
io.Color.Output(display_name="hex")
|
||||
io.Color.Output(display_name="hex"),
|
||||
io.Float.Output(display_name="alpha"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, color: str) -> io.NodeOutput:
|
||||
# expect format #RRGGBB
|
||||
if len(color) != 7 or color[0] != "#":
|
||||
raise ValueError("Color must be in format #RRGGBB")
|
||||
# expect format #RRGGBB or #RRGGBBAA
|
||||
if len(color) not in (7, 9) or color[0] != "#":
|
||||
raise ValueError("Color must be in format #RRGGBB or #RRGGBBAA")
|
||||
try:
|
||||
int(color[1:], 16)
|
||||
except ValueError:
|
||||
raise ValueError("Color must be in format #RRGGBB") from None
|
||||
raise ValueError("Color must be in format #RRGGBB or #RRGGBBAA") from None
|
||||
|
||||
alpha = 1.0
|
||||
if len(color) == 9:
|
||||
alpha = int(color[7:9], 16) / 255.0
|
||||
color = color[:7]
|
||||
|
||||
r, g, b = hex_to_rgb(color)
|
||||
|
||||
rgb_int = r * 256 * 256 + g * 256 + b
|
||||
return io.NodeOutput(rgb_int, color)
|
||||
return io.NodeOutput(rgb_int, color, alpha)
|
||||
|
||||
|
||||
class ColorExtension(ComfyExtension):
|
||||
|
||||
2
main.py
2
main.py
@ -314,7 +314,7 @@ def prompt_worker(q, server_instance):
|
||||
cache_ram = 0
|
||||
cache_ram_inactive = 0
|
||||
if not args.cache_classic and not args.cache_none and args.cache_lru <= 0:
|
||||
cache_ram = min(10.0, max(1.5, comfy.model_management.total_ram * 0.05 / 1024.0))
|
||||
cache_ram = min(10.0, max(2.0, comfy.model_management.total_ram * 0.10 / 1024.0))
|
||||
cache_ram_inactive = min(96.0, comfy.model_management.total_ram / 1024.0)
|
||||
if len(args.cache_ram) > 0:
|
||||
cache_ram = args.cache_ram[0]
|
||||
|
||||
Reference in New Issue
Block a user