Merge branch 'master' into cursor/mark-deprecated-cloud-endpoints-e81e

This commit is contained in:
Matt Miller
2026-05-12 09:11:38 -07:00
committed by GitHub
22 changed files with 1402 additions and 28 deletions

View File

@ -242,6 +242,7 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -373,6 +374,7 @@ def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -686,6 +688,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
@ -747,6 +750,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -832,6 +836,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
old_denoised = None
h, h_last = None, None
@ -889,6 +894,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
denoised_1, denoised_2 = None, None
h, h_1, h_2 = None, None, None
@ -1006,23 +1012,39 @@ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None,
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, s_noise=1.0, s_noise_end=None, noise_clip_std=0.0):
# s_noise / s_noise_end: per-step noise multiplier, linearly interpolated across steps
# noise_clip_std: clamp injected noise to +/- N stddevs (0 disables).
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
n_steps = max(1, len(sigmas) - 1)
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
s_start = float(s_noise)
s_end = s_start if s_noise_end is None else float(s_noise_end)
for i in trange(n_steps, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
x = denoised
if sigmas[i + 1] > 0:
x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
noise = noise_sampler(sigmas[i], sigmas[i + 1])
if noise_clip_std > 0:
clip_val = noise_clip_std * noise.std()
noise = noise.clamp(min=-clip_val, max=clip_val)
t = (i / (n_steps - 1)) if n_steps > 1 else 0.0
s_noise_i = s_start + (s_end - s_start) * t
if s_noise_i != 1.0:
noise = noise * s_noise_i
x = model_sampling.noise_scaling(sigmas[i + 1], noise, x)
return x
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
@ -1249,6 +1271,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
uncond_denoised = None
@ -1296,6 +1319,7 @@ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
temp = [0]
def post_cfg_function(args):
@ -1371,6 +1395,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
@ -1504,6 +1529,7 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
def default_er_sde_noise_scaler(x):
@ -1574,9 +1600,10 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
inject_noise = eta > 0 and s_noise > 0
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
@ -1645,9 +1672,10 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
inject_noise = eta > 0 and s_noise > 0
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
@ -1713,6 +1741,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)

View File

@ -792,6 +792,13 @@ class ZImagePixelSpace(ChromaRadiance):
"""
pass
class HiDreamO1Pixel(ChromaRadiance):
"""Pixel-space latent format for HiDream-O1.
No VAE — model patches/unpatches raw RGB internally with patch_size=32.
"""
pass
class CogVideoX(LatentFormat):
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).

View File

@ -0,0 +1,41 @@
"""HiDream-O1 two-pass attention: tokens [0, ar_len) are causal, [ar_len, T)
attend full K/V. Splitting Q at the boundary avoids the (B, 1, T, T) additive
mask the general-purpose path would build (~500 MB at T~16K) and lets the
gen half hit the user's preferred backend via optimized_attention.
"""
import torch
import comfy.ops
from comfy.ldm.modules.attention import optimized_attention
def make_two_pass_attention(ar_len: int, transformer_options=None):
"""Build a two-pass attention callable. AR pass uses SDPA-causal directly, gen pass routes through optimized_attention.
The AR pass goes through SDPA directand bypasses wrappers, it is only ~1% of T at typical edit sizes.
"""
def two_pass_attention(q, k, v, heads, **kwargs):
B, H, T, D = q.shape
if T < k.shape[2]: # KV-cache hot path: Q is shorter than K/V (cached AR prefix is in K/V only), all fresh Q positions are in the gen region, single full-attention call
out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
elif ar_len >= T:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
elif ar_len <= 0:
out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
else:
out_ar = comfy.ops.scaled_dot_product_attention(
q[:, :, :ar_len], k[:, :, :ar_len], v[:, :, :ar_len],
attn_mask=None, dropout_p=0.0, is_causal=True,
)
out_gen = optimized_attention(
q[:, :, ar_len:], k, v, heads,
mask=None, skip_reshape=True, skip_output_reshape=True,
transformer_options=transformer_options,
)
out = torch.cat([out_ar, out_gen], dim=2)
return out.transpose(1, 2).reshape(B, T, H * D)
return two_pass_attention

View File

@ -0,0 +1,230 @@
"""HiDream-O1 conditioning prep — ref-image dual path + extra_conds assembly.
Each ref image goes through two paths: a 32x32 patchified stream concatenated
to the noised target, and a Qwen3-VL ViT path producing tokens that scatter
into input_ids at <|image_pad|> positions.
"""
from typing import List
import torch
import comfy.utils
from comfy.text_encoders.qwen_vl import process_qwen2vl_images
from .utils import (PATCH_SIZE, calculate_dimensions, cond_image_size, ref_max_size, resize_tensor)
# Qwen3-VL ViT preprocessing constants (preprocessor_config.json).
VIT_PATCH = 16
VIT_MERGE = 2
VIT_IMAGE_MEAN = [0.5, 0.5, 0.5]
VIT_IMAGE_STD = [0.5, 0.5, 0.5]
def prepare_ref_images(
ref_images: List[torch.Tensor],
target_h: int,
target_w: int,
device: torch.device,
dtype: torch.dtype,
):
"""Build the dual-path tensors for K reference images at (target_h, target_w).
Returns None for K=0, else a dict with ref_patches, ref_pixel_values,
ref_image_grid_thw, per_ref_vit_tokens, per_ref_patch_grids.
"""
K = len(ref_images)
if K == 0:
return None
max_size = ref_max_size(max(target_h, target_w), K)
cis = cond_image_size(K)
refs_t = [img[0].clamp(0, 1).permute(2, 0, 1).unsqueeze(0).contiguous().float() for img in ref_images]
refs_t = [resize_tensor(t, max_size, PATCH_SIZE) for t in refs_t]
# 32-patch path.
ref_patches_per = []
per_ref_patch_grids = []
for t in refs_t:
t_norm = (t.squeeze(0) - 0.5) / 0.5 # (3, H, W) in [-1, 1]
h_p, w_p = t_norm.shape[-2] // PATCH_SIZE, t_norm.shape[-1] // PATCH_SIZE
per_ref_patch_grids.append((h_p, w_p))
patches = (
t_norm.reshape(3, h_p, PATCH_SIZE, w_p, PATCH_SIZE)
.permute(1, 3, 0, 2, 4)
.reshape(h_p * w_p, 3 * PATCH_SIZE * PATCH_SIZE)
)
ref_patches_per.append(patches)
ref_patches = torch.cat(ref_patches_per, dim=0).unsqueeze(0).to(device=device, dtype=dtype)
# ViT path.
refs_vlm_t = []
for t in refs_t:
_, _, h, w = t.shape
cond_w, cond_h = calculate_dimensions(cis, w / h)
cond_w = max(cond_w, VIT_PATCH * VIT_MERGE)
cond_h = max(cond_h, VIT_PATCH * VIT_MERGE)
refs_vlm_t.append(comfy.utils.common_upscale(t, cond_w, cond_h, "lanczos", "disabled"))
pv_list, grid_list, per_ref_vit_tokens = [], [], []
for t_v in refs_vlm_t:
pv, grid_thw = process_qwen2vl_images(
t_v.permute(0, 2, 3, 1),
min_pixels=0, max_pixels=10**12,
patch_size=VIT_PATCH, merge_size=VIT_MERGE,
image_mean=VIT_IMAGE_MEAN, image_std=VIT_IMAGE_STD,
)
grid_thw = grid_thw[0]
pv_list.append(pv.to(device=device, dtype=dtype))
grid_list.append(grid_thw.to(device=device))
# Post-merge token count = number of <|image_pad|> tokens this image expands to in input_ids.
gh, gw = int(grid_thw[1].item()), int(grid_thw[2].item())
per_ref_vit_tokens.append((gh // VIT_MERGE) * (gw // VIT_MERGE))
return {
"ref_patches": ref_patches,
"ref_pixel_values": torch.cat(pv_list, dim=0),
"ref_image_grid_thw": torch.stack(grid_list, dim=0),
"per_ref_vit_tokens": per_ref_vit_tokens,
"per_ref_patch_grids": per_ref_patch_grids,
}
def build_ref_input_ids(
text_input_ids: torch.Tensor,
per_ref_vit_tokens: List[int],
image_token_id: int,
vision_start_id: int,
vision_end_id: int,
):
"""Splice [vision_start, image_pad*N, vision_end] blocks into input_ids
after the [im_start, user, \\n] prefix (matches original chat template).
"""
ids = text_input_ids[0].tolist()
inserted = []
for n_pad in per_ref_vit_tokens:
inserted.extend([vision_start_id] + [image_token_id] * n_pad + [vision_end_id])
new_ids = ids[:3] + inserted + ids[3:] # 3 = len([im_start, user, \n])
return torch.tensor([new_ids], dtype=text_input_ids.dtype, device=text_input_ids.device)
def build_extra_conds(
text_input_ids: torch.Tensor,
noise: torch.Tensor,
ref_images: List[torch.Tensor] = None,
target_patch_size: int = 32,
):
"""Assemble all conditioning tensors for HiDreamO1Transformer.forward:
input_ids (with ref-vision tokens spliced in for the edit/IP path),
position_ids (MRoPE), token_types, vinput_mask, plus the ref
dual-path tensors when refs are provided.
"""
from .utils import get_rope_index_fix_point
from comfy.text_encoders.hidream_o1 import (
IMAGE_TOKEN_ID, VISION_START_ID, VISION_END_ID,
)
if text_input_ids.dim() == 1:
text_input_ids = text_input_ids.unsqueeze(0)
text_input_ids = text_input_ids.long().to(noise.device)
B = noise.shape[0]
if text_input_ids.shape[0] == 1 and B > 1:
text_input_ids = text_input_ids.expand(B, -1)
H, W = noise.shape[-2], noise.shape[-1]
h_p, w_p = H // target_patch_size, W // target_patch_size
image_len = h_p * w_p
image_grid_thw_tgt = torch.tensor(
[[1, h_p, w_p]], dtype=torch.long, device=text_input_ids.device,
)
out = {}
if ref_images:
ref = prepare_ref_images(ref_images, H, W, device=noise.device, dtype=noise.dtype)
text_input_ids = build_ref_input_ids(
text_input_ids, ref["per_ref_vit_tokens"],
IMAGE_TOKEN_ID, VISION_START_ID, VISION_END_ID,
)
new_txt_len = text_input_ids.shape[1]
# Each ref's patchified stream gets a [vision_start, image_pad*N-1]
# block in the position-id stream after the noised target.
ref_grid_lengths = [hp * wp for (hp, wp) in ref["per_ref_patch_grids"]]
tgt_vision = torch.full((1, image_len), IMAGE_TOKEN_ID,
dtype=text_input_ids.dtype, device=text_input_ids.device)
tgt_vision[:, 0] = VISION_START_ID
ref_vision_blocks = []
for rl in ref_grid_lengths:
blk = torch.full((1, rl), IMAGE_TOKEN_ID,
dtype=text_input_ids.dtype, device=text_input_ids.device)
blk[:, 0] = VISION_START_ID
ref_vision_blocks.append(blk)
ref_vision_cat = torch.cat([tgt_vision] + ref_vision_blocks, dim=1)
input_ids_pad = torch.cat([text_input_ids, ref_vision_cat], dim=-1)
total_ref_patches_len = sum(ref_grid_lengths)
total_len = new_txt_len + image_len + total_ref_patches_len
# K (ViT, post-merge) + 1 (target) + K (ref-patches) image grids.
K = len(ref_images)
igthw_cond = ref["ref_image_grid_thw"].clone()
igthw_cond[:, 1] //= 2
igthw_cond[:, 2] //= 2
image_grid_thw_ref = torch.tensor(
[[1, hp, wp] for (hp, wp) in ref["per_ref_patch_grids"]],
dtype=torch.long, device=text_input_ids.device,
)
igthw_all = torch.cat([
igthw_cond.to(text_input_ids.device),
image_grid_thw_tgt,
image_grid_thw_ref,
], dim=0)
position_ids, _ = get_rope_index_fix_point(
spatial_merge_size=1,
image_token_id=IMAGE_TOKEN_ID,
vision_start_token_id=VISION_START_ID,
input_ids=input_ids_pad, image_grid_thw=igthw_all,
attention_mask=None,
skip_vision_start_token=[0] * K + [1] + [1] * K,
fix_point=4096,
)
# tms + target_image + ref_patches are all gen.
tms_pos = new_txt_len - 1
ar_len = tms_pos
token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device)
token_types[:, tms_pos:] = 1
vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device)
vinput_mask[:, new_txt_len:] = True
# Leading batch dim sidesteps CONDRegular.process_cond's repeat_to_batch_size truncation
out["ref_pixel_values"] = ref["ref_pixel_values"].unsqueeze(0)
out["ref_image_grid_thw"] = ref["ref_image_grid_thw"].unsqueeze(0)
out["ref_patches"] = ref["ref_patches"]
else:
# T2I: text + noised target only, vision_start replaces the first image token
txt_len = text_input_ids.shape[1]
total_len = txt_len + image_len
vision_tokens = torch.full((B, image_len), IMAGE_TOKEN_ID,
dtype=text_input_ids.dtype, device=text_input_ids.device)
vision_tokens[:, 0] = VISION_START_ID
input_ids_pad = torch.cat([text_input_ids, vision_tokens], dim=-1)
position_ids, _ = get_rope_index_fix_point(
spatial_merge_size=1,
image_token_id=IMAGE_TOKEN_ID,
vision_start_token_id=VISION_START_ID,
input_ids=input_ids_pad, image_grid_thw=image_grid_thw_tgt,
attention_mask=None,
skip_vision_start_token=[1],
)
ar_len = txt_len - 1
token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device)
token_types[:, ar_len:] = 1
vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device)
vinput_mask[:, txt_len:] = True
out["input_ids"] = text_input_ids
out["position_ids"] = position_ids[:, 0].unsqueeze(0) # Collapse position_ids batch and add a leading dim so CONDRegular's batch-resize doesn't truncate the 3-axis MRoPE dim
out["token_types"] = token_types
out["vinput_mask"] = vinput_mask
out["ar_len"] = ar_len
return out

View File

@ -0,0 +1,306 @@
"""HiDream-O1-Image transformer.
Pixel-space DiT built on Qwen3-VL: the vision tower (Qwen35VisionModel)
encodes ref images, the Qwen3-VL-8B decoder (Llama2_ with interleaved MRoPE)
processes a unified text+image sequence, and 32x32 patch embed/unembed
shims map raw RGB in and out of LLM hidden space. The Qwen3-VL deepstack
mergers go unused — their weights are dropped at load.
"""
from dataclasses import dataclass, field
from typing import List, Optional
import einops
import torch
import torch.nn as nn
import comfy.patcher_extension
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.text_encoders.llama import Llama2_
from comfy.text_encoders.qwen35 import Qwen35VisionModel
from .attention import make_two_pass_attention
IMAGE_TOKEN_ID = 151655 # Qwen3-VL <|image_pad|>
TMS_TOKEN_ID = 151673 # HiDream-O1 <|tms_token|>
PATCH_SIZE = 32
@dataclass
class HiDreamO1TextConfig:
"""Qwen3-VL-8B text-decoder dims (matches public Qwen3-VL-8B-Instruct)."""
vocab_size: int = 151936
hidden_size: int = 4096
intermediate_size: int = 12288
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
head_dim: int = 128
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 5000000.0
rope_scale: Optional[float] = None
rope_dims: List[int] = field(default_factory=lambda: [24, 20, 20])
interleaved_mrope: bool = True
transformer_type: str = "llama"
rms_norm_add: bool = False
mlp_activation: str = "silu"
qkv_bias: bool = False
q_norm: str = "gemma3"
k_norm: str = "gemma3"
final_norm: bool = True
lm_head: bool = False
stop_tokens: List[int] = field(default_factory=lambda: [151643, 151645])
QWEN3VL_VISION_DEFAULTS = dict(
hidden_size=1152,
num_heads=16,
intermediate_size=4304,
depth=27,
patch_size=16,
temporal_patch_size=2,
in_channels=3,
spatial_merge_size=2,
num_position_embeddings=2304,
deepstack_visual_indexes=(8, 16, 24),
out_hidden_size=4096, # final merger projects directly into LLM hidden
)
class BottleneckPatchEmbed(nn.Module):
# 3072 -> 1024 -> 4096 (raw 32x32 RGB patch -> bottleneck -> LLM hidden).
def __init__(self, patch_size=32, in_chans=3, pca_dim=1024, embed_dim=4096, bias=True, device=None, dtype=None, ops=None):
super().__init__()
self.proj1 = ops.Linear(patch_size * patch_size * in_chans, pca_dim, bias=False, device=device, dtype=dtype)
self.proj2 = ops.Linear(pca_dim, embed_dim, bias=bias, device=device, dtype=dtype)
def forward(self, x):
return self.proj2(self.proj1(x))
class FinalLayer(nn.Module):
# 4096 -> 3072 (LLM hidden -> flat pixel patch).
def __init__(self, hidden_size, patch_size=32, out_channels=3, device=None, dtype=None, ops=None):
super().__init__()
self.linear = ops.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, device=device, dtype=dtype)
def forward(self, x):
return self.linear(x)
class HiDreamO1Transformer(nn.Module):
"""HiDream-O1 unified pixel-level transformer."""
def __init__(self, image_model=None, dtype=None, device=None, operations=None,
text_config_overrides=None, vision_config_overrides=None, **kwargs):
super().__init__()
self.dtype = dtype
text_cfg = HiDreamO1TextConfig(**(text_config_overrides or {}))
vision_cfg = dict(QWEN3VL_VISION_DEFAULTS)
if vision_config_overrides:
vision_cfg.update(vision_config_overrides)
vision_cfg["out_hidden_size"] = text_cfg.hidden_size
self.text_config = text_cfg
self.vision_config = vision_cfg
self.hidden_size = text_cfg.hidden_size
self.patch_size = PATCH_SIZE
self.in_channels = 3
self.tms_token_id = TMS_TOKEN_ID
self.visual = Qwen35VisionModel(vision_cfg, device=device, dtype=dtype, ops=operations)
self.language_model = Llama2_(text_cfg, device=device, dtype=dtype, ops=operations)
self.t_embedder1 = TimestepEmbedder(
text_cfg.hidden_size, device=device, dtype=dtype, operations=operations,
)
self.x_embedder = BottleneckPatchEmbed(
patch_size=self.patch_size, in_chans=self.in_channels,
pca_dim=text_cfg.hidden_size // 4, embed_dim=text_cfg.hidden_size,
bias=True, device=device, dtype=dtype, ops=operations,
)
self.final_layer2 = FinalLayer(
text_cfg.hidden_size, patch_size=self.patch_size,
out_channels=self.in_channels, device=device, dtype=dtype, ops=operations,
)
self._visual_cache = None
self._kv_cache_entries = []
def clear_kv_cache(self):
self._kv_cache_entries = []
self._visual_cache = None
def forward(self, x, timesteps, context=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timesteps, context, transformer_options, **kwargs)
def _forward(self, x, timesteps, context=None, transformer_options={}, input_ids=None, attention_mask=None, position_ids=None,
vinput_mask=None, ar_len=None, ref_pixel_values=None, ref_image_grid_thw=None, ref_patches=None, **kwargs):
"""Returns flow-match velocity (x - x_pred) / sigma"""
if input_ids is None or position_ids is None:
raise ValueError("HiDreamO1Transformer requires input_ids and position_ids in conditioning")
B, _, H, W = x.shape
h_p, w_p = H // self.patch_size, W // self.patch_size
tgt_image_len = h_p * w_p
z = einops.rearrange(
x, 'B C (H p1) (W p2) -> B (H W) (C p1 p2)',
p1=self.patch_size, p2=self.patch_size,
)
vinputs = torch.cat([z, ref_patches.to(z.dtype)], dim=1) if ref_patches is not None else z
inputs_embeds = self.language_model.embed_tokens(input_ids).to(x.dtype)
if ref_pixel_values is not None and ref_image_grid_thw is not None:
# ViT output is constant across sampling steps within a generation
# identity-key by the input tensor so refs don't recompute every step.
cached = self._visual_cache
if cached is not None and cached[0] is ref_pixel_values:
image_embeds = cached[1]
else:
ref_pv = ref_pixel_values.to(inputs_embeds.device)
ref_grid = ref_image_grid_thw.to(inputs_embeds.device).long()
# extra_conds wraps with a leading batch dim; refs are model-level so [0] always recovers them.
if ref_pv.dim() == 3:
ref_pv = ref_pv[0]
if ref_grid.dim() == 3:
ref_grid = ref_grid[0]
image_embeds = self.visual(ref_pv, ref_grid).to(inputs_embeds.dtype)
self._visual_cache = (ref_pixel_values, image_embeds)
# image_pad positions identical across batch (input_ids shared cond/uncond).
image_idx = (input_ids[0] == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]
if image_idx.shape[0] != image_embeds.shape[0]:
raise ValueError(
f"Image-token count {image_idx.shape[0]} != ViT output count "
f"{image_embeds.shape[0]}; check tokenizer/processor alignment."
)
inputs_embeds[:, image_idx] = image_embeds.unsqueeze(0).expand(B, -1, -1)
sigma = timesteps.float() / 1000.0
t_pixeldit = 1.0 - sigma
t_emb = self.t_embedder1(t_pixeldit * 1000, inputs_embeds.dtype)
tms_mask_3d = (input_ids == self.tms_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = torch.where(tms_mask_3d, t_emb.unsqueeze(1).expand_as(inputs_embeds), inputs_embeds)
vinputs_embedded = self.x_embedder(vinputs.to(inputs_embeds.dtype))
inputs_embeds = torch.cat([inputs_embeds, vinputs_embedded], dim=1)
# extra_conds stores position_ids as (1, 3, T); process_cond repeats dim 0 to B. Take row 0.
freqs_cis = self.language_model.compute_freqs_cis(position_ids[0].to(x.device), x.device)
freqs_cis = tuple(t.to(x.dtype) for t in freqs_cis)
two_pass_attn = make_two_pass_attention(ar_len, transformer_options=transformer_options)
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.language_model.layers)
transformer_options["block_type"] = "double"
# Cache prefix K/V across steps. Key includes input_ids (prompt), ref_id
# (refs scatter into inputs_embeds), and position_ids (RoPE baked into cached K).
can_cache = not blocks_replace and ar_len > 0
cache_len = ar_len if can_cache else 0
ref_id = id(ref_pixel_values) if ref_pixel_values is not None else None
pos_ids_key = position_ids[..., :cache_len] if can_cache else position_ids
cache_entries = self._kv_cache_entries
# Drop stale entries from a previous device (model was unloaded and reloaded).
if cache_entries and cache_entries[0]["input_ids"].device != input_ids.device:
cache_entries = []
self._kv_cache_entries = []
kv_cache = None
if can_cache:
for entry in cache_entries:
ck = entry["input_ids"]
ep = entry["position_ids"]
if (entry["cache_len"] == cache_len
and ck.shape == input_ids.shape and torch.equal(ck, input_ids)
and entry["ref_id"] == ref_id
and ep.shape == pos_ids_key.shape and torch.equal(ep, pos_ids_key)):
kv_cache = entry
break
if kv_cache is not None:
# Hot path: project Q/K/V only for fresh positions; past_key_value prepends cached AR K/V.
hidden_states = inputs_embeds[:, cache_len:]
sliced_freqs = tuple(t[..., cache_len:, :] for t in freqs_cis)
for i, layer in enumerate(self.language_model.layers):
transformer_options["block_index"] = i
K_i, V_i = kv_cache["kv"][i]
hidden_states, _ = layer(
x=hidden_states, attention_mask=None, freqs_cis=sliced_freqs, optimized_attention=two_pass_attn,
past_key_value=(K_i, V_i, cache_len),
)
else:
# Cold path: run full sequence; if cacheable, snapshot K/V at AR positions.
snapshots = [] if can_cache else None
past_kv_cold = () if can_cache else None
hidden_states = inputs_embeds
for i, layer in enumerate(self.language_model.layers):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args, _layer=layer):
out = {}
out["x"], _ = _layer(
x=args["x"], attention_mask=args.get("attention_mask"),
freqs_cis=args["freqs_cis"], optimized_attention=args["optimized_attention"],
past_key_value=None,
)
return out
out = blocks_replace[("double_block", i)](
{"x": hidden_states, "attention_mask": None,
"freqs_cis": freqs_cis, "optimized_attention": two_pass_attn,
"transformer_options": transformer_options},
{"original_block": block_wrap},
)
hidden_states = out["x"]
else:
hidden_states, present_kv = layer(
x=hidden_states, attention_mask=None,
freqs_cis=freqs_cis, optimized_attention=two_pass_attn,
past_key_value=past_kv_cold,
)
if snapshots is not None:
K, V, _ = present_kv
snapshots.append((K[:, :, :cache_len].contiguous(),
V[:, :, :cache_len].contiguous()))
if snapshots is not None:
# Cap at 2 entries (cond + uncond). Multi-cond workflows LRU-evict.
new_entry = {
"input_ids": input_ids.clone(),
"cache_len": cache_len,
"kv": snapshots,
"ref_id": ref_id,
"position_ids": pos_ids_key.clone(),
}
self._kv_cache_entries = (cache_entries + [new_entry])[-2:]
if self.language_model.norm is not None:
hidden_states = self.language_model.norm(hidden_states)
# Slice target-image positions before the final projection so the Linear only runs on tgt_image_len tokens.
# In the hot path hidden_states starts at original position cache_len, so masks/indices shift by cache_len.
sliced_offset = cache_len if kv_cache is not None else 0
if vinput_mask is not None:
vmask = vinput_mask.to(x.device).bool()
if sliced_offset > 0:
vmask = vmask[:, sliced_offset:]
target_hidden = hidden_states[vmask].view(B, -1, hidden_states.shape[-1])[:, :tgt_image_len]
else:
txt_seq_len = input_ids.shape[1]
start = txt_seq_len - sliced_offset
target_hidden = hidden_states[:, start:start + tgt_image_len]
x_pred_tgt = self.final_layer2(target_hidden)
# fp32 final subtraction, bf16 here noticeably degrades samples.
x_pred_img = einops.rearrange(
x_pred_tgt, 'B (H W) (C p1 p2) -> B C (H p1) (W p2)',
H=h_p, W=w_p, p1=self.patch_size, p2=self.patch_size,
)
return (x.float() - x_pred_img.float()) / sigma.view(B, 1, 1, 1).clamp_min(1e-3)

View File

@ -0,0 +1,173 @@
"""HiDream-O1 input-prep helpers: image/resolution math and unified-sequence
RoPE position-id assembly. The fix_point offset in get_rope_index_fix_point
lets the target image and patchified ref images share spatial RoPE positions
despite living at different sequence indices — same 2D image plane.
"""
import math
from typing import Optional
import torch
PATCH_SIZE = 32
CONDITION_IMAGE_SIZE = 384 # ViT-side base size for ref images
def resize_tensor(img_t, image_size, patch_size=16):
"""img_t: (1, 3, H, W) float [0, 1]. Fit to image_size**2 area, patch-aligned, center-cropped."""
while min(img_t.shape[-2], img_t.shape[-1]) >= 2 * image_size: # Pre-halves with 2x2 box averaging while the image is still very large
img_t = torch.nn.functional.avg_pool2d(img_t, kernel_size=2, stride=2)
_, _, height, width = img_t.shape
m = patch_size
s_max = image_size * image_size
scale = math.sqrt(s_max / (width * height))
candidates = [
(round(width * scale) // m * m, round(height * scale) // m * m),
(round(width * scale) // m * m, math.floor(height * scale) // m * m),
(math.floor(width * scale) // m * m, round(height * scale) // m * m),
(math.floor(width * scale) // m * m, math.floor(height * scale) // m * m),
]
candidates = sorted(candidates, key=lambda x: x[0] * x[1], reverse=True)
new_size = candidates[-1]
for c in candidates:
if c[0] * c[1] <= s_max:
new_size = c
break
new_w, new_h = new_size
s1 = width / new_w
s2 = height / new_h
if s1 < s2:
resize_w, resize_h = new_w, round(height / s1)
else:
resize_w, resize_h = round(width / s2), new_h
img_t = torch.nn.functional.interpolate(img_t, size=(resize_h, resize_w), mode="bicubic")
top = (resize_h - new_h) // 2
left = (resize_w - new_w) // 2
return img_t[..., top:top + new_h, left:left + new_w]
def calculate_dimensions(max_size, ratio):
"""(W, H) for an aspect ratio fitting in max_size**2 area, 32-aligned."""
width = math.sqrt(max_size * max_size * ratio)
height = width / ratio
width = int(width / 32) * 32
height = int(height / 32) * 32
return width, height
def ref_max_size(target_max_dim, k):
"""K-dependent ref-image max dim before patchifying."""
if k == 1:
return target_max_dim
if k == 2:
return target_max_dim * 48 // 64
if k <= 4:
return target_max_dim // 2
if k <= 8:
return target_max_dim * 24 // 64
return target_max_dim // 4
def cond_image_size(k):
"""K-dependent ViT-side image size."""
if k <= 4:
return CONDITION_IMAGE_SIZE
if k <= 8:
return CONDITION_IMAGE_SIZE * 48 // 64
return CONDITION_IMAGE_SIZE // 2
def get_rope_index_fix_point(
spatial_merge_size: int,
image_token_id: int,
vision_start_token_id: int,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
skip_vision_start_token=None,
fix_point: int = 4096,
):
mrope_position_deltas = []
if input_ids is not None and image_grid_thw is not None:
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3, input_ids.shape[0], input_ids.shape[1],
dtype=input_ids.dtype, device=input_ids.device,
)
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids_b in enumerate(total_input_ids):
fp = fix_point
image_index = 0
input_ids_b = input_ids_b[attention_mask[i] == 1]
vision_start_indices = torch.argwhere(input_ids_b == vision_start_token_id).squeeze(1)
vision_tokens = input_ids_b[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
input_tokens = input_ids_b.tolist()
llm_pos_ids_list = []
st = 0
remain_images = image_nums
for _ in range(image_nums):
if image_token_id in input_tokens and remain_images > 0:
ed = input_tokens.index(image_token_id, st)
else:
ed = len(input_tokens) + 1
t = image_grid_thw[image_index][0]
h = image_grid_thw[image_index][1]
w = image_grid_thw[image_index][2]
image_index += 1
remain_images -= 1
llm_grid_t = t.item()
llm_grid_h = h.item() // spatial_merge_size
llm_grid_w = w.item() // spatial_merge_size
text_len = ed - st
text_len -= skip_vision_start_token[image_index - 1]
text_len = max(0, text_len)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
if skip_vision_start_token[image_index - 1]:
if fp > 0:
fp = fp - st_idx
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fp + st_idx)
fp = 0
else:
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1).expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas

View File

@ -97,12 +97,14 @@ def load_lora(lora, to_load, log_missing=True):
def model_lora_keys_clip(model, key_map={}):
sdk = model.state_dict().keys()
prefix_set = set()
for k in sdk:
if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
if tp > 0 and not k.startswith("clip_"):
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
prefix_set.add(k.split('.')[0])
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
@ -163,6 +165,13 @@ def model_lora_keys_clip(model, key_map={}):
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
if len(prefix_set) == 1:
full_prefix = "{}.transformer.model.".format(next(iter(prefix_set))) # kohya anima and maybe other single TE models that use a single llama arch based te
for k in sdk:
if k.endswith(".weight"):
if k.startswith(full_prefix):
l_key = k[len(full_prefix):-len(".weight")]
key_map["lora_te_{}".format(l_key.replace(".", "_"))] = k
k = "clip_g.transformer.text_projection.weight"
if k in sdk:

View File

@ -58,6 +58,8 @@ import comfy.ldm.cogvideo.model
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
import comfy.ldm.hidream_o1.model
from comfy.ldm.hidream_o1.conditioning import build_extra_conds
import comfy.model_management
import comfy.patcher_extension
@ -1674,6 +1676,32 @@ class HiDream(BaseModel):
out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
return out
class HiDreamO1(BaseModel):
"""HiDream-O1-Image: pixel-space DiT (no VAE). Refs from HiDreamO1ReferenceImages and tokens from the stub TE flow through
extra_conds; the heavy preprocessing lives in comfy.ldm.hidream_o1.conditioning."""
PATCH_SIZE = 32
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream_o1.model.HiDreamO1Transformer)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
text_input_ids = kwargs.get("text_input_ids", None)
noise = kwargs.get("noise", None)
if text_input_ids is None or noise is None:
return out
conds = build_extra_conds(
text_input_ids, noise,
ref_images=kwargs.get("reference_latents", None),
target_patch_size=self.PATCH_SIZE,
)
for k, v in conds.items():
# ar_len is a Python int (precomputed to avoid a GPU sync in forward).
cls = comfy.conds.CONDConstant if k == "ar_len" else comfy.conds.CONDRegular
out[k] = cls(v)
return out
class Chroma(Flux):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma):
super().__init__(model_config, model_type, device=device, unet_model=unet_model)

View File

@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config
if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
return {"image_model": "hidream_o1"}
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"

View File

@ -242,6 +242,37 @@ class LazyCastingParam(torch.nn.Parameter):
return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
class LazyCastingQuantizedParam:
def __init__(self, model, key):
self.model = model
self.key = key
self.cpu_state_dict = None
def state_dict_tensor(self, state_dict_key):
if self.cpu_state_dict is None:
weight = self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True)
self.cpu_state_dict = {k: v.to("cpu") for k, v in weight.state_dict(self.key).items()}
return self.cpu_state_dict[state_dict_key]
class LazyCastingParamPiece(torch.nn.Parameter):
def __new__(cls, caster, state_dict_key, tensor):
return super().__new__(cls, tensor)
def __init__(self, caster, state_dict_key, tensor):
self.caster = caster
self.state_dict_key = state_dict_key
@property
def device(self):
return CustomTorchDevice
def to(self, *args, **kwargs):
caster = self.caster
del self.caster
return caster.state_dict_tensor(self.state_dict_key)
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
@ -1463,20 +1494,37 @@ class ModelPatcher:
self.clear_cached_hook_weights()
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
unet_state_dict = self.model.diffusion_model.state_dict()
for k, v in unet_state_dict.items():
original_state_dict = self.model.diffusion_model.state_dict()
unet_state_dict = {}
keys = list(original_state_dict)
while len(keys) > 0:
k = keys.pop(0)
v = original_state_dict[k]
op_keys = k.rsplit('.', 1)
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
unet_state_dict[k] = v
continue
try:
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
except:
unet_state_dict[k] = v
continue
if not op or not hasattr(op, "comfy_cast_weights") or \
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
unet_state_dict[k] = v
continue
key = "diffusion_model." + k
unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
weight = comfy.utils.get_attr(self.model, key)
if isinstance(weight, QuantizedTensor) and k in original_state_dict:
qt_state_dict = weight.state_dict(k)
caster = LazyCastingQuantizedParam(self, key)
for group_key in (x for x in qt_state_dict if x in original_state_dict):
if group_key in keys:
keys.remove(group_key)
unet_state_dict.pop(group_key, "")
unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key])
continue
unet_state_dict[k] = LazyCastingParam(self, key, weight)
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
def __del__(self):

View File

@ -93,7 +93,8 @@ class CONST:
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = reshape_sigma(sigma, noise.ndim)
return sigma * noise + (1.0 - sigma) * latent_image
s = getattr(self, "noise_scale", 1.0)
return sigma * (s * noise) + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
sigma = reshape_sigma(sigma, latent.ndim)
@ -288,7 +289,11 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
else:
sampling_settings = {}
self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000))
self.set_noise_scale(sampling_settings.get("noise_scale", 1.0))
self.set_parameters(
shift=sampling_settings.get("shift", 1.0),
multiplier=sampling_settings.get("multiplier", 1000),
)
def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000):
self.shift = shift
@ -296,6 +301,9 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier)
self.register_buffer('sigmas', ts)
def set_noise_scale(self, noise_scale):
self.noise_scale = float(noise_scale)
@property
def sigma_min(self):
return self.sigmas[0]

View File

@ -1285,7 +1285,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key]

View File

@ -239,7 +239,8 @@ class CLIP:
model_management.archive_model_dtypes(self.cond_stage_model)
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
te_disable_dynamic = disable_dynamic or getattr(self.cond_stage_model, "disable_offload", False)
ModelPatcher = comfy.model_patcher.ModelPatcher if te_disable_dynamic else comfy.model_patcher.CoreModelPatcher
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
self.patcher.set_model_compute_dtype(torch.float32)
@ -776,6 +777,7 @@ class VAE:
self.latent_channels = 3
self.latent_dim = 2
self.output_channels = 3
self.disable_offload = True
elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
sample_rate = 16000
if sample_rate == 16000:

View File

@ -28,6 +28,7 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo
import comfy.text_encoders.hidream_o1
from . import supported_models_base
from . import latent_formats
@ -1431,6 +1432,50 @@ class HiDream(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return None # TODO
class HiDreamO1(supported_models_base.BASE):
unet_config = {
"image_model": "hidream_o1",
}
sampling_settings = {
"shift": 3.0,
"noise_scale": 8.0,
}
latent_format = latent_formats.HiDreamO1Pixel
memory_usage_factor = 0.6
# fp16 not supported: LM MLP down_proj activations fp16 overflow, causing NaNs
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
optimizations = {"fp8": False}
def get_model(self, state_dict, prefix="", device=None):
return model_base.HiDreamO1(self, device=device)
def process_unet_state_dict(self, state_dict):
# Drop unused Qwen3-VL deepstack merger weights; upstream discards them at inference.
for key in list(state_dict.keys()):
if "visual.deepstack_merger_list" in key:
del state_dict[key]
return state_dict
def process_vae_state_dict(self, state_dict):
# Pixel-space model: inject sentinel so VAE construction picks PixelspaceConversionVAE.
return {"pixel_space_vae": torch.tensor(1.0)}
def process_clip_state_dict(self, state_dict):
# Tokenizer-only TE: inject sentinel so load_state_dict_guess_config triggers CLIP init.
return {"_hidream_o1_te_sentinel": torch.zeros(1)}
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(
comfy.text_encoders.hidream_o1.HiDreamO1Tokenizer,
comfy.text_encoders.hidream_o1.HiDreamO1TE,
)
class Chroma(supported_models_base.BASE):
unet_config = {
"image_model": "chroma",
@ -2018,6 +2063,7 @@ models = [
Hunyuan3Dv2,
Hunyuan3Dv2_1,
HiDream,
HiDreamO1,
Chroma,
ChromaRadiance,
ACEStep,

View File

@ -0,0 +1,119 @@
"""HiDream-O1-Image tokenizer-only text encoder.
The real Qwen3-VL backbone runs inside diffusion_model.* every step, so this
module just tokenizes the prompt into text_input_ids and emits them as
conditioning. Position ids / token_types / vinput_mask depend on target H/W
and are built later in model_base.HiDreamO1.extra_conds.
"""
import os
import torch
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
# Qwen3-VL special tokens
IM_START_ID = 151644
IM_END_ID = 151645
ASSISTANT_ID = 77091
USER_ID = 872
NEWLINE_ID = 198
VISION_START_ID = 151652
VISION_END_ID = 151653
IMAGE_TOKEN_ID = 151655
VIDEO_TOKEN_ID = 151656
# HiDream-O1-specific tokens
BOI_TOKEN_ID = 151669
BOR_TOKEN_ID = 151670
EOR_TOKEN_ID = 151671
BOT_TOKEN_ID = 151672
TMS_TOKEN_ID = 151673
class HiDreamO1QwenTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer"
)
super().__init__(
tokenizer_path,
pad_with_end=False,
embedding_size=4096,
embedding_key="hidream_o1",
tokenizer_class=Qwen2Tokenizer,
has_start_token=False,
has_end_token=False,
pad_to_max_length=False,
max_length=99999999,
min_length=1,
pad_token=151643,
tokenizer_data=tokenizer_data,
)
class HiDreamO1Tokenizer(sd1_clip.SD1Tokenizer):
"""Wraps prompt in the upstream chat template ending with boi/tms markers.
Image tokens get spliced in at sample time once target H/W is known.
"""
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory,
tokenizer_data=tokenizer_data,
name="hidream_o1",
tokenizer=HiDreamO1QwenTokenizer,
)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
text_tokens_dict = super().tokenize_with_weights(
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
)
text_tuples = text_tokens_dict["hidream_o1"][0]
text_tuples = [t for t in text_tuples if int(t[0]) != 151643] # strip pad
# <|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|boi|><|tms|>
def tok(tid):
return (tid, 1.0) if not return_word_ids else (tid, 1.0, 0)
prefix = [tok(IM_START_ID), tok(USER_ID), tok(NEWLINE_ID)]
suffix = [
tok(IM_END_ID), tok(NEWLINE_ID),
tok(IM_START_ID), tok(ASSISTANT_ID), tok(NEWLINE_ID),
tok(BOI_TOKEN_ID), tok(TMS_TOKEN_ID),
]
full = prefix + list(text_tuples) + suffix
return {"hidream_o1": [full]}
class HiDreamO1TE(torch.nn.Module):
"""Passthrough TE: emits int token ids; the Qwen3-VL backbone in diffusion_model does the actual encoding."""
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = {torch.float32}
self.disable_offload = True # skips dynamic VRAM management for this zero-parameter module
self.device = torch.device("cpu") if device is None else torch.device(device)
def encode_token_weights(self, token_weight_pairs):
tok_pairs = token_weight_pairs["hidream_o1"][0]
ids = [int(t[0]) for t in tok_pairs]
input_ids = torch.tensor([ids], dtype=torch.long)
# Surrogate keeps the cross_attn slot non-empty for CONDITIONING
# plumbing; the model reads text_input_ids out of `extra` instead.
cross_attn = input_ids.unsqueeze(-1).to(torch.float32)
extra = {"text_input_ids": input_ids}
return cross_attn, None, extra
def load_sd(self, sd):
return []
def get_sd(self):
return {}
def reset_clip_options(self):
pass
def set_clip_options(self, options):
pass

View File

@ -397,7 +397,7 @@ class RMSNorm(nn.Module):
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None, interleaved_mrope=False):
if not isinstance(theta, list):
theta = [theta]
@ -415,16 +415,27 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
if rope_dims is not None and position_ids.shape[0] > 1 and interleaved_mrope:
# Qwen3-VL interleaved MRoPE: T-freqs by default, H/W replace every 3rd dim.
freqs_inter = freqs[0].clone()
for axis_idx, offset in ((1, 1), (2, 2)):
length = rope_dims[axis_idx] * 3
idx = slice(offset, length, 3)
freqs_inter[..., idx] = freqs[axis_idx, ..., idx]
emb = torch.cat((freqs_inter, freqs_inter), dim=-1)
cos = emb.cos().unsqueeze(0)
sin = emb.sin().unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
sin_split = sin.shape[-1] // 2
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
@ -689,6 +700,7 @@ class Llama2_(nn.Module):
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
interleaved_mrope=getattr(self.config, "interleaved_mrope", False),
device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):

View File

@ -451,9 +451,8 @@ class Qwen35VisionPatchEmbed(nn.Module):
self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)
def forward(self, x):
target_dtype = self.proj.weight.dtype
x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)
return self.proj(x).view(-1, self.embed_dim)
class Qwen35VisionMLP(nn.Module):
@ -651,7 +650,7 @@ class Qwen35VisionModel(nn.Module):
x = self.patch_embed(x)
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
x = x + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
rotary_pos_emb = self.rot_pos_emb(grid_thw).to(x.device)
seq_len = x.shape[0]
x = x.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

View File

@ -143,7 +143,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
if reference_images:
references = []
for key in reference_images:
url = await upload_image_to_comfyapi(cls, reference_images[key])
url = await upload_image_to_comfyapi(cls, reference_images[key], mime_type="image/png")
references.append(QuiverImageObject(url=url))
if len(references) > 4:
raise ValueError("Maximum 4 reference images are allowed.")
@ -252,7 +252,7 @@ class QuiverImageToSVGNode(IO.ComfyNode):
model: dict,
seed: int,
) -> IO.NodeOutput:
image_url = await upload_image_to_comfyapi(cls, image)
image_url = await upload_image_to_comfyapi(cls, image, mime_type="image/png")
response = await sync_op(
cls,

View File

@ -86,6 +86,37 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No
return x
class SamplerLCM(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="SamplerLCM",
category="sampling/samplers",
description=("LCM sampler with tunable per-step noise. s_noise is a multiplier on the model's training noise scale"),
inputs=[
io.Float.Input("s_noise", default=1.0, min=0.0, max=64.0, step=0.01,
tooltip="Per-step noise multiplier at the first step (1.0 = match training)."),
io.Float.Input("s_noise_end", default=1.0, min=0.0, max=64.0, step=0.01,
tooltip="Per-step noise multiplier at the last step. Set equal to s_noise for a constant schedule."),
io.Float.Input("noise_clip_std", default=0.0, min=0.0, max=10.0, step=0.01,
tooltip="Clamp per-step noise to +/- N*std. 0 disables."),
],
outputs=[io.Sampler.Output()],
)
@classmethod
def execute(cls, s_noise, s_noise_end, noise_clip_std) -> io.NodeOutput:
sampler = comfy.samplers.ksampler(
"lcm",
{
"s_noise": float(s_noise),
"s_noise_end": float(s_noise_end),
"noise_clip_std": float(noise_clip_std),
},
)
return io.NodeOutput(sampler)
class SamplerEulerCFGpp(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
@ -114,6 +145,7 @@ class AdvancedSamplersExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SamplerLCMUpscale,
SamplerLCM,
SamplerEulerCFGpp,
]

View File

@ -0,0 +1,256 @@
from typing_extensions import override
import torch
import comfy.model_management
import comfy.patcher_extension
import node_helpers
from comfy_api.latest import ComfyExtension, io
class EmptyHiDreamO1LatentImage(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="EmptyHiDreamO1LatentImage",
display_name="Empty HiDream-O1 Latent Image",
category="latent/image",
description=(
"Empty pixel-space latent for HiDream-O1-Image. The model was "
"trained at ~4 megapixels; lower resolutions go off-distribution "
"and quality regresses noticeably. Trained resolutions: "
"2048x2048, 2304x1728, 1728x2304, 2560x1440, 1440x2560, "
"2496x1664, 1664x2496, 3104x1312, 1312x3104, 2304x1792, 1792x2304."
),
inputs=[
io.Int.Input(id="width", default=2048, min=64, max=4096, step=32),
io.Int.Input(id="height", default=2048, min=64, max=4096, step=32),
io.Int.Input(id="batch_size", default=1, min=1, max=64),
],
outputs=[io.Latent().Output()],
)
@classmethod
def execute(cls, *, width: int, height: int, batch_size: int = 1) -> io.NodeOutput:
latent = torch.zeros(
(batch_size, 3, height, width),
device=comfy.model_management.intermediate_device(),
)
return io.NodeOutput({"samples": latent})
class HiDreamO1ReferenceImages(io.ComfyNode):
"""Attach reference images to both positive and negative conditioning."""
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="HiDreamO1ReferenceImages",
display_name="HiDream-O1 Reference Images",
category="conditioning/image",
description=(
"Attach 1-10 reference images to conditioning, one for edit instruction"
"or multiple for subject-driven personalization."
),
inputs=[
io.Conditioning.Input(id="positive"),
io.Conditioning.Input(id="negative"),
io.Autogrow.Input(
"images",
template=io.Autogrow.TemplateNames(
io.Image.Input("image"),
names=[f"image_{i}" for i in range(1, 11)],
min=1,
),
tooltip=("Reference images. 1 image = instruction edit; 2-10 images = multi reference."
),
),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
@classmethod
def execute(cls, *, positive, negative, images: io.Autogrow.Type) -> io.NodeOutput:
refs = [images[f"image_{i}"] for i in range(1, 11) if f"image_{i}" in images]
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": refs}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": refs}, append=True)
return io.NodeOutput(positive, negative)
class HiDreamO1PatchSeamSmoothing(io.ComfyNode):
PATCH_SIZE = 32
EDGE_FEATHER = 4
# Shift presets per (pattern, N). 8-pass = 4-quadrant + 4 quarter-patch offsets.
SHIFTS_BY_PATTERN = {
("single_shift", 2): [(0, 0), (16, 16)],
("single_shift", 4): [(0, 0), (16, 0), (0, 16), (16, 16)],
("single_shift", 8): [(0, 0), (16, 0), (0, 16), (16, 16),
(8, 8), (24, 8), (8, 24), (24, 24)],
("symmetric", 2): [(-8, -8), (8, 8)],
("symmetric", 4): [(-8, -8), (8, -8), (-8, 8), (8, 8)],
("symmetric", 8): [(-12, -12), (4, -12), (-12, 4), (4, 4),
(-4, -4), (12, -4), (-4, 12), (12, 12)],
}
RAMP_LEVELS = {
"2": [2],
"4": [4],
"ramp_2_4": [2, 4],
"ramp_2_4_8": [2, 4, 8],
}
@staticmethod
def _hann_tile(cy: int, cx: int, size: int = 32) -> torch.Tensor:
"""size x size Hann tile peaking at (cy, cx) within a patch."""
half = size // 2
yy = torch.arange(size).view(size, 1)
xx = torch.arange(size).view(1, size)
dy = ((yy - cy + half) % size) - half
dx = ((xx - cx + half) % size) - half
return 0.25 * (1 + torch.cos(torch.pi * dy / half)) * (1 + torch.cos(torch.pi * dx / half))
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="HiDreamO1PatchSeamSmoothing",
display_name="HiDream-O1 Patch Seam Smoothing",
category="advanced/model",
is_experimental=True,
description=(
"Average the model output across multiple shifted patch-grid "
"positions during the late portion of sampling. Cancels seams."
),
inputs=[
io.Model.Input(id="model"),
io.Float.Input(id="start_percent", default=0.8, min=0.0, max=1.0, step=0.01,
tooltip="Sampling progress (0=start, 1=end) at which the blend turns ON.",
),
io.Float.Input(id="end_percent", default=1.0, min=0.0, max=1.0, step=0.01,
tooltip="Sampling progress at which the blend turns OFF.",
),
io.Combo.Input(
id="pattern",
options=["single_shift", "symmetric"],
default="single_shift",
tooltip="Shift layout. single_shift: one pass at the natural patch grid + others offset. symmetric: all passes off-grid, shifts split around origin.",
),
io.Combo.Input(
id="passes",
options=["2", "4", "ramp_2_4", "ramp_2_4_8"],
default="2",
tooltip="Number of passes per gated step. 2/4 = fixed. ramp_*: pass count increases as sampling approaches end (more smoothing where seams are most visible).",
),
io.Combo.Input(
id="blend",
options=["average", "window", "median"],
default="average",
tooltip="average: equal-weight mean. window: Hann-windowed weighting favoring each pass away from its patch boundaries. median: per-pixel median, rejects wraparound-outlier passes.",
),
io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0, step=0.01,
tooltip="Interpolation between the natural-grid pred (0) and the averaged result (1).",
),
],
outputs=[io.Model.Output()],
)
@classmethod
def execute(cls, *, model, start_percent: float, end_percent: float, pattern: str, passes: str, blend: str, strength: float) -> io.NodeOutput:
if strength <= 0.0 or end_percent <= start_percent:
return io.NodeOutput(model)
P = cls.PATCH_SIZE
half = P // 2
shift_levels = [cls.SHIFTS_BY_PATTERN[(pattern, n)] for n in cls.RAMP_LEVELS[passes]]
if blend == "window":
window_tile_levels = [
torch.stack([cls._hann_tile((half - sy) % P, (half - sx) % P, P) for sy, sx in lst], dim=0)
for lst in shift_levels
]
else:
window_tile_levels = [None] * len(shift_levels)
m = model.clone()
model_sampling = m.get_model_object("model_sampling")
multiplier = float(model_sampling.multiplier)
start_t = float(model_sampling.percent_to_sigma(start_percent)) * multiplier
end_t = float(model_sampling.percent_to_sigma(end_percent)) * multiplier
edge_ramp_cache: dict = {}
def get_edge_ramp(H: int, W: int, device, dtype) -> torch.Tensor:
key = (H, W, device, dtype)
cached = edge_ramp_cache.get(key)
if cached is not None:
return cached
feather = cls.EDGE_FEATHER
ys = torch.minimum(torch.arange(H, device=device, dtype=torch.float32),
(H - 1) - torch.arange(H, device=device, dtype=torch.float32))
xs = torch.minimum(torch.arange(W, device=device, dtype=torch.float32),
(W - 1) - torch.arange(W, device=device, dtype=torch.float32))
y_mask = ((ys - P) / feather).clamp(0, 1)
x_mask = ((xs - P) / feather).clamp(0, 1)
ramp = (y_mask[:, None] * x_mask[None, :]).to(dtype)
edge_ramp_cache[key] = ramp
return ramp
def smoothing_wrapper(executor, *args, **kwargs):
x = args[0]
t = float(args[1][0])
pred = executor(*args, **kwargs)
if not (end_t <= t <= start_t):
return pred
# Pick shift-level by sigma phase across the gated range.
if len(shift_levels) == 1:
level_idx = 0
else:
phase = (start_t - t) / max(start_t - end_t, 1e-8)
level_idx = min(int(phase * len(shift_levels)), len(shift_levels) - 1)
shifts = shift_levels[level_idx]
window_tiles = window_tile_levels[level_idx]
preds = []
for sy, sx in shifts:
if sy == 0 and sx == 0:
preds.append(pred)
continue
x_rolled = torch.roll(x, shifts=(sy, sx), dims=(-2, -1))
pred_rolled = executor(x_rolled, *args[1:], **kwargs)
preds.append(torch.roll(pred_rolled, shifts=(-sy, -sx), dims=(-2, -1)))
stacked = torch.stack(preds, dim=0) # (N, B, C, H, W)
_, _, _, H, W = stacked.shape
if blend == "window":
N = stacked.shape[0]
tiles = window_tiles.to(device=stacked.device, dtype=stacked.dtype)
w = tiles.repeat(1, H // P, W // P)[:, :H, :W]
sum_w = w.sum(dim=0, keepdim=True)
w = torch.where(sum_w < 1e-3, torch.full_like(w, 1.0 / N), w / sum_w.clamp(min=1e-8))
avg = (stacked * w[:, None, None, :, :]).sum(dim=0)
elif blend == "median":
avg = torch.median(stacked, dim=0).values
else:
avg = stacked.mean(dim=0)
# Mask out the P-px wraparound contamination strip at each edge.
mask = get_edge_ramp(H, W, pred.device, pred.dtype)
return pred * (1.0 - mask * strength) + avg * (mask * strength)
m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "hidream_o1_patch_seam_smoothing", smoothing_wrapper)
return io.NodeOutput(m)
class HiDreamO1Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyHiDreamO1LatentImage,
HiDreamO1ReferenceImages,
HiDreamO1PatchSeamSmoothing,
]
async def comfy_entrypoint() -> HiDreamO1Extension:
return HiDreamO1Extension()

View File

@ -300,6 +300,29 @@ class RescaleCFG:
m.set_model_sampler_cfg_function(rescale_cfg)
return (m, )
class ModelNoiseScale:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"noise_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 64.0, "step": 0.01,
"tooltip": "Absolute training noise scale. For example HiDream-O1 base: 8.0, dev: 7.5."}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, noise_scale):
m = model.clone()
original = m.model.model_sampling
ms = type(original)(m.model.model_config)
ms.set_parameters(shift=original.shift, multiplier=original.multiplier)
ms.set_noise_scale(noise_scale)
m.add_object_patch("model_sampling", ms)
return (m, )
class ModelComputeDtype:
SEARCH_ALIASES = ["model precision", "change dtype"]
@classmethod
@ -327,6 +350,7 @@ NODE_CLASS_MAPPINGS = {
"ModelSamplingSD3": ModelSamplingSD3,
"ModelSamplingAuraFlow": ModelSamplingAuraFlow,
"ModelSamplingFlux": ModelSamplingFlux,
"ModelNoiseScale": ModelNoiseScale,
"RescaleCFG": RescaleCFG,
"ModelComputeDtype": ModelComputeDtype,
}

View File

@ -2435,6 +2435,7 @@ async def init_builtin_extra_nodes():
"nodes_sam3.py",
"nodes_void.py",
"nodes_wandancer.py",
"nodes_hidream_o1.py",
]
import_failed = []