mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-12 12:57:35 +08:00

Compare commits — 7 commits

| SHA1 |
|---|
| 0155ddcbe3 |
| 8e53f001a4 |
| 0a7d2ffd68 |
| 20e439419c |
| 428c323780 |
| 46063aa927 |
| b565dc7a6c |
@@ -242,6 +242,7 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
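Every sampler hunk below applies the same one-line change, so a minimal self-contained sketch of the pattern is worth stating once (the stand-in class is illustrative, not part of the diff):

# Pattern: fold an optional model-level noise_scale into the sampler's
# s_noise multiplier; getattr defaults to 1.0, so models without the
# attribute behave exactly as before.
class _ModelSampling:  # stand-in for model.inner_model.model_patcher.get_model_object('model_sampling')
    noise_scale = 8.0  # e.g. the value HiDreamO1 sets in supported_models.py later in this compare

s_noise = 1.0
s_noise = s_noise * getattr(_ModelSampling(), "noise_scale", 1.0)
assert s_noise == 8.0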
@@ -373,6 +374,7 @@ def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -686,6 +688,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
     lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
@@ -747,6 +750,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
     sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
     lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
     sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+    s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)

     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -832,6 +836,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
     lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
     sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+    s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)

     old_denoised = None
     h, h_last = None, None
@@ -889,6 +894,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
     lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
     sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+    s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)

     denoised_1, denoised_2 = None, None
     h, h_1, h_2 = None, None, None
@@ -1006,23 +1012,39 @@ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None,
     return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)


 @torch.no_grad()
-def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
+def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, s_noise=1.0, s_noise_end=None, noise_clip_std=0.0):
+    # s_noise / s_noise_end: per-step noise multiplier, linearly interpolated across steps.
+    # noise_clip_std: clamp injected noise to +/- N stddevs (0 disables).
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
-    for i in trange(len(sigmas) - 1, disable=disable):
+    n_steps = max(1, len(sigmas) - 1)
+    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+
+    s_start = float(s_noise)
+    s_end = s_start if s_noise_end is None else float(s_noise_end)
+    for i in trange(n_steps, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

         x = denoised
         if sigmas[i + 1] > 0:
-            x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
+            noise = noise_sampler(sigmas[i], sigmas[i + 1])
+            if noise_clip_std > 0:
+                clip_val = noise_clip_std * noise.std()
+                noise = noise.clamp(min=-clip_val, max=clip_val)
+            t = (i / (n_steps - 1)) if n_steps > 1 else 0.0
+            s_noise_i = s_start + (s_end - s_start) * t
+            if s_noise_i != 1.0:
+                noise = noise * s_noise_i
+            x = model_sampling.noise_scaling(sigmas[i + 1], noise, x)
     return x


 @torch.no_grad()
 def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
     # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
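A minimal standalone sketch of the s_noise interpolation those comments describe (nothing here beyond the arguments above):

# Per-step noise multipliers produced by sample_lcm's s_noise / s_noise_end:
# step i of n_steps gets s_start + (s_end - s_start) * i / (n_steps - 1).
def lcm_noise_schedule(s_noise=1.0, s_noise_end=None, n_steps=4):
    s_start = float(s_noise)
    s_end = s_start if s_noise_end is None else float(s_noise_end)
    if n_steps <= 1:
        return [s_start]
    return [s_start + (s_end - s_start) * i / (n_steps - 1) for i in range(n_steps)]

print(lcm_noise_schedule(1.0, 0.4, 4))  # [1.0, 0.8, 0.6, 0.4]: full noise first, gentler at the end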
@@ -1249,6 +1271,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No

     model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
     lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+    s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)

     uncond_denoised = None
@@ -1296,6 +1319,7 @@ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)

     temp = [0]
     def post_cfg_function(args):
@@ -1371,6 +1395,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
     t_fn = lambda sigma: sigma.log().neg()
@@ -1504,6 +1529,7 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
     s_in = x.new_ones([x.shape[0]])

     def default_er_sde_noise_scaler(x):
@@ -1574,9 +1600,10 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
-    inject_noise = eta > 0 and s_noise > 0

     model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+    s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
+    inject_noise = eta > 0 and s_noise > 0
     sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
     lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
     sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
@@ -1645,9 +1672,10 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
-    inject_noise = eta > 0 and s_noise > 0

     model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+    s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
+    inject_noise = eta > 0 and s_noise > 0
     sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
     lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
     sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
@@ -1713,6 +1741,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
     s_in = x.new_ones([x.shape[0]])

     model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
+    s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
     sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
     lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)
@@ -792,6 +792,13 @@ class ZImagePixelSpace(ChromaRadiance):
     """
     pass


+class HiDreamO1Pixel(ChromaRadiance):
+    """Pixel-space latent format for HiDream-O1.
+    No VAE — model patches/unpatches raw RGB internally with patch_size=32.
+    """
+    pass
+
+
 class CogVideoX(LatentFormat):
     """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
comfy/ldm/hidream_o1/attention.py (new file, 41 lines)
@@ -0,0 +1,41 @@
"""HiDream-O1 two-pass attention: tokens [0, ar_len) are causal, [ar_len, T)
attend full K/V. Splitting Q at the boundary avoids the (B, 1, T, T) additive
mask the general-purpose path would build (~500 MB at T~16K) and lets the
gen half hit the user's preferred backend via optimized_attention.
"""

import torch

import comfy.ops
from comfy.ldm.modules.attention import optimized_attention


def make_two_pass_attention(ar_len: int, transformer_options=None):
    """Build a two-pass attention callable. AR pass uses SDPA-causal directly, gen pass routes through optimized_attention.
    The AR pass goes through SDPA directly and bypasses wrappers; it is only ~1% of T at typical edit sizes.
    """

    def two_pass_attention(q, k, v, heads, **kwargs):
        B, H, T, D = q.shape

        if T < k.shape[2]:
            # KV-cache hot path: Q is shorter than K/V (the cached AR prefix is in K/V only),
            # so all fresh Q positions are in the gen region: single full-attention call.
            out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
        elif ar_len >= T:
            out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
        elif ar_len <= 0:
            out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
        else:
            out_ar = comfy.ops.scaled_dot_product_attention(
                q[:, :, :ar_len], k[:, :, :ar_len], v[:, :, :ar_len],
                attn_mask=None, dropout_p=0.0, is_causal=True,
            )
            out_gen = optimized_attention(
                q[:, :, ar_len:], k, v, heads,
                mask=None, skip_reshape=True, skip_output_reshape=True,
                transformer_options=transformer_options,
            )
            out = torch.cat([out_ar, out_gen], dim=2)

        return out.transpose(1, 2).reshape(B, T, H * D)

    return two_pass_attention
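A quick check of the ~500 MB figure in the docstring (fp16 assumed; plain arithmetic, not part of the diff):

# The additive mask the generic path would allocate is (B, 1, T, T) in the
# activation dtype. At B=1, T=16384, fp16 (2 bytes/element):
T = 16384
mask_bytes = 1 * 1 * T * T * 2
print(mask_bytes / 2**20)  # 512.0 MiB — the ~500 MB the split avoids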
comfy/ldm/hidream_o1/conditioning.py (new file, 230 lines)
@@ -0,0 +1,230 @@
"""HiDream-O1 conditioning prep — ref-image dual path + extra_conds assembly.

Each ref image goes through two paths: a 32x32 patchified stream concatenated
to the noised target, and a Qwen3-VL ViT path producing tokens that scatter
into input_ids at <|image_pad|> positions.
"""

from typing import List

import torch

import comfy.utils
from comfy.text_encoders.qwen_vl import process_qwen2vl_images

from .utils import (PATCH_SIZE, calculate_dimensions, cond_image_size, ref_max_size, resize_tensor)

# Qwen3-VL ViT preprocessing constants (preprocessor_config.json).
VIT_PATCH = 16
VIT_MERGE = 2
VIT_IMAGE_MEAN = [0.5, 0.5, 0.5]
VIT_IMAGE_STD = [0.5, 0.5, 0.5]


def prepare_ref_images(
    ref_images: List[torch.Tensor],
    target_h: int,
    target_w: int,
    device: torch.device,
    dtype: torch.dtype,
):
    """Build the dual-path tensors for K reference images at (target_h, target_w).

    Returns None for K=0, else a dict with ref_patches, ref_pixel_values,
    ref_image_grid_thw, per_ref_vit_tokens, per_ref_patch_grids.
    """
    K = len(ref_images)
    if K == 0:
        return None
    max_size = ref_max_size(max(target_h, target_w), K)
    cis = cond_image_size(K)

    refs_t = [img[0].clamp(0, 1).permute(2, 0, 1).unsqueeze(0).contiguous().float() for img in ref_images]
    refs_t = [resize_tensor(t, max_size, PATCH_SIZE) for t in refs_t]

    # 32-patch path.
    ref_patches_per = []
    per_ref_patch_grids = []
    for t in refs_t:
        t_norm = (t.squeeze(0) - 0.5) / 0.5  # (3, H, W) in [-1, 1]
        h_p, w_p = t_norm.shape[-2] // PATCH_SIZE, t_norm.shape[-1] // PATCH_SIZE
        per_ref_patch_grids.append((h_p, w_p))
        patches = (
            t_norm.reshape(3, h_p, PATCH_SIZE, w_p, PATCH_SIZE)
            .permute(1, 3, 0, 2, 4)
            .reshape(h_p * w_p, 3 * PATCH_SIZE * PATCH_SIZE)
        )
        ref_patches_per.append(patches)
    ref_patches = torch.cat(ref_patches_per, dim=0).unsqueeze(0).to(device=device, dtype=dtype)

    # ViT path.
    refs_vlm_t = []
    for t in refs_t:
        _, _, h, w = t.shape
        cond_w, cond_h = calculate_dimensions(cis, w / h)
        cond_w = max(cond_w, VIT_PATCH * VIT_MERGE)
        cond_h = max(cond_h, VIT_PATCH * VIT_MERGE)
        refs_vlm_t.append(comfy.utils.common_upscale(t, cond_w, cond_h, "lanczos", "disabled"))

    pv_list, grid_list, per_ref_vit_tokens = [], [], []
    for t_v in refs_vlm_t:
        pv, grid_thw = process_qwen2vl_images(
            t_v.permute(0, 2, 3, 1),
            min_pixels=0, max_pixels=10**12,
            patch_size=VIT_PATCH, merge_size=VIT_MERGE,
            image_mean=VIT_IMAGE_MEAN, image_std=VIT_IMAGE_STD,
        )
        grid_thw = grid_thw[0]
        pv_list.append(pv.to(device=device, dtype=dtype))
        grid_list.append(grid_thw.to(device=device))
        # Post-merge token count = number of <|image_pad|> tokens this image expands to in input_ids.
        gh, gw = int(grid_thw[1].item()), int(grid_thw[2].item())
        per_ref_vit_tokens.append((gh // VIT_MERGE) * (gw // VIT_MERGE))

    return {
        "ref_patches": ref_patches,
        "ref_pixel_values": torch.cat(pv_list, dim=0),
        "ref_image_grid_thw": torch.stack(grid_list, dim=0),
        "per_ref_vit_tokens": per_ref_vit_tokens,
        "per_ref_patch_grids": per_ref_patch_grids,
    }


def build_ref_input_ids(
    text_input_ids: torch.Tensor,
    per_ref_vit_tokens: List[int],
    image_token_id: int,
    vision_start_id: int,
    vision_end_id: int,
):
    """Splice [vision_start, image_pad*N, vision_end] blocks into input_ids
    after the [im_start, user, \\n] prefix (matches original chat template).
    """
    ids = text_input_ids[0].tolist()
    inserted = []
    for n_pad in per_ref_vit_tokens:
        inserted.extend([vision_start_id] + [image_token_id] * n_pad + [vision_end_id])
    new_ids = ids[:3] + inserted + ids[3:]  # 3 = len([im_start, user, \n])
    return torch.tensor([new_ids], dtype=text_input_ids.dtype, device=text_input_ids.device)


def build_extra_conds(
    text_input_ids: torch.Tensor,
    noise: torch.Tensor,
    ref_images: List[torch.Tensor] = None,
    target_patch_size: int = 32,
):
    """Assemble all conditioning tensors for HiDreamO1Transformer.forward:
    input_ids (with ref-vision tokens spliced in for the edit/IP path),
    position_ids (MRoPE), token_types, vinput_mask, plus the ref
    dual-path tensors when refs are provided.
    """
    from .utils import get_rope_index_fix_point
    from comfy.text_encoders.hidream_o1 import (
        IMAGE_TOKEN_ID, VISION_START_ID, VISION_END_ID,
    )

    if text_input_ids.dim() == 1:
        text_input_ids = text_input_ids.unsqueeze(0)
    text_input_ids = text_input_ids.long().to(noise.device)
    B = noise.shape[0]
    if text_input_ids.shape[0] == 1 and B > 1:
        text_input_ids = text_input_ids.expand(B, -1)

    H, W = noise.shape[-2], noise.shape[-1]
    h_p, w_p = H // target_patch_size, W // target_patch_size
    image_len = h_p * w_p
    image_grid_thw_tgt = torch.tensor(
        [[1, h_p, w_p]], dtype=torch.long, device=text_input_ids.device,
    )

    out = {}
    if ref_images:
        ref = prepare_ref_images(ref_images, H, W, device=noise.device, dtype=noise.dtype)
        text_input_ids = build_ref_input_ids(
            text_input_ids, ref["per_ref_vit_tokens"],
            IMAGE_TOKEN_ID, VISION_START_ID, VISION_END_ID,
        )
        new_txt_len = text_input_ids.shape[1]

        # Each ref's patchified stream gets a [vision_start, image_pad*N-1]
        # block in the position-id stream after the noised target.
        ref_grid_lengths = [hp * wp for (hp, wp) in ref["per_ref_patch_grids"]]
        tgt_vision = torch.full((1, image_len), IMAGE_TOKEN_ID,
                                dtype=text_input_ids.dtype, device=text_input_ids.device)
        tgt_vision[:, 0] = VISION_START_ID
        ref_vision_blocks = []
        for rl in ref_grid_lengths:
            blk = torch.full((1, rl), IMAGE_TOKEN_ID,
                             dtype=text_input_ids.dtype, device=text_input_ids.device)
            blk[:, 0] = VISION_START_ID
            ref_vision_blocks.append(blk)
        ref_vision_cat = torch.cat([tgt_vision] + ref_vision_blocks, dim=1)
        input_ids_pad = torch.cat([text_input_ids, ref_vision_cat], dim=-1)
        total_ref_patches_len = sum(ref_grid_lengths)
        total_len = new_txt_len + image_len + total_ref_patches_len

        # K (ViT, post-merge) + 1 (target) + K (ref-patches) image grids.
        K = len(ref_images)
        igthw_cond = ref["ref_image_grid_thw"].clone()
        igthw_cond[:, 1] //= 2
        igthw_cond[:, 2] //= 2
        image_grid_thw_ref = torch.tensor(
            [[1, hp, wp] for (hp, wp) in ref["per_ref_patch_grids"]],
            dtype=torch.long, device=text_input_ids.device,
        )
        igthw_all = torch.cat([
            igthw_cond.to(text_input_ids.device),
            image_grid_thw_tgt,
            image_grid_thw_ref,
        ], dim=0)
        position_ids, _ = get_rope_index_fix_point(
            spatial_merge_size=1,
            image_token_id=IMAGE_TOKEN_ID,
            vision_start_token_id=VISION_START_ID,
            input_ids=input_ids_pad, image_grid_thw=igthw_all,
            attention_mask=None,
            skip_vision_start_token=[0] * K + [1] + [1] * K,
            fix_point=4096,
        )

        # tms + target_image + ref_patches are all gen.
        tms_pos = new_txt_len - 1
        ar_len = tms_pos
        token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device)
        token_types[:, tms_pos:] = 1
        vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device)
        vinput_mask[:, new_txt_len:] = True

        # Leading batch dim sidesteps CONDRegular.process_cond's repeat_to_batch_size truncation.
        out["ref_pixel_values"] = ref["ref_pixel_values"].unsqueeze(0)
        out["ref_image_grid_thw"] = ref["ref_image_grid_thw"].unsqueeze(0)
        out["ref_patches"] = ref["ref_patches"]
    else:
        # T2I: text + noised target only, vision_start replaces the first image token.
        txt_len = text_input_ids.shape[1]
        total_len = txt_len + image_len
        vision_tokens = torch.full((B, image_len), IMAGE_TOKEN_ID,
                                   dtype=text_input_ids.dtype, device=text_input_ids.device)
        vision_tokens[:, 0] = VISION_START_ID
        input_ids_pad = torch.cat([text_input_ids, vision_tokens], dim=-1)
        position_ids, _ = get_rope_index_fix_point(
            spatial_merge_size=1,
            image_token_id=IMAGE_TOKEN_ID,
            vision_start_token_id=VISION_START_ID,
            input_ids=input_ids_pad, image_grid_thw=image_grid_thw_tgt,
            attention_mask=None,
            skip_vision_start_token=[1],
        )
        ar_len = txt_len - 1
        token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device)
        token_types[:, ar_len:] = 1
        vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device)
        vinput_mask[:, txt_len:] = True

    out["input_ids"] = text_input_ids
    out["position_ids"] = position_ids[:, 0].unsqueeze(0)  # Collapse position_ids batch and add a leading dim so CONDRegular's batch-resize doesn't truncate the 3-axis MRoPE dim
    out["token_types"] = token_types
    out["vinput_mask"] = vinput_mask
    out["ar_len"] = ar_len
    return out
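The reshape/permute patchify in prepare_ref_images is easy to sanity-check on a toy tensor; a minimal sketch (shapes chosen purely for illustration):

# The 32x32 patchify used in prepare_ref_images, on a 96x64 toy image:
import torch
PATCH_SIZE = 32
t_norm = torch.randn(3, 96, 64)                      # (3, H, W) with H = 3*32, W = 2*32
h_p, w_p = 96 // PATCH_SIZE, 64 // PATCH_SIZE        # 3 x 2 patch grid
patches = (
    t_norm.reshape(3, h_p, PATCH_SIZE, w_p, PATCH_SIZE)
    .permute(1, 3, 0, 2, 4)                          # (h_p, w_p, 3, 32, 32)
    .reshape(h_p * w_p, 3 * PATCH_SIZE * PATCH_SIZE)
)
assert patches.shape == (6, 3072)                    # 6 patches, each a flat 3*32*32 vector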
comfy/ldm/hidream_o1/model.py (new file, 306 lines)
@@ -0,0 +1,306 @@
"""HiDream-O1-Image transformer.

Pixel-space DiT built on Qwen3-VL: the vision tower (Qwen35VisionModel)
encodes ref images, the Qwen3-VL-8B decoder (Llama2_ with interleaved MRoPE)
processes a unified text+image sequence, and 32x32 patch embed/unembed
shims map raw RGB in and out of LLM hidden space. The Qwen3-VL deepstack
mergers go unused — their weights are dropped at load.
"""

from dataclasses import dataclass, field
from typing import List, Optional

import einops
import torch
import torch.nn as nn

import comfy.patcher_extension
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.text_encoders.llama import Llama2_
from comfy.text_encoders.qwen35 import Qwen35VisionModel

from .attention import make_two_pass_attention


IMAGE_TOKEN_ID = 151655  # Qwen3-VL <|image_pad|>
TMS_TOKEN_ID = 151673  # HiDream-O1 <|tms_token|>
PATCH_SIZE = 32


@dataclass
class HiDreamO1TextConfig:
    """Qwen3-VL-8B text-decoder dims (matches public Qwen3-VL-8B-Instruct)."""
    vocab_size: int = 151936
    hidden_size: int = 4096
    intermediate_size: int = 12288
    num_hidden_layers: int = 36
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    head_dim: int = 128
    max_position_embeddings: int = 128000
    rms_norm_eps: float = 1e-6
    rope_theta: float = 5000000.0
    rope_scale: Optional[float] = None
    rope_dims: List[int] = field(default_factory=lambda: [24, 20, 20])
    interleaved_mrope: bool = True
    transformer_type: str = "llama"
    rms_norm_add: bool = False
    mlp_activation: str = "silu"
    qkv_bias: bool = False
    q_norm: str = "gemma3"
    k_norm: str = "gemma3"
    final_norm: bool = True
    lm_head: bool = False
    stop_tokens: List[int] = field(default_factory=lambda: [151643, 151645])


QWEN3VL_VISION_DEFAULTS = dict(
    hidden_size=1152,
    num_heads=16,
    intermediate_size=4304,
    depth=27,
    patch_size=16,
    temporal_patch_size=2,
    in_channels=3,
    spatial_merge_size=2,
    num_position_embeddings=2304,
    deepstack_visual_indexes=(8, 16, 24),
    out_hidden_size=4096,  # final merger projects directly into LLM hidden
)


class BottleneckPatchEmbed(nn.Module):
    # 3072 -> 1024 -> 4096 (raw 32x32 RGB patch -> bottleneck -> LLM hidden).
    def __init__(self, patch_size=32, in_chans=3, pca_dim=1024, embed_dim=4096, bias=True, device=None, dtype=None, ops=None):
        super().__init__()
        self.proj1 = ops.Linear(patch_size * patch_size * in_chans, pca_dim, bias=False, device=device, dtype=dtype)
        self.proj2 = ops.Linear(pca_dim, embed_dim, bias=bias, device=device, dtype=dtype)

    def forward(self, x):
        return self.proj2(self.proj1(x))


class FinalLayer(nn.Module):
    # 4096 -> 3072 (LLM hidden -> flat pixel patch).
    def __init__(self, hidden_size, patch_size=32, out_channels=3, device=None, dtype=None, ops=None):
        super().__init__()
        self.linear = ops.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        return self.linear(x)


class HiDreamO1Transformer(nn.Module):
    """HiDream-O1 unified pixel-level transformer."""

    def __init__(self, image_model=None, dtype=None, device=None, operations=None,
                 text_config_overrides=None, vision_config_overrides=None, **kwargs):
        super().__init__()
        self.dtype = dtype

        text_cfg = HiDreamO1TextConfig(**(text_config_overrides or {}))
        vision_cfg = dict(QWEN3VL_VISION_DEFAULTS)
        if vision_config_overrides:
            vision_cfg.update(vision_config_overrides)
        vision_cfg["out_hidden_size"] = text_cfg.hidden_size

        self.text_config = text_cfg
        self.vision_config = vision_cfg
        self.hidden_size = text_cfg.hidden_size
        self.patch_size = PATCH_SIZE
        self.in_channels = 3
        self.tms_token_id = TMS_TOKEN_ID

        self.visual = Qwen35VisionModel(vision_cfg, device=device, dtype=dtype, ops=operations)
        self.language_model = Llama2_(text_cfg, device=device, dtype=dtype, ops=operations)
        self.t_embedder1 = TimestepEmbedder(
            text_cfg.hidden_size, device=device, dtype=dtype, operations=operations,
        )
        self.x_embedder = BottleneckPatchEmbed(
            patch_size=self.patch_size, in_chans=self.in_channels,
            pca_dim=text_cfg.hidden_size // 4, embed_dim=text_cfg.hidden_size,
            bias=True, device=device, dtype=dtype, ops=operations,
        )
        self.final_layer2 = FinalLayer(
            text_cfg.hidden_size, patch_size=self.patch_size,
            out_channels=self.in_channels, device=device, dtype=dtype, ops=operations,
        )

        self._visual_cache = None
        self._kv_cache_entries = []

    def clear_kv_cache(self):
        self._kv_cache_entries = []
        self._visual_cache = None

    def forward(self, x, timesteps, context=None, transformer_options={}, **kwargs):
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
        ).execute(x, timesteps, context, transformer_options, **kwargs)

    def _forward(self, x, timesteps, context=None, transformer_options={}, input_ids=None, attention_mask=None, position_ids=None,
                 vinput_mask=None, ar_len=None, ref_pixel_values=None, ref_image_grid_thw=None, ref_patches=None, **kwargs):
        """Returns flow-match velocity (x - x_pred) / sigma."""
        if input_ids is None or position_ids is None:
            raise ValueError("HiDreamO1Transformer requires input_ids and position_ids in conditioning")

        B, _, H, W = x.shape
        h_p, w_p = H // self.patch_size, W // self.patch_size
        tgt_image_len = h_p * w_p

        z = einops.rearrange(
            x, 'B C (H p1) (W p2) -> B (H W) (C p1 p2)',
            p1=self.patch_size, p2=self.patch_size,
        )
        vinputs = torch.cat([z, ref_patches.to(z.dtype)], dim=1) if ref_patches is not None else z

        inputs_embeds = self.language_model.embed_tokens(input_ids).to(x.dtype)

        if ref_pixel_values is not None and ref_image_grid_thw is not None:
            # ViT output is constant across sampling steps within a generation;
            # identity-key by the input tensor so refs don't recompute every step.
            cached = self._visual_cache
            if cached is not None and cached[0] is ref_pixel_values:
                image_embeds = cached[1]
            else:
                ref_pv = ref_pixel_values.to(inputs_embeds.device)
                ref_grid = ref_image_grid_thw.to(inputs_embeds.device).long()
                # extra_conds wraps with a leading batch dim; refs are model-level so [0] always recovers them.
                if ref_pv.dim() == 3:
                    ref_pv = ref_pv[0]
                if ref_grid.dim() == 3:
                    ref_grid = ref_grid[0]
                image_embeds = self.visual(ref_pv, ref_grid).to(inputs_embeds.dtype)
                self._visual_cache = (ref_pixel_values, image_embeds)
            # image_pad positions identical across batch (input_ids shared cond/uncond).
            image_idx = (input_ids[0] == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]
            if image_idx.shape[0] != image_embeds.shape[0]:
                raise ValueError(
                    f"Image-token count {image_idx.shape[0]} != ViT output count "
                    f"{image_embeds.shape[0]}; check tokenizer/processor alignment."
                )
            inputs_embeds[:, image_idx] = image_embeds.unsqueeze(0).expand(B, -1, -1)

        sigma = timesteps.float() / 1000.0
        t_pixeldit = 1.0 - sigma
        t_emb = self.t_embedder1(t_pixeldit * 1000, inputs_embeds.dtype)
        tms_mask_3d = (input_ids == self.tms_token_id).unsqueeze(-1).expand_as(inputs_embeds)
        inputs_embeds = torch.where(tms_mask_3d, t_emb.unsqueeze(1).expand_as(inputs_embeds), inputs_embeds)

        vinputs_embedded = self.x_embedder(vinputs.to(inputs_embeds.dtype))
        inputs_embeds = torch.cat([inputs_embeds, vinputs_embedded], dim=1)

        # extra_conds stores position_ids as (1, 3, T); process_cond repeats dim 0 to B. Take row 0.
        freqs_cis = self.language_model.compute_freqs_cis(position_ids[0].to(x.device), x.device)
        freqs_cis = tuple(t.to(x.dtype) for t in freqs_cis)

        two_pass_attn = make_two_pass_attention(ar_len, transformer_options=transformer_options)
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        transformer_options["total_blocks"] = len(self.language_model.layers)
        transformer_options["block_type"] = "double"

        # Cache prefix K/V across steps. Key includes input_ids (prompt), ref_id
        # (refs scatter into inputs_embeds), and position_ids (RoPE baked into cached K).
        can_cache = not blocks_replace and ar_len > 0
        cache_len = ar_len if can_cache else 0
        ref_id = id(ref_pixel_values) if ref_pixel_values is not None else None
        pos_ids_key = position_ids[..., :cache_len] if can_cache else position_ids
        cache_entries = self._kv_cache_entries
        # Drop stale entries from a previous device (model was unloaded and reloaded).
        if cache_entries and cache_entries[0]["input_ids"].device != input_ids.device:
            cache_entries = []
            self._kv_cache_entries = []
        kv_cache = None
        if can_cache:
            for entry in cache_entries:
                ck = entry["input_ids"]
                ep = entry["position_ids"]
                if (entry["cache_len"] == cache_len
                        and ck.shape == input_ids.shape and torch.equal(ck, input_ids)
                        and entry["ref_id"] == ref_id
                        and ep.shape == pos_ids_key.shape and torch.equal(ep, pos_ids_key)):
                    kv_cache = entry
                    break

        if kv_cache is not None:
            # Hot path: project Q/K/V only for fresh positions; past_key_value prepends cached AR K/V.
            hidden_states = inputs_embeds[:, cache_len:]
            sliced_freqs = tuple(t[..., cache_len:, :] for t in freqs_cis)
            for i, layer in enumerate(self.language_model.layers):
                transformer_options["block_index"] = i
                K_i, V_i = kv_cache["kv"][i]
                hidden_states, _ = layer(
                    x=hidden_states, attention_mask=None, freqs_cis=sliced_freqs, optimized_attention=two_pass_attn,
                    past_key_value=(K_i, V_i, cache_len),
                )
        else:
            # Cold path: run full sequence; if cacheable, snapshot K/V at AR positions.
            snapshots = [] if can_cache else None
            past_kv_cold = () if can_cache else None
            hidden_states = inputs_embeds
            for i, layer in enumerate(self.language_model.layers):
                transformer_options["block_index"] = i
                if ("double_block", i) in blocks_replace:
                    def block_wrap(args, _layer=layer):
                        out = {}
                        out["x"], _ = _layer(
                            x=args["x"], attention_mask=args.get("attention_mask"),
                            freqs_cis=args["freqs_cis"], optimized_attention=args["optimized_attention"],
                            past_key_value=None,
                        )
                        return out
                    out = blocks_replace[("double_block", i)](
                        {"x": hidden_states, "attention_mask": None,
                         "freqs_cis": freqs_cis, "optimized_attention": two_pass_attn,
                         "transformer_options": transformer_options},
                        {"original_block": block_wrap},
                    )
                    hidden_states = out["x"]
                else:
                    hidden_states, present_kv = layer(
                        x=hidden_states, attention_mask=None,
                        freqs_cis=freqs_cis, optimized_attention=two_pass_attn,
                        past_key_value=past_kv_cold,
                    )
                    if snapshots is not None:
                        K, V, _ = present_kv
                        snapshots.append((K[:, :, :cache_len].contiguous(),
                                          V[:, :, :cache_len].contiguous()))
            if snapshots is not None:
                # Cap at 2 entries (cond + uncond). Multi-cond workflows LRU-evict.
                new_entry = {
                    "input_ids": input_ids.clone(),
                    "cache_len": cache_len,
                    "kv": snapshots,
                    "ref_id": ref_id,
                    "position_ids": pos_ids_key.clone(),
                }
                self._kv_cache_entries = (cache_entries + [new_entry])[-2:]

        if self.language_model.norm is not None:
            hidden_states = self.language_model.norm(hidden_states)

        # Slice target-image positions before the final projection so the Linear only runs on tgt_image_len tokens.
        # In the hot path hidden_states starts at original position cache_len, so masks/indices shift by cache_len.
        sliced_offset = cache_len if kv_cache is not None else 0
        if vinput_mask is not None:
            vmask = vinput_mask.to(x.device).bool()
            if sliced_offset > 0:
                vmask = vmask[:, sliced_offset:]
            target_hidden = hidden_states[vmask].view(B, -1, hidden_states.shape[-1])[:, :tgt_image_len]
        else:
            txt_seq_len = input_ids.shape[1]
            start = txt_seq_len - sliced_offset
            target_hidden = hidden_states[:, start:start + tgt_image_len]
        x_pred_tgt = self.final_layer2(target_hidden)

        # fp32 final subtraction; bf16 here noticeably degrades samples.
        x_pred_img = einops.rearrange(
            x_pred_tgt, 'B (H W) (C p1 p2) -> B C (H p1) (W p2)',
            H=h_p, W=w_p, p1=self.patch_size, p2=self.patch_size,
        )
        return (x.float() - x_pred_img.float()) / sigma.view(B, 1, 1, 1).clamp_min(1e-3)
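The velocity named in the _forward docstring follows directly from the CONST interpolation patched elsewhere in this compare; a small standalone numeric check (x_pred standing in for a perfect prediction):

# With x = (1 - sigma) * x0 + sigma * noise and a perfect x_pred == x0,
# (x - x_pred) / sigma recovers the flow-match velocity noise - x0.
import torch
x0, noise, sigma = torch.randn(8), torch.randn(8), 0.3
x = (1 - sigma) * x0 + sigma * noise
v = (x - x0) / sigma
assert torch.allclose(v, noise - x0, atol=1e-6)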
comfy/ldm/hidream_o1/utils.py (new file, 173 lines)
@@ -0,0 +1,173 @@
"""HiDream-O1 input-prep helpers: image/resolution math and unified-sequence
RoPE position-id assembly. The fix_point offset in get_rope_index_fix_point
lets the target image and patchified ref images share spatial RoPE positions
despite living at different sequence indices — same 2D image plane.
"""

import math
from typing import Optional

import torch


PATCH_SIZE = 32
CONDITION_IMAGE_SIZE = 384  # ViT-side base size for ref images


def resize_tensor(img_t, image_size, patch_size=16):
    """img_t: (1, 3, H, W) float [0, 1]. Fit to image_size**2 area, patch-aligned, center-cropped."""
    # Pre-halve with 2x2 box averaging while the image is still very large.
    while min(img_t.shape[-2], img_t.shape[-1]) >= 2 * image_size:
        img_t = torch.nn.functional.avg_pool2d(img_t, kernel_size=2, stride=2)

    _, _, height, width = img_t.shape
    m = patch_size
    s_max = image_size * image_size
    scale = math.sqrt(s_max / (width * height))

    candidates = [
        (round(width * scale) // m * m, round(height * scale) // m * m),
        (round(width * scale) // m * m, math.floor(height * scale) // m * m),
        (math.floor(width * scale) // m * m, round(height * scale) // m * m),
        (math.floor(width * scale) // m * m, math.floor(height * scale) // m * m),
    ]
    candidates = sorted(candidates, key=lambda x: x[0] * x[1], reverse=True)
    new_size = candidates[-1]
    for c in candidates:
        if c[0] * c[1] <= s_max:
            new_size = c
            break

    new_w, new_h = new_size
    s1 = width / new_w
    s2 = height / new_h
    if s1 < s2:
        resize_w, resize_h = new_w, round(height / s1)
    else:
        resize_w, resize_h = round(width / s2), new_h
    img_t = torch.nn.functional.interpolate(img_t, size=(resize_h, resize_w), mode="bicubic")
    top = (resize_h - new_h) // 2
    left = (resize_w - new_w) // 2
    return img_t[..., top:top + new_h, left:left + new_w]


def calculate_dimensions(max_size, ratio):
    """(W, H) for an aspect ratio fitting in max_size**2 area, 32-aligned."""
    width = math.sqrt(max_size * max_size * ratio)
    height = width / ratio
    width = int(width / 32) * 32
    height = int(height / 32) * 32
    return width, height


def ref_max_size(target_max_dim, k):
    """K-dependent ref-image max dim before patchifying."""
    if k == 1:
        return target_max_dim
    if k == 2:
        return target_max_dim * 48 // 64
    if k <= 4:
        return target_max_dim // 2
    if k <= 8:
        return target_max_dim * 24 // 64
    return target_max_dim // 4


def cond_image_size(k):
    """K-dependent ViT-side image size."""
    if k <= 4:
        return CONDITION_IMAGE_SIZE
    if k <= 8:
        return CONDITION_IMAGE_SIZE * 48 // 64
    return CONDITION_IMAGE_SIZE // 2


def get_rope_index_fix_point(
    spatial_merge_size: int,
    image_token_id: int,
    vision_start_token_id: int,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    skip_vision_start_token=None,
    fix_point: int = 4096,
):
    mrope_position_deltas = []
    if input_ids is not None and image_grid_thw is not None:
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)
        position_ids = torch.ones(
            3, input_ids.shape[0], input_ids.shape[1],
            dtype=input_ids.dtype, device=input_ids.device,
        )
        attention_mask = attention_mask.to(total_input_ids.device)
        for i, input_ids_b in enumerate(total_input_ids):
            fp = fix_point
            image_index = 0
            input_ids_b = input_ids_b[attention_mask[i] == 1]
            vision_start_indices = torch.argwhere(input_ids_b == vision_start_token_id).squeeze(1)
            vision_tokens = input_ids_b[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            input_tokens = input_ids_b.tolist()
            llm_pos_ids_list = []
            st = 0
            remain_images = image_nums
            for _ in range(image_nums):
                if image_token_id in input_tokens and remain_images > 0:
                    ed = input_tokens.index(image_token_id, st)
                else:
                    ed = len(input_tokens) + 1
                t = image_grid_thw[image_index][0]
                h = image_grid_thw[image_index][1]
                w = image_grid_thw[image_index][2]
                image_index += 1
                remain_images -= 1
                llm_grid_t = t.item()
                llm_grid_h = h.item() // spatial_merge_size
                llm_grid_w = w.item() // spatial_merge_size
                text_len = ed - st
                text_len -= skip_vision_start_token[image_index - 1]
                text_len = max(0, text_len)
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()

                if skip_vision_start_token[image_index - 1]:
                    if fp > 0:
                        fp = fp - st_idx
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fp + st_idx)
                    fp = 0
                else:
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w

            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
        return position_ids, mrope_position_deltas

    if attention_mask is not None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
        max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
        mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
    else:
        position_ids = (
            torch.arange(input_ids.shape[1], device=input_ids.device)
            .view(1, 1, -1).expand(3, input_ids.shape[0], -1)
        )
        mrope_position_deltas = torch.zeros(
            [input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype,
        )
    return position_ids, mrope_position_deltas
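calculate_dimensions is pure arithmetic, so a worked example pins it down (values follow from the function itself):

# calculate_dimensions(384, 16/9): width = sqrt(384**2 * 16/9) = 512,
# height = 512 / (16/9) = 288; both already 32-aligned.
import math
def calc(max_size, ratio):
    width = math.sqrt(max_size * max_size * ratio)
    height = width / ratio
    return int(width / 32) * 32, int(height / 32) * 32
print(calc(384, 16 / 9))  # (512, 288) — 512 * 288 = 147456 = 384**2 exactly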
@@ -97,12 +97,14 @@ def load_lora(lora, to_load, log_missing=True):

 def model_lora_keys_clip(model, key_map={}):
     sdk = model.state_dict().keys()
+    prefix_set = set()
     for k in sdk:
         if k.endswith(".weight"):
             key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
             tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
             if tp > 0 and not k.startswith("clip_"):
                 key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
+            prefix_set.add(k.split('.')[0])

     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     clip_l_present = False

@@ -163,6 +165,13 @@ def model_lora_keys_clip(model, key_map={}):
             lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
             key_map[lora_key] = k

+    if len(prefix_set) == 1:
+        full_prefix = "{}.transformer.model.".format(next(iter(prefix_set)))  # kohya anima and maybe other single TE models that use a single llama arch based te
+        for k in sdk:
+            if k.endswith(".weight"):
+                if k.startswith(full_prefix):
+                    l_key = k[len(full_prefix):-len(".weight")]
+                    key_map["lora_te_{}".format(l_key.replace(".", "_"))] = k
+
     k = "clip_g.transformer.text_projection.weight"
     if k in sdk:
@@ -58,6 +58,8 @@ import comfy.ldm.cogvideo.model
 import comfy.ldm.rt_detr.rtdetr_v4
 import comfy.ldm.ernie.model
 import comfy.ldm.sam3.detector
+import comfy.ldm.hidream_o1.model
+from comfy.ldm.hidream_o1.conditioning import build_extra_conds

 import comfy.model_management
 import comfy.patcher_extension

@@ -1674,6 +1676,32 @@ class HiDream(BaseModel):
         out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
         return out

+class HiDreamO1(BaseModel):
+    """HiDream-O1-Image: pixel-space DiT (no VAE). Refs from HiDreamO1ReferenceImages and tokens from the stub TE flow through
+    extra_conds; the heavy preprocessing lives in comfy.ldm.hidream_o1.conditioning."""
+    PATCH_SIZE = 32
+
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream_o1.model.HiDreamO1Transformer)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        text_input_ids = kwargs.get("text_input_ids", None)
+        noise = kwargs.get("noise", None)
+        if text_input_ids is None or noise is None:
+            return out
+
+        conds = build_extra_conds(
+            text_input_ids, noise,
+            ref_images=kwargs.get("reference_latents", None),
+            target_patch_size=self.PATCH_SIZE,
+        )
+        for k, v in conds.items():
+            # ar_len is a Python int (precomputed to avoid a GPU sync in forward).
+            cls = comfy.conds.CONDConstant if k == "ar_len" else comfy.conds.CONDRegular
+            out[k] = cls(v)
+        return out
+
 class Chroma(Flux):
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma):
         super().__init__(model_config, model_type, device=device, unet_model=unet_model)
@@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
         return dit_config

+    if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
+        return {"image_model": "hidream_o1"}
+
     if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
         dit_config = {}
         dit_config["image_model"] = "hidream"
@@ -242,6 +242,37 @@ class LazyCastingParam(torch.nn.Parameter):
         return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")


+class LazyCastingQuantizedParam:
+    def __init__(self, model, key):
+        self.model = model
+        self.key = key
+        self.cpu_state_dict = None
+
+    def state_dict_tensor(self, state_dict_key):
+        if self.cpu_state_dict is None:
+            weight = self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True)
+            self.cpu_state_dict = {k: v.to("cpu") for k, v in weight.state_dict(self.key).items()}
+        return self.cpu_state_dict[state_dict_key]
+
+
+class LazyCastingParamPiece(torch.nn.Parameter):
+    def __new__(cls, caster, state_dict_key, tensor):
+        return super().__new__(cls, tensor)
+
+    def __init__(self, caster, state_dict_key, tensor):
+        self.caster = caster
+        self.state_dict_key = state_dict_key
+
+    @property
+    def device(self):
+        return CustomTorchDevice
+
+    def to(self, *args, **kwargs):
+        caster = self.caster
+        del self.caster
+        return caster.state_dict_tensor(self.state_dict_key)
+
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         self.size = size

@@ -1463,20 +1494,37 @@ class ModelPatcher:
         self.clear_cached_hook_weights()

     def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
-        unet_state_dict = self.model.diffusion_model.state_dict()
-        for k, v in unet_state_dict.items():
+        original_state_dict = self.model.diffusion_model.state_dict()
+        unet_state_dict = {}
+        keys = list(original_state_dict)
+        while len(keys) > 0:
+            k = keys.pop(0)
+            v = original_state_dict[k]
             op_keys = k.rsplit('.', 1)
             if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
                 unet_state_dict[k] = v
                 continue
             try:
                 op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
             except:
                 unet_state_dict[k] = v
                 continue
             if not op or not hasattr(op, "comfy_cast_weights") or \
                 (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
                 unet_state_dict[k] = v
                 continue
             key = "diffusion_model." + k
-            unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
+            weight = comfy.utils.get_attr(self.model, key)
+            if isinstance(weight, QuantizedTensor) and k in original_state_dict:
+                qt_state_dict = weight.state_dict(k)
+                caster = LazyCastingQuantizedParam(self, key)
+                for group_key in (x for x in qt_state_dict if x in original_state_dict):
+                    if group_key in keys:
+                        keys.remove(group_key)
+                    unet_state_dict.pop(group_key, "")
+                    unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key])
+                continue
+            unet_state_dict[k] = LazyCastingParam(self, key, weight)
         return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)

     def __del__(self):
@@ -93,7 +93,8 @@ class CONST:

     def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
         sigma = reshape_sigma(sigma, noise.ndim)
-        return sigma * noise + (1.0 - sigma) * latent_image
+        s = getattr(self, "noise_scale", 1.0)
+        return sigma * (s * noise) + (1.0 - sigma) * latent_image

     def inverse_noise_scaling(self, sigma, latent):
         sigma = reshape_sigma(sigma, latent.ndim)

@@ -288,7 +289,11 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
         else:
             sampling_settings = {}

-        self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000))
+        self.set_noise_scale(sampling_settings.get("noise_scale", 1.0))
+        self.set_parameters(
+            shift=sampling_settings.get("shift", 1.0),
+            multiplier=sampling_settings.get("multiplier", 1000),
+        )

     def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000):
         self.shift = shift

@@ -296,6 +301,9 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
         ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier)
         self.register_buffer('sigmas', ts)

+    def set_noise_scale(self, noise_scale):
+        self.noise_scale = float(noise_scale)
+
     @property
     def sigma_min(self):
         return self.sigmas[0]
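The new noise_scale hook is easiest to see numerically; a minimal sketch of the patched CONST formula (toy tensors, nothing model-specific):

# At sigma = 1 (pure noise) the scaled formula starts the latent at
# noise_scale times unit std instead of 1.0.
import torch
def noise_scaling(sigma, noise, latent_image, noise_scale=1.0):
    return sigma * (noise_scale * noise) + (1.0 - sigma) * latent_image
noise, latent = torch.randn(4, 64, 64), torch.zeros(4, 64, 64)
x = noise_scaling(1.0, noise, latent, noise_scale=8.0)
print(float(x.std()))  # ~8.0, matching the noise_scale HiDreamO1 sets below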
@@ -1285,7 +1285,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
                     self.quant_format = quant_format
                     qconfig = QUANT_ALGOS[quant_format]
-                    layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
+                    self.layout_type = qconfig["comfy_tensor_layout"]
+                    layout_cls = get_layout_class(self.layout_type)
                     weight = state_dict.pop(weight_key)
                     manually_loaded_keys = [weight_key]
@@ -239,7 +239,8 @@ class CLIP:
         model_management.archive_model_dtypes(self.cond_stage_model)

         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
-        ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
+        te_disable_dynamic = disable_dynamic or getattr(self.cond_stage_model, "disable_offload", False)
+        ModelPatcher = comfy.model_patcher.ModelPatcher if te_disable_dynamic else comfy.model_patcher.CoreModelPatcher
         self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         #Match torch.float32 hardcode upcast in TE implemention
         self.patcher.set_model_compute_dtype(torch.float32)

@@ -776,6 +777,7 @@ class VAE:
             self.latent_channels = 3
             self.latent_dim = 2
             self.output_channels = 3
+            self.disable_offload = True
         elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
             sample_rate = 16000
             if sample_rate == 16000:
@@ -28,6 +28,7 @@ import comfy.text_encoders.ace15
 import comfy.text_encoders.longcat_image
 import comfy.text_encoders.ernie
 import comfy.text_encoders.cogvideo
+import comfy.text_encoders.hidream_o1

 from . import supported_models_base
 from . import latent_formats

@@ -1431,6 +1432,50 @@ class HiDream(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         return None  # TODO

+class HiDreamO1(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "hidream_o1",
+    }
+
+    sampling_settings = {
+        "shift": 3.0,
+        "noise_scale": 8.0,
+    }
+
+    latent_format = latent_formats.HiDreamO1Pixel
+    memory_usage_factor = 0.6
+    # fp16 not supported: LM MLP down_proj activations overflow in fp16, causing NaNs
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    optimizations = {"fp8": False}
+
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.HiDreamO1(self, device=device)
+
+    def process_unet_state_dict(self, state_dict):
+        # Drop unused Qwen3-VL deepstack merger weights; upstream discards them at inference.
+        for key in list(state_dict.keys()):
+            if "visual.deepstack_merger_list" in key:
+                del state_dict[key]
+        return state_dict
+
+    def process_vae_state_dict(self, state_dict):
+        # Pixel-space model: inject sentinel so VAE construction picks PixelspaceConversionVAE.
+        return {"pixel_space_vae": torch.tensor(1.0)}
+
+    def process_clip_state_dict(self, state_dict):
+        # Tokenizer-only TE: inject sentinel so load_state_dict_guess_config triggers CLIP init.
+        return {"_hidream_o1_te_sentinel": torch.zeros(1)}
+
+    def clip_target(self, state_dict={}):
+        return supported_models_base.ClipTarget(
+            comfy.text_encoders.hidream_o1.HiDreamO1Tokenizer,
+            comfy.text_encoders.hidream_o1.HiDreamO1TE,
+        )
+
 class Chroma(supported_models_base.BASE):
     unet_config = {
         "image_model": "chroma",

@@ -2018,6 +2063,7 @@ models = [
     Hunyuan3Dv2,
     Hunyuan3Dv2_1,
     HiDream,
+    HiDreamO1,
     Chroma,
     ChromaRadiance,
     ACEStep,
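For context on shift=3.0: assuming ComfyUI's usual discrete-flow time shift sigma(t) = shift*t / (1 + (shift - 1)*t) (an assumption here, not shown in this diff), the schedule is biased toward high noise:

# Hypothetical check of what shift=3.0 does to sigma, under the assumed
# time-shift formula above.
def shifted_sigma(t, shift=3.0):
    return shift * t / (1 + (shift - 1) * t)
print([round(shifted_sigma(t), 3) for t in (0.25, 0.5, 0.75)])  # [0.5, 0.75, 0.9]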
119
comfy/text_encoders/hidream_o1.py
Normal file
119
comfy/text_encoders/hidream_o1.py
Normal file
@ -0,0 +1,119 @@
"""HiDream-O1-Image tokenizer-only text encoder.

The real Qwen3-VL backbone runs inside diffusion_model.* every step, so this
module just tokenizes the prompt into text_input_ids and emits them as
conditioning. Position ids / token_types / vinput_mask depend on the target H/W
and are built later in model_base.HiDreamO1.extra_conds.
"""

import os

import torch
from transformers import Qwen2Tokenizer

from comfy import sd1_clip


# Qwen3-VL special tokens
IM_START_ID = 151644
IM_END_ID = 151645
ASSISTANT_ID = 77091
USER_ID = 872
NEWLINE_ID = 198
VISION_START_ID = 151652
VISION_END_ID = 151653
IMAGE_TOKEN_ID = 151655
VIDEO_TOKEN_ID = 151656
# HiDream-O1-specific tokens
BOI_TOKEN_ID = 151669
BOR_TOKEN_ID = 151670
EOR_TOKEN_ID = 151671
BOT_TOKEN_ID = 151672
TMS_TOKEN_ID = 151673


class HiDreamO1QwenTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer"
        )
        super().__init__(
            tokenizer_path,
            pad_with_end=False,
            embedding_size=4096,
            embedding_key="hidream_o1",
            tokenizer_class=Qwen2Tokenizer,
            has_start_token=False,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            pad_token=151643,
            tokenizer_data=tokenizer_data,
        )


class HiDreamO1Tokenizer(sd1_clip.SD1Tokenizer):
    """Wraps the prompt in the upstream chat template ending with boi/tms markers.

    Image tokens get spliced in at sample time once the target H/W is known.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            name="hidream_o1",
            tokenizer=HiDreamO1QwenTokenizer,
        )

    def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
        text_tokens_dict = super().tokenize_with_weights(
            text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
        )
        text_tuples = text_tokens_dict["hidream_o1"][0]
        text_tuples = [t for t in text_tuples if int(t[0]) != 151643]  # strip pad

        # <|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|boi|><|tms|>
        def tok(tid):
            return (tid, 1.0) if not return_word_ids else (tid, 1.0, 0)

        prefix = [tok(IM_START_ID), tok(USER_ID), tok(NEWLINE_ID)]
        suffix = [
            tok(IM_END_ID), tok(NEWLINE_ID),
            tok(IM_START_ID), tok(ASSISTANT_ID), tok(NEWLINE_ID),
            tok(BOI_TOKEN_ID), tok(TMS_TOKEN_ID),
        ]
        full = prefix + list(text_tuples) + suffix
        return {"hidream_o1": [full]}


class HiDreamO1TE(torch.nn.Module):
    """Passthrough TE: emits int token ids; the Qwen3-VL backbone in diffusion_model does the actual encoding."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__()
        self.dtypes = {torch.float32}
        self.disable_offload = True  # skips dynamic VRAM management for this zero-parameter module
        self.device = torch.device("cpu") if device is None else torch.device(device)

    def encode_token_weights(self, token_weight_pairs):
        tok_pairs = token_weight_pairs["hidream_o1"][0]
        ids = [int(t[0]) for t in tok_pairs]
        input_ids = torch.tensor([ids], dtype=torch.long)
        # The surrogate keeps the cross_attn slot non-empty for CONDITIONING
        # plumbing; the model reads text_input_ids out of `extra` instead.
        cross_attn = input_ids.unsqueeze(-1).to(torch.float32)
        extra = {"text_input_ids": input_ids}
        return cross_attn, None, extra

    def load_sd(self, sd):
        return []

    def get_sd(self):
        return {}

    def reset_clip_options(self):
        pass

    def set_clip_options(self, options):
        pass
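
For orientation, a minimal sketch of the id sequence the chat template above produces (illustrative only; p1 and p2 are placeholder prompt token ids, not real tokenizer output):

# assuming the prompt tokenizes to hypothetical ids [p1, p2]
p1, p2 = 12345, 67890  # placeholders
full_ids = (
    [IM_START_ID, USER_ID, NEWLINE_ID]         # <|im_start|>user\n
    + [p1, p2]                                 # prompt tokens
    + [IM_END_ID, NEWLINE_ID]                  # <|im_end|>\n
    + [IM_START_ID, ASSISTANT_ID, NEWLINE_ID]  # <|im_start|>assistant\n
    + [BOI_TOKEN_ID, TMS_TOKEN_ID]             # <|boi|><|tms|>
)

Each id is then wrapped as an (id, 1.0) weight tuple by tokenize_with_weights.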
@ -397,7 +397,7 @@ class RMSNorm(nn.Module):


def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None, interleaved_mrope=False):
    if not isinstance(theta, list):
        theta = [theta]

@ -415,16 +415,27 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        if rope_dims is not None and position_ids.shape[0] > 1:
            mrope_section = rope_dims * 2
            cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
            sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
        if rope_dims is not None and position_ids.shape[0] > 1 and interleaved_mrope:
            # Qwen3-VL interleaved MRoPE: T-freqs by default, H/W replace every 3rd dim.
            freqs_inter = freqs[0].clone()
            for axis_idx, offset in ((1, 1), (2, 2)):
                length = rope_dims[axis_idx] * 3
                idx = slice(offset, length, 3)
                freqs_inter[..., idx] = freqs[axis_idx, ..., idx]
            emb = torch.cat((freqs_inter, freqs_inter), dim=-1)
            cos = emb.cos().unsqueeze(0)
            sin = emb.sin().unsqueeze(0)
        else:
            cos = cos.unsqueeze(1)
            sin = sin.unsqueeze(1)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
            if rope_dims is not None and position_ids.shape[0] > 1:
                mrope_section = rope_dims * 2
                cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
                sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
            else:
                cos = cos.unsqueeze(1)
                sin = sin.unsqueeze(1)
        sin_split = sin.shape[-1] // 2
        out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
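
A standalone sketch of the interleaved layout the comment above describes (hypothetical helper name, mirroring the diff's logic): the temporal axis fills every dim by default, then H frequencies overwrite dims 1, 4, 7, ... and W frequencies overwrite dims 2, 5, 8, ...:

import torch

def interleave_mrope(freqs: torch.Tensor, rope_dims: list) -> torch.Tensor:
    # freqs: (3, seq, dim) per-axis rotary frequencies for T/H/W.
    out = freqs[0].clone()                       # temporal frequencies everywhere
    for axis_idx, offset in ((1, 1), (2, 2)):    # H -> dims 1,4,7,...; W -> dims 2,5,8,...
        idx = slice(offset, rope_dims[axis_idx] * 3, 3)
        out[..., idx] = freqs[axis_idx, ..., idx]
    return out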
@ -689,6 +700,7 @@ class Llama2_(nn.Module):
            self.config.rope_theta,
            self.config.rope_scale,
            self.config.rope_dims,
            interleaved_mrope=getattr(self.config, "interleaved_mrope", False),
            device=device)

    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):

@ -451,9 +451,8 @@ class Qwen35VisionPatchEmbed(nn.Module):
        self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        target_dtype = self.proj.weight.dtype
        x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
        return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)
        return self.proj(x).view(-1, self.embed_dim)


class Qwen35VisionMLP(nn.Module):

@ -651,7 +650,7 @@ class Qwen35VisionModel(nn.Module):
        x = self.patch_embed(x)
        pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
        x = x + pos_embeds
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        rotary_pos_emb = self.rot_pos_emb(grid_thw).to(x.device)
        seq_len = x.shape[0]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
@ -198,6 +198,62 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
    ("Custom", None, None),
]

_PRESETS_SEEDREAM_1K = [
    ("(1K) 1024x1024 (1:1)", 1024, 1024),
    ("(1K) 864x1152 (3:4)", 864, 1152),
    ("(1K) 1152x864 (4:3)", 1152, 864),
    ("(1K) 1312x736 (16:9)", 1312, 736),
    ("(1K) 736x1312 (9:16)", 736, 1312),
    ("(1K) 832x1248 (2:3)", 832, 1248),
    ("(1K) 1248x832 (3:2)", 1248, 832),
    ("(1K) 1568x672 (21:9)", 1568, 672),
]

_PRESETS_SEEDREAM_2K = [
    ("(2K) 2048x2048 (1:1)", 2048, 2048),
    ("(2K) 1728x2304 (3:4)", 1728, 2304),
    ("(2K) 2304x1728 (4:3)", 2304, 1728),
    ("(2K) 2848x1600 (16:9)", 2848, 1600),
    ("(2K) 1600x2848 (9:16)", 1600, 2848),
    ("(2K) 1664x2496 (2:3)", 1664, 2496),
    ("(2K) 2496x1664 (3:2)", 2496, 1664),
    ("(2K) 3136x1344 (21:9)", 3136, 1344),
]

_PRESETS_SEEDREAM_3K = [
    ("(3K) 3072x3072 (1:1)", 3072, 3072),
    ("(3K) 2592x3456 (3:4)", 2592, 3456),
    ("(3K) 3456x2592 (4:3)", 3456, 2592),
    ("(3K) 4096x2304 (16:9)", 4096, 2304),
    ("(3K) 2304x4096 (9:16)", 2304, 4096),
    ("(3K) 2496x3744 (2:3)", 2496, 3744),
    ("(3K) 3744x2496 (3:2)", 3744, 2496),
    ("(3K) 4704x2016 (21:9)", 4704, 2016),
]

_PRESETS_SEEDREAM_4K = [
    ("(4K) 4096x4096 (1:1)", 4096, 4096),
    ("(4K) 3520x4704 (3:4)", 3520, 4704),
    ("(4K) 4704x3520 (4:3)", 4704, 3520),
    ("(4K) 5504x3040 (16:9)", 5504, 3040),
    ("(4K) 3040x5504 (9:16)", 3040, 5504),
    ("(4K) 3328x4992 (2:3)", 3328, 4992),
    ("(4K) 4992x3328 (3:2)", 4992, 3328),
    ("(4K) 6240x2656 (21:9)", 6240, 2656),
]

_CUSTOM_PRESET = [("Custom", None, None)]

RECOMMENDED_PRESETS_SEEDREAM_5_LITE = (
    _PRESETS_SEEDREAM_2K + _PRESETS_SEEDREAM_3K + _PRESETS_SEEDREAM_4K + _CUSTOM_PRESET
)
RECOMMENDED_PRESETS_SEEDREAM_4_5 = (
    _PRESETS_SEEDREAM_2K + _PRESETS_SEEDREAM_4K + _CUSTOM_PRESET
)
RECOMMENDED_PRESETS_SEEDREAM_4_0 = (
    _PRESETS_SEEDREAM_1K + _PRESETS_SEEDREAM_2K + _PRESETS_SEEDREAM_4K + _CUSTOM_PRESET
)

# Seedance 2.0 reference video pixel count limits per model and output resolution.
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
    "dreamina-seedance-2-0-260128": {
@ -596,6 +596,7 @@ class Flux2ProImageNode(IO.ComfyNode):
                depends_on=IO.PriceBadgeDepends(widgets=["width", "height"], inputs=["images"]),
                expr=cls.PRICE_BADGE_EXPR,
            ),
            is_deprecated=True,
        )

    @classmethod

@ -674,6 +675,175 @@ class Flux2MaxImageNode(Flux2ProImageNode):
    """


_FLUX2_MODEL_ENDPOINTS = {
    "Flux.2 [pro]": "/proxy/bfl/flux-2-pro/generate",
    "Flux.2 [max]": "/proxy/bfl/flux-2-max/generate",
}


def _flux2_model_inputs():
    return [
        IO.Int.Input(
            "width",
            default=1024,
            min=256,
            max=2048,
            step=32,
        ),
        IO.Int.Input(
            "height",
            default=768,
            min=256,
            max=2048,
            step=32,
        ),
        IO.Autogrow.Input(
            "images",
            template=IO.Autogrow.TemplateNames(
                IO.Image.Input("image"),
                names=[f"image_{i}" for i in range(1, 9)],
                min=0,
            ),
            tooltip="Optional reference image(s) for image-to-image generation. Up to 8 images.",
        ),
    ]


class Flux2ImageNode(IO.ComfyNode):

    @classmethod
    def define_schema(cls) -> IO.Schema:
        return IO.Schema(
            node_id="Flux2ImageNode",
            display_name="Flux.2 Image",
            category="api node/image/BFL",
            description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    tooltip="Prompt for the image generation or edit",
                ),
                IO.DynamicCombo.Input(
                    "model",
                    options=[
                        IO.DynamicCombo.Option("Flux.2 [pro]", _flux2_model_inputs()),
                        IO.DynamicCombo.Option("Flux.2 [max]", _flux2_model_inputs()),
                    ],
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    control_after_generate=True,
                    tooltip="The random seed used for creating the noise.",
                ),
            ],
            outputs=[IO.Image.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(
                    widgets=["model", "model.width", "model.height"],
                    input_groups=["model.images"],
                ),
                expr="""
                (
                    $isMax := widgets.model = "flux.2 [max]";
                    $MP := 1024 * 1024;
                    $w := $lookup(widgets, "model.width");
                    $h := $lookup(widgets, "model.height");
                    $outMP := $max([1, $floor((($w * $h) + $MP - 1) / $MP)]);
                    $outputCost := $isMax
                        ? (0.07 + 0.03 * ($outMP - 1))
                        : (0.03 + 0.015 * ($outMP - 1));
                    $refMin := $isMax ? 0.03 : 0.015;
                    $refMax := $isMax ? 0.24 : 0.12;
                    $hasRefs := $lookup(inputGroups, "model.images") > 0;
                    $hasRefs
                        ? {
                            "type": "range_usd",
                            "min_usd": $outputCost + $refMin,
                            "max_usd": $outputCost + $refMax,
                            "format": { "approximate": true }
                        }
                        : {"type": "usd", "usd": $outputCost}
                )
                """,
            ),
        )
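
A quick worked check of the badge arithmetic above (plain Python mirroring the JSONata expression; not an API call):

w, h = 2048, 1152                        # a 16:9 size
MP = 1024 * 1024
out_mp = max(1, (w * h + MP - 1) // MP)  # ceil to whole megapixel units -> 3
max_cost = 0.07 + 0.03 * (out_mp - 1)    # Flux.2 [max] base output cost -> 0.13 USD
pro_cost = 0.03 + 0.015 * (out_mp - 1)   # Flux.2 [pro] base output cost -> 0.06 USD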

    @classmethod
    async def execute(
        cls,
        prompt: str,
        model: dict,
        seed: int,
    ) -> IO.NodeOutput:
        model_choice = model["model"]
        endpoint = _FLUX2_MODEL_ENDPOINTS[model_choice]
        width = model["width"]
        height = model["height"]
        images_dict = model.get("images") or {}

        image_tensors: list[Input.Image] = [t for t in images_dict.values() if t is not None]
        n_images = sum(get_number_of_images(t) for t in image_tensors)
        if n_images > 8:
            raise ValueError("The current maximum number of supported images is 8.")

        flat_tensors: list[torch.Tensor] = []
        for tensor in image_tensors:
            if len(tensor.shape) == 4:
                flat_tensors.extend(tensor[i] for i in range(tensor.shape[0]))
            else:
                flat_tensors.append(tensor)

        reference_images: dict[str, str] = {}
        for idx, tensor in enumerate(flat_tensors):
            key_name = f"input_image_{idx + 1}" if idx else "input_image"
            reference_images[key_name] = tensor_to_base64_string(tensor, total_pixels=2048 * 2048)

        initial_response = await sync_op(
            cls,
            ApiEndpoint(path=endpoint, method="POST"),
            response_model=BFLFluxProGenerateResponse,
            data=Flux2ProGenerateRequest(
                prompt=prompt,
                width=width,
                height=height,
                seed=seed,
                **reference_images,
            ),
        )

        def price_extractor(_r: BaseModel) -> float | None:
            return None if initial_response.cost is None else initial_response.cost / 100

        response = await poll_op(
            cls,
            ApiEndpoint(initial_response.polling_url),
            response_model=BFLFluxStatusResponse,
            status_extractor=lambda r: r.status,
            progress_extractor=lambda r: r.progress,
            price_extractor=price_extractor,
            completed_statuses=[BFLStatus.ready],
            failed_statuses=[
                BFLStatus.request_moderated,
                BFLStatus.content_moderated,
                BFLStatus.error,
                BFLStatus.task_not_found,
            ],
            queued_statuses=[],
        )
        return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))

class BFLExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:

@ -685,6 +855,7 @@ class BFLExtension(ComfyExtension):
            FluxProFillNode,
            Flux2ProImageNode,
            Flux2MaxImageNode,
            Flux2ImageNode,
        ]


@ -10,6 +10,9 @@ from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance import (
    RECOMMENDED_PRESETS,
    RECOMMENDED_PRESETS_SEEDREAM_4,
    RECOMMENDED_PRESETS_SEEDREAM_4_0,
    RECOMMENDED_PRESETS_SEEDREAM_4_5,
    RECOMMENDED_PRESETS_SEEDREAM_5_LITE,
    SEEDANCE2_PRICE_PER_1K_TOKENS,
    SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
    VIDEO_TASKS_EXECUTION_TIME,

@ -68,6 +71,12 @@ SEEDREAM_MODELS = {
    "seedream-4-0-250828": "seedream-4-0-250828",
}

SEEDREAM_PRESETS = {
    "seedream-5-0-260128": RECOMMENDED_PRESETS_SEEDREAM_5_LITE,
    "seedream-4-5-251128": RECOMMENDED_PRESETS_SEEDREAM_4_5,
    "seedream-4-0-250828": RECOMMENDED_PRESETS_SEEDREAM_4_0,
}

# Long-running task endpoints (e.g., video)
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"  # + /{task_id}

@ -562,6 +571,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
                )
                """,
            ),
            is_deprecated=True,
        )

    @classmethod

@ -651,6 +661,226 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
        return IO.NodeOutput(torch.cat([await download_url_to_image_tensor(i) for i in urls]))

def _seedream_model_inputs(*, max_ref_images: int, presets: list):
    return [
        IO.Combo.Input(
            "size_preset",
            options=[label for label, _, _ in presets],
            tooltip="Pick a recommended size. Select Custom to use the width and height below.",
        ),
        IO.Int.Input(
            "width",
            default=2048,
            min=1024,
            max=6240,
            step=2,
            tooltip="Custom width for the image. Only used when `size_preset` is set to `Custom`.",
        ),
        IO.Int.Input(
            "height",
            default=2048,
            min=1024,
            max=4992,
            step=2,
            tooltip="Custom height for the image. Only used when `size_preset` is set to `Custom`.",
        ),
        IO.Int.Input(
            "max_images",
            default=1,
            min=1,
            max=max_ref_images,
            step=1,
            display_mode=IO.NumberDisplay.number,
            tooltip="Maximum number of images to generate. With 1, exactly one image is produced. "
                    "With >1, the model generates between 1 and max_images related images "
                    "(e.g., story scenes, character variations). "
                    "Total images (input + generated) cannot exceed 15.",
        ),
        IO.Autogrow.Input(
            "images",
            template=IO.Autogrow.TemplateNames(
                IO.Image.Input("image"),
                names=[f"image_{i}" for i in range(1, max_ref_images + 1)],
                min=0,
            ),
            tooltip=f"Optional reference image(s) for image-to-image or multi-reference generation. "
                    f"Up to {max_ref_images} images.",
        ),
        IO.Boolean.Input(
            "fail_on_partial",
            default=False,
            tooltip="If enabled, abort execution if any requested images are missing or return an error.",
            advanced=True,
        ),
    ]


class ByteDanceSeedreamNodeV2(IO.ComfyNode):

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="ByteDanceSeedreamNodeV2",
            display_name="ByteDance Seedream 4.5 & 5.0",
            category="api node/image/ByteDance",
            description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    tooltip="Text prompt for creating or editing an image.",
                ),
                IO.DynamicCombo.Input(
                    "model",
                    options=[
                        IO.DynamicCombo.Option(
                            "seedream 5.0 lite",
                            _seedream_model_inputs(max_ref_images=14, presets=RECOMMENDED_PRESETS_SEEDREAM_5_LITE),
                        ),
                        IO.DynamicCombo.Option(
                            "seedream-4-5-251128",
                            _seedream_model_inputs(max_ref_images=10, presets=RECOMMENDED_PRESETS_SEEDREAM_4_5),
                        ),
                        IO.DynamicCombo.Option(
                            "seedream-4-0-250828",
                            _seedream_model_inputs(max_ref_images=10, presets=RECOMMENDED_PRESETS_SEEDREAM_4_0),
                        ),
                    ],
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    step=1,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Seed to use for generation.",
                ),
                IO.Boolean.Input(
                    "watermark",
                    default=False,
                    tooltip='Whether to add an "AI generated" watermark to the image.',
                    advanced=True,
                ),
            ],
            outputs=[
                IO.Image.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model"]),
                expr="""
                (
                    $price := $contains(widgets.model, "5.0 lite") ? 0.035 :
                              $contains(widgets.model, "4-5") ? 0.04 : 0.03;
                    {
                        "type":"usd",
                        "usd": $price,
                        "format": { "suffix":" x images/Run", "approximate": true }
                    }
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        model: dict,
        seed: int = 0,
        watermark: bool = False,
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=True, min_length=1)
        model_id = SEEDREAM_MODELS[model["model"]]
        presets = SEEDREAM_PRESETS[model_id]

        size_preset = model.get("size_preset", presets[0][0])
        width = model.get("width", 2048)
        height = model.get("height", 2048)
        max_images = model.get("max_images", 1)
        sequential_image_generation = "disabled" if max_images == 1 else "auto"
        images_dict = model.get("images") or {}
        fail_on_partial = model.get("fail_on_partial", False)

        w = h = None
        for label, tw, th in presets:
            if label == size_preset:
                w, h = tw, th
                break
        if w is None or h is None:
            w, h = width, height

        out_num_pixels = w * h
        mp_provided = out_num_pixels / 1_000_000.0
        if ("seedream-4-5" in model_id or "seedream-5-0" in model_id) and out_num_pixels < 3686400:
            raise ValueError(
                f"Minimum image resolution for the selected model is 3.68MP, but {mp_provided:.2f}MP provided."
            )
        if "seedream-4-0" in model_id and out_num_pixels < 921600:
            raise ValueError(
                f"Minimum image resolution that the selected model can generate is 0.92MP, "
                f"but {mp_provided:.2f}MP provided."
            )
        if out_num_pixels > 16_777_216:
            raise ValueError(
                f"Maximum image resolution for the selected model is 16.78MP, but {mp_provided:.2f}MP provided."
            )

        image_tensors: list[Input.Image] = [t for t in images_dict.values() if t is not None]
        n_input_images = sum(get_number_of_images(t) for t in image_tensors)
        max_num_of_images = 14 if model_id == "seedream-5-0-260128" else 10
        if n_input_images > max_num_of_images:
            raise ValueError(
                f"A maximum of {max_num_of_images} reference images is supported, but {n_input_images} were received."
            )
        if sequential_image_generation == "auto" and n_input_images + max_images > 15:
            raise ValueError(
                "The maximum number of generated images plus the number of reference images cannot exceed 15."
            )

        reference_images_urls: list[str] = []
        if image_tensors:
            for tensor in image_tensors:
                validate_image_aspect_ratio(tensor, (1, 3), (3, 1))
            reference_images_urls = await upload_images_to_comfyapi(
                cls,
                image_tensors,
                max_images=n_input_images,
                mime_type="image/png",
                wait_label="Uploading reference images",
            )

        response = await sync_op(
            cls,
            ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
            response_model=ImageTaskCreationResponse,
            data=Seedream4TaskCreationRequest(
                model=model_id,
                prompt=prompt,
                image=reference_images_urls,
                size=f"{w}x{h}",
                seed=seed,
                sequential_image_generation=sequential_image_generation,
                sequential_image_generation_options=Seedream4Options(max_images=max_images),
                watermark=watermark,
            ),
        )
        if len(response.data) == 1:
            return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
        urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d]
        if fail_on_partial and len(urls) < len(response.data):
            raise RuntimeError(f"Only {len(urls)} of {len(response.data)} images were generated before the error.")
        return IO.NodeOutput(torch.cat([await download_url_to_image_tensor(i) for i in urls]))
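
For reference, the pixel thresholds used above in plain numbers (exact arithmetic; not part of the node):

assert 3686400 == 1920 * 1920        # ~3.68MP minimum for seedream-4-5 / seedream-5-0
assert 921600 == 1280 * 720          # ~0.92MP minimum for seedream-4-0
assert 16_777_216 == 4096 * 4096     # ~16.78MP maximum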

class ByteDanceTextToVideoNode(IO.ComfyNode):

    @classmethod

@ -2105,6 +2335,7 @@ class ByteDanceExtension(ComfyExtension):
        return [
            ByteDanceImageNode,
            ByteDanceSeedreamNode,
            ByteDanceSeedreamNodeV2,
            ByteDanceTextToVideoNode,
            ByteDanceImageToVideoNode,
            ByteDanceFirstLastFrameNode,
@ -162,6 +162,61 @@ class GrokImageNode(IO.ComfyNode):
        )


_GROK_IMAGE_EDIT_ASPECT_RATIO_OPTIONS = [
    "auto",
    "1:1",
    "2:3",
    "3:2",
    "3:4",
    "4:3",
    "9:16",
    "16:9",
    "9:19.5",
    "19.5:9",
    "9:20",
    "20:9",
    "1:2",
    "2:1",
]


def _grok_image_edit_model_inputs(*, max_ref_images: int, with_aspect_ratio: bool):
    inputs = [
        IO.Autogrow.Input(
            "images",
            template=IO.Autogrow.TemplateNames(
                IO.Image.Input("image"),
                names=[f"image_{i}" for i in range(1, max_ref_images + 1)],
                min=1,
            ),
            tooltip=(
                "Reference image to edit."
                if max_ref_images == 1
                else f"Reference image(s) to edit. Up to {max_ref_images} images."
            ),
        ),
        IO.Combo.Input("resolution", options=["1K", "2K"]),
        IO.Int.Input(
            "number_of_images",
            default=1,
            min=1,
            max=10,
            step=1,
            tooltip="Number of edited images to generate",
            display_mode=IO.NumberDisplay.number,
        ),
    ]
    if with_aspect_ratio:
        inputs.append(
            IO.Combo.Input(
                "aspect_ratio",
                options=_GROK_IMAGE_EDIT_ASPECT_RATIO_OPTIONS,
                tooltip="Only allowed when multiple images are connected.",
            )
        )
    return inputs


class GrokImageEditNode(IO.ComfyNode):

    @classmethod

@ -256,6 +311,7 @@ class GrokImageEditNode(IO.ComfyNode):
                )
                """,
            ),
            is_deprecated=True,
        )
    @classmethod

@ -303,6 +359,143 @@ class GrokImageEditNode(IO.ComfyNode):
        )


class GrokImageEditNodeV2(IO.ComfyNode):

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="GrokImageEditNodeV2",
            display_name="Grok Image Edit",
            category="api node/image/Grok",
            description="Modify an existing image based on a text prompt",
            inputs=[
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    tooltip="The text prompt used to generate the image",
                ),
                IO.DynamicCombo.Input(
                    "model",
                    options=[
                        IO.DynamicCombo.Option(
                            "grok-imagine-image-quality",
                            _grok_image_edit_model_inputs(max_ref_images=3, with_aspect_ratio=True),
                        ),
                        IO.DynamicCombo.Option(
                            "grok-imagine-image-pro",
                            _grok_image_edit_model_inputs(max_ref_images=1, with_aspect_ratio=False),
                        ),
                        IO.DynamicCombo.Option(
                            "grok-imagine-image",
                            _grok_image_edit_model_inputs(max_ref_images=3, with_aspect_ratio=True),
                        ),
                    ],
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    step=1,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Seed to determine if the node should re-run; "
                            "actual results are nondeterministic regardless of seed.",
                ),
            ],
            outputs=[
                IO.Image.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(
                    widgets=["model", "model.resolution", "model.number_of_images"],
                ),
                expr="""
                (
                    $isQualityModel := widgets.model = "grok-imagine-image-quality";
                    $isPro := $contains(widgets.model, "pro");
                    $res := $lookup(widgets, "model.resolution");
                    $n := $lookup(widgets, "model.number_of_images");
                    $rate := $isQualityModel
                        ? ($res = "1k" ? 0.05 : 0.07)
                        : ($isPro ? 0.07 : 0.02);
                    $base := $isQualityModel ? 0.01 : 0.002;
                    $output := $rate * $n;
                    $isPro
                        ? {"type":"usd","usd": $base + $output}
                        : {"type":"range_usd","min_usd": $base + $output, "max_usd": 3 * $base + $output}
                )
                """,
            ),
        )
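
Reading the badge expression above, a worked example (plain arithmetic mirroring the JSONata; not an API call):

# grok-imagine-image-quality at 2K with number_of_images = 2:
rate, base, n = 0.07, 0.01, 2
output = rate * n                              # 0.14
low, high = base + output, 3 * base + output   # 0.15 .. 0.17 USD (range badge)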

    @classmethod
    async def execute(
        cls,
        prompt: str,
        model: dict,
        seed: int,
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=True, min_length=1)
        model_id = model["model"]
        resolution = model["resolution"]
        number_of_images = model["number_of_images"]
        images_dict = model.get("images") or {}
        aspect_ratio = model.get("aspect_ratio", "auto")

        image_tensors: list[Input.Image] = [t for t in images_dict.values() if t is not None]
        n_images = sum(get_number_of_images(t) for t in image_tensors)
        if n_images < 1:
            raise ValueError("At least one image is required for editing.")
        if model_id == "grok-imagine-image-pro" and n_images > 1:
            raise ValueError("The pro model supports only 1 input image.")
        if model_id != "grok-imagine-image-pro" and n_images > 3:
            raise ValueError("A maximum of 3 input images is supported.")
        if aspect_ratio != "auto" and n_images == 1:
            raise ValueError(
                "Custom aspect ratio is only allowed when multiple images are connected to the image input."
            )

        flat_tensors: list[torch.Tensor] = []
        for tensor in image_tensors:
            if len(tensor.shape) == 4:
                flat_tensors.extend(tensor[i] for i in range(tensor.shape[0]))
            else:
                flat_tensors.append(tensor)

        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
            data=ImageEditRequest(
                model=model_id,
                images=[
                    InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in flat_tensors
                ],
                prompt=prompt,
                resolution=resolution.lower(),
                n=number_of_images,
                seed=seed,
                aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio,
            ),
            response_model=ImageGenerationResponse,
            price_extractor=_extract_grok_price,
        )
        if len(response.data) == 1:
            return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
        return IO.NodeOutput(
            torch.cat(
                [await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]],
            )
        )

class GrokVideoNode(IO.ComfyNode):

    @classmethod

@ -737,6 +930,7 @@ class GrokExtension(ComfyExtension):
        return [
            GrokImageNode,
            GrokImageEditNode,
            GrokImageEditNodeV2,
            GrokVideoNode,
            GrokVideoReferenceNode,
            GrokVideoEditNode,

@ -27,6 +27,7 @@ from comfy_api_nodes.util import (
    ApiEndpoint,
    download_url_to_bytesio,
    downscale_image_tensor,
    get_number_of_images,
    poll_op,
    sync_op,
    tensor_to_base64_string,

@ -372,6 +373,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
            display_name="OpenAI GPT Image 2",
            category="api node/image/OpenAI",
            description="Generates images synchronously via OpenAI's GPT Image endpoint.",
            is_deprecated=True,
            inputs=[
                IO.String.Input(
                    "prompt",

@ -640,6 +642,316 @@ class OpenAIGPTImage1(IO.ComfyNode):
        return IO.NodeOutput(await validate_and_cast_response(response))

def _gpt_image_shared_inputs():
    """Inputs shared by all GPT Image models (quality + reference images + mask)."""
    return [
        IO.Combo.Input(
            "quality",
            default="low",
            options=["low", "medium", "high"],
            tooltip="Image quality, affects cost and generation time.",
        ),
        IO.Autogrow.Input(
            "images",
            template=IO.Autogrow.TemplateNames(
                IO.Image.Input("image"),
                names=[f"image_{i}" for i in range(1, 17)],
                min=0,
            ),
            tooltip="Optional reference image(s) for image editing. Up to 16 images.",
        ),
        IO.Mask.Input(
            "mask",
            optional=True,
            tooltip="Optional mask for inpainting (white areas will be replaced). "
                    "Requires exactly one reference image.",
        ),
    ]


def _gpt_image_legacy_model_inputs():
    """Per-model widget set for legacy gpt-image-1 / gpt-image-1.5 (4 base sizes, transparent bg allowed)."""
    return [
        IO.Combo.Input(
            "size",
            default="auto",
            options=["auto", "1024x1024", "1024x1536", "1536x1024"],
            tooltip="Image size.",
        ),
        IO.Combo.Input(
            "background",
            default="auto",
            options=["auto", "opaque", "transparent"],
            tooltip="Return image with or without background.",
        ),
        *_gpt_image_shared_inputs(),
    ]


class OpenAIGPTImageNodeV2(IO.ComfyNode):

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="OpenAIGPTImageNodeV2",
            display_name="OpenAI GPT Image 2",
            category="api node/image/OpenAI",
            description="Generates images via OpenAI's GPT Image endpoint.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    default="",
                    multiline=True,
                    tooltip="Text prompt for GPT Image",
                ),
                IO.DynamicCombo.Input(
                    "model",
                    options=[
                        IO.DynamicCombo.Option(
                            "gpt-image-2",
                            [
                                IO.Combo.Input(
                                    "size",
                                    default="auto",
                                    options=[
                                        "auto",
                                        "1024x1024",
                                        "1024x1536",
                                        "1536x1024",
                                        "2048x2048",
                                        "2048x1152",
                                        "1152x2048",
                                        "3840x2160",
                                        "2160x3840",
                                        "Custom",
                                    ],
                                    tooltip="Image size. Select 'Custom' to use the custom width and height.",
                                ),
                                IO.Int.Input(
                                    "custom_width",
                                    default=1024,
                                    min=1024,
                                    max=3840,
                                    step=16,
                                    tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16.",
                                ),
                                IO.Int.Input(
                                    "custom_height",
                                    default=1024,
                                    min=1024,
                                    max=3840,
                                    step=16,
                                    tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16.",
                                ),
                                IO.Combo.Input(
                                    "background",
                                    default="auto",
                                    options=["auto", "opaque"],
                                    tooltip="Return image with or without background.",
                                ),
                                *_gpt_image_shared_inputs(),
                            ],
                        ),
                        IO.DynamicCombo.Option("gpt-image-1.5", _gpt_image_legacy_model_inputs()),
                        IO.DynamicCombo.Option("gpt-image-1", _gpt_image_legacy_model_inputs()),
                    ],
                ),
                IO.Int.Input(
                    "n",
                    default=1,
                    min=1,
                    max=8,
                    step=1,
                    tooltip="How many images to generate",
                    display_mode=IO.NumberDisplay.number,
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    step=1,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Not implemented in the backend yet.",
                ),
            ],
            outputs=[IO.Image.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model", "model.quality", "n"]),
                expr="""
                (
                    $ranges := {
                        "gpt-image-1": {
                            "low": [0.011, 0.02],
                            "medium": [0.042, 0.07],
                            "high": [0.167, 0.25]
                        },
                        "gpt-image-1.5": {
                            "low": [0.009, 0.02],
                            "medium": [0.034, 0.062],
                            "high": [0.133, 0.22]
                        },
                        "gpt-image-2": {
                            "low": [0.0048, 0.019],
                            "medium": [0.041, 0.168],
                            "high": [0.165, 0.67]
                        }
                    };
                    $range := $lookup($lookup($ranges, widgets.model), $lookup(widgets, "model.quality"));
                    $nRaw := widgets.n;
                    $n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1;
                    ($n = 1)
                        ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1], "format": {"approximate": true}}
                        : {
                            "type":"range_usd",
                            "min_usd": $range[0] * $n,
                            "max_usd": $range[1] * $n,
                            "format": { "suffix": "/Run", "approximate": true }
                        }
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        model: dict,
        n: int,
        seed: int,
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=False)

        model_id = model["model"]
        size = model["size"]
        background = model["background"]
        quality = model["quality"]
        custom_width = model.get("custom_width", 1024)
        custom_height = model.get("custom_height", 1024)

        images_dict = model.get("images") or {}
        image_tensors: list[Input.Image] = [t for t in images_dict.values() if t is not None]
        n_images = sum(get_number_of_images(t) for t in image_tensors)
        mask = model.get("mask")

        if mask is not None and n_images == 0:
            raise ValueError("Cannot use a mask without an input image")

        if size == "Custom":
            if custom_width % 16 != 0 or custom_height % 16 != 0:
                raise ValueError(
                    f"Custom width and height must be multiples of 16, got {custom_width}x{custom_height}"
                )
            if max(custom_width, custom_height) > 3840:
                raise ValueError(
                    f"Custom resolution max edge must be <= 3840, got {custom_width}x{custom_height}"
                )
            ratio = max(custom_width, custom_height) / min(custom_width, custom_height)
            if ratio > 3:
                raise ValueError(
                    f"Custom resolution aspect ratio must not exceed 3:1, got {custom_width}x{custom_height}"
                )
            total_pixels = custom_width * custom_height
            if not 655_360 <= total_pixels <= 8_294_400:
                raise ValueError(
                    f"Custom resolution total pixels must be between 655,360 and 8,294,400, got {total_pixels}"
                )
            size = f"{custom_width}x{custom_height}"

        if model_id == "gpt-image-1":
            price_extractor = calculate_tokens_price_image_1
        elif model_id == "gpt-image-1.5":
            price_extractor = calculate_tokens_price_image_1_5
        elif model_id == "gpt-image-2":
            price_extractor = calculate_tokens_price_image_2_0
        else:
            raise ValueError(f"Unknown model: {model_id}")

        if image_tensors:
            flat: list[torch.Tensor] = []
            for tensor in image_tensors:
                if len(tensor.shape) == 4:
                    flat.extend(tensor[i : i + 1] for i in range(tensor.shape[0]))
                else:
                    flat.append(tensor.unsqueeze(0))

            files = []
            for i, single_image in enumerate(flat):
                scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze()
                image_np = (scaled_image.numpy() * 255).astype(np.uint8)
                img = Image.fromarray(image_np)
                img_byte_arr = BytesIO()
                img.save(img_byte_arr, format="PNG")
                img_byte_arr.seek(0)

                if len(flat) == 1:
                    files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
                else:
                    files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))

            if mask is not None:
                if len(flat) != 1:
                    raise Exception("Cannot use a mask with multiple images")
                ref_image = flat[0]
                if mask.shape[1:] != ref_image.shape[1:-1]:
                    raise Exception("Mask and image must be the same size")
                _, height, width = mask.shape
                rgba_mask = torch.zeros(height, width, 4, device="cpu")
                rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
                scaled_mask = downscale_image_tensor(
                    rgba_mask.unsqueeze(0), total_pixels=2048 * 2048
                ).squeeze()
                mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
                mask_img = Image.fromarray(mask_np)
                mask_img_byte_arr = BytesIO()
                mask_img.save(mask_img_byte_arr, format="PNG")
                mask_img_byte_arr.seek(0)
                files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))

            response = await sync_op(
                cls,
                ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
                response_model=OpenAIImageGenerationResponse,
                data=OpenAIImageEditRequest(
                    model=model_id,
                    prompt=prompt,
                    quality=quality,
                    background=background,
                    n=n,
                    size=size,
                    moderation="low",
                ),
                content_type="multipart/form-data",
                files=files,
                price_extractor=price_extractor,
            )
        else:
            response = await sync_op(
                cls,
                ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
                response_model=OpenAIImageGenerationResponse,
                data=OpenAIImageGenerationRequest(
                    model=model_id,
                    prompt=prompt,
                    quality=quality,
                    background=background,
                    n=n,
                    size=size,
                    moderation="low",
                ),
                price_extractor=price_extractor,
            )
        return IO.NodeOutput(await validate_and_cast_response(response))
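
A quick arithmetic check of the Custom-size rules above (plain Python; no API call):

w, h = 2048, 1152
assert w % 16 == 0 and h % 16 == 0      # multiples of 16
assert max(w, h) <= 3840                # max edge
assert max(w, h) / min(w, h) <= 3       # aspect ratio cap 3:1
assert 655_360 <= w * h <= 8_294_400    # total pixel bounds (2,359,296 here)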

class OpenAIChatNode(IO.ComfyNode):
    """
    Node to generate text responses from an OpenAI model.

@ -999,6 +1311,7 @@ class OpenAIExtension(ComfyExtension):
            OpenAIDalle2,
            OpenAIDalle3,
            OpenAIGPTImage1,
            OpenAIGPTImageNodeV2,
            OpenAIChatNode,
            OpenAIInputFiles,
            OpenAIChatConfig,

@ -86,6 +86,37 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No
    return x

class SamplerLCM(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="SamplerLCM",
            category="sampling/samplers",
            description=("LCM sampler with tunable per-step noise. s_noise is a multiplier on the model's training noise scale."),
            inputs=[
                io.Float.Input("s_noise", default=1.0, min=0.0, max=64.0, step=0.01,
                               tooltip="Per-step noise multiplier at the first step (1.0 = match training)."),
                io.Float.Input("s_noise_end", default=1.0, min=0.0, max=64.0, step=0.01,
                               tooltip="Per-step noise multiplier at the last step. Set equal to s_noise for a constant schedule."),
                io.Float.Input("noise_clip_std", default=0.0, min=0.0, max=10.0, step=0.01,
                               tooltip="Clamp per-step noise to +/- N*std. 0 disables."),
            ],
            outputs=[io.Sampler.Output()],
        )

    @classmethod
    def execute(cls, s_noise, s_noise_end, noise_clip_std) -> io.NodeOutput:
        sampler = comfy.samplers.ksampler(
            "lcm",
            {
                "s_noise": float(s_noise),
                "s_noise_end": float(s_noise_end),
                "noise_clip_std": float(noise_clip_std),
            },
        )
        return io.NodeOutput(sampler)
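
Reading the two tooltips above, the implied schedule interpolates between the endpoints; a sketch assuming a plain linear ramp (the actual interpolation lives inside the "lcm" sampler and is not shown in this diff):

def s_noise_at(step: int, total_steps: int, s_noise: float, s_noise_end: float) -> float:
    # Assumed linear blend from the first-step to the last-step multiplier.
    t = step / max(total_steps - 1, 1)
    return s_noise * (1.0 - t) + s_noise_end * t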

class SamplerEulerCFGpp(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:

@ -114,6 +145,7 @@ class AdvancedSamplersExtension(ComfyExtension):
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            SamplerLCMUpscale,
            SamplerLCM,
            SamplerEulerCFGpp,
        ]

256 comfy_extras/nodes_hidream_o1.py Normal file
@ -0,0 +1,256 @@
from typing_extensions import override

import torch

import comfy.model_management
import comfy.patcher_extension
import node_helpers
from comfy_api.latest import ComfyExtension, io


class EmptyHiDreamO1LatentImage(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="EmptyHiDreamO1LatentImage",
            display_name="Empty HiDream-O1 Latent Image",
            category="latent/image",
            description=(
                "Empty pixel-space latent for HiDream-O1-Image. The model was "
                "trained at ~4 megapixels; lower resolutions go off-distribution "
                "and quality regresses noticeably. Trained resolutions: "
                "2048x2048, 2304x1728, 1728x2304, 2560x1440, 1440x2560, "
                "2496x1664, 1664x2496, 3104x1312, 1312x3104, 2304x1792, 1792x2304."
            ),
            inputs=[
                io.Int.Input(id="width", default=2048, min=64, max=4096, step=32),
                io.Int.Input(id="height", default=2048, min=64, max=4096, step=32),
                io.Int.Input(id="batch_size", default=1, min=1, max=64),
            ],
            outputs=[io.Latent.Output()],
        )

    @classmethod
    def execute(cls, *, width: int, height: int, batch_size: int = 1) -> io.NodeOutput:
        latent = torch.zeros(
            (batch_size, 3, height, width),
            device=comfy.model_management.intermediate_device(),
        )
        return io.NodeOutput({"samples": latent})


class HiDreamO1ReferenceImages(io.ComfyNode):
    """Attach reference images to both positive and negative conditioning."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="HiDreamO1ReferenceImages",
            display_name="HiDream-O1 Reference Images",
            category="conditioning/image",
            description=(
                "Attach 1-10 reference images to conditioning: one for an edit instruction "
                "or multiple for subject-driven personalization."
            ),
            inputs=[
                io.Conditioning.Input(id="positive"),
                io.Conditioning.Input(id="negative"),
                io.Autogrow.Input(
                    "images",
                    template=io.Autogrow.TemplateNames(
                        io.Image.Input("image"),
                        names=[f"image_{i}" for i in range(1, 11)],
                        min=1,
                    ),
                    tooltip="Reference images. 1 image = instruction edit; 2-10 images = multi reference.",
                ),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
            ],
        )

    @classmethod
    def execute(cls, *, positive, negative, images: io.Autogrow.Type) -> io.NodeOutput:
        refs = [images[f"image_{i}"] for i in range(1, 11) if f"image_{i}" in images]
        positive = node_helpers.conditioning_set_values(positive, {"reference_latents": refs}, append=True)
        negative = node_helpers.conditioning_set_values(negative, {"reference_latents": refs}, append=True)
        return io.NodeOutput(positive, negative)

class HiDreamO1PatchSeamSmoothing(io.ComfyNode):
    PATCH_SIZE = 32
    EDGE_FEATHER = 4

    # Shift presets per (pattern, N). 8-pass = 4-quadrant + 4 quarter-patch offsets.
    SHIFTS_BY_PATTERN = {
        ("single_shift", 2): [(0, 0), (16, 16)],
        ("single_shift", 4): [(0, 0), (16, 0), (0, 16), (16, 16)],
        ("single_shift", 8): [(0, 0), (16, 0), (0, 16), (16, 16),
                              (8, 8), (24, 8), (8, 24), (24, 24)],
        ("symmetric", 2): [(-8, -8), (8, 8)],
        ("symmetric", 4): [(-8, -8), (8, -8), (-8, 8), (8, 8)],
        ("symmetric", 8): [(-12, -12), (4, -12), (-12, 4), (4, 4),
                           (-4, -4), (12, -4), (-4, 12), (12, 12)],
    }
    RAMP_LEVELS = {
        "2": [2],
        "4": [4],
        "ramp_2_4": [2, 4],
        "ramp_2_4_8": [2, 4, 8],
    }

    @staticmethod
    def _hann_tile(cy: int, cx: int, size: int = 32) -> torch.Tensor:
        """size x size Hann tile peaking at (cy, cx) within a patch."""
        half = size // 2
        yy = torch.arange(size).view(size, 1)
        xx = torch.arange(size).view(1, size)
        dy = ((yy - cy + half) % size) - half
        dx = ((xx - cx + half) % size) - half
        return 0.25 * (1 + torch.cos(torch.pi * dy / half)) * (1 + torch.cos(torch.pi * dx / half))

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="HiDreamO1PatchSeamSmoothing",
            display_name="HiDream-O1 Patch Seam Smoothing",
            category="advanced/model",
            is_experimental=True,
            description=(
                "Average the model output across multiple shifted patch-grid "
                "positions during the late portion of sampling. Cancels seams."
            ),
            inputs=[
                io.Model.Input(id="model"),
                io.Float.Input(id="start_percent", default=0.8, min=0.0, max=1.0, step=0.01,
                               tooltip="Sampling progress (0=start, 1=end) at which the blend turns ON.",
                               ),
                io.Float.Input(id="end_percent", default=1.0, min=0.0, max=1.0, step=0.01,
                               tooltip="Sampling progress at which the blend turns OFF.",
                               ),
                io.Combo.Input(
                    id="pattern",
                    options=["single_shift", "symmetric"],
                    default="single_shift",
                    tooltip="Shift layout. single_shift: one pass at the natural patch grid + others offset. symmetric: all passes off-grid, shifts split around origin.",
                ),
                io.Combo.Input(
                    id="passes",
                    options=["2", "4", "ramp_2_4", "ramp_2_4_8"],
                    default="2",
                    tooltip="Number of passes per gated step. 2/4 = fixed. ramp_*: pass count increases as sampling approaches the end (more smoothing where seams are most visible).",
                ),
                io.Combo.Input(
                    id="blend",
                    options=["average", "window", "median"],
                    default="average",
                    tooltip="average: equal-weight mean. window: Hann-windowed weighting favoring each pass away from its patch boundaries. median: per-pixel median, rejects wraparound-outlier passes.",
                ),
                io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0, step=0.01,
                               tooltip="Interpolation between the natural-grid pred (0) and the averaged result (1).",
                               ),
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(cls, *, model, start_percent: float, end_percent: float, pattern: str, passes: str, blend: str, strength: float) -> io.NodeOutput:
        if strength <= 0.0 or end_percent <= start_percent:
            return io.NodeOutput(model)

        P = cls.PATCH_SIZE
        half = P // 2
        shift_levels = [cls.SHIFTS_BY_PATTERN[(pattern, n)] for n in cls.RAMP_LEVELS[passes]]

        if blend == "window":
            window_tile_levels = [
                torch.stack([cls._hann_tile((half - sy) % P, (half - sx) % P, P) for sy, sx in lst], dim=0)
                for lst in shift_levels
            ]
        else:
            window_tile_levels = [None] * len(shift_levels)

        m = model.clone()
        model_sampling = m.get_model_object("model_sampling")
        multiplier = float(model_sampling.multiplier)
        start_t = float(model_sampling.percent_to_sigma(start_percent)) * multiplier
        end_t = float(model_sampling.percent_to_sigma(end_percent)) * multiplier

        edge_ramp_cache: dict = {}

        def get_edge_ramp(H: int, W: int, device, dtype) -> torch.Tensor:
            key = (H, W, device, dtype)
            cached = edge_ramp_cache.get(key)
            if cached is not None:
                return cached
            feather = cls.EDGE_FEATHER
            ys = torch.minimum(torch.arange(H, device=device, dtype=torch.float32),
                               (H - 1) - torch.arange(H, device=device, dtype=torch.float32))
            xs = torch.minimum(torch.arange(W, device=device, dtype=torch.float32),
                               (W - 1) - torch.arange(W, device=device, dtype=torch.float32))
            y_mask = ((ys - P) / feather).clamp(0, 1)
            x_mask = ((xs - P) / feather).clamp(0, 1)
            ramp = (y_mask[:, None] * x_mask[None, :]).to(dtype)
            edge_ramp_cache[key] = ramp
            return ramp

        def smoothing_wrapper(executor, *args, **kwargs):
            x = args[0]
            t = float(args[1][0])
            pred = executor(*args, **kwargs)
            if not (end_t <= t <= start_t):
                return pred
            # Pick the shift level by sigma phase across the gated range.
            if len(shift_levels) == 1:
                level_idx = 0
            else:
                phase = (start_t - t) / max(start_t - end_t, 1e-8)
                level_idx = min(int(phase * len(shift_levels)), len(shift_levels) - 1)
            shifts = shift_levels[level_idx]
            window_tiles = window_tile_levels[level_idx]

            preds = []
            for sy, sx in shifts:
                if sy == 0 and sx == 0:
                    preds.append(pred)
                    continue
                x_rolled = torch.roll(x, shifts=(sy, sx), dims=(-2, -1))
                pred_rolled = executor(x_rolled, *args[1:], **kwargs)
                preds.append(torch.roll(pred_rolled, shifts=(-sy, -sx), dims=(-2, -1)))
            stacked = torch.stack(preds, dim=0)  # (N, B, C, H, W)
            _, _, _, H, W = stacked.shape
            if blend == "window":
                N = stacked.shape[0]
                tiles = window_tiles.to(device=stacked.device, dtype=stacked.dtype)
                w = tiles.repeat(1, H // P, W // P)[:, :H, :W]
                sum_w = w.sum(dim=0, keepdim=True)
                w = torch.where(sum_w < 1e-3, torch.full_like(w, 1.0 / N), w / sum_w.clamp(min=1e-8))
                avg = (stacked * w[:, None, None, :, :]).sum(dim=0)
            elif blend == "median":
                avg = torch.median(stacked, dim=0).values
            else:
                avg = stacked.mean(dim=0)

            # Mask out the P-px wraparound contamination strip at each edge.
            mask = get_edge_ramp(H, W, pred.device, pred.dtype)
            return pred * (1.0 - mask * strength) + avg * (mask * strength)

        m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "hidream_o1_patch_seam_smoothing", smoothing_wrapper)
        return io.NodeOutput(m)


class HiDreamO1Extension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            EmptyHiDreamO1LatentImage,
            HiDreamO1ReferenceImages,
            HiDreamO1PatchSeamSmoothing,
        ]


async def comfy_entrypoint() -> HiDreamO1Extension:
    return HiDreamO1Extension()
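
The core trick in smoothing_wrapper above, reduced to a standalone sketch (illustrative only; model_fn stands in for the wrapped diffusion-model call):

def two_pass_seam_average(model_fn, x, shift: int = 16):
    # Evaluate on the natural grid and on a half-patch-shifted grid, then
    # unshift the second prediction and average; seams land in different
    # places in each pass, so they cancel.
    pred = model_fn(x)
    x_rolled = torch.roll(x, shifts=(shift, shift), dims=(-2, -1))
    pred_back = torch.roll(model_fn(x_rolled), shifts=(-shift, -shift), dims=(-2, -1))
    return 0.5 * (pred + pred_back)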

@ -300,6 +300,29 @@ class RescaleCFG:
        m.set_model_sampler_cfg_function(rescale_cfg)
        return (m, )

class ModelNoiseScale:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "noise_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 64.0, "step": 0.01,
                                                        "tooltip": "Absolute training noise scale. For example HiDream-O1 base: 8.0, dev: 7.5."}),
                              }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, noise_scale):
        m = model.clone()
        original = m.model.model_sampling
        ms = type(original)(m.model.model_config)
        ms.set_parameters(shift=original.shift, multiplier=original.multiplier)
        ms.set_noise_scale(noise_scale)
        m.add_object_patch("model_sampling", ms)
        return (m, )
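
A minimal usage sketch of the node above (assumes `model` is an already-loaded MODEL patcher; per the SamplerLCM description earlier, samplers treat the value as a multiplier on their per-step noise):

(patched_model,) = ModelNoiseScale().patch(model, noise_scale=8.0)  # HiDream-O1 base training scale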


class ModelComputeDtype:
    SEARCH_ALIASES = ["model precision", "change dtype"]
    @classmethod

@ -327,6 +350,7 @@ NODE_CLASS_MAPPINGS = {
    "ModelSamplingSD3": ModelSamplingSD3,
    "ModelSamplingAuraFlow": ModelSamplingAuraFlow,
    "ModelSamplingFlux": ModelSamplingFlux,
    "ModelNoiseScale": ModelNoiseScale,
    "RescaleCFG": RescaleCFG,
    "ModelComputeDtype": ModelComputeDtype,
}