mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 14:47:21 +08:00
Compare commits
3 Commits
add-cla-wo
...
comfyanony
| Author | SHA1 | Date | |
|---|---|---|---|
| 7dacbbdee3 | |||
| 2a61015582 | |||
| 6978a466b8 |
62
.github/workflows/cla.yml
vendored
62
.github/workflows/cla.yml
vendored
@ -1,62 +0,0 @@
|
||||
name: CLA Assistant
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
pull_request_target:
|
||||
types: [opened, synchronize, closed]
|
||||
|
||||
permissions:
|
||||
actions: write
|
||||
contents: read # 'read' is enough because signatures live in a REMOTE repo
|
||||
pull-requests: write
|
||||
statuses: write
|
||||
|
||||
jobs:
|
||||
cla-assistant:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: CLA Assistant
|
||||
# Run on PR events, on "recheck" comment, or when someone posts the exact signing phrase.
|
||||
# IMPORTANT: this phrase must match `custom-pr-sign-comment` below.
|
||||
if: >
|
||||
github.event_name == 'pull_request_target' ||
|
||||
github.event.comment.body == 'recheck' ||
|
||||
github.event.comment.body == 'I have read and agree to the Contributor License Agreement'
|
||||
uses: contributor-assistant/github-action@v2.6.1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
# PAT required to write to the centralized signatures repo.
|
||||
PERSONAL_ACCESS_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
|
||||
with:
|
||||
# Where the CLA document lives (shown to contributors)
|
||||
path-to-document: https://github.com/Comfy-Org/comfy-cla/blob/main/comfyui_icla.md
|
||||
|
||||
# Centralized signature storage
|
||||
remote-organization-name: comfy-org
|
||||
remote-repository-name: comfy-cla
|
||||
path-to-signatures: signatures/cla.json
|
||||
branch: main
|
||||
|
||||
# Allowlist bots so they don't need to sign (optional, comma-separated).
|
||||
# *[bot] is a catch-all for any GitHub App bot account.
|
||||
allowlist: ampagent,claude,coderabbitai[bot],comfy-pr-bot,dependabot[bot],github-actions[bot],copilot-swe-agent[bot],devin-ai-integration[bot],*[bot]
|
||||
|
||||
# Custom PR comment messages
|
||||
custom-notsigned-prcomment: |
|
||||
🎉 Thank you for your contribution, we really appreciate it! 🎉
|
||||
|
||||
Like many open source projects, we require contributors to sign our [Contributor License Agreement (CLA)](https://github.com/Comfy-Org/comfy-cla/blob/main/comfyui_icla.md). A CLA makes the ownership of contributions explicit, so contributors and the project share a clear understanding of how the code can be used. By signing, you:
|
||||
|
||||
- Confirm that you own your contribution.
|
||||
- Keep the right to reuse your own code.
|
||||
- Grant us a copyright license to include and share it within our projects.
|
||||
|
||||
CLAs are standard practice across major open source projects including those under the Apache Software Foundation and the Linux Foundation. Ours is based on the Apache Software Foundation's CLA. Most importantly, it would enable us to relicense the project under a more permissive license in the future, giving the project and its community greater flexibility.
|
||||
|
||||
✍ **To sign, please post a new comment on this PR with exactly the following text:** ✍
|
||||
|
||||
custom-pr-sign-comment: I have read and agree to the Contributor License Agreement
|
||||
|
||||
custom-allsigned-prcomment: |
|
||||
✅ All contributors have signed the CLA. Thank you! This PR is ready to be merged.
|
||||
290
comfy/ldm/krea2/model.py
Normal file
290
comfy/ldm/krea2/model.py
Normal file
@ -0,0 +1,290 @@
|
||||
"""Krea 2 (K2) — single-stream MMDiT.
|
||||
|
||||
Text tokens produced by a Qwen3-VL-4B 12-layer ``txtfusion`` adapter and patchified image tokens are
|
||||
concatenated into one sequence and run through ``layers`` shared transformer blocks with
|
||||
AdaLN-single modulation, GQA + per-head QK-norm + sigmoid-gated attention, SwiGLU MLP, and 3-axis RoPE.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
from comfy.ldm.flux.layers import EmbedND, timestep_embedding
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
"""RMSNorm with the reference ``(1 + scale)`` weight convention (scale stored zero-centered)."""
|
||||
|
||||
def __init__(self, features: int, eps: float = 1e-5, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.scale = nn.Parameter(torch.empty(features, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
dtype = x.dtype
|
||||
weight = comfy.model_management.cast_to(self.scale, dtype=torch.float32, device=x.device) + 1.0
|
||||
return F.rms_norm(x.float(), (x.shape[-1],), weight=weight, eps=self.eps).to(dtype)
|
||||
|
||||
|
||||
class QKNorm(nn.Module):
|
||||
def __init__(self, dim: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.qnorm = RMSNorm(dim, device=device, dtype=dtype, operations=operations)
|
||||
self.knorm = RMSNorm(dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, q, k):
|
||||
return self.qnorm(q), self.knorm(k)
|
||||
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
def __init__(self, features: int, multiplier: int, bias: bool = False, multiple: int = 128,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
mlpdim = int(2 * features / 3) * multiplier
|
||||
mlpdim = multiple * ((mlpdim + multiple - 1) // multiple)
|
||||
self.gate = operations.Linear(features, mlpdim, bias=bias, device=device, dtype=dtype)
|
||||
self.up = operations.Linear(features, mlpdim, bias=bias, device=device, dtype=dtype)
|
||||
self.down = operations.Linear(mlpdim, features, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.down(F.silu(self.gate(x)).mul_(self.up(x)))
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim: int, heads: int, kvheads: Optional[int] = None, bias: bool = False,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.kvheads = kvheads if kvheads is not None else heads
|
||||
self.headdim = dim // self.heads
|
||||
self.wq = operations.Linear(dim, self.headdim * self.heads, bias=bias, device=device, dtype=dtype)
|
||||
self.wk = operations.Linear(dim, self.headdim * self.kvheads, bias=bias, device=device, dtype=dtype)
|
||||
self.wv = operations.Linear(dim, self.headdim * self.kvheads, bias=bias, device=device, dtype=dtype)
|
||||
self.gate = operations.Linear(dim, dim, bias=bias, device=device, dtype=dtype)
|
||||
self.qknorm = QKNorm(self.headdim, device=device, dtype=dtype, operations=operations)
|
||||
self.wo = operations.Linear(dim, dim, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, freqs=None, mask=None, transformer_options={}):
|
||||
q, k, v, gate = self.wq(x), self.wk(x), self.wv(x), self.gate(x)
|
||||
q = rearrange(q, "B L (H D) -> B H L D", H=self.heads)
|
||||
k = rearrange(k, "B L (H D) -> B H L D", H=self.kvheads)
|
||||
v = rearrange(v, "B L (H D) -> B H L D", H=self.kvheads)
|
||||
q, k = self.qknorm(q, k)
|
||||
if freqs is not None:
|
||||
q, k = apply_rope(q, k, freqs)
|
||||
if self.kvheads != self.heads:
|
||||
rep = self.heads // self.kvheads
|
||||
k = k.repeat_interleave(rep, dim=1)
|
||||
v = v.repeat_interleave(rep, dim=1)
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask=mask, skip_reshape=True,
|
||||
transformer_options=transformer_options)
|
||||
return self.wo(out * F.sigmoid(gate))
|
||||
|
||||
|
||||
class SimpleModulation(nn.Module):
|
||||
def __init__(self, dim: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.lin = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, vec):
|
||||
out = vec + comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device).unsqueeze(0)
|
||||
scale, shift = out.chunk(2, dim=1)
|
||||
return scale, shift
|
||||
|
||||
|
||||
class DoubleSharedModulation(nn.Module):
|
||||
def __init__(self, dim: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.lin = nn.Parameter(torch.empty(6 * dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, vec):
|
||||
out = vec + comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device)
|
||||
return out.chunk(6, dim=-1)
|
||||
|
||||
|
||||
class TextFusionBlock(nn.Module):
|
||||
def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.prenorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
|
||||
self.postnorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
|
||||
self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, device=device, dtype=dtype, operations=operations)
|
||||
self.mlp = SwiGLU(features, multiplier, bias, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x, mask=None, transformer_options={}):
|
||||
x = x + self.attn(self.prenorm(x), mask=mask, transformer_options=transformer_options)
|
||||
x = x + self.mlp(self.postnorm(x))
|
||||
return x
|
||||
|
||||
|
||||
class TextFusionTransformer(nn.Module):
|
||||
def __init__(self, num_txt_layers, txt_dim, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.layerwise_blocks = nn.ModuleList([
|
||||
TextFusionBlock(txt_dim, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(2)
|
||||
])
|
||||
self.projector = operations.Linear(num_txt_layers, 1, bias=False, device=device, dtype=dtype)
|
||||
self.refiner_blocks = nn.ModuleList([
|
||||
TextFusionBlock(txt_dim, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(2)
|
||||
])
|
||||
|
||||
def forward(self, x, mask=None, transformer_options={}):
|
||||
b, l, n, d = x.shape
|
||||
x = x.reshape(b * l, n, d)
|
||||
for block in self.layerwise_blocks:
|
||||
x = block(x.contiguous(), mask=None, transformer_options=transformer_options)
|
||||
x = rearrange(x, "(b l) n d -> b l d n", b=b, l=l)
|
||||
x = self.projector(x).squeeze(-1)
|
||||
for block in self.refiner_blocks:
|
||||
x = block(x, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.mod = DoubleSharedModulation(features, device=device, dtype=dtype, operations=operations)
|
||||
self.prenorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
|
||||
self.postnorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
|
||||
self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, device=device, dtype=dtype, operations=operations)
|
||||
self.mlp = SwiGLU(features, multiplier, bias, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x, vec, freqs, mask=None, transformer_options={}):
|
||||
prescale, preshift, pregate, postscale, postshift, postgate = self.mod(vec)
|
||||
x = x + pregate * self.attn((1 + prescale) * self.prenorm(x) + preshift, freqs, mask, transformer_options=transformer_options)
|
||||
x = x + postgate * self.mlp((1 + postscale) * self.postnorm(x) + postshift)
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, features, patch, channels, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
|
||||
self.linear = operations.Linear(features, patch * patch * channels, bias=True, device=device, dtype=dtype)
|
||||
self.modulation = SimpleModulation(features, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x, tvec):
|
||||
scale, shift = self.modulation(tvec)
|
||||
x = (1 + scale) * self.norm(x) + shift
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
class SingleStreamDiT(nn.Module):
|
||||
def __init__(self, features=6144, tdim=256, txtdim=2560, heads=48, kvheads=12, multiplier=4,
|
||||
layers=28, patch=2, channels=16, bias=False, theta=1e3, txtlayers=12,
|
||||
txtheads=20, txtkvheads=20, image_model=None,
|
||||
device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.patch = patch
|
||||
self.channels = channels
|
||||
self.tdim = tdim
|
||||
self.heads = heads
|
||||
self.txtdim = txtdim
|
||||
self.txtlayers = txtlayers
|
||||
|
||||
headdim = features // heads
|
||||
axes = [headdim - 12 * (headdim // 16), 6 * (headdim // 16), 6 * (headdim // 16)]
|
||||
assert sum(axes) == headdim, f"axes {axes} sum != headdim {headdim}"
|
||||
self.pe_embedder = EmbedND(dim=headdim, theta=int(theta), axes_dim=axes)
|
||||
|
||||
self.first = operations.Linear(channels * patch ** 2, features, bias=True, device=device, dtype=dtype)
|
||||
self.blocks = nn.ModuleList([
|
||||
SingleStreamBlock(features, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(layers)
|
||||
])
|
||||
self.tmlp = nn.Sequential(
|
||||
operations.Linear(tdim, features, device=device, dtype=dtype),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(features, features, device=device, dtype=dtype),
|
||||
)
|
||||
self.txtfusion = TextFusionTransformer(txtlayers, txtdim, txtheads, multiplier, bias, txtkvheads,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.txtmlp = nn.Sequential(
|
||||
RMSNorm(txtdim, device=device, dtype=dtype, operations=operations),
|
||||
operations.Linear(txtdim, features, device=device, dtype=dtype),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(features, features, device=device, dtype=dtype),
|
||||
)
|
||||
self.last = LastLayer(features, patch, channels, device=device, dtype=dtype, operations=operations)
|
||||
self.tproj = nn.Sequential(
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(features, features * 6, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs):
|
||||
temporal = x.ndim == 5
|
||||
if temporal:
|
||||
b5, c5, t5, h5, w5 = x.shape
|
||||
x = x.reshape(b5 * t5, c5, h5, w5)
|
||||
bs, c, H_orig, W_orig = x.shape
|
||||
patch = self.patch
|
||||
# Pad the latent up to a multiple of patch (as Flux/Lumina/QwenImage do); crop back at the end.
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch, patch))
|
||||
H, W = x.shape[-2], x.shape[-1]
|
||||
h_, w_ = H // patch, W // patch
|
||||
|
||||
# context arrives as (B, seq, txtlayers*txtdim); reshape to (B, txtlayers, seq, txtdim).
|
||||
context = self._unpack_context(context)
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch, pw=patch)
|
||||
img = self.first(img)
|
||||
|
||||
t = self.tmlp(timestep_embedding(timesteps, self.tdim).unsqueeze(1).to(img.dtype))
|
||||
tvec = self.tproj(t)
|
||||
|
||||
context = self.txtfusion(context, mask=None, transformer_options=transformer_options)
|
||||
context = self.txtmlp(context)
|
||||
|
||||
txtlen, imglen = context.shape[1], img.shape[1]
|
||||
combined = torch.cat((context, img), dim=1)
|
||||
|
||||
# Position ids: text at 0, image at (0, h_idx, w_idx).
|
||||
device = combined.device
|
||||
txtpos = torch.zeros(bs, txtlen, 3, device=device, dtype=torch.float32)
|
||||
imgids = torch.zeros(h_, w_, 3, device=device, dtype=torch.float32)
|
||||
imgids[..., 1] = torch.arange(h_, device=device, dtype=torch.float32)[:, None]
|
||||
imgids[..., 2] = torch.arange(w_, device=device, dtype=torch.float32)[None, :]
|
||||
imgpos = imgids.reshape(1, h_ * w_, 3).repeat(bs, 1, 1)
|
||||
pos = torch.cat((txtpos, imgpos), dim=1)
|
||||
|
||||
freqs = self.pe_embedder(pos)
|
||||
|
||||
for block in self.blocks:
|
||||
combined = block(combined, tvec, freqs, None, transformer_options=transformer_options)
|
||||
|
||||
final = self.last(combined, t)
|
||||
out = final[:, txtlen:txtlen + imglen, :]
|
||||
out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=h_, w=w_, ph=patch, pw=patch, c=self.channels)
|
||||
out = out[:, :, :H_orig, :W_orig] # crop padding back off
|
||||
if temporal:
|
||||
out = out.reshape(b5, t5, self.channels, H_orig, W_orig).movedim(1, 2)
|
||||
return out
|
||||
|
||||
def _unpack_context(self, context):
|
||||
# context: (B, seq, txtlayers*txtdim) -> (B, seq, txtlayers, txtdim).
|
||||
b, seq, fused = context.shape
|
||||
if fused != self.txtlayers * self.txtdim:
|
||||
raise ValueError(
|
||||
f"Krea2 expects conditioning with {self.txtlayers}x{self.txtdim}={self.txtlayers * self.txtdim} "
|
||||
f"features (a {self.txtlayers}-layer Qwen3-VL stack) but got {fused}. "
|
||||
f"Load the text encoder with CLIPLoader type 'krea2'."
|
||||
)
|
||||
return context.reshape(b, seq, self.txtlayers, self.txtdim)
|
||||
@ -326,6 +326,17 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
|
||||
|
||||
if isinstance(model, comfy.model_base.Krea2):
|
||||
diffusers_keys = comfy.utils.krea2_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = k[:-len(".weight")]
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["transformer.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.Lumina2):
|
||||
diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
|
||||
@ -58,6 +58,7 @@ import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.boogu.model
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.ideogram4.model
|
||||
import comfy.ldm.krea2.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
@ -2278,6 +2279,17 @@ class Ideogram4(BaseModel):
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class Krea2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.krea2.model.SingleStreamDiT)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class HunyuanImage21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
|
||||
@ -834,6 +834,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}txtfusion.projector.weight'.format(key_prefix) in state_dict_keys: # Krea 2 (K2)
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "krea2"
|
||||
head_dim = 128
|
||||
first_w = state_dict['{}first.weight'.format(key_prefix)] # (features, channels*patch^2)
|
||||
dit_config["features"] = first_w.shape[0]
|
||||
dit_config["channels"] = first_w.shape[1] // (2 * 2) # patch=2
|
||||
dit_config["patch"] = 2
|
||||
dit_config["layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["heads"] = state_dict['{}blocks.0.attn.wq.weight'.format(key_prefix)].shape[0] // head_dim
|
||||
dit_config["kvheads"] = state_dict['{}blocks.0.attn.wk.weight'.format(key_prefix)].shape[0] // head_dim
|
||||
dit_config["txtlayers"] = state_dict['{}txtfusion.projector.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["txtdim"] = state_dict['{}txtfusion.layerwise_blocks.0.prenorm.scale'.format(key_prefix)].shape[0]
|
||||
return dit_config
|
||||
|
||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||
dit_config = {}
|
||||
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
|
||||
@ -58,6 +58,7 @@ import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.krea2
|
||||
import comfy.text_encoders.ideogram4
|
||||
import comfy.text_encoders.ovis
|
||||
import comfy.text_encoders.kandinsky5
|
||||
@ -1303,6 +1304,7 @@ class CLIPType(Enum):
|
||||
PIXELDIT = 29
|
||||
IDEOGRAM4 = 30
|
||||
BOOGU = 31
|
||||
KREA2 = 32
|
||||
|
||||
|
||||
|
||||
@ -1628,6 +1630,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||
clip_target.clip = comfy.text_encoders.boogu.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.boogu.BooguTokenizer
|
||||
elif clip_type == CLIPType.KREA2 and te_model == TEModel.QWEN3VL_4B: # Krea2: full Qwen3-VL-4B (12-layer tap for conditioning + multimodal generate).
|
||||
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||
clip_target.clip = comfy.text_encoders.krea2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.krea2.Krea2Tokenizer
|
||||
elif clip_type in (CLIPType.FLUX, CLIPType.FLUX2): # Flux2 Klein reuses the Qwen3-VL LM (3-layer tap -> 12288); visual unused.
|
||||
klein_model_type = "qwen3_8b" if te_model == TEModel.QWEN3VL_8B else "qwen3_4b"
|
||||
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type=klein_model_type)
|
||||
|
||||
@ -26,6 +26,7 @@ import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.ideogram4
|
||||
import comfy.text_encoders.boogu
|
||||
import comfy.text_encoders.krea2
|
||||
import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
@ -1818,6 +1819,35 @@ class Ideogram4(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ideogram4.Ideogram4Tokenizer, comfy.text_encoders.ideogram4.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class Krea2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "krea2",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 1.15,
|
||||
}
|
||||
|
||||
memory_usage_factor = 2.2
|
||||
|
||||
latent_format = latent_formats.Wan21
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Krea2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_4b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.krea2.Krea2Tokenizer, comfy.text_encoders.krea2.te(**hunyuan_detect))
|
||||
|
||||
class QwenImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "qwen_image",
|
||||
@ -2325,6 +2355,7 @@ models = [
|
||||
Boogu,
|
||||
QwenImage,
|
||||
Ideogram4,
|
||||
Krea2,
|
||||
Flux2,
|
||||
Lens,
|
||||
Kandinsky5Image,
|
||||
|
||||
84
comfy/text_encoders/krea2.py
Normal file
84
comfy/text_encoders/krea2.py
Normal file
@ -0,0 +1,84 @@
|
||||
"""Krea 2 (K2) text encoder: Qwen3-VL-4B, 12-layer tap.
|
||||
|
||||
K2 conditions on a stack of hidden states from 12 layers of Qwen3-VL-4B
|
||||
(reference taps ``hidden_states[2,5,8,...,35]``), kept as a ``(B, 12, seq, 2560)`` tensor and
|
||||
consumed by the DiT's internal ``txtfusion`` adapter. Comfy carries conditioning as a 3D tensor,
|
||||
so the 12-layer stack is flattened to ``(B, seq, 12*2560)`` here and unpacked inside the model.
|
||||
"""
|
||||
|
||||
import numbers
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.text_encoders.qwen3vl
|
||||
from comfy import sd1_clip
|
||||
|
||||
# tap k == hidden_states[k] (no offset).
|
||||
KREA2_TAP_LAYERS = [2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35]
|
||||
|
||||
# Identical system template to Qwen-Image; Krea2 strips the system+user-opening prefix.
|
||||
KREA2_TEMPLATE = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
|
||||
class Krea2Tokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type="qwen3vl_4b")
|
||||
self.llama_template = KREA2_TEMPLATE # conditioning template; image text-gen uses qwen3vl's default image template.
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs):
|
||||
# Krea2 conditions on the no-think template; thinking=True drops the empty <think> block qwen3vl adds.
|
||||
return super().tokenize_with_weights(text, return_word_ids=return_word_ids, llama_template=llama_template, images=images, prevent_empty_text=prevent_empty_text, thinking=thinking, **kwargs)
|
||||
|
||||
|
||||
class Krea2Qwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=KREA2_TAP_LAYERS, layer_idx=None, dtype=dtype,
|
||||
attention_mask=attention_mask, model_options=model_options, model_type="qwen3vl_4b")
|
||||
|
||||
|
||||
class Krea2TEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3vl_4b", clip_model=Krea2Qwen3VLClipModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs, template_end=-1):
|
||||
out, pooled, extra = super().encode_token_weights(token_weight_pairs) # out: (B, 12, seq, 2560)
|
||||
tok_pairs = token_weight_pairs["qwen3vl_4b"][0]
|
||||
|
||||
# Strip the system + user-opening prefix
|
||||
count_im_start = 0
|
||||
if template_end == -1:
|
||||
for i, v in enumerate(tok_pairs):
|
||||
elem = v[0]
|
||||
if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral):
|
||||
if elem == 151644 and count_im_start < 2:
|
||||
template_end = i
|
||||
count_im_start += 1
|
||||
if out.shape[2] > (template_end + 3):
|
||||
if tok_pairs[template_end + 1][0] == 872: # "user"
|
||||
if tok_pairs[template_end + 2][0] == 198: # "\n"
|
||||
template_end += 3
|
||||
|
||||
out = out[:, :, template_end:]
|
||||
|
||||
b, n, seq, h = out.shape
|
||||
# Flatten the 12-layer axis into the feature dim: (B, seq, 12*2560). Unpacked in the model.
|
||||
out = out.permute(0, 2, 1, 3).reshape(b, seq, n * h)
|
||||
|
||||
if "attention_mask" in extra:
|
||||
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
|
||||
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
|
||||
extra.pop("attention_mask")
|
||||
|
||||
return out, pooled, extra
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class Krea2TEModel_(Krea2TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Krea2TEModel_
|
||||
@ -818,6 +818,44 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""):
|
||||
|
||||
return key_map
|
||||
|
||||
def krea2_to_diffusers(mmdit_config, output_prefix=""):
|
||||
n_layers = mmdit_config.get("layers", 0)
|
||||
n_txt_layerwise = 2 # TextFusionTransformer hardcodes 2 layerwise + 2 refiner blocks
|
||||
n_txt_refiner = 2
|
||||
key_map = {}
|
||||
|
||||
def add_block(prefix_to, prefix_from):
|
||||
block_map = {
|
||||
"attn.to_q": "attn.wq", "attn.to_k": "attn.wk", "attn.to_v": "attn.wv",
|
||||
"attn.to_gate": "attn.gate", "attn.to_out.0": "attn.wo",
|
||||
"attn.to_out": "attn.wo", # some tools drop the ".0" on to_out
|
||||
"ff.gate": "mlp.gate", "ff.up": "mlp.up", "ff.down": "mlp.down",
|
||||
}
|
||||
for d, c in block_map.items():
|
||||
key_map["{}.{}.weight".format(prefix_to, d)] = "{}{}.{}.weight".format(output_prefix, prefix_from, c)
|
||||
|
||||
for i in range(n_layers):
|
||||
add_block("transformer_blocks.{}".format(i), "blocks.{}".format(i))
|
||||
for i in range(n_txt_layerwise):
|
||||
add_block("text_fusion.layerwise_blocks.{}".format(i), "txtfusion.layerwise_blocks.{}".format(i))
|
||||
for i in range(n_txt_refiner):
|
||||
add_block("text_fusion.refiner_blocks.{}".format(i), "txtfusion.refiner_blocks.{}".format(i))
|
||||
|
||||
MAP_BASIC = [
|
||||
("img_in", "first"),
|
||||
("time_embed.linear_1", "tmlp.0"),
|
||||
("time_embed.linear_2", "tmlp.2"),
|
||||
("time_mod_proj", "tproj.1"),
|
||||
("txt_in.linear_1", "txtmlp.1"),
|
||||
("txt_in.linear_2", "txtmlp.3"),
|
||||
("text_fusion.projector", "txtfusion.projector"),
|
||||
("final_layer.linear", "last.linear"),
|
||||
]
|
||||
for d, c in MAP_BASIC:
|
||||
key_map["{}.weight".format(d)] = "{}{}.weight".format(output_prefix, c)
|
||||
|
||||
return key_map
|
||||
|
||||
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
||||
if tensor.shape[dim] > batch_size:
|
||||
return tensor.narrow(dim, 0, batch_size)
|
||||
|
||||
2
nodes.py
2
nodes.py
@ -969,7 +969,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu", "krea2"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.45.19
|
||||
comfyui-workflow-templates==0.10.0
|
||||
comfyui-embedded-docs==0.5.4
|
||||
comfyui-embedded-docs==0.5.5
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
||||
Reference in New Issue
Block a user