mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-30 20:57:27 +08:00
Compare commits
1 Commits
master
...
cloud-open
| Author | SHA1 | Date | |
|---|---|---|---|
| 63dc90e6c0 |
@ -15,6 +15,15 @@ import comfy.patcher_extension
|
|||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(
|
||||||
|
t: torch.Tensor,
|
||||||
|
freqs: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||||
|
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||||
|
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||||
|
return t_out
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- Feed Forward Network -----------------------
|
# ---------------------- Feed Forward Network -----------------------
|
||||||
class GPT2FeedForward(nn.Module):
|
class GPT2FeedForward(nn.Module):
|
||||||
@ -164,7 +173,8 @@ class Attention(nn.Module):
|
|||||||
k = self.k_norm(k)
|
k = self.k_norm(k)
|
||||||
v = self.v_norm(v)
|
v = self.v_norm(v)
|
||||||
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
|
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
|
||||||
q, k = comfy.quant_ops.ck.apply_rope_split_half(q, k, rope_emb)
|
q = apply_rotary_pos_emb(q, rope_emb)
|
||||||
|
k = apply_rotary_pos_emb(k, rope_emb)
|
||||||
return q, k, v
|
return q, k, v
|
||||||
|
|
||||||
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
|
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
|
||||||
|
|||||||
@ -51,6 +51,15 @@ class FeedForward(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(x, freqs_cis):
|
||||||
|
if x.shape[1] == 0:
|
||||||
|
return x
|
||||||
|
|
||||||
|
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||||
|
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||||
|
return t_out.reshape(*x.shape)
|
||||||
|
|
||||||
|
|
||||||
class QwenTimestepProjEmbeddings(nn.Module):
|
class QwenTimestepProjEmbeddings(nn.Module):
|
||||||
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
|
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
21614
openapi.yaml
21614
openapi.yaml
File diff suppressed because it is too large
Load Diff
@ -22,7 +22,7 @@ alembic
|
|||||||
SQLAlchemy>=2.0.0
|
SQLAlchemy>=2.0.0
|
||||||
filelock
|
filelock
|
||||||
av>=16.0.0
|
av>=16.0.0
|
||||||
comfy-kitchen==0.2.10
|
comfy-kitchen==0.2.9
|
||||||
comfy-aimdo==0.4.5
|
comfy-aimdo==0.4.5
|
||||||
requests
|
requests
|
||||||
simpleeval>=1.0.0
|
simpleeval>=1.0.0
|
||||||
|
|||||||
Reference in New Issue
Block a user