Compare commits

..

1 Commits

Author SHA1 Message Date
2a61015582 feat: Support Krea2 (#14589) 2026-06-22 14:35:00 -07:00
11 changed files with 510 additions and 51 deletions

290
comfy/ldm/krea2/model.py Normal file
View File

@ -0,0 +1,290 @@
"""Krea 2 (K2) — single-stream MMDiT.
Text tokens produced by a Qwen3-VL-4B 12-layer ``txtfusion`` adapter and patchified image tokens are
concatenated into one sequence and run through ``layers`` shared transformer blocks with
AdaLN-single modulation, GQA + per-head QK-norm + sigmoid-gated attention, SwiGLU MLP, and 3-axis RoPE.
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import comfy.model_management
import comfy.patcher_extension
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import EmbedND, timestep_embedding
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.attention import optimized_attention_masked
class RMSNorm(nn.Module):
"""RMSNorm with the reference ``(1 + scale)`` weight convention (scale stored zero-centered)."""
def __init__(self, features: int, eps: float = 1e-5, device=None, dtype=None, operations=None):
super().__init__()
self.eps = eps
self.scale = nn.Parameter(torch.empty(features, device=device, dtype=dtype))
def forward(self, x: torch.Tensor) -> torch.Tensor:
dtype = x.dtype
weight = comfy.model_management.cast_to(self.scale, dtype=torch.float32, device=x.device) + 1.0
return F.rms_norm(x.float(), (x.shape[-1],), weight=weight, eps=self.eps).to(dtype)
class QKNorm(nn.Module):
def __init__(self, dim: int, device=None, dtype=None, operations=None):
super().__init__()
self.qnorm = RMSNorm(dim, device=device, dtype=dtype, operations=operations)
self.knorm = RMSNorm(dim, device=device, dtype=dtype, operations=operations)
def forward(self, q, k):
return self.qnorm(q), self.knorm(k)
class SwiGLU(nn.Module):
def __init__(self, features: int, multiplier: int, bias: bool = False, multiple: int = 128,
device=None, dtype=None, operations=None):
super().__init__()
mlpdim = int(2 * features / 3) * multiplier
mlpdim = multiple * ((mlpdim + multiple - 1) // multiple)
self.gate = operations.Linear(features, mlpdim, bias=bias, device=device, dtype=dtype)
self.up = operations.Linear(features, mlpdim, bias=bias, device=device, dtype=dtype)
self.down = operations.Linear(mlpdim, features, bias=bias, device=device, dtype=dtype)
def forward(self, x):
return self.down(F.silu(self.gate(x)).mul_(self.up(x)))
class Attention(nn.Module):
def __init__(self, dim: int, heads: int, kvheads: Optional[int] = None, bias: bool = False,
device=None, dtype=None, operations=None):
super().__init__()
self.heads = heads
self.kvheads = kvheads if kvheads is not None else heads
self.headdim = dim // self.heads
self.wq = operations.Linear(dim, self.headdim * self.heads, bias=bias, device=device, dtype=dtype)
self.wk = operations.Linear(dim, self.headdim * self.kvheads, bias=bias, device=device, dtype=dtype)
self.wv = operations.Linear(dim, self.headdim * self.kvheads, bias=bias, device=device, dtype=dtype)
self.gate = operations.Linear(dim, dim, bias=bias, device=device, dtype=dtype)
self.qknorm = QKNorm(self.headdim, device=device, dtype=dtype, operations=operations)
self.wo = operations.Linear(dim, dim, bias=bias, device=device, dtype=dtype)
def forward(self, x, freqs=None, mask=None, transformer_options={}):
q, k, v, gate = self.wq(x), self.wk(x), self.wv(x), self.gate(x)
q = rearrange(q, "B L (H D) -> B H L D", H=self.heads)
k = rearrange(k, "B L (H D) -> B H L D", H=self.kvheads)
v = rearrange(v, "B L (H D) -> B H L D", H=self.kvheads)
q, k = self.qknorm(q, k)
if freqs is not None:
q, k = apply_rope(q, k, freqs)
if self.kvheads != self.heads:
rep = self.heads // self.kvheads
k = k.repeat_interleave(rep, dim=1)
v = v.repeat_interleave(rep, dim=1)
out = optimized_attention_masked(q, k, v, self.heads, mask=mask, skip_reshape=True,
transformer_options=transformer_options)
return self.wo(out * F.sigmoid(gate))
class SimpleModulation(nn.Module):
def __init__(self, dim: int, device=None, dtype=None, operations=None):
super().__init__()
self.lin = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
def forward(self, vec):
out = vec + comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device).unsqueeze(0)
scale, shift = out.chunk(2, dim=1)
return scale, shift
class DoubleSharedModulation(nn.Module):
def __init__(self, dim: int, device=None, dtype=None, operations=None):
super().__init__()
self.lin = nn.Parameter(torch.empty(6 * dim, device=device, dtype=dtype))
def forward(self, vec):
out = vec + comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device)
return out.chunk(6, dim=-1)
class TextFusionBlock(nn.Module):
def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
super().__init__()
self.prenorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
self.postnorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, device=device, dtype=dtype, operations=operations)
self.mlp = SwiGLU(features, multiplier, bias, device=device, dtype=dtype, operations=operations)
def forward(self, x, mask=None, transformer_options={}):
x = x + self.attn(self.prenorm(x), mask=mask, transformer_options=transformer_options)
x = x + self.mlp(self.postnorm(x))
return x
class TextFusionTransformer(nn.Module):
def __init__(self, num_txt_layers, txt_dim, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
super().__init__()
self.layerwise_blocks = nn.ModuleList([
TextFusionBlock(txt_dim, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations)
for _ in range(2)
])
self.projector = operations.Linear(num_txt_layers, 1, bias=False, device=device, dtype=dtype)
self.refiner_blocks = nn.ModuleList([
TextFusionBlock(txt_dim, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations)
for _ in range(2)
])
def forward(self, x, mask=None, transformer_options={}):
b, l, n, d = x.shape
x = x.reshape(b * l, n, d)
for block in self.layerwise_blocks:
x = block(x.contiguous(), mask=None, transformer_options=transformer_options)
x = rearrange(x, "(b l) n d -> b l d n", b=b, l=l)
x = self.projector(x).squeeze(-1)
for block in self.refiner_blocks:
x = block(x, mask=mask, transformer_options=transformer_options)
return x
class SingleStreamBlock(nn.Module):
def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
super().__init__()
self.mod = DoubleSharedModulation(features, device=device, dtype=dtype, operations=operations)
self.prenorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
self.postnorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, device=device, dtype=dtype, operations=operations)
self.mlp = SwiGLU(features, multiplier, bias, device=device, dtype=dtype, operations=operations)
def forward(self, x, vec, freqs, mask=None, transformer_options={}):
prescale, preshift, pregate, postscale, postshift, postgate = self.mod(vec)
x = x + pregate * self.attn((1 + prescale) * self.prenorm(x) + preshift, freqs, mask, transformer_options=transformer_options)
x = x + postgate * self.mlp((1 + postscale) * self.postnorm(x) + postshift)
return x
class LastLayer(nn.Module):
def __init__(self, features, patch, channels, device=None, dtype=None, operations=None):
super().__init__()
self.norm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
self.linear = operations.Linear(features, patch * patch * channels, bias=True, device=device, dtype=dtype)
self.modulation = SimpleModulation(features, device=device, dtype=dtype, operations=operations)
def forward(self, x, tvec):
scale, shift = self.modulation(tvec)
x = (1 + scale) * self.norm(x) + shift
return self.linear(x)
class SingleStreamDiT(nn.Module):
def __init__(self, features=6144, tdim=256, txtdim=2560, heads=48, kvheads=12, multiplier=4,
layers=28, patch=2, channels=16, bias=False, theta=1e3, txtlayers=12,
txtheads=20, txtkvheads=20, image_model=None,
device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
self.patch = patch
self.channels = channels
self.tdim = tdim
self.heads = heads
self.txtdim = txtdim
self.txtlayers = txtlayers
headdim = features // heads
axes = [headdim - 12 * (headdim // 16), 6 * (headdim // 16), 6 * (headdim // 16)]
assert sum(axes) == headdim, f"axes {axes} sum != headdim {headdim}"
self.pe_embedder = EmbedND(dim=headdim, theta=int(theta), axes_dim=axes)
self.first = operations.Linear(channels * patch ** 2, features, bias=True, device=device, dtype=dtype)
self.blocks = nn.ModuleList([
SingleStreamBlock(features, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations)
for _ in range(layers)
])
self.tmlp = nn.Sequential(
operations.Linear(tdim, features, device=device, dtype=dtype),
nn.GELU(approximate="tanh"),
operations.Linear(features, features, device=device, dtype=dtype),
)
self.txtfusion = TextFusionTransformer(txtlayers, txtdim, txtheads, multiplier, bias, txtkvheads,
device=device, dtype=dtype, operations=operations)
self.txtmlp = nn.Sequential(
RMSNorm(txtdim, device=device, dtype=dtype, operations=operations),
operations.Linear(txtdim, features, device=device, dtype=dtype),
nn.GELU(approximate="tanh"),
operations.Linear(features, features, device=device, dtype=dtype),
)
self.last = LastLayer(features, patch, channels, device=device, dtype=dtype, operations=operations)
self.tproj = nn.Sequential(
nn.GELU(approximate="tanh"),
operations.Linear(features, features * 6, device=device, dtype=dtype),
)
def forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
def _forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs):
temporal = x.ndim == 5
if temporal:
b5, c5, t5, h5, w5 = x.shape
x = x.reshape(b5 * t5, c5, h5, w5)
bs, c, H_orig, W_orig = x.shape
patch = self.patch
# Pad the latent up to a multiple of patch (as Flux/Lumina/QwenImage do); crop back at the end.
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch, patch))
H, W = x.shape[-2], x.shape[-1]
h_, w_ = H // patch, W // patch
# context arrives as (B, seq, txtlayers*txtdim); reshape to (B, txtlayers, seq, txtdim).
context = self._unpack_context(context)
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch, pw=patch)
img = self.first(img)
t = self.tmlp(timestep_embedding(timesteps, self.tdim).unsqueeze(1).to(img.dtype))
tvec = self.tproj(t)
context = self.txtfusion(context, mask=None, transformer_options=transformer_options)
context = self.txtmlp(context)
txtlen, imglen = context.shape[1], img.shape[1]
combined = torch.cat((context, img), dim=1)
# Position ids: text at 0, image at (0, h_idx, w_idx).
device = combined.device
txtpos = torch.zeros(bs, txtlen, 3, device=device, dtype=torch.float32)
imgids = torch.zeros(h_, w_, 3, device=device, dtype=torch.float32)
imgids[..., 1] = torch.arange(h_, device=device, dtype=torch.float32)[:, None]
imgids[..., 2] = torch.arange(w_, device=device, dtype=torch.float32)[None, :]
imgpos = imgids.reshape(1, h_ * w_, 3).repeat(bs, 1, 1)
pos = torch.cat((txtpos, imgpos), dim=1)
freqs = self.pe_embedder(pos)
for block in self.blocks:
combined = block(combined, tvec, freqs, None, transformer_options=transformer_options)
final = self.last(combined, t)
out = final[:, txtlen:txtlen + imglen, :]
out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=h_, w=w_, ph=patch, pw=patch, c=self.channels)
out = out[:, :, :H_orig, :W_orig] # crop padding back off
if temporal:
out = out.reshape(b5, t5, self.channels, H_orig, W_orig).movedim(1, 2)
return out
def _unpack_context(self, context):
# context: (B, seq, txtlayers*txtdim) -> (B, seq, txtlayers, txtdim).
b, seq, fused = context.shape
if fused != self.txtlayers * self.txtdim:
raise ValueError(
f"Krea2 expects conditioning with {self.txtlayers}x{self.txtdim}={self.txtlayers * self.txtdim} "
f"features (a {self.txtlayers}-layer Qwen3-VL stack) but got {fused}. "
f"Load the text encoder with CLIPLoader type 'krea2'."
)
return context.reshape(b, seq, self.txtlayers, self.txtdim)

View File

@ -326,6 +326,17 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(key_lora)] = k
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
if isinstance(model, comfy.model_base.Krea2):
diffusers_keys = comfy.utils.krea2_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = k[:-len(".weight")]
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["transformer.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
key_map[key_lora] = to
if isinstance(model, comfy.model_base.Lumina2):
diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:

View File

@ -58,6 +58,7 @@ import comfy.ldm.omnigen.omnigen2
import comfy.ldm.boogu.model
import comfy.ldm.qwen_image.model
import comfy.ldm.ideogram4.model
import comfy.ldm.krea2.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15
@ -2278,6 +2279,17 @@ class Ideogram4(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class Krea2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.krea2.model.SingleStreamDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class HunyuanImage21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)

View File

@ -834,6 +834,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
return dit_config
if '{}txtfusion.projector.weight'.format(key_prefix) in state_dict_keys: # Krea 2 (K2)
dit_config = {}
dit_config["image_model"] = "krea2"
head_dim = 128
first_w = state_dict['{}first.weight'.format(key_prefix)] # (features, channels*patch^2)
dit_config["features"] = first_w.shape[0]
dit_config["channels"] = first_w.shape[1] // (2 * 2) # patch=2
dit_config["patch"] = 2
dit_config["layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
dit_config["heads"] = state_dict['{}blocks.0.attn.wq.weight'.format(key_prefix)].shape[0] // head_dim
dit_config["kvheads"] = state_dict['{}blocks.0.attn.wk.weight'.format(key_prefix)].shape[0] // head_dim
dit_config["txtlayers"] = state_dict['{}txtfusion.projector.weight'.format(key_prefix)].shape[1]
dit_config["txtdim"] = state_dict['{}txtfusion.layerwise_blocks.0.prenorm.scale'.format(key_prefix)].shape[0]
return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
dit_config = {}
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]

View File

@ -58,6 +58,7 @@ import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.text_encoders.krea2
import comfy.text_encoders.ideogram4
import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5
@ -1303,6 +1304,7 @@ class CLIPType(Enum):
PIXELDIT = 29
IDEOGRAM4 = 30
BOOGU = 31
KREA2 = 32
@ -1628,6 +1630,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
clip_target.clip = comfy.text_encoders.boogu.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.boogu.BooguTokenizer
elif clip_type == CLIPType.KREA2 and te_model == TEModel.QWEN3VL_4B: # Krea2: full Qwen3-VL-4B (12-layer tap for conditioning + multimodal generate).
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
clip_target.clip = comfy.text_encoders.krea2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.krea2.Krea2Tokenizer
elif clip_type in (CLIPType.FLUX, CLIPType.FLUX2): # Flux2 Klein reuses the Qwen3-VL LM (3-layer tap -> 12288); visual unused.
klein_model_type = "qwen3_8b" if te_model == TEModel.QWEN3VL_8B else "qwen3_4b"
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type=klein_model_type)

View File

@ -26,6 +26,7 @@ import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
import comfy.text_encoders.ideogram4
import comfy.text_encoders.boogu
import comfy.text_encoders.krea2
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
@ -1818,6 +1819,35 @@ class Ideogram4(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.ideogram4.Ideogram4Tokenizer, comfy.text_encoders.ideogram4.te(**hunyuan_detect))
class Krea2(supported_models_base.BASE):
unet_config = {
"image_model": "krea2",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 1.15,
}
memory_usage_factor = 3.0 #TODO
latent_format = latent_formats.Wan21
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Krea2(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_4b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.krea2.Krea2Tokenizer, comfy.text_encoders.krea2.te(**hunyuan_detect))
class QwenImage(supported_models_base.BASE):
unet_config = {
"image_model": "qwen_image",
@ -2325,6 +2355,7 @@ models = [
Boogu,
QwenImage,
Ideogram4,
Krea2,
Flux2,
Lens,
Kandinsky5Image,

View File

@ -0,0 +1,84 @@
"""Krea 2 (K2) text encoder: Qwen3-VL-4B, 12-layer tap.
K2 conditions on a stack of hidden states from 12 layers of Qwen3-VL-4B
(reference taps ``hidden_states[2,5,8,...,35]``), kept as a ``(B, 12, seq, 2560)`` tensor and
consumed by the DiT's internal ``txtfusion`` adapter. Comfy carries conditioning as a 3D tensor,
so the 12-layer stack is flattened to ``(B, seq, 12*2560)`` here and unpacked inside the model.
"""
import numbers
import torch
import comfy.text_encoders.qwen3vl
from comfy import sd1_clip
# tap k == hidden_states[k] (no offset).
KREA2_TAP_LAYERS = [2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35]
# Identical system template to Qwen-Image; Krea2 strips the system+user-opening prefix.
KREA2_TEMPLATE = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
class Krea2Tokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type="qwen3vl_4b")
self.llama_template = KREA2_TEMPLATE # conditioning template; image text-gen uses qwen3vl's default image template.
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs):
# Krea2 conditions on the no-think template; thinking=True drops the empty <think> block qwen3vl adds.
return super().tokenize_with_weights(text, return_word_ids=return_word_ids, llama_template=llama_template, images=images, prevent_empty_text=prevent_empty_text, thinking=thinking, **kwargs)
class Krea2Qwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel):
def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=KREA2_TAP_LAYERS, layer_idx=None, dtype=dtype,
attention_mask=attention_mask, model_options=model_options, model_type="qwen3vl_4b")
class Krea2TEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3vl_4b", clip_model=Krea2Qwen3VLClipModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs, template_end=-1):
out, pooled, extra = super().encode_token_weights(token_weight_pairs) # out: (B, 12, seq, 2560)
tok_pairs = token_weight_pairs["qwen3vl_4b"][0]
# Strip the system + user-opening prefix
count_im_start = 0
if template_end == -1:
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral):
if elem == 151644 and count_im_start < 2:
template_end = i
count_im_start += 1
if out.shape[2] > (template_end + 3):
if tok_pairs[template_end + 1][0] == 872: # "user"
if tok_pairs[template_end + 2][0] == 198: # "\n"
template_end += 3
out = out[:, :, template_end:]
b, n, seq, h = out.shape
# Flatten the 12-layer axis into the feature dim: (B, seq, 12*2560). Unpacked in the model.
out = out.permute(0, 2, 1, 3).reshape(b, seq, n * h)
if "attention_mask" in extra:
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
extra.pop("attention_mask")
return out, pooled, extra
def te(dtype_llama=None, llama_quantization_metadata=None):
class Krea2TEModel_(Krea2TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Krea2TEModel_

View File

@ -818,6 +818,44 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""):
return key_map
def krea2_to_diffusers(mmdit_config, output_prefix=""):
n_layers = mmdit_config.get("layers", 0)
n_txt_layerwise = 2 # TextFusionTransformer hardcodes 2 layerwise + 2 refiner blocks
n_txt_refiner = 2
key_map = {}
def add_block(prefix_to, prefix_from):
block_map = {
"attn.to_q": "attn.wq", "attn.to_k": "attn.wk", "attn.to_v": "attn.wv",
"attn.to_gate": "attn.gate", "attn.to_out.0": "attn.wo",
"attn.to_out": "attn.wo", # some tools drop the ".0" on to_out
"ff.gate": "mlp.gate", "ff.up": "mlp.up", "ff.down": "mlp.down",
}
for d, c in block_map.items():
key_map["{}.{}.weight".format(prefix_to, d)] = "{}{}.{}.weight".format(output_prefix, prefix_from, c)
for i in range(n_layers):
add_block("transformer_blocks.{}".format(i), "blocks.{}".format(i))
for i in range(n_txt_layerwise):
add_block("text_fusion.layerwise_blocks.{}".format(i), "txtfusion.layerwise_blocks.{}".format(i))
for i in range(n_txt_refiner):
add_block("text_fusion.refiner_blocks.{}".format(i), "txtfusion.refiner_blocks.{}".format(i))
MAP_BASIC = [
("img_in", "first"),
("time_embed.linear_1", "tmlp.0"),
("time_embed.linear_2", "tmlp.2"),
("time_mod_proj", "tproj.1"),
("txt_in.linear_1", "txtmlp.1"),
("txt_in.linear_2", "txtmlp.3"),
("text_fusion.projector", "txtfusion.projector"),
("final_layer.linear", "last.linear"),
]
for d, c in MAP_BASIC:
key_map["{}.weight".format(d)] = "{}{}.weight".format(output_prefix, c)
return key_map
def repeat_to_batch_size(tensor, batch_size, dim=0):
if tensor.shape[dim] > batch_size:
return tensor.narrow(dim, 0, batch_size)

View File

@ -163,27 +163,15 @@ class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.")
# Dollars per 1K tokens, keyed by (model_id, has_video_input, resolution).
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
SEEDANCE2_PRICE_PER_1K_TOKENS = {
("dreamina-seedance-2-0-260128", False, "480p"): 0.007,
("dreamina-seedance-2-0-260128", True, "480p"): 0.0043,
("dreamina-seedance-2-0-260128", False, "720p"): 0.007,
("dreamina-seedance-2-0-260128", True, "720p"): 0.0043,
("dreamina-seedance-2-0-260128", False, "1080p"): 0.0077,
("dreamina-seedance-2-0-260128", True, "1080p"): 0.0047,
("dreamina-seedance-2-0-260128", False, "4k"): 0.004,
("dreamina-seedance-2-0-260128", True, "4k"): 0.0024,
("dreamina-seedance-2-0-fast-260128", False, "480p"): 0.0056,
("dreamina-seedance-2-0-fast-260128", True, "480p"): 0.0033,
("dreamina-seedance-2-0-fast-260128", False, "720p"): 0.0056,
("dreamina-seedance-2-0-fast-260128", True, "720p"): 0.0033,
("dreamina-seedance-2-0-260128", False): 0.007,
("dreamina-seedance-2-0-260128", True): 0.0043,
("dreamina-seedance-2-0-fast-260128", False): 0.0056,
("dreamina-seedance-2-0-fast-260128", True): 0.0033,
}
def seedance2_price_per_1k_tokens(model_id: str, has_video_input: bool, resolution: str) -> float | None:
return SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input, resolution))
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),

View File

@ -15,6 +15,7 @@ from comfy_api_nodes.apis.bytedance import (
RECOMMENDED_PRESETS_SEEDREAM_4_0,
RECOMMENDED_PRESETS_SEEDREAM_4_5,
RECOMMENDED_PRESETS_SEEDREAM_5_LITE,
SEEDANCE2_PRICE_PER_1K_TOKENS,
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
VIDEO_TASKS_EXECUTION_TIME,
GetAssetResponse,
@ -39,7 +40,6 @@ from comfy_api_nodes.apis.bytedance import (
TaskVideoContentUrl,
Text2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
seedance2_price_per_1k_tokens,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@ -141,7 +141,7 @@ SEEDANCE2_RATIO_WH = {
"9:16": (9, 16),
"21:9": (21, 9),
}
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080, "4k": 2160}
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080}
def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]:
@ -377,9 +377,9 @@ async def _seedance_virtual_library_upload_video_asset(
return f"asset://{create_resp.asset_id}"
def _seedance2_price_extractor(model_id: str, has_video_input: bool, resolution: str):
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
rate = seedance2_price_per_1k_tokens(model_id, has_video_input, resolution)
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
if rate is None:
return None
@ -1621,7 +1621,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p", "4k"])),
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
@ -1660,15 +1660,11 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
$rate480 := 10044;
$rate720 := 21600;
$rate1080 := 48800;
$rate4k := 195200;
$m := widgets.model;
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$pricePer1K := $res = "4k" ? 0.00572 :
$res = "1080p" ? 0.011011 :
$contains($m, "fast") ? 0.008008 : 0.01001;
$rate := $res = "4k" ? $rate4k :
$res = "1080p" ? $rate1080 :
$rate := $res = "1080p" ? $rate1080 :
$res = "720p" ? $rate720 :
$rate480;
$cost := $dur * $rate * $pricePer1K / 1000;
@ -1707,7 +1703,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False, resolution=model["resolution"]),
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
poll_interval=9,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
@ -1728,7 +1724,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
options=[
IO.DynamicCombo.Option(
"Seedance 2.0",
_seedance2_text_inputs(["480p", "720p", "1080p", "4k"], default_ratio="adaptive"),
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
@ -1795,15 +1791,11 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
$rate480 := 10044;
$rate720 := 21600;
$rate1080 := 48800;
$rate4k := 195200;
$m := widgets.model;
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$pricePer1K := $res = "4k" ? 0.00572 :
$res = "1080p" ? 0.011011 :
$contains($m, "fast") ? 0.008008 : 0.01001;
$rate := $res = "4k" ? $rate4k :
$res = "1080p" ? $rate1080 :
$rate := $res = "1080p" ? $rate1080 :
$res = "720p" ? $rate720 :
$rate480;
$cost := $dur * $rate * $pricePer1K / 1000;
@ -1921,7 +1913,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False, resolution=model["resolution"]),
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
poll_interval=9,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
@ -2018,7 +2010,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
options=[
IO.DynamicCombo.Option(
"Seedance 2.0",
_seedance2_reference_inputs(["480p", "720p", "1080p", "4k"], default_ratio="adaptive"),
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
@ -2064,19 +2056,13 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
$rate480 := 10044;
$rate720 := 21600;
$rate1080 := 48800;
$rate4k := 195200;
$m := widgets.model;
$hasVideo := $lookup(inputGroups, "model.reference_videos") > 0;
$noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
$videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149;
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$noVideoPricePer1K := $res = "4k" ? 0.00572 :
$res = "1080p" ? 0.011011 :
$contains($m, "fast") ? 0.008008 : 0.01001;
$videoPricePer1K := $res = "4k" ? 0.003432 :
$res = "1080p" ? 0.006721 :
$contains($m, "fast") ? 0.004719 : 0.006149;
$rate := $res = "4k" ? $rate4k :
$res = "1080p" ? $rate1080 :
$rate := $res = "1080p" ? $rate1080 :
$res = "720p" ? $rate720 :
$rate480;
$noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000;
@ -2272,9 +2258,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(
model_id, has_video_input=has_video_input, resolution=model["resolution"]
),
price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input),
poll_interval=9,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))

View File

@ -969,7 +969,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu", "krea2"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),